diff --git a/realman_src/realman_aloha/__init__.py b/realman_src/realman_aloha/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/realman_src/realman_aloha/shadow_camera/.gitignore b/realman_src/realman_aloha/shadow_camera/.gitignore
new file mode 100644
index 000000000..17567f898
--- /dev/null
+++ b/realman_src/realman_aloha/shadow_camera/.gitignore
@@ -0,0 +1,4 @@
+__pycache__/
+*.pyc
+*.pyo
+*.pt
\ No newline at end of file
diff --git a/realman_src/realman_aloha/shadow_camera/README.md b/realman_src/realman_aloha/shadow_camera/README.md
new file mode 100644
index 000000000..e69de29bb
diff --git a/realman_src/realman_aloha/shadow_camera/__init__.py b/realman_src/realman_aloha/shadow_camera/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/realman_src/realman_aloha/shadow_camera/pyproject.toml b/realman_src/realman_aloha/shadow_camera/pyproject.toml
new file mode 100644
index 000000000..d34c92730
--- /dev/null
+++ b/realman_src/realman_aloha/shadow_camera/pyproject.toml
@@ -0,0 +1,33 @@
+[tool.poetry]
+name = "shadow_camera"
+version = "0.1.0"
+description = "Camera classes; currently includes RealSense"
+readme = "README.md"
+authors = ["Shadow "]
+license = "MIT"
+#include = ["realman_vision/pytransform/_pytransform.so",]
+classifiers = [
+    "Operating System :: POSIX :: Linux",
+    "Programming Language :: Python :: 3.10",
+]
+
+[tool.poetry.dependencies]
+python = ">=3.9"
+numpy = ">=2.0.1"
+opencv-python = ">=4.10.0.84"
+pyrealsense2 = ">=2.55.1.6486"
+
+[tool.poetry.dev-dependencies]  # Dependencies needed only during development, e.g. testing and documentation tools.
+pytest = ">=8.3"
+black = ">=24.10.0"
+
+[tool.poetry.plugins."scripts"]  # Command-line scripts, so users can run the named functions from the shell.
+
+
+[tool.poetry.group.dev.dependencies]
+
+
+
+[build-system]
+requires = ["poetry-core>=1.8.4"]
+build-backend = "poetry.core.masonry.api"
diff --git a/realman_src/realman_aloha/shadow_camera/src/__init__.py b/realman_src/realman_aloha/shadow_camera/src/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/realman_src/realman_aloha/shadow_camera/src/shadow_camera/__init__.py b/realman_src/realman_aloha/shadow_camera/src/shadow_camera/__init__.py
new file mode 100644
index 000000000..541f859dc
--- /dev/null
+++ b/realman_src/realman_aloha/shadow_camera/src/shadow_camera/__init__.py
@@ -0,0 +1 @@
+__version__ = '0.1.0'
\ No newline at end of file
diff --git a/realman_src/realman_aloha/shadow_camera/src/shadow_camera/base_camera.py b/realman_src/realman_aloha/shadow_camera/src/shadow_camera/base_camera.py
new file mode 100644
index 000000000..6e44fa476
--- /dev/null
+++ b/realman_src/realman_aloha/shadow_camera/src/shadow_camera/base_camera.py
@@ -0,0 +1,38 @@
+from abc import ABCMeta, abstractmethod
+
+
+class BaseCamera(metaclass=ABCMeta):
+    """Camera base class."""
+
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def start_camera(self):
+        """Start the camera."""
+        pass
+
+    @abstractmethod
+    def stop_camera(self):
+        """Stop the camera."""
+        pass
+
+    @abstractmethod
+    def set_resolution(self, resolution_width, resolution_height):
+        """Set the camera's color image resolution."""
+        pass
+
+    @abstractmethod
+    def set_frame_rate(self, fps):
+        """Set the camera's color image frame rate."""
+        pass
+
+    @abstractmethod
+    def read_frame(self):
+        """Read one frame of color and depth images."""
+        pass
+
+    @abstractmethod
+    def get_camera_intrinsics(self):
+        """Get the intrinsics of the color and depth images."""
+        pass
\ No newline at end of file
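Because every interface method above is declared with `@abstractmethod`, a subclass must implement all six before it can be instantiated. A minimal sketch of that contract (hypothetical class, for illustration only; not part of the patch):

    from shadow_camera.base_camera import BaseCamera

    class IncompleteCamera(BaseCamera):
        def start_camera(self): ...
        # the other five abstract methods are deliberately missing

    try:
        IncompleteCamera()
    except TypeError as exc:
        print(exc)  # Can't instantiate abstract class IncompleteCamera ...
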
diff --git a/realman_src/realman_aloha/shadow_camera/src/shadow_camera/datasets/20250121/test_dataset_synchronized.hdf5 b/realman_src/realman_aloha/shadow_camera/src/shadow_camera/datasets/20250121/test_dataset_synchronized.hdf5
new file mode 100644
index 000000000..2e8b55389
Binary files /dev/null and b/realman_src/realman_aloha/shadow_camera/src/shadow_camera/datasets/20250121/test_dataset_synchronized.hdf5 differ
diff --git a/realman_src/realman_aloha/shadow_camera/src/shadow_camera/opencv.py b/realman_src/realman_aloha/shadow_camera/src/shadow_camera/opencv.py
new file mode 100644
index 000000000..7d12cafa5
--- /dev/null
+++ b/realman_src/realman_aloha/shadow_camera/src/shadow_camera/opencv.py
@@ -0,0 +1,67 @@
+from shadow_camera import base_camera
+import cv2
+
+class OpenCVCamera(base_camera.BaseCamera):
+    """OpenCV-based camera class."""
+
+    def __init__(self, device_id=0):
+        """Initialize the video capture.
+
+        Args:
+            device_id: camera device ID
+        """
+        self.device_id = device_id
+        self.cap = cv2.VideoCapture(device_id)
+
+    def start_camera(self):
+        """Start the camera (reopen the capture if it was released)."""
+        if not self.cap.isOpened():
+            self.cap.open(self.device_id)
+
+    def stop_camera(self):
+        """Stop the camera and release the capture device."""
+        self.cap.release()
+
+    def set_resolution(self, resolution_width, resolution_height):
+        """Set the camera's color image resolution."""
+        self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, resolution_width)
+        self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, resolution_height)
+
+    def set_frame_rate(self, fps):
+        """Set the camera's color image frame rate."""
+        self.cap.set(cv2.CAP_PROP_FPS, fps)
+
+    def read_frame(self):
+        """Read one frame; a plain webcam provides no depth image."""
+        return self.get_frame(), None
+
+    def get_camera_intrinsics(self):
+        """Plain webcams expose no intrinsics through OpenCV."""
+        raise NotImplementedError
+
+    def get_frame(self):
+        """Get the current frame.
+
+        Returns:
+            frame: image data of the current frame, or None if unavailable
+        """
+        ret, frame = self.cap.read()
+        return frame if ret else None
+
+    def get_frame_info(self):
+        """Get information about the current frame.
+
+        Returns:
+            dict: frame info dictionary
+        """
+        width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+        height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        # OpenCV exposes no channel-count capture property, so infer it from a frame
+        frame = self.get_frame()
+        channels = frame.shape[2] if frame is not None and frame.ndim == 3 else None
+
+        return {
+            'width': width,
+            'height': height,
+            'channels': channels
+        }
\ No newline at end of file
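A short usage sketch for the class above (webcam index 0 assumed):

    from shadow_camera.opencv import OpenCVCamera

    camera = OpenCVCamera(device_id=0)
    frame = camera.get_frame()          # BGR ndarray, or None if the read failed
    if frame is not None:
        print(frame.shape, camera.get_frame_info())
    camera.stop_camera()                # releases the capture device
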
diff --git a/realman_src/realman_aloha/shadow_camera/src/shadow_camera/output_data.h5 b/realman_src/realman_aloha/shadow_camera/src/shadow_camera/output_data.h5
new file mode 100644
index 000000000..07f26f3dc
Binary files /dev/null and b/realman_src/realman_aloha/shadow_camera/src/shadow_camera/output_data.h5 differ
diff --git a/realman_src/realman_aloha/shadow_camera/src/shadow_camera/output_data.npz b/realman_src/realman_aloha/shadow_camera/src/shadow_camera/output_data.npz
new file mode 100644
index 000000000..14eae2b9c
Binary files /dev/null and b/realman_src/realman_aloha/shadow_camera/src/shadow_camera/output_data.npz differ
diff --git a/realman_src/realman_aloha/shadow_camera/src/shadow_camera/raw_data.h5 b/realman_src/realman_aloha/shadow_camera/src/shadow_camera/raw_data.h5
new file mode 100644
index 000000000..7a16090c8
Binary files /dev/null and b/realman_src/realman_aloha/shadow_camera/src/shadow_camera/raw_data.h5 differ
diff --git a/realman_src/realman_aloha/shadow_camera/src/shadow_camera/realsense.py b/realman_src/realman_aloha/shadow_camera/src/shadow_camera/realsense.py
new file mode 100644
index 000000000..39235dbb1
--- /dev/null
+++ b/realman_src/realman_aloha/shadow_camera/src/shadow_camera/realsense.py
@@ -0,0 +1,280 @@
+import time
+import logging
+import numpy as np
+import pyrealsense2 as rs
+from shadow_camera import base_camera
+
+# Logging configuration
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+
+
+class RealSenseCamera(base_camera.BaseCamera):
+    """Intel RealSense camera class."""
+
+    def __init__(self, serial_num=None, is_depth_frame=False):
+        """
+        Initialize the camera object.
+        :param serial_num: camera serial number, defaults to None
+        :param is_depth_frame: whether to enable the depth stream
+        """
+        super().__init__()
+        self._color_resolution = [640, 480]
+        self._depth_resolution = [640, 480]
+        self._color_frames_rate = 30
+        self._depth_frames_rate = 15
+        self.timestamp = 0
+        self.color_timestamp = 0
+        self.depth_timestamp = 0
+        self._colorizer = rs.colorizer()
+        self._config = rs.config()
+        self.is_depth_frame = is_depth_frame
+        self.camera_on = False
+        self.serial_num = serial_num
+
+    def get_serial_num(self):
+        serial_num = {}
+        context = rs.context()
+        devices = context.query_devices()  # enumerate all connected devices
+        if len(context.devices) > 0:
+            for i, device in enumerate(devices):
+                serial_num[i] = device.get_info(rs.camera_info.serial_number)
+
+        logging.info(f"Detected serial numbers: {serial_num}")
+        return serial_num
+
+    def _set_config(self):
+        if self.serial_num is not None:
+            logging.info(f"Setting device with serial number: {self.serial_num}")
+            self._config.enable_device(self.serial_num)
+
+        self._config.enable_stream(
+            rs.stream.color,
+            self._color_resolution[0],
+            self._color_resolution[1],
+            rs.format.rgb8,
+            self._color_frames_rate,
+        )
+        if self.is_depth_frame:
+            self._config.enable_stream(
+                rs.stream.depth,
+                self._depth_resolution[0],
+                self._depth_resolution[1],
+                rs.format.z16,
+                self._depth_frames_rate,
+            )
+
+    def start_camera(self):
+        """
+        Start the camera and fetch the intrinsics. If frame alignment is used
+        later, both sets of intrinsics are the color intrinsics.
+        """
+        self._pipeline = rs.pipeline()
+        if self.is_depth_frame:
+            self.point_cloud = rs.pointcloud()
+            self._align = rs.align(rs.stream.color)
+        self._set_config()
+
+        self.profile = self._pipeline.start(self._config)
+
+        if self.is_depth_frame:
+            self._depth_intrinsics = (
+                self.profile.get_stream(rs.stream.depth)
+                .as_video_stream_profile()
+                .get_intrinsics()
+            )
+
+        self._color_intrinsics = (
+            self.profile.get_stream(rs.stream.color)
+            .as_video_stream_profile()
+            .get_intrinsics()
+        )
+        self.camera_on = True
+        logging.info("Camera started successfully")
+        logging.info(
+            f"Camera started with color resolution: {self._color_resolution}, depth resolution: {self._depth_resolution}"
+        )
+        logging.info(
+            f"Color FPS: {self._color_frames_rate}, Depth FPS: {self._depth_frames_rate}"
+        )
+
+    def stop_camera(self):
+        """
+        Stop the camera.
+        """
+        self._pipeline.stop()
+        self.camera_on = False
+        logging.info("Camera stopped")
+
+    def set_resolution(self, color_resolution, depth_resolution):
+        self._color_resolution = color_resolution
+        self._depth_resolution = depth_resolution
+        logging.info(
+            "Optional color resolution:"
+            "[320, 180] [320, 240] [424, 240] [640, 360] [640, 480]"
+            "[848, 480] [960, 540] [1280, 720] [1920, 1080]"
+        )
+        logging.info(
+            "Optional depth resolution:"
+            "[256, 144] [424, 240] [480, 270] [640, 360] [640, 400]"
+            "[640, 480] [848, 100] [848, 480] [1280, 720] [1280, 800]"
+        )
+        logging.info(f"Set color resolution to: {color_resolution}")
+        logging.info(f"Set depth resolution to: {depth_resolution}")
+
+    def set_frame_rate(self, color_fps, depth_fps):
+        self._color_frames_rate = color_fps
+        self._depth_frames_rate = depth_fps
+        logging.info("Optional color fps: 6 15 30 60 ")
+        logging.info("Optional depth fps: 6 15 30 60 90 100 300")
+        logging.info(f"Set color FPS to: {color_fps}")
+        logging.info(f"Set depth FPS to: {depth_fps}")
+
+    # TODO: adjust the white balance to compensate
+    # def set_exposure(self, exposure):
+
+    def read_frame(self, is_color=True, is_depth=True, is_colorized_depth=False, is_point_cloud=False):
+        """
+        Read one frame of color and depth images.
+        :return: NumPy arrays of the color and depth images
+        """
+        while not self.camera_on:
+            time.sleep(0.5)
+        color_image = None
+        depth_image = None
+        colorized_depth = None
+        point_cloud = None
+        try:
+            frames = self._pipeline.wait_for_frames()
+            if is_color:
+                color_frame = frames.get_color_frame()
+                color_image = np.asanyarray(color_frame.get_data())
+
+            if is_depth:
+                depth_frame = frames.get_depth_frame()
+                depth_image = np.asanyarray(depth_frame.get_data())
+                # Both of these need the depth frame, so they are only computed
+                # here; the point cloud also requires is_depth_frame=True.
+                if is_colorized_depth:
+                    colorized_depth = np.asanyarray(
+                        self._colorizer.colorize(depth_frame).get_data()
+                    )
+                if is_point_cloud:
+                    point_cloud = np.asanyarray(
+                        self.point_cloud.calculate(depth_frame).get_vertices()
+                    )
+            # Timestamps are in ms. After alignment the color timestamp >=
+            # depth timestamp (= aligned timestamp), so the color one is used.
+            if is_color:
+                self.color_timestamp = color_frame.get_timestamp()
+            if is_depth and self.is_depth_frame:
+                self.depth_timestamp = depth_frame.get_timestamp()
+
+        except Exception as e:
+            logging.warning(e)
+            if "Frame didn't arrive within 5000" in str(e):
+                logging.warning("Frame didn't arrive within 5000ms, resetting device")
+                self.stop_camera()
+                self.start_camera()
+
+        return color_image, depth_image, colorized_depth, point_cloud
+
+    def read_align_frame(self, is_color=True, is_depth=True, is_colorized_depth=False, is_point_cloud=False):
+        """
+        Read one aligned frame of color and depth images. Requires the camera
+        to have been created with is_depth_frame=True, since alignment needs
+        the depth stream.
+        :return: NumPy arrays of the color and depth images
+        """
+        while not self.camera_on:
+            time.sleep(0.5)
+        try:
+            frames = self._pipeline.wait_for_frames()
+            aligned_frames = self._align.process(frames)
+            aligned_color_frame = aligned_frames.get_color_frame()
+            self._aligned_depth_frame = aligned_frames.get_depth_frame()
+
+            color_image = np.asanyarray(aligned_color_frame.get_data())
+            depth_image = np.asanyarray(self._aligned_depth_frame.get_data())
+            colorized_depth = (
+                np.asanyarray(
+                    self._colorizer.colorize(self._aligned_depth_frame).get_data()
+                )
+                if is_colorized_depth
+                else None
+            )
+
+            if is_point_cloud:
+                points = self.point_cloud.calculate(self._aligned_depth_frame)
+                # Convert the vertex buffer into an (N, 3) NumPy array
+                point_cloud = np.array(
+                    [[point[0], point[1], point[2]] for point in points.get_vertices()]
+                )
+            else:
+                point_cloud = None
+
+            # Timestamps are in ms. After alignment the color timestamp >=
+            # depth timestamp (= aligned timestamp), so the color one is used.
+            self.timestamp = aligned_color_frame.get_timestamp()
+
+            return color_image, depth_image, colorized_depth, point_cloud
+
+        except Exception as e:
+            # Falls through and implicitly returns None on a dropped frame
+            if "Frame didn't arrive within 5000" in str(e):
+                logging.warning("Frame didn't arrive within 5000ms, resetting device")
+                self.stop_camera()
+                self.start_camera()
+            # device = self.profile.get_device()
+            # device.hardware_reset()
+
+    def get_camera_intrinsics(self):
+        """
+        Get the intrinsics of the color and depth images. The depth intrinsics
+        are only set when the camera was created with is_depth_frame=True.
+        :return: color and depth intrinsics
+        """
+        # Width/height: .width, .height; focal length: .fx, .fy;
+        # principal point: .ppx, .ppy; distortion coefficients: .coeffs
+        logging.info("Getting camera intrinsics")
+        logging.info(
+            "Width and height: .width, .height; Focal length: .fx, .fy; Pixel coordinates: .ppx, .ppy; Distortion coefficient: .coeffs"
+        )
+        return self._color_intrinsics, self._depth_intrinsics
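+    # Hedged helper sketch (added for illustration, not in the original patch):
+    # pack the pyrealsense2 color intrinsics into a 3x3 pinhole matrix K,
+    # using the .fx/.fy/.ppx/.ppy fields documented above.
+    def color_intrinsics_matrix(self):
+        intr = self._color_intrinsics
+        return np.array([
+            [intr.fx, 0.0, intr.ppx],
+            [0.0, intr.fy, intr.ppy],
+            [0.0, 0.0, 1.0],
+        ])
+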
+    def get_3d_camera_coordinate(self, depth_pixel, align=False):
+        """
+        Get the 3D coordinate in the depth camera frame.
+        :param depth_pixel: depth pixel coordinate
+        :param align: whether the frames were aligned
+
+        :return: depth value and camera coordinate
+        """
+        if not hasattr(self, "_aligned_depth_frame"):
+            raise AttributeError(
+                "Aligned depth frame not set. Call read_align_frame() first."
+            )
+
+        distance = self._aligned_depth_frame.get_distance(
+            depth_pixel[0], depth_pixel[1]
+        )
+        intrinsics = self._color_intrinsics if align else self._depth_intrinsics
+        camera_coordinate = rs.rs2_deproject_pixel_to_point(
+            intrinsics, depth_pixel, distance
+        )
+        return distance, camera_coordinate
+
+
+if __name__ == "__main__":
+
+    camera = RealSenseCamera(is_depth_frame=False)
+    camera.get_serial_num()
+    camera.start_camera()
+    # camera.set_frame_rate(60, 60)
+    color_image, depth_image, colorized_depth, point_cloud = camera.read_frame()
+    camera.stop_camera()
+    logging.info(f"Color image shape: {color_image.shape}")
+    # logging.info(f"Depth image shape: {depth_image.shape}")
+    # logging.info(f"Colorized depth image shape: {colorized_depth.shape}")
+    # logging.info(f"Point cloud shape: {point_cloud.shape}")
+    # read_frame() fills color_timestamp; timestamp is only set by read_align_frame()
+    logging.info(f"Color timestamp: {camera.color_timestamp}")
+    # logging.info(f"Depth timestamp: {camera.depth_timestamp}")
+    logging.info("Test passed")
diff --git a/realman_src/realman_aloha/shadow_camera/src/shadow_camera/test.py b/realman_src/realman_aloha/shadow_camera/src/shadow_camera/test.py
new file mode 100644
index 000000000..6e28d7c2f
--- /dev/null
+++ b/realman_src/realman_aloha/shadow_camera/src/shadow_camera/test.py
@@ -0,0 +1,101 @@
+import pyrealsense2 as rs
+import numpy as np
+import h5py
+import time
+import threading
+import keyboard  # listens for keyboard input (on Linux this package typically requires root)
+
+# Global variables
+is_recording = False  # flag controlling the recording state
+color_images = []  # stored color images
+depth_images = []  # stored depth images
+timestamps = []  # stored timestamps
+
+# Configure the D435 camera
+def configure_camera():
+    pipeline = rs.pipeline()
+    config = rs.config()
+    config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30)  # color stream
+    config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, 30)  # depth stream
+    pipeline.start(config)
+    return pipeline
+
+# Listen for keyboard input and control the recording state
+def listen_for_keyboard():
+    global is_recording
+    while True:
+        if keyboard.is_pressed('s'):  # press 's' to start recording
+            is_recording = True
+            print("Recording started.")
+            time.sleep(0.5)  # debounce
+        elif keyboard.is_pressed('q'):  # press 'q' to stop recording
+            is_recording = False
+            print("Recording stopped.")
+            time.sleep(0.5)  # debounce
+        elif keyboard.is_pressed('e'):  # press 'e' to exit the program
+            print("Exiting program.")
+            exit()  # note: raises SystemExit only in this listener thread
+        time.sleep(0.1)
+
+# Capture image data
+def capture_frames(pipeline):
+    global is_recording, color_images, depth_images, timestamps
+    try:
+        while True:
+            if is_recording:
+                frames = pipeline.wait_for_frames()
+                color_frame = frames.get_color_frame()
+                depth_frame = frames.get_depth_frame()
+
+                if not color_frame or not depth_frame:
+                    continue
+
+                # Current timestamp
+                timestamp = time.time()
+
+                # Convert the images to NumPy arrays
+                color_image = np.asanyarray(color_frame.get_data())
+                depth_image = np.asanyarray(depth_frame.get_data())
+
+                # Store the data
+                color_images.append(color_image)
+                depth_images.append(depth_image)
+                timestamps.append(timestamp)
+
+                print(f"Captured frame at {timestamp}")
+
+            else:
+                time.sleep(0.1)  # not recording; wait a moment
+
+    finally:
+        pipeline.stop()
+
+# Save to an HDF5 file
+def save_to_hdf5(color_images, depth_images, timestamps, filename="output.h5"):
+    with h5py.File(filename, "w") as f:
+        f.create_dataset("color_images", data=np.array(color_images), compression="gzip")
+        f.create_dataset("depth_images", data=np.array(depth_images), compression="gzip")
+        f.create_dataset("timestamps", data=np.array(timestamps), compression="gzip")
+    print(f"Data saved to {filename}")
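+# Hedged read-back sketch (added for illustration): verify a file written by
+# save_to_hdf5 above; dataset names and dtypes follow the writer.
+def load_from_hdf5(filename="output.h5"):
+    with h5py.File(filename, "r") as f:
+        color = f["color_images"][:]   # (N, 480, 640, 3) uint8, BGR
+        depth = f["depth_images"][:]   # (N, 480, 640) uint16, z16
+        stamps = f["timestamps"][:]    # (N,) float64 seconds from time.time()
+    print(f"{len(stamps)} frames, color {color.shape}, depth {depth.shape}")
+    return color, depth, stamps
+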
+# Main function
+def main():
+    global is_recording, color_images, depth_images, timestamps
+
+    # Start the keyboard listener thread
+    keyboard_thread = threading.Thread(target=listen_for_keyboard)
+    keyboard_thread.daemon = True
+    keyboard_thread.start()
+
+    # Configure the camera
+    pipeline = configure_camera()
+
+    # Start capturing images (runs until the pipeline stops)
+    capture_frames(pipeline)
+
+    # Save the data once recording has finished
+    if color_images and depth_images and timestamps:
+        save_to_hdf5(color_images, depth_images, timestamps, "mobile_aloha_data.h5")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
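Since the `keyboard` package generally needs root privileges on Linux, a root-free variant of the same start/stop protocol can poll stdin from the listener thread instead. A minimal hedged sketch (hypothetical substitute, not part of the patch):

    import threading

    def listen_for_stdin(state):
        # state is a plain dict holding the shared flags
        while not state["exit"]:
            cmd = input("s=start, q=stop, e=exit > ").strip().lower()
            if cmd == "s":
                state["is_recording"] = True
            elif cmd == "q":
                state["is_recording"] = False
            elif cmd == "e":
                state["exit"] = True

    state = {"is_recording": False, "exit": False}
    threading.Thread(target=listen_for_stdin, args=(state,), daemon=True).start()
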
diff --git a/realman_src/realman_aloha/shadow_camera/test/test_camera.py b/realman_src/realman_aloha/shadow_camera/test/test_camera.py
new file mode 100644
index 000000000..1fa5df5e9
--- /dev/null
+++ b/realman_src/realman_aloha/shadow_camera/test/test_camera.py
@@ -0,0 +1,152 @@
+import os
+import cv2
+import time
+import numpy as np
+from os import path
+import pyrealsense2 as rs
+from shadow_camera import realsense
+import logging
+
+
+
+def test_camera():
+    camera = realsense.RealSenseCamera('241122071186')
+    camera.start_camera()
+
+    while True:
+        # result = camera.read_align_frame()
+        # if result is None:
+        #     print('is None')
+        #     continue
+        # start_time = time.time()
+        color_image, depth_image, colorized_depth, vtx = camera.read_frame()
+        color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
+
+        print(f"color_image: {color_image.shape}")
+        # print(f"Time: {end_time - start_time}")
+        cv2.imshow("bgr_image", color_image)
+
+        if cv2.waitKey(1) & 0xFF == ord("q"):
+            break
+    camera.stop_camera()
+
+
+def test_get_serial_num():
+    camera = realsense.RealSenseCamera()
+    device = camera.get_serial_num()
+
+
+class CameraCapture:
+    def __init__(self, camera_serial_num=None, save_dir="./save"):
+        self._camera_serial_num = camera_serial_num
+        self._color_save_dir = path.join(save_dir, "color")
+        self._depth_save_dir = path.join(save_dir, "depth")
+        os.makedirs(save_dir, exist_ok=True)
+        os.makedirs(self._color_save_dir, exist_ok=True)
+        os.makedirs(self._depth_save_dir, exist_ok=True)
+
+    def get_serial_num(self):
+        self._camera_serial_num = {}
+        camera_names = ["left", "right", "head", "table"]
+        context = rs.context()
+        devices = context.query_devices()  # enumerate all connected devices
+        if len(context.devices) > 0:
+            for i, device in enumerate(devices):
+                self._camera_serial_num[camera_names[i]] = device.get_info(
+                    rs.camera_info.serial_number
+                )
+        print(self._camera_serial_num)
+
+        return self._camera_serial_num
+
+    def start_camera(self):
+        if self._camera_serial_num is None:
+            self.get_serial_num()
+        self._camera_left = realsense.RealSenseCamera(self._camera_serial_num["left"])
+        self._camera_right = realsense.RealSenseCamera(self._camera_serial_num["right"])
+        self._camera_head = realsense.RealSenseCamera(self._camera_serial_num["head"])
+
+        self._camera_left.start_camera()
+        self._camera_right.start_camera()
+        self._camera_head.start_camera()
+
+    def stop_camera(self):
+        self._camera_left.stop_camera()
+        self._camera_right.stop_camera()
+        self._camera_head.stop_camera()
+
+    def _save_datas(self, timestamp, color_image, depth_image, camera_name):
+        color_filename = path.join(
+            self._color_save_dir, f"{timestamp}" + camera_name + ".jpg"
+        )
+        depth_filename = path.join(
+            self._depth_save_dir, f"{timestamp}" + camera_name + ".png"
+        )
+        cv2.imwrite(color_filename, color_image)
+        cv2.imwrite(depth_filename, depth_image)
+
+    def capture_images(self):
+        while True:
+            (
+                color_image_left,
+                depth_image_left,
+                _,
+                _,
+            ) = self._camera_left.read_align_frame()
+            (
+                color_image_right,
+                depth_image_right,
+                _,
+                _,
+            ) = self._camera_right.read_align_frame()
+            (
+                color_image_head,
+                depth_image_head,
+                _,
+                point_cloud3,
+            ) = self._camera_head.read_align_frame()
+
+            bgr_color_image_left = cv2.cvtColor(color_image_left, cv2.COLOR_RGB2BGR)
+            bgr_color_image_right = cv2.cvtColor(color_image_right, cv2.COLOR_RGB2BGR)
+            bgr_color_image_head = cv2.cvtColor(color_image_head, cv2.COLOR_RGB2BGR)
+
+            timestamp = time.time() * 1000
+
+            cv2.imshow("Camera left", bgr_color_image_left)
+            cv2.imshow("Camera right", bgr_color_image_right)
+            cv2.imshow("Camera head", bgr_color_image_head)
+
+            # self._save_datas(
+            #     timestamp, bgr_color_image_left, depth_image_left, "left"
+            # )
+            # self._save_datas(
+            #     timestamp, bgr_color_image_right, depth_image_right, "right"
+            # )
+            # self._save_datas(
+            #     timestamp, bgr_color_image_head, depth_image_head, "head"
+            # )
+
+            if cv2.waitKey(1) & 0xFF == ord("q"):
+                break
+
+        cv2.destroyAllWindows()
+
+
+if __name__ == "__main__":
+    #test_camera()
+    test_get_serial_num()
+    """
+    Pass camera serial numbers to assign the left/right cameras:
+        dict: {'left': '241222075132', 'right': '242322076532', 'head': '242322076532'}
+    Save path:
+        str: ./save
+    If the input is empty, serial numbers are assigned automatically (left/right/head unspecified) and the save path defaults to ./save.
+    """
+
+    # capture = CameraCapture()
+    # capture.get_serial_num()
+    # test_get_serial_num()
+
+    # capture.start_camera()
+    # capture.capture_images()
+    # capture.stop_camera()
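CameraCapture reads the three streams sequentially, so their timestamps drift slightly; if recorded frames need to be matched afterwards (as in the synchronized test dataset above), nearest-timestamp pairing is one option. A hedged sketch (hypothetical helper, assuming millisecond timestamps as produced by read_align_frame):

    import numpy as np

    def nearest_timestamp_pairs(ts_a, ts_b, tol_ms=20.0):
        """For each timestamp in ts_a, index of the closest timestamp in ts_b,
        or -1 when the gap exceeds tol_ms."""
        ts_a, ts_b = np.asarray(ts_a), np.asarray(ts_b)
        idx = np.abs(ts_b[None, :] - ts_a[:, None]).argmin(axis=1)
        ok = np.abs(ts_b[idx] - ts_a) <= tol_ms
        return np.where(ok, idx, -1)
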
diff --git a/realman_src/realman_aloha/shadow_camera/test/test_realsense.py b/realman_src/realman_aloha/shadow_camera/test/test_realsense.py
new file mode 100644
index 000000000..bfb80f576
--- /dev/null
+++ b/realman_src/realman_aloha/shadow_camera/test/test_realsense.py
@@ -0,0 +1,71 @@
+import pytest
+import pyrealsense2 as rs
+from shadow_camera.realsense import RealSenseCamera
+
+
+class TestRealSenseCamera:
+    @pytest.fixture(autouse=True)
+    def setup_camera(self):
+        self.camera = RealSenseCamera()
+
+    def test_get_serial_num(self):
+        serial_nums = self.camera.get_serial_num()
+        assert isinstance(serial_nums, dict)
+        assert len(serial_nums) > 0
+
+    def test_start_stop_camera(self):
+        self.camera.start_camera()
+        assert self.camera.camera_on is True
+        self.camera.stop_camera()
+        assert self.camera.camera_on is False
+
+    def test_set_resolution(self):
+        color_resolution = [1280, 720]
+        depth_resolution = [1280, 720]
+        self.camera.set_resolution(color_resolution, depth_resolution)
+        assert self.camera._color_resolution == color_resolution
+        assert self.camera._depth_resolution == depth_resolution
+
+    def test_set_frame_rate(self):
+        color_fps = 60
+        depth_fps = 60
+        self.camera.set_frame_rate(color_fps, depth_fps)
+        assert self.camera._color_frames_rate == color_fps
+        assert self.camera._depth_frames_rate == depth_fps
+
+    def test_read_frame(self):
+        self.camera.start_camera()
+        color_image, depth_image, colorized_depth, point_cloud = (
+            self.camera.read_frame()
+        )
+        assert color_image is not None
+        assert depth_image is not None
+        self.camera.stop_camera()
+
+    def test_read_align_frame(self):
+        self.camera.start_camera()
+        color_image, depth_image, colorized_depth, point_cloud = (
+            self.camera.read_align_frame()
+        )
+        assert color_image is not None
+        assert depth_image is not None
+        self.camera.stop_camera()
+
+    def test_get_camera_intrinsics(self):
+        self.camera.start_camera()
+        color_intrinsics, depth_intrinsics = self.camera.get_camera_intrinsics()
+        assert color_intrinsics is not None
+        assert depth_intrinsics is not None
+        self.camera.stop_camera()
+
+    def test_get_3d_camera_coordinate(self):
+        self.camera.start_camera()
+        # Call read_align_frame() first to make sure _aligned_depth_frame is set
+        self.camera.read_align_frame()
+        depth_pixel = [320, 240]
+        distance, camera_coordinate = self.camera.get_3d_camera_coordinate(
+            depth_pixel, align=True
+        )
+        assert distance > 0
+        assert len(camera_coordinate) == 3
+        self.camera.stop_camera()
diff --git a/realman_src/realman_aloha/shadow_rm_act/.gitignore b/realman_src/realman_aloha/shadow_rm_act/.gitignore
new file mode 100644
index 000000000..bfa287bf9
--- /dev/null
+++ b/realman_src/realman_aloha/shadow_rm_act/.gitignore
@@ -0,0 +1,10 @@
+__pycache__/
+build/
+devel/
+dist/
+data/
+.catkin_workspace
+*.pyc
+*.pyo
+*.pt
+.vscode/
diff --git a/realman_src/realman_aloha/shadow_rm_act/README.md b/realman_src/realman_aloha/shadow_rm_act/README.md
new file mode 100644
index 000000000..2c6f726dc
--- /dev/null
+++ b/realman_src/realman_aloha/shadow_rm_act/README.md
@@ -0,0 +1,89 @@
+# ACT: Action Chunking with Transformers
+
+### *New*: [ACT tuning tips](https://docs.google.com/document/d/1FVIZfoALXg_ZkYKaYVh-qOlaXveq5CtvJHXkY25eYhs/edit?usp=sharing)
+TL;DR: if your ACT policy is jerky or pauses in the middle of an episode, just train for longer! Success rate and smoothness can improve well after the loss plateaus.
+
+#### Project Website: https://tonyzhaozh.github.io/aloha/
+
+This repo contains the implementation of ACT, together with 2 simulated environments:
+Transfer Cube and Bimanual Insertion. You can train and evaluate ACT in sim or real.
+For real, you would also need to install [ALOHA](https://github.com/tonyzhaozh/aloha).
+
+### Updates:
+You can find all scripted/human demos for the simulated environments [here](https://drive.google.com/drive/folders/1gPR03v05S1xiInoVJn7G7VJ9pDCnxq9O?usp=share_link).
+
+
+### Repo Structure
+- ``imitate_episodes.py`` Train and evaluate ACT
+- ``policy.py`` An adaptor for the ACT policy
+- ``detr`` Model definitions of ACT, modified from DETR
+- ``sim_env.py`` Mujoco + DM_Control environments with joint space control
+- ``ee_sim_env.py`` Mujoco + DM_Control environments with EE space control
+- ``scripted_policy.py`` Scripted policies for sim environments
+- ``constants.py`` Constants shared across files
+- ``utils.py`` Utils such as data loading and helper functions
+- ``visualize_episodes.py`` Save videos from a .hdf5 dataset
+
+
+### Installation
+
+    conda create -n aloha python=3.8.10
+    conda activate aloha
+    pip install torchvision
+    pip install torch
+    pip install pyquaternion
+    pip install pyyaml
+    pip install rospkg
+    pip install pexpect
+    pip install mujoco==2.3.7
+    pip install dm_control==1.0.14
+    pip install opencv-python
+    pip install matplotlib
+    pip install einops
+    pip install packaging
+    pip install h5py
+    pip install ipython
+    cd act/detr && pip install -e .
+
+### Example Usages
+
+To set up a new terminal, run:
+
+    conda activate aloha
+    cd <path to act repo>
+
+### Simulated experiments
+
+We use the ``sim_transfer_cube_scripted`` task in the examples below. Another option is ``sim_insertion_scripted``.
+To generate 50 episodes of scripted data, run:
+
+    python3 record_sim_episodes.py \
+    --task_name sim_transfer_cube_scripted \
+    --dataset_dir <data save dir> \
+    --num_episodes 50
+
+You can add the flag ``--onscreen_render`` to see real-time rendering.
+To visualize an episode after it is collected, run
+
+    python3 visualize_episodes.py --dataset_dir <data save dir> --episode_idx 0
+
+To train ACT:
+
+    # Transfer Cube task
+    python3 imitate_episodes.py \
+    --task_name sim_transfer_cube_scripted \
+    --ckpt_dir <ckpt dir> \
+    --policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 512 --batch_size 8 --dim_feedforward 3200 \
+    --num_epochs 2000 --lr 1e-5 \
+    --seed 0
+
+
+To evaluate the policy, run the same command but add ``--eval``. This loads the best validation checkpoint.
+The success rate should be around 90% for transfer cube, and around 50% for insertion.
+To enable temporal ensembling, add flag ``--temporal_agg``.
+Videos will be saved to ``<ckpt dir>`` for each rollout.
+You can also add ``--onscreen_render`` to see real-time rendering during evaluation.
+
+For real-world data, where things can be harder to model, train for at least 5000 epochs, or 3-4 times as long after the loss has plateaued.
+Please refer to [tuning tips](https://docs.google.com/document/d/1FVIZfoALXg_ZkYKaYVh-qOlaXveq5CtvJHXkY25eYhs/edit?usp=sharing) for more info.
+
diff --git a/realman_src/realman_aloha/shadow_rm_act/config/config.yaml b/realman_src/realman_aloha/shadow_rm_act/config/config.yaml
new file mode 100644
index 000000000..fa6edee19
--- /dev/null
+++ b/realman_src/realman_aloha/shadow_rm_act/config/config.yaml
@@ -0,0 +1,74 @@
+robot_env: {
+  # TODO change the path to the correct one
+  rm_left_arm: '/home/rm/aloha/shadow_rm_aloha/config/rm_left_arm.yaml',
+  rm_right_arm: '/home/rm/aloha/shadow_rm_aloha/config/rm_right_arm.yaml',
+  arm_axis: 6,
+  head_camera: '215222076892',
+  bottom_camera: '215222076981',
+  left_camera: '152122078151',
+  right_camera: '152122073489',
+  # init_left_arm_angle: [0.226, 21.180, 91.304, -0.515, 67.486, 2.374, 0.9],
+  # init_right_arm_angle: [-1.056, 33.057, 84.376, -0.204, 66.357, -3.236, 0.9]
+  init_left_arm_angle: [6.45, 66.093, 2.9, 20.919, -1.491, 100.756, 18.808, 0.617],
+  init_right_arm_angle: [166.953, -33.575, -163.917, 73.3, -9.581, 69.51, 0.876]
+}
+dataset_dir: '/home/rm/aloha/shadow_rm_aloha/data/dataset/20250103'
+checkpoint_dir: '/home/rm/aloha/shadow_rm_act/data'
+# checkpoint_name: 'policy_best.ckpt'
+checkpoint_name: 'policy_9500.ckpt'
+state_dim: 14
+save_episode: True
+num_rollouts: 50  # number of rollouts (trajectories) to collect during training
+real_robot: True
+policy_class: 'ACT'
+onscreen_render: False
+camera_names: ['cam_high', 'cam_low', 'cam_left', 'cam_right']
+episode_len: 300  # maximum episode length (number of timesteps)
+task_name: 'aloha_01_11.28'
+temporal_agg: False  # whether to use temporal aggregation
+batch_size: 8  # number of samples per batch during training
+seed: 1000  # random seed
+chunk_size: 30  # chunk size used when processing sequences
+eval_every: 1  # evaluate the model every eval_every steps
+num_steps: 10000  # total number of training steps
+validate_every: 1  # validate the model every validate_every steps
+save_every: 500  # save a checkpoint every save_every steps
+load_pretrain: False  # whether to load a pretrained model
+resume_ckpt_path:
+name_filter: # TODO
+skip_mirrored_data: False  # whether to skip mirrored data (e.g. symmetry-based data augmentation)
+stats_dir:
+sample_weights:
+train_ratio: 0.8  # fraction of the data used for training (the rest is used for validation)
+
+policy_config: {
+  hidden_dim: 512,  # Size of the embeddings (dimension of the transformer)
+  state_dim: 14,  # Dimension of the state
+  position_embedding: 'sine',  # ('sine', 'learned'). Type of positional embedding to use on top of the image features
+  lr_backbone: 1.0e-5,
+  masks: False,  # If true, the model masks the non-visible pixels
+  backbone: 'resnet18',
+  dilation: False,  # If true, we replace stride with dilation in the last convolutional block (DC5)
+  dropout: 0.1,  # Dropout applied in the transformer
+  nheads: 8,
+  dim_feedforward: 3200,  # Intermediate size of the feedforward layers in the transformer blocks
+  enc_layers: 4,  # Number of encoding layers in the transformer
+  dec_layers: 7,  # Number of decoding layers in the transformer
+  pre_norm: False,  # If true, apply LayerNorm to the input instead of the output of the MultiheadAttention and FeedForward
+  num_queries: 30,
+  camera_names: ['cam_high', 'cam_low', 'cam_left', 'cam_right'],
+  vq: False,
+  vq_class: none,
+  vq_dim: 64,
+  action_dim: 14,
+  no_encoder: False,
+  lr: 1.0e-5,
+  weight_decay: 1.0e-4,
+  kl_weight: 10,
+
+  # lr_drop: 200,
+  # clip_max_norm: 0.1,
+}
+
+
+
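The config above is plain YAML mappings, so loading it is a one-liner; a hedged sketch (file path assumed). Note that with PyYAML the empty values such as `resume_ckpt_path:` load as None, while the lowercase `none` under `vq_class` loads as the string 'none':

    import yaml

    with open("shadow_rm_act/config/config.yaml") as f:
        cfg = yaml.safe_load(f)

    policy_cfg = cfg["policy_config"]
    print(cfg["task_name"], policy_cfg["hidden_dim"], policy_cfg["num_queries"])
    print(cfg["resume_ckpt_path"], repr(policy_cfg["vq_class"]))  # None, 'none'
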
diff --git a/realman_src/realman_aloha/shadow_rm_act/ee_sim_env.py b/realman_src/realman_aloha/shadow_rm_act/ee_sim_env.py
new file mode 100644
index 000000000..e3a460677
--- /dev/null
+++ b/realman_src/realman_aloha/shadow_rm_act/ee_sim_env.py
@@ -0,0 +1,267 @@
+import numpy as np
+import collections
+import os
+
+from constants import DT, XML_DIR, START_ARM_POSE
+from constants import PUPPET_GRIPPER_POSITION_CLOSE
+from constants import PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN
+from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN
+from constants import PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN
+
+from src.shadow_act.utils.utils import sample_box_pose, sample_insertion_pose
+from dm_control import mujoco
+from dm_control.rl import control
+from dm_control.suite import base
+
+import IPython
+e = IPython.embed
+
+
+def make_ee_sim_env(task_name):
+    """
+    Environment for simulated robot bi-manual manipulation, with end-effector control.
+    Action space:      [left_arm_pose (7),             # position and quaternion for end effector
+                        left_gripper_positions (1),    # normalized gripper position (0: close, 1: open)
+                        right_arm_pose (7),            # position and quaternion for end effector
+                        right_gripper_positions (1),]  # normalized gripper position (0: close, 1: open)
+
+    Observation space: {"qpos": Concat[ left_arm_qpos (6),          # absolute joint position
+                                        left_gripper_position (1),  # normalized gripper position (0: close, 1: open)
+                                        right_arm_qpos (6),         # absolute joint position
+                                        right_gripper_qpos (1)]     # normalized gripper position (0: close, 1: open)
+                        "qvel": Concat[ left_arm_qvel (6),          # absolute joint velocity (rad)
+                                        left_gripper_velocity (1),  # normalized gripper velocity (pos: opening, neg: closing)
+                                        right_arm_qvel (6),         # absolute joint velocity (rad)
+                                        right_gripper_qvel (1)]     # normalized gripper velocity (pos: opening, neg: closing)
+                        "images": {"main": (480x640x3)}             # h, w, c, dtype='uint8'
+    """
+    if 'sim_transfer_cube' in task_name:
+        xml_path = os.path.join(XML_DIR, f'bimanual_viperx_ee_transfer_cube.xml')
+        physics = mujoco.Physics.from_xml_path(xml_path)
+        task = TransferCubeEETask(random=False)
+        env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
+                                  n_sub_steps=None, flat_observation=False)
+    elif 'sim_insertion' in task_name:
+        xml_path = os.path.join(XML_DIR, f'bimanual_viperx_ee_insertion.xml')
+        physics = mujoco.Physics.from_xml_path(xml_path)
+        task = InsertionEETask(random=False)
+        env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
+                                  n_sub_steps=None, flat_observation=False)
+    else:
+        raise NotImplementedError
+    return env
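+
+# Hedged usage sketch (added for illustration, not in the original patch).
+# Per the docstring above, an action is 16-D: [pos(3) + quat(4) + gripper(1)]
+# per arm; holding the initial mocap poses keeps the arms still.
+def _demo_ee_env(task_name='sim_transfer_cube_scripted', n_steps=5):
+    env = make_ee_sim_env(task_name)
+    ts = env.reset()
+    print(sorted(ts.observation.keys()))
+    for _ in range(n_steps):
+        action = np.concatenate([ts.observation['mocap_pose_left'], [1.0],
+                                 ts.observation['mocap_pose_right'], [1.0]])
+        ts = env.step(action)
+    return ts
+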
+class BimanualViperXEETask(base.Task):
+    def __init__(self, random=None):
+        super().__init__(random=random)
+
+    def before_step(self, action, physics):
+        a_len = len(action) // 2
+        action_left = action[:a_len]
+        action_right = action[a_len:]
+
+        # set mocap position and quat
+        # left
+        np.copyto(physics.data.mocap_pos[0], action_left[:3])
+        np.copyto(physics.data.mocap_quat[0], action_left[3:7])
+        # right
+        np.copyto(physics.data.mocap_pos[1], action_right[:3])
+        np.copyto(physics.data.mocap_quat[1], action_right[3:7])
+
+        # set gripper
+        g_left_ctrl = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(action_left[7])
+        g_right_ctrl = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(action_right[7])
+        np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl]))
+
+    def initialize_robots(self, physics):
+        # reset joint position
+        physics.named.data.qpos[:16] = START_ARM_POSE
+
+        # reset mocap to align with end effector
+        # to obtain these numbers:
+        # (1) make an ee_sim env and reset to the same start_pose
+        # (2) get env._physics.named.data.xpos['vx300s_left/gripper_link']
+        #     get env._physics.named.data.xquat['vx300s_left/gripper_link']
+        #     repeat the same for right side
+        np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084])
+        np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0])
+        # right
+        np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084]))
+        np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0])
+
+        # reset gripper control
+        close_gripper_control = np.array([
+            PUPPET_GRIPPER_POSITION_CLOSE,
+            -PUPPET_GRIPPER_POSITION_CLOSE,
+            PUPPET_GRIPPER_POSITION_CLOSE,
+            -PUPPET_GRIPPER_POSITION_CLOSE,
+        ])
+        np.copyto(physics.data.ctrl, close_gripper_control)
+
+    def initialize_episode(self, physics):
+        """Sets the state of the environment at the start of each episode."""
+        super().initialize_episode(physics)
+
+    @staticmethod
+    def get_qpos(physics):
+        qpos_raw = physics.data.qpos.copy()
+        left_qpos_raw = qpos_raw[:8]
+        right_qpos_raw = qpos_raw[8:16]
+        left_arm_qpos = left_qpos_raw[:6]
+        right_arm_qpos = right_qpos_raw[:6]
+        left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[6])]
+        right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[6])]
+        return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
+
+    @staticmethod
+    def get_qvel(physics):
+        qvel_raw = physics.data.qvel.copy()
+        left_qvel_raw = qvel_raw[:8]
+        right_qvel_raw = qvel_raw[8:16]
+        left_arm_qvel = left_qvel_raw[:6]
+        right_arm_qvel = right_qvel_raw[:6]
+        left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[6])]
+        right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[6])]
+        return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
+
+    @staticmethod
+    def get_env_state(physics):
+        raise NotImplementedError
+
+    def get_observation(self, physics):
+        # note: it is important to do .copy()
+        obs = collections.OrderedDict()
+        obs['qpos'] = self.get_qpos(physics)
+        obs['qvel'] = self.get_qvel(physics)
+        obs['env_state'] = self.get_env_state(physics)
+        obs['images'] = dict()
+        obs['images']['top'] = physics.render(height=480, width=640, camera_id='top')
+        obs['images']['angle'] = physics.render(height=480, width=640, camera_id='angle')
+        obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close')
+        # used in scripted policy to obtain starting pose
+        obs['mocap_pose_left'] = np.concatenate([physics.data.mocap_pos[0], physics.data.mocap_quat[0]]).copy()
+        obs['mocap_pose_right'] = np.concatenate([physics.data.mocap_pos[1], physics.data.mocap_quat[1]]).copy()
+
+        # used when replaying joint trajectory
+        obs['gripper_ctrl'] = physics.data.ctrl.copy()
+        return obs
+
+    def get_reward(self, physics):
+        raise NotImplementedError
+
+
+class 
TransferCubeEETask(BimanualViperXEETask): + def __init__(self, random=None): + super().__init__(random=random) + self.max_reward = 4 + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + self.initialize_robots(physics) + # randomize box position + cube_pose = sample_box_pose() + box_start_idx = physics.model.name2id('red_box_joint', 'joint') + np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose) + # print(f"randomized cube position to {cube_position}") + + super().initialize_episode(physics) + + @staticmethod + def get_env_state(physics): + env_state = physics.data.qpos.copy()[16:] + return env_state + + def get_reward(self, physics): + # return whether left gripper is holding the box + all_contact_pairs = [] + for i_contact in range(physics.data.ncon): + id_geom_1 = physics.data.contact[i_contact].geom1 + id_geom_2 = physics.data.contact[i_contact].geom2 + name_geom_1 = physics.model.id2name(id_geom_1, 'geom') + name_geom_2 = physics.model.id2name(id_geom_2, 'geom') + contact_pair = (name_geom_1, name_geom_2) + all_contact_pairs.append(contact_pair) + + touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs + touch_table = ("red_box", "table") in all_contact_pairs + + reward = 0 + if touch_right_gripper: + reward = 1 + if touch_right_gripper and not touch_table: # lifted + reward = 2 + if touch_left_gripper: # attempted transfer + reward = 3 + if touch_left_gripper and not touch_table: # successful transfer + reward = 4 + return reward + + +class InsertionEETask(BimanualViperXEETask): + def __init__(self, random=None): + super().__init__(random=random) + self.max_reward = 4 + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + self.initialize_robots(physics) + # randomize peg and socket position + peg_pose, socket_pose = sample_insertion_pose() + id2index = lambda j_id: 16 + (j_id - 16) * 7 # first 16 is robot qpos, 7 is pose dim # hacky + + peg_start_id = physics.model.name2id('red_peg_joint', 'joint') + peg_start_idx = id2index(peg_start_id) + np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose) + # print(f"randomized cube position to {cube_position}") + + socket_start_id = physics.model.name2id('blue_socket_joint', 'joint') + socket_start_idx = id2index(socket_start_id) + np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose) + # print(f"randomized cube position to {cube_position}") + + super().initialize_episode(physics) + + @staticmethod + def get_env_state(physics): + env_state = physics.data.qpos.copy()[16:] + return env_state + + def get_reward(self, physics): + # return whether peg touches the pin + all_contact_pairs = [] + for i_contact in range(physics.data.ncon): + id_geom_1 = physics.data.contact[i_contact].geom1 + id_geom_2 = physics.data.contact[i_contact].geom2 + name_geom_1 = physics.model.id2name(id_geom_1, 'geom') + name_geom_2 = physics.model.id2name(id_geom_2, 'geom') + contact_pair = (name_geom_1, name_geom_2) + all_contact_pairs.append(contact_pair) + + touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs + touch_left_gripper = ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \ + ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \ + ("socket-3", 
"vx300s_left/10_left_gripper_finger") in all_contact_pairs or \ + ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + + peg_touch_table = ("red_peg", "table") in all_contact_pairs + socket_touch_table = ("socket-1", "table") in all_contact_pairs or \ + ("socket-2", "table") in all_contact_pairs or \ + ("socket-3", "table") in all_contact_pairs or \ + ("socket-4", "table") in all_contact_pairs + peg_touch_socket = ("red_peg", "socket-1") in all_contact_pairs or \ + ("red_peg", "socket-2") in all_contact_pairs or \ + ("red_peg", "socket-3") in all_contact_pairs or \ + ("red_peg", "socket-4") in all_contact_pairs + pin_touched = ("red_peg", "pin") in all_contact_pairs + + reward = 0 + if touch_left_gripper and touch_right_gripper: # touch both + reward = 1 + if touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table): # grasp both + reward = 2 + if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching + reward = 3 + if pin_touched: # successful insertion + reward = 4 + return reward diff --git a/realman_src/realman_aloha/shadow_rm_act/pyproject.toml b/realman_src/realman_aloha/shadow_rm_act/pyproject.toml new file mode 100644 index 000000000..eafd8d644 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_act/pyproject.toml @@ -0,0 +1,36 @@ +[tool.poetry] +name = "shadow_act" +version = "0.1.0" +description = "Embodied data, ACT and other methods; training and verification function packages" +readme = "README.md" +authors = ["Shadow "] +license = "MIT" +# include = ["realman_vision/pytransform/_pytransform.so",] +classifiers = [ + "Operating System :: POSIX :: Linux amd64", + "Programming Language :: Python :: 3.10", +] + +[tool.poetry.dependencies] +python = ">=3.9" +wandb = ">=0.18.0" +einops = ">=0.8.0" + + + +[tool.poetry.dev-dependencies] # 列出开发时所需的依赖项,比如测试、文档生成等工具。 +pytest = ">=8.3" +black = ">=24.10.0" + + + +[tool.poetry.plugins."scripts"] # 定义命令行脚本,使得用户可以通过命令行运行指定的函数。 + + +[tool.poetry.group.dev.dependencies] + + + +[build-system] +requires = ["poetry-core>=1.8.4"] +build-backend = "poetry.core.masonry.api" diff --git a/realman_src/realman_aloha/shadow_rm_act/record_sim_episodes.py b/realman_src/realman_aloha/shadow_rm_act/record_sim_episodes.py new file mode 100644 index 000000000..253fdea1c --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_act/record_sim_episodes.py @@ -0,0 +1,189 @@ +import time +import os +import numpy as np +import argparse +import matplotlib.pyplot as plt +import h5py + +from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN, SIM_TASK_CONFIGS +from ee_sim_env import make_ee_sim_env +from sim_env import make_sim_env, BOX_POSE +from scripted_policy import PickAndTransferPolicy, InsertionPolicy + +import IPython +e = IPython.embed + + +def main(args): + """ + Generate demonstration data in simulation. + First rollout the policy (defined in ee space) in ee_sim_env. Obtain the joint trajectory. + Replace the gripper joint positions with the commanded joint position. + Replay this joint trajectory (as action sequence) in sim_env, and record all observations. + Save this episode of data, and continue to next episode of data collection. 
+ """ + + task_name = args['task_name'] + dataset_dir = args['dataset_dir'] + num_episodes = args['num_episodes'] + onscreen_render = args['onscreen_render'] + inject_noise = False + render_cam_name = 'angle' + + if not os.path.isdir(dataset_dir): + os.makedirs(dataset_dir, exist_ok=True) + + episode_len = SIM_TASK_CONFIGS[task_name]['episode_len'] + camera_names = SIM_TASK_CONFIGS[task_name]['camera_names'] + if task_name == 'sim_transfer_cube_scripted': + policy_cls = PickAndTransferPolicy + elif task_name == 'sim_insertion_scripted': + policy_cls = InsertionPolicy + else: + raise NotImplementedError + + success = [] + for episode_idx in range(num_episodes): + print(f'{episode_idx=}') + print('Rollout out EE space scripted policy') + # setup the environment + env = make_ee_sim_env(task_name) + ts = env.reset() + episode = [ts] + policy = policy_cls(inject_noise) + # setup plotting + if onscreen_render: + ax = plt.subplot() + plt_img = ax.imshow(ts.observation['images'][render_cam_name]) + plt.ion() + for step in range(episode_len): + action = policy(ts) + ts = env.step(action) + episode.append(ts) + if onscreen_render: + plt_img.set_data(ts.observation['images'][render_cam_name]) + plt.pause(0.002) + plt.close() + + episode_return = np.sum([ts.reward for ts in episode[1:]]) + episode_max_reward = np.max([ts.reward for ts in episode[1:]]) + if episode_max_reward == env.task.max_reward: + print(f"{episode_idx=} Successful, {episode_return=}") + else: + print(f"{episode_idx=} Failed") + + joint_traj = [ts.observation['qpos'] for ts in episode] + # replace gripper pose with gripper control + gripper_ctrl_traj = [ts.observation['gripper_ctrl'] for ts in episode] + for joint, ctrl in zip(joint_traj, gripper_ctrl_traj): + left_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[0]) + right_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[2]) + joint[6] = left_ctrl + joint[6+7] = right_ctrl + + subtask_info = episode[0].observation['env_state'].copy() # box pose at step 0 + + # clear unused variables + del env + del episode + del policy + + # setup the environment + print('Replaying joint commands') + env = make_sim_env(task_name) + BOX_POSE[0] = subtask_info # make sure the sim_env has the same object configurations as ee_sim_env + ts = env.reset() + + episode_replay = [ts] + # setup plotting + if onscreen_render: + ax = plt.subplot() + plt_img = ax.imshow(ts.observation['images'][render_cam_name]) + plt.ion() + for t in range(len(joint_traj)): # note: this will increase episode length by 1 + action = joint_traj[t] + ts = env.step(action) + episode_replay.append(ts) + if onscreen_render: + plt_img.set_data(ts.observation['images'][render_cam_name]) + plt.pause(0.02) + + episode_return = np.sum([ts.reward for ts in episode_replay[1:]]) + episode_max_reward = np.max([ts.reward for ts in episode_replay[1:]]) + if episode_max_reward == env.task.max_reward: + success.append(1) + print(f"{episode_idx=} Successful, {episode_return=}") + else: + success.append(0) + print(f"{episode_idx=} Failed") + + plt.close() + + """ + For each timestep: + observations + - images + - each_cam_name (480, 640, 3) 'uint8' + - qpos (14,) 'float64' + - qvel (14,) 'float64' + + action (14,) 'float64' + """ + + data_dict = { + '/observations/qpos': [], + '/observations/qvel': [], + '/action': [], + } + for cam_name in camera_names: + data_dict[f'/observations/images/{cam_name}'] = [] + + # because the replaying, there will be eps_len + 1 actions and eps_len + 2 timesteps + # truncate here to be consistent + joint_traj = 
diff --git a/realman_src/realman_aloha/shadow_rm_act/scripted_policy.py b/realman_src/realman_aloha/shadow_rm_act/scripted_policy.py
new file mode 100644
index 000000000..4fd8f0007
--- /dev/null
+++ b/realman_src/realman_aloha/shadow_rm_act/scripted_policy.py
@@ -0,0 +1,194 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from pyquaternion import Quaternion
+
+from constants import SIM_TASK_CONFIGS
+from ee_sim_env import make_ee_sim_env
+
+import IPython
+e = IPython.embed
+
+
+class BasePolicy:
+    def __init__(self, inject_noise=False):
+        self.inject_noise = inject_noise
+        self.step_count = 0
+        self.left_trajectory = None
+        self.right_trajectory = None
+
+    def generate_trajectory(self, ts_first):
+        raise NotImplementedError
+
+    @staticmethod
+    def interpolate(curr_waypoint, next_waypoint, t):
+        # Linear interpolation between waypoints; note that quaternions are
+        # interpolated component-wise (lerp rather than slerp), which is
+        # adequate for the nearby waypoints used here.
+        t_frac = (t - curr_waypoint["t"]) / (next_waypoint["t"] - curr_waypoint["t"])
+        curr_xyz = curr_waypoint['xyz']
+        curr_quat = curr_waypoint['quat']
+        curr_grip = curr_waypoint['gripper']
+        next_xyz = next_waypoint['xyz']
+        next_quat = next_waypoint['quat']
+        next_grip = next_waypoint['gripper']
+        xyz = curr_xyz + (next_xyz - curr_xyz) * t_frac
+        quat = curr_quat + (next_quat - curr_quat) * t_frac
+        gripper = curr_grip + (next_grip - curr_grip) * t_frac
+        return xyz, quat, gripper
+
+    def __call__(self, ts):
+        # generate trajectory at first timestep, then open-loop execution
+        if self.step_count == 0:
+            self.generate_trajectory(ts)
+
+        # obtain left and right waypoints
+        if self.left_trajectory[0]['t'] == self.step_count:
+            self.curr_left_waypoint = self.left_trajectory.pop(0)
+        next_left_waypoint = self.left_trajectory[0]
+
+        if self.right_trajectory[0]['t'] == self.step_count:
+            self.curr_right_waypoint = self.right_trajectory.pop(0)
+        next_right_waypoint = self.right_trajectory[0]
+
+        # interpolate between waypoints to obtain current pose and gripper command
+        left_xyz, left_quat, left_gripper = self.interpolate(self.curr_left_waypoint, next_left_waypoint, self.step_count)
+        right_xyz, right_quat, right_gripper = self.interpolate(self.curr_right_waypoint, next_right_waypoint, self.step_count)
+
+        # Inject noise
+        if self.inject_noise:
+            scale = 0.01
+            left_xyz = left_xyz + np.random.uniform(-scale, scale, left_xyz.shape)
+            right_xyz = right_xyz + np.random.uniform(-scale, scale, right_xyz.shape)
+
+        action_left = np.concatenate([left_xyz, left_quat, [left_gripper]])
+        action_right = np.concatenate([right_xyz, right_quat, [right_gripper]])
+
+        self.step_count += 1
+        return np.concatenate([action_left, action_right])
+
+
+class PickAndTransferPolicy(BasePolicy):
+
+    def generate_trajectory(self, ts_first):
+        init_mocap_pose_right = ts_first.observation['mocap_pose_right']
+        init_mocap_pose_left = ts_first.observation['mocap_pose_left']
+
+        box_info = np.array(ts_first.observation['env_state'])
+        box_xyz = box_info[:3]
+        box_quat = box_info[3:]
+        # print(f"Generate trajectory for {box_xyz=}")
+
+        gripper_pick_quat = Quaternion(init_mocap_pose_right[3:])
+        gripper_pick_quat = gripper_pick_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=-60)
+
+        meet_left_quat = Quaternion(axis=[1.0, 0.0, 0.0], degrees=90)
+
+        meet_xyz = np.array([0, 0.5, 0.25])
+
+        self.left_trajectory = [
+            {"t": 0, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 0}, # sleep
+            {"t": 100, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 1}, # approach meet position
+            {"t": 260, "xyz": meet_xyz + np.array([0.02, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 1}, # move to meet position
+            {"t": 310, "xyz": meet_xyz + np.array([0.02, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 0}, # close gripper
+            {"t": 360, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": np.array([1, 0, 0, 0]), "gripper": 0}, # move left
+            {"t": 400, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": np.array([1, 0, 0, 0]), "gripper": 0}, # stay
+        ]
+
+        self.right_trajectory = [
+            {"t": 0, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 0}, # sleep
+            {"t": 90, "xyz": box_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat.elements, "gripper": 1}, # approach the cube
+            {"t": 130, "xyz": box_xyz + np.array([0, 0, -0.015]), "quat": gripper_pick_quat.elements, "gripper": 1}, # go down
+            {"t": 170, "xyz": box_xyz + np.array([0, 0, -0.015]), "quat": gripper_pick_quat.elements, "gripper": 0}, # close gripper
+            {"t": 200, "xyz": meet_xyz + np.array([0.05, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 0}, # approach meet position
+            {"t": 220, "xyz": meet_xyz, "quat": gripper_pick_quat.elements, "gripper": 0}, # move to meet position
+            {"t": 310, "xyz": meet_xyz, "quat": gripper_pick_quat.elements, "gripper": 1}, # open gripper
+            {"t": 360, "xyz": meet_xyz + np.array([0.1, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 1}, # move to right
+            {"t": 400, "xyz": meet_xyz + np.array([0.1, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 1}, # stay
+        ]
+
+
+class InsertionPolicy(BasePolicy):
+
+    def generate_trajectory(self, ts_first):
+        init_mocap_pose_right = 
ts_first.observation['mocap_pose_right'] + init_mocap_pose_left = ts_first.observation['mocap_pose_left'] + + peg_info = np.array(ts_first.observation['env_state'])[:7] + peg_xyz = peg_info[:3] + peg_quat = peg_info[3:] + + socket_info = np.array(ts_first.observation['env_state'])[7:] + socket_xyz = socket_info[:3] + socket_quat = socket_info[3:] + + gripper_pick_quat_right = Quaternion(init_mocap_pose_right[3:]) + gripper_pick_quat_right = gripper_pick_quat_right * Quaternion(axis=[0.0, 1.0, 0.0], degrees=-60) + + gripper_pick_quat_left = Quaternion(init_mocap_pose_right[3:]) + gripper_pick_quat_left = gripper_pick_quat_left * Quaternion(axis=[0.0, 1.0, 0.0], degrees=60) + + meet_xyz = np.array([0, 0.5, 0.15]) + lift_right = 0.00715 + + self.left_trajectory = [ + {"t": 0, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 0}, # sleep + {"t": 120, "xyz": socket_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat_left.elements, "gripper": 1}, # approach the cube + {"t": 170, "xyz": socket_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_left.elements, "gripper": 1}, # go down + {"t": 220, "xyz": socket_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # close gripper + {"t": 285, "xyz": meet_xyz + np.array([-0.1, 0, 0]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # approach meet position + {"t": 340, "xyz": meet_xyz + np.array([-0.05, 0, 0]), "quat": gripper_pick_quat_left.elements,"gripper": 0}, # insertion + {"t": 400, "xyz": meet_xyz + np.array([-0.05, 0, 0]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # insertion + ] + + self.right_trajectory = [ + {"t": 0, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 0}, # sleep + {"t": 120, "xyz": peg_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat_right.elements, "gripper": 1}, # approach the cube + {"t": 170, "xyz": peg_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_right.elements, "gripper": 1}, # go down + {"t": 220, "xyz": peg_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # close gripper + {"t": 285, "xyz": meet_xyz + np.array([0.1, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # approach meet position + {"t": 340, "xyz": meet_xyz + np.array([0.05, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # insertion + {"t": 400, "xyz": meet_xyz + np.array([0.05, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # insertion + + ] + + +def test_policy(task_name): + # example rolling out pick_and_transfer policy + onscreen_render = True + inject_noise = False + + # setup the environment + episode_len = SIM_TASK_CONFIGS[task_name]['episode_len'] + if 'sim_transfer_cube' in task_name: + env = make_ee_sim_env('sim_transfer_cube') + elif 'sim_insertion' in task_name: + env = make_ee_sim_env('sim_insertion') + else: + raise NotImplementedError + + for episode_idx in range(2): + ts = env.reset() + episode = [ts] + if onscreen_render: + ax = plt.subplot() + plt_img = ax.imshow(ts.observation['images']['angle']) + plt.ion() + + policy = PickAndTransferPolicy(inject_noise) + for step in range(episode_len): + action = policy(ts) + ts = env.step(action) + episode.append(ts) + if onscreen_render: + plt_img.set_data(ts.observation['images']['angle']) + plt.pause(0.02) + plt.close() + + episode_return = np.sum([ts.reward for ts in episode[1:]]) + if episode_return > 0: + print(f"{episode_idx=} 
Successful, {episode_return=}") + else: + print(f"{episode_idx=} Failed") + + +if __name__ == '__main__': + test_task_name = 'sim_transfer_cube_scripted' + test_policy(test_task_name) + diff --git a/realman_src/realman_aloha/shadow_rm_act/sim_env.py b/realman_src/realman_aloha/shadow_rm_act/sim_env.py new file mode 100644 index 000000000..b79b935b1 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_act/sim_env.py @@ -0,0 +1,278 @@ +import numpy as np +import os +import collections +import matplotlib.pyplot as plt +from dm_control import mujoco +from dm_control.rl import control +from dm_control.suite import base + +from constants import DT, XML_DIR, START_ARM_POSE +from constants import PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN +from constants import MASTER_GRIPPER_POSITION_NORMALIZE_FN +from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN +from constants import PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN + +import IPython +e = IPython.embed + +BOX_POSE = [None] # to be changed from outside + +def make_sim_env(task_name): + """ + Environment for simulated robot bi-manual manipulation, with joint position control + Action space: [left_arm_qpos (6), # absolute joint position + left_gripper_positions (1), # normalized gripper position (0: close, 1: open) + right_arm_qpos (6), # absolute joint position + right_gripper_positions (1),] # normalized gripper position (0: close, 1: open) + + Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position + left_gripper_position (1), # normalized gripper position (0: close, 1: open) + right_arm_qpos (6), # absolute joint position + right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open) + "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad) + left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing) + right_arm_qvel (6), # absolute joint velocity (rad) + right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing) + "images": {"main": (480x640x3)} # h, w, c, dtype='uint8' + """ + if 'sim_transfer_cube' in task_name: + xml_path = os.path.join(XML_DIR, f'bimanual_viperx_transfer_cube.xml') + physics = mujoco.Physics.from_xml_path(xml_path) + task = TransferCubeTask(random=False) + env = control.Environment(physics, task, time_limit=20, control_timestep=DT, + n_sub_steps=None, flat_observation=False) + elif 'sim_insertion' in task_name: + xml_path = os.path.join(XML_DIR, f'bimanual_viperx_insertion.xml') + physics = mujoco.Physics.from_xml_path(xml_path) + task = InsertionTask(random=False) + env = control.Environment(physics, task, time_limit=20, control_timestep=DT, + n_sub_steps=None, flat_observation=False) + else: + raise NotImplementedError + return env + +class BimanualViperXTask(base.Task): + def __init__(self, random=None): + super().__init__(random=random) + + def before_step(self, action, physics): + left_arm_action = action[:6] + right_arm_action = action[7:7+6] + normalized_left_gripper_action = action[6] + normalized_right_gripper_action = action[7+6] + + left_gripper_action = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(normalized_left_gripper_action) + right_gripper_action = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(normalized_right_gripper_action) + + full_left_gripper_action = [left_gripper_action, -left_gripper_action] + full_right_gripper_action = [right_gripper_action, -right_gripper_action] + + env_action = np.concatenate([left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action]) + 
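+        # The 14-dim policy action is [left_arm(6), left_gripper(1), right_arm(6),
+        # right_gripper(1)]; each normalized gripper command is unnormalized and then
+        # mirrored onto two opposing finger joints, giving the 16-dim env_action.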
super().before_step(env_action, physics) + return + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + super().initialize_episode(physics) + + @staticmethod + def get_qpos(physics): + qpos_raw = physics.data.qpos.copy() + left_qpos_raw = qpos_raw[:8] + right_qpos_raw = qpos_raw[8:16] + left_arm_qpos = left_qpos_raw[:6] + right_arm_qpos = right_qpos_raw[:6] + left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[6])] + right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[6])] + return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos]) + + @staticmethod + def get_qvel(physics): + qvel_raw = physics.data.qvel.copy() + left_qvel_raw = qvel_raw[:8] + right_qvel_raw = qvel_raw[8:16] + left_arm_qvel = left_qvel_raw[:6] + right_arm_qvel = right_qvel_raw[:6] + left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[6])] + right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[6])] + return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel]) + + @staticmethod + def get_env_state(physics): + raise NotImplementedError + + def get_observation(self, physics): + obs = collections.OrderedDict() + obs['qpos'] = self.get_qpos(physics) + obs['qvel'] = self.get_qvel(physics) + obs['env_state'] = self.get_env_state(physics) + obs['images'] = dict() + obs['images']['top'] = physics.render(height=480, width=640, camera_id='top') + obs['images']['angle'] = physics.render(height=480, width=640, camera_id='angle') + obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close') + + return obs + + def get_reward(self, physics): + # return whether left gripper is holding the box + raise NotImplementedError + + +class TransferCubeTask(BimanualViperXTask): + def __init__(self, random=None): + super().__init__(random=random) + self.max_reward = 4 + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + # TODO Notice: this function does not randomize the env configuration. 
Instead, set BOX_POSE from outside + # reset qpos, control and box position + with physics.reset_context(): + physics.named.data.qpos[:16] = START_ARM_POSE + np.copyto(physics.data.ctrl, START_ARM_POSE) + assert BOX_POSE[0] is not None + physics.named.data.qpos[-7:] = BOX_POSE[0] + # print(f"{BOX_POSE=}") + super().initialize_episode(physics) + + @staticmethod + def get_env_state(physics): + env_state = physics.data.qpos.copy()[16:] + return env_state + + def get_reward(self, physics): + # return whether left gripper is holding the box + all_contact_pairs = [] + for i_contact in range(physics.data.ncon): + id_geom_1 = physics.data.contact[i_contact].geom1 + id_geom_2 = physics.data.contact[i_contact].geom2 + name_geom_1 = physics.model.id2name(id_geom_1, 'geom') + name_geom_2 = physics.model.id2name(id_geom_2, 'geom') + contact_pair = (name_geom_1, name_geom_2) + all_contact_pairs.append(contact_pair) + + touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs + touch_table = ("red_box", "table") in all_contact_pairs + + reward = 0 + if touch_right_gripper: + reward = 1 + if touch_right_gripper and not touch_table: # lifted + reward = 2 + if touch_left_gripper: # attempted transfer + reward = 3 + if touch_left_gripper and not touch_table: # successful transfer + reward = 4 + return reward + + +class InsertionTask(BimanualViperXTask): + def __init__(self, random=None): + super().__init__(random=random) + self.max_reward = 4 + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + # TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside + # reset qpos, control and box position + with physics.reset_context(): + physics.named.data.qpos[:16] = START_ARM_POSE + np.copyto(physics.data.ctrl, START_ARM_POSE) + assert BOX_POSE[0] is not None + physics.named.data.qpos[-7*2:] = BOX_POSE[0] # two objects + # print(f"{BOX_POSE=}") + super().initialize_episode(physics) + + @staticmethod + def get_env_state(physics): + env_state = physics.data.qpos.copy()[16:] + return env_state + + def get_reward(self, physics): + # return whether peg touches the pin + all_contact_pairs = [] + for i_contact in range(physics.data.ncon): + id_geom_1 = physics.data.contact[i_contact].geom1 + id_geom_2 = physics.data.contact[i_contact].geom2 + name_geom_1 = physics.model.id2name(id_geom_1, 'geom') + name_geom_2 = physics.model.id2name(id_geom_2, 'geom') + contact_pair = (name_geom_1, name_geom_2) + all_contact_pairs.append(contact_pair) + + touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs + touch_left_gripper = ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \ + ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \ + ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \ + ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + + peg_touch_table = ("red_peg", "table") in all_contact_pairs + socket_touch_table = ("socket-1", "table") in all_contact_pairs or \ + ("socket-2", "table") in all_contact_pairs or \ + ("socket-3", "table") in all_contact_pairs or \ + ("socket-4", "table") in all_contact_pairs + peg_touch_socket = ("red_peg", "socket-1") in all_contact_pairs or \ + ("red_peg", "socket-2") in all_contact_pairs or \ + ("red_peg", "socket-3") in 
all_contact_pairs or \ + ("red_peg", "socket-4") in all_contact_pairs + pin_touched = ("red_peg", "pin") in all_contact_pairs + + reward = 0 + if touch_left_gripper and touch_right_gripper: # touch both + reward = 1 + if touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table): # grasp both + reward = 2 + if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching + reward = 3 + if pin_touched: # successful insertion + reward = 4 + return reward + + +def get_action(master_bot_left, master_bot_right): + action = np.zeros(14) + # arm action + action[:6] = master_bot_left.dxl.joint_states.position[:6] + action[7:7+6] = master_bot_right.dxl.joint_states.position[:6] + # gripper action + left_gripper_pos = master_bot_left.dxl.joint_states.position[7] + right_gripper_pos = master_bot_right.dxl.joint_states.position[7] + normalized_left_pos = MASTER_GRIPPER_POSITION_NORMALIZE_FN(left_gripper_pos) + normalized_right_pos = MASTER_GRIPPER_POSITION_NORMALIZE_FN(right_gripper_pos) + action[6] = normalized_left_pos + action[7+6] = normalized_right_pos + return action + +def test_sim_teleop(): + """ Testing teleoperation in sim with ALOHA. Requires hardware and ALOHA repo to work. """ + from interbotix_xs_modules.arm import InterbotixManipulatorXS + + BOX_POSE[0] = [0.2, 0.5, 0.05, 1, 0, 0, 0] + + # source of data + master_bot_left = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper", + robot_name=f'master_left', init_node=True) + master_bot_right = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper", + robot_name=f'master_right', init_node=False) + + # setup the environment + env = make_sim_env('sim_transfer_cube') + ts = env.reset() + episode = [ts] + # setup plotting + ax = plt.subplot() + plt_img = ax.imshow(ts.observation['images']['angle']) + plt.ion() + + for t in range(1000): + action = get_action(master_bot_left, master_bot_right) + ts = env.step(action) + episode.append(ts) + + plt_img.set_data(ts.observation['images']['angle']) + plt.pause(0.02) + + +if __name__ == '__main__': + test_sim_teleop() + diff --git a/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/__init__.py b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/__init__.py new file mode 100644 index 000000000..541f859dc --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/__init__.py @@ -0,0 +1 @@ +__version__ = '0.1.0' \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/eval/__init__.py b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/eval/__init__.py new file mode 100644 index 000000000..541f859dc --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/eval/__init__.py @@ -0,0 +1 @@ +__version__ = '0.1.0' \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/eval/rm_act_eval.py b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/eval/rm_act_eval.py new file mode 100644 index 000000000..b89b523a0 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/eval/rm_act_eval.py @@ -0,0 +1,575 @@ +import os +import time +import yaml +import torch +import pickle +import dm_env +import logging +import collections +import numpy as np +import tracemalloc +from einops import rearrange +import matplotlib.pyplot as plt +from torchvision import transforms +from shadow_rm_robot.realman_arm import RmArm +from 
shadow_camera.realsense import RealSenseCamera
+from shadow_act.models.latent_model import Latent_Model_Transformer
+from shadow_act.network.policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy
+from shadow_act.utils.utils import set_seed
+
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+# # Hide the h5py warning "Creating converter from 7 to 5"
+# logging.getLogger("h5py").setLevel(logging.WARNING)
+
+
+class RmActEvaluator:
+    def __init__(self, config, save_episode=True, num_rollouts=50):
+        """
+        Initialize the evaluator
+
+        Args:
+            config (dict): configuration dictionary
+            save_episode (bool): whether to save each episode
+            num_rollouts (int): number of evaluation rollouts
+        """
+        self.config = config
+        self._seed = config["seed"]
+        self.robot_env = config["robot_env"]
+        self.checkpoint_dir = config["checkpoint_dir"]
+        self.checkpoint_name = config["checkpoint_name"]
+        self.save_episode = save_episode
+        self.num_rollouts = num_rollouts
+        self.state_dim = config["state_dim"]
+        self.real_robot = config["real_robot"]
+        self.policy_class = config["policy_class"]
+        self.onscreen_render = config["onscreen_render"]
+        self.camera_names = config["camera_names"]
+        self.max_timesteps = config["episode_len"]
+        self.task_name = config["task_name"]
+        self.temporal_agg = config["temporal_agg"]
+        self.onscreen_cam = "angle"
+        self.policy_config = config["policy_config"]
+        self.vq = config["policy_config"]["vq"]
+        # self.actuator_config = config["actuator_config"]
+        # self.use_actuator_net = self.actuator_config["actuator_network_dir"] is not None
+        self.stats = None
+        self.env = None
+        self.env_max_reward = 0
+
+    def _make_policy(self, policy_class, policy_config):
+        """
+        Create a policy object from a policy class name and its configuration
+
+        Args:
+            policy_class (str): policy class name
+            policy_config (dict): policy configuration dictionary
+
+        Returns:
+            policy: the created policy object
+        """
+        if policy_class == "ACT":
+            return ACTPolicy(policy_config)
+        elif policy_class == "CNNMLP":
+            return CNNMLPPolicy(policy_config)
+        elif policy_class == "Diffusion":
+            return DiffusionPolicy(policy_config)
+        else:
+            raise NotImplementedError(f"Policy class {policy_class} is not implemented")
+
+    def load_policy_and_stats(self):
+        """
+        Load the policy weights and the dataset statistics
+        """
+        checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name)
+        logging.info(f"Loading policy from: {checkpoint_path}")
+        self.policy = self._make_policy(self.policy_class, self.policy_config)
+        # Load the model and switch it to evaluation mode
+        self.policy.load_state_dict(torch.load(checkpoint_path, weights_only=True))
+        self.policy.cuda()
+        self.policy.eval()
+
+        if self.vq:
+            vq_dim = self.config["policy_config"]["vq_dim"]
+            vq_class = self.config["policy_config"]["vq_class"]
+            self.latent_model = Latent_Model_Transformer(vq_dim, vq_dim, vq_class)
+            latent_model_checkpoint_path = os.path.join(
+                self.checkpoint_dir, "latent_model_last.ckpt"
+            )
+            self.latent_model.deserialize(torch.load(latent_model_checkpoint_path))
+            self.latent_model.eval()
+            self.latent_model.cuda()
+            logging.info(
+                f"Loaded policy from: {checkpoint_path}, latent model from: {latent_model_checkpoint_path}"
+            )
+        else:
+            logging.info(f"Loaded: {checkpoint_path}")
+
+        stats_path = os.path.join(self.checkpoint_dir, "dataset_stats.pkl")
+        with open(stats_path, "rb") as f:
+            self.stats = pickle.load(f)
+
+    def pre_process(self, state_qpos):
+        """
+        Pre-process the state (joint positions)
+
+        Args:
+            state_qpos (np.array): raw joint-position array
+
+        Returns:
+            np.array: normalized joint positions
+        """
+        if self.policy_class == "Diffusion":
+            return ((state_qpos + 1) / 2) * (
+                self.stats["action_max"] - self.stats["action_min"]
+            ) + self.stats["action_min"]
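+        # NOTE (editor's sketch): for Diffusion the stats hold min/max, so the branch
+        # above maps a [-1, 1]-scaled vector back to the original range:
+        #     x = (x_norm + 1) / 2 * (max - min) + min
+        # e.g. x_norm = 0 with min = -1.0 and max = 3.0 gives x = 1.0.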
+        # Standardize to zero mean and unit variance
+
+        return (state_qpos - self.stats["qpos_mean"]) / self.stats["qpos_std"]
+
+    def post_process(self, action):
+        """
+        Post-process an action
+
+        Args:
+            action (np.array): normalized action array
+
+        Returns:
+            np.array: de-normalized action
+        """
+        # Undo the standardization
+        return action * self.stats["action_std"] + self.stats["action_mean"]
+
+    def get_image_torch(self, timestep, camera_names, random_crop_resize=False):
+        """
+        Fetch the camera images as a tensor
+
+        Args:
+            timestep (object): current timestep object
+            camera_names (list): list of camera names
+            random_crop_resize (bool): whether to random-crop and resize
+
+        Returns:
+            torch.Tensor: processed images, normalized, (num_cameras, channels, height, width)
+        """
+        current_images = []
+        for cam_name in camera_names:
+            current_image = rearrange(
+                timestep.observation["images"][cam_name], "h w c -> c h w"
+            )
+            current_images.append(current_image)
+        current_image = np.stack(current_images, axis=0)
+        current_image = (
+            torch.from_numpy(current_image / 255.0).float().cuda().unsqueeze(0)
+        )
+
+        if random_crop_resize:
+            logging.info("Random crop resize is used!")
+            original_size = current_image.shape[-2:]
+            ratio = 0.95
+            current_image = current_image[
+                ...,
+                int(original_size[0] * (1 - ratio) / 2) : int(
+                    original_size[0] * (1 + ratio) / 2
+                ),
+                int(original_size[1] * (1 - ratio) / 2) : int(
+                    original_size[1] * (1 + ratio) / 2
+                ),
+            ]
+            current_image = current_image.squeeze(0)
+            resize_transform = transforms.Resize(original_size, antialias=True)
+            current_image = resize_transform(current_image)
+            current_image = current_image.unsqueeze(0)
+
+        return current_image
+
+    def load_environment(self):
+        """
+        Load the evaluation environment (real robot or simulation)
+        """
+        if self.real_robot:
+            self.env = DeviceAloha(self.robot_env)
+            self.env_max_reward = 0
+        else:
+            from sim_env import make_sim_env
+
+            self.env = make_sim_env(self.task_name)
+            self.env_max_reward = self.env.task.max_reward
+
+    def get_auto_index(self, checkpoint_dir):
+        max_idx = 1000
+        for i in range(max_idx + 1):
+            if not os.path.isfile(os.path.join(checkpoint_dir, f"qpos_{i}.npy")):
+                return i
+        raise Exception(f"Error getting auto index, or more than {max_idx} episodes")
+
+    def evaluate(self, checkpoint_name=None):
+        """
+        Evaluate the policy
+
+        Returns:
+            tuple: success rate and average return
+        """
+        if checkpoint_name is not None:
+            self.checkpoint_name = checkpoint_name
+        set_seed(self._seed)  # seeds both numpy and torch
+        self.load_policy_and_stats()
+        self.load_environment()
+
+        query_frequency = self.policy_config["num_queries"]
+
+        # With temporal aggregation, query the policy at every timestep
+        if self.temporal_agg:
+            query_frequency = 1
+            num_queries = self.policy_config["num_queries"]
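+
+        # NOTE (editor's sketch): num_queries is the action-chunk size. Without temporal
+        # aggregation the policy is queried once every num_queries steps and the chunk is
+        # executed open loop; with aggregation it is queried every step and the
+        # overlapping chunks are fused inside the rollout loop below.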
+
+        # # On the real robot, the base delay was measured as 13 steps???
+        # if self.real_robot:
+        #     BASE_DELAY = 13
+        #     # query_frequency -= BASE_DELAY
+
+        max_timesteps = int(self.max_timesteps * 1)  # may increase for real-world tasks
+        episode_returns = []
+        highest_rewards = []
+
+        for rollout_id in range(self.num_rollouts):
+
+            timestep = self.env.reset()
+
+            if self.onscreen_render:
+                # TODO plotting
+                pass
+            if self.temporal_agg:
+                all_time_actions = torch.zeros(
+                    [max_timesteps, max_timesteps + num_queries, self.state_dim]
+                ).cuda()
+            qpos_history_raw = np.zeros((max_timesteps, self.state_dim))
+            rewards = []
+
+            with torch.inference_mode():
+                time_0 = time.time()
+                DT = 1 / 30
+                accumulated_delay = 0
+                for t in range(max_timesteps):
+                    time_1 = time.time()
+                    if self.onscreen_render:
+                        # TODO display the rendered image
+                        pass
+                    # process previous timestep to get qpos and image_list
+                    obs = timestep.observation
+                    qpos_numpy = np.array(obs["qpos"])
+                    qpos_history_raw[t] = qpos_numpy
+                    qpos = self.pre_process(qpos_numpy)
+                    qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
+
+                    logging.info(f"t{t}")
+
+                    if t % query_frequency == 0:
+                        current_image = self.get_image_torch(
+                            timestep,
+                            self.camera_names,
+                            random_crop_resize=(
+                                self.config["policy_class"] == "Diffusion"
+                            ),
+                        )
+
+                    if t == 0:
+                        # Warm up the network
+                        for _ in range(10):
+                            self.policy(qpos, current_image)
+                        logging.info("Network warm up done")
+
+                    if self.config["policy_class"] == "ACT":
+                        if t % query_frequency == 0:
+                            if self.vq:
+                                if rollout_id == 0:
+                                    for _ in range(10):
+                                        vq_sample = self.latent_model.generate(
+                                            1, temperature=1, x=None
+                                        )
+                                        logging.info(
+                                            torch.nonzero(vq_sample[0])[:, 1]
+                                            .cpu()
+                                            .numpy()
+                                        )
+                                vq_sample = self.latent_model.generate(
+                                    1, temperature=1, x=None
+                                )
+                                all_actions = self.policy(
+                                    qpos, current_image, vq_sample=vq_sample
+                                )
+                            else:
+                                all_actions = self.policy(qpos, current_image)
+                            # if self.real_robot:
+                            #     all_actions = torch.cat(
+                            #         [
+                            #             all_actions[:, :-BASE_DELAY, :-2],
+                            #             all_actions[:, BASE_DELAY:, -2:],
+                            #         ],
+                            #         dim=2,
+                            #     )
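+                        # NOTE (editor's sketch): temporal ensembling below fuses every
+                        # chunk that predicted an action for step t, using exponential
+                        # weights w_i = exp(-k * i) / sum_j exp(-k * j), where i = 0 is
+                        # the oldest prediction; k = 0.01 weights old and new predictions
+                        # almost equally (smoother, but slower to react).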
+                        if self.temporal_agg:
+                            all_time_actions[[t], t : t + num_queries] = all_actions
+                            actions_for_curr_step = all_time_actions[:, t]
+                            actions_populated = torch.all(
+                                actions_for_curr_step != 0, axis=1
+                            )
+                            actions_for_curr_step = actions_for_curr_step[
+                                actions_populated
+                            ]
+                            k = 0.01
+                            exp_weights = np.exp(
+                                -k * np.arange(len(actions_for_curr_step))
+                            )
+                            exp_weights = exp_weights / exp_weights.sum()
+                            exp_weights = (
+                                torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
+                            )
+                            raw_action = (actions_for_curr_step * exp_weights).sum(
+                                dim=0, keepdim=True
+                            )
+                        else:
+                            raw_action = all_actions[:, t % query_frequency]
+                    elif self.config["policy_class"] == "Diffusion":
+                        if t % query_frequency == 0:
+                            all_actions = self.policy(qpos, current_image)
+                            # if self.real_robot:
+                            #     all_actions = torch.cat(
+                            #         [
+                            #             all_actions[:, :-BASE_DELAY, :-2],
+                            #             all_actions[:, BASE_DELAY:, -2:],
+                            #         ],
+                            #         dim=2,
+                            #     )
+                        raw_action = all_actions[:, t % query_frequency]
+                    elif self.config["policy_class"] == "CNNMLP":
+                        raw_action = self.policy(qpos, current_image)
+                        all_actions = raw_action.unsqueeze(0)
+                    else:
+                        raise NotImplementedError
+
+                    ### post-process actions
+                    raw_action = raw_action.squeeze(0).cpu().numpy()
+                    action = self.post_process(raw_action)
+
+                    ### step the environment
+                    if self.real_robot:
+                        logging.info(f"  action = {action}")
+                    timestep = self.env.step(action)
+
+                    rewards.append(timestep.reward)
+                    duration = time.time() - time_1
+                    sleep_time = max(0, DT - duration)
+                    time.sleep(sleep_time)
+                    if duration >= DT:
+                        accumulated_delay += duration - DT
+                        logging.warning(
+                            f"Warning: step duration: {duration:.3f} s at step {t} longer than DT: {DT} s, accumulated delay: {accumulated_delay:.3f} s"
+                        )
+
+            logging.info(f"Avg fps {max_timesteps / (time.time() - time_0)}")
+            plt.close()
+
+            if self.real_robot:
+                log_id = self.get_auto_index(self.checkpoint_dir)
+                np.save(
+                    os.path.join(self.checkpoint_dir, f"qpos_{log_id}.npy"),
+                    qpos_history_raw,
+                )
+                plt.figure(figsize=(10, 20))
+                for i in range(self.state_dim):
+                    plt.subplot(self.state_dim, 1, i + 1)
+                    plt.plot(qpos_history_raw[:, i])
+                    if i != self.state_dim - 1:
+                        plt.xticks([])
+                plt.tight_layout()
+                plt.savefig(os.path.join(self.checkpoint_dir, f"qpos_{log_id}.png"))
+                plt.close()
+
+            rewards = np.array(rewards)
+            episode_return = np.sum(rewards[rewards != None])
+            episode_returns.append(episode_return)
+            episode_highest_reward = np.max(rewards)
+            highest_rewards.append(episode_highest_reward)
+            logging.info(
+                f"Rollout {rollout_id}\n{episode_return=}, {episode_highest_reward=}, {self.env_max_reward=}, Success: {episode_highest_reward == self.env_max_reward}"
+            )
+
+        success_rate = np.mean(np.array(highest_rewards) == self.env_max_reward)
+        avg_return = np.mean(episode_returns)
+        summary_str = (
+            f"\nSuccess rate: {success_rate}\nAverage return: {avg_return}\n\n"
+        )
+        for r in range(self.env_max_reward + 1):
+            more_or_equal_r = (np.array(highest_rewards) >= r).sum()
+            more_or_equal_r_rate = more_or_equal_r / self.num_rollouts
+            summary_str += f"Reward >= {r}: {more_or_equal_r}/{self.num_rollouts} = {more_or_equal_r_rate * 100}%\n"
+
+        logging.info(summary_str)
+
+        result_file_name = "result_" + self.checkpoint_name.split(".")[0] + ".txt"
+        with open(os.path.join(self.checkpoint_dir, result_file_name), "w") as f:
+            f.write(summary_str)
+            f.write(repr(episode_returns))
+            f.write("\n\n")
+            f.write(repr(highest_rewards))
+
+        return success_rate, avg_return
+
+
+class DeviceAloha:
+    def __init__(self, aloha_config):
+        """
+        Initialize the ALOHA-style device (two arms plus four cameras)
+
+        Args:
+            aloha_config (dict): device configuration dictionary
+        """
+        config_left_arm = aloha_config["rm_left_arm"]
+        config_right_arm = aloha_config["rm_right_arm"]
+        config_head_camera = aloha_config["head_camera"]
+        config_bottom_camera = aloha_config["bottom_camera"]
+        config_left_camera = aloha_config["left_camera"]
+        config_right_camera = aloha_config["right_camera"]
+        self.init_left_arm_angle = aloha_config["init_left_arm_angle"]
+        self.init_right_arm_angle = aloha_config["init_right_arm_angle"]
+        self.arm_axis = aloha_config["arm_axis"]
+        self.arm_left = RmArm(config_left_arm)
+        self.arm_right = RmArm(config_right_arm)
+        # Pair each camera with its matching config (head -> top, bottom -> bottom, ...)
+        self.camera_top = RealSenseCamera(config_head_camera, False)
+        self.camera_bottom = RealSenseCamera(config_bottom_camera, False)
+        self.camera_left = RealSenseCamera(config_left_camera, False)
+        self.camera_right = RealSenseCamera(config_right_camera, False)
+        self.camera_left.start_camera()
+        self.camera_right.start_camera()
+        self.camera_bottom.start_camera()
+        self.camera_top.start_camera()
+
+    def close(self):
+        """
+        Close the cameras
+        """
+        self.camera_left.close()
+        self.camera_right.close()
+        self.camera_bottom.close()
+        self.camera_top.close()
+
+    def get_qps(self):
+        """
+        Get the joint angles of both arms
+
+        Returns:
+            np.array: joint angles
+        """
+        left_slave_arm_angle = self.arm_left.get_joint_angle()
+        left_joint_angles_array = np.array(list(left_slave_arm_angle.values()))
+        right_slave_arm_angle = self.arm_right.get_joint_angle()
+        right_joint_angles_array = np.array(list(right_slave_arm_angle.values()))
+        return np.concatenate([left_joint_angles_array, right_joint_angles_array])
+
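+    # NOTE (editor's sketch): the dict -> array conversions in these getters rely on the
+    # RmArm driver returning joints in a stable insertion order (guaranteed for Python
+    # 3.7+ dicts), so each vector is always ordered [left joints..., right joints...].
+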
+    def get_qvel(self):
+        """
+        Get the joint velocities of both arms
+
+        Returns:
+            np.array: joint velocities
+        """
+        left_slave_arm_velocity = self.arm_left.get_joint_velocity()
+        left_joint_velocity_array = np.array(list(left_slave_arm_velocity.values()))
+        right_slave_arm_velocity = self.arm_right.get_joint_velocity()
+        right_joint_velocity_array = np.array(list(right_slave_arm_velocity.values()))
+        return np.concatenate([left_joint_velocity_array, right_joint_velocity_array])
+
+    def get_effort(self):
+        """
+        Get the joint efforts of both arms
+
+        Returns:
+            np.array: joint efforts
+        """
+        left_slave_arm_effort = self.arm_left.get_joint_effort()
+        left_joint_effort_array = np.array(list(left_slave_arm_effort.values()))
+        right_slave_arm_effort = self.arm_right.get_joint_effort()
+        right_joint_effort_array = np.array(list(right_slave_arm_effort.values()))
+        return np.concatenate([left_joint_effort_array, right_joint_effort_array])
+
+    def get_images(self):
+        """
+        Get one frame from each camera
+
+        Returns:
+            dict: image dictionary keyed by camera name
+        """
+        self.top_image, _, _, _ = self.camera_top.read_frame(True, False, False, False)
+        self.bottom_image, _, _, _ = self.camera_bottom.read_frame(
+            True, False, False, False
+        )
+        self.left_image, _, _, _ = self.camera_left.read_frame(
+            True, False, False, False
+        )
+        self.right_image, _, _, _ = self.camera_right.read_frame(
+            True, False, False, False
+        )
+        return {
+            "cam_high": self.top_image,
+            "cam_low": self.bottom_image,
+            "cam_left": self.left_image,
+            "cam_right": self.right_image,
+        }
+
+    def get_observation(self):
+        obs = collections.OrderedDict()
+        obs["qpos"] = self.get_qps()
+        obs["qvel"] = self.get_qvel()
+        obs["effort"] = self.get_effort()
+        obs["images"] = self.get_images()
+        return obs
+
+    def reset(self):
+        logging.info("Resetting the environment")
+        self.arm_left.set_joint_position(self.init_left_arm_angle[0:self.arm_axis])
+        self.arm_right.set_joint_position(self.init_right_arm_angle[0:self.arm_axis])
+        self.arm_left.set_gripper_position(0)
+        self.arm_right.set_gripper_position(0)
+        return dm_env.TimeStep(
+            step_type=dm_env.StepType.FIRST,
+            reward=0,
+            discount=None,
+            observation=self.get_observation(),
+        )
+
+    def step(self, target_angle):
+        self.arm_left.set_joint_canfd_position(target_angle[0:self.arm_axis])
+        self.arm_right.set_joint_canfd_position(target_angle[self.arm_axis+1:self.arm_axis*2+1])
+        self.arm_left.set_gripper_position(target_angle[self.arm_axis])
+        self.arm_right.set_gripper_position(target_angle[(self.arm_axis*2 + 1)])
+        return dm_env.TimeStep(
+            step_type=dm_env.StepType.MID,
+            reward=0,
+            discount=None,
+            observation=self.get_observation(),
+        )
+
+
+if __name__ == "__main__":
+    # with open("/home/rm/code/shadow_act/config/config.yaml", "r") as f:
+    #     config = yaml.safe_load(f)
+    # aloha_config = config["robot_env"]
+    # device = DeviceAloha(aloha_config)
+    # device.reset()
+    # while True:
+    #     init_angle = np.concatenate([device.init_left_arm_angle, device.init_right_arm_angle])
+    #     time_step = time.time()
+    #     timestep = device.step(init_angle)
+    #     logging.info(f"Time: {time.time() - time_step}")
+    #     obs = timestep.observation
+
+    with open("/home/wang/project/shadow_rm_act/config/config.yaml", "r") as f:
+        config = yaml.safe_load(f)
+    # logging.info(f"Config: {config}")
+    evaluator = RmActEvaluator(config)
+    success_rate, avg_return = evaluator.evaluate()
diff --git a/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/models/__init__.py b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/models/__init__.py
new file mode 100644
index 000000000..541f859dc
--- /dev/null
+++ 
b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/models/__init__.py @@ -0,0 +1 @@ +__version__ = '0.1.0' \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/models/backbone.py b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/models/backbone.py new file mode 100644 index 000000000..e62bb6eed --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/models/backbone.py @@ -0,0 +1,153 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Backbone modules. +""" +import torch +import torchvision +from torch import nn +from typing import Dict, List +import torch.nn.functional as F +from .position_encoding import build_position_encoding +from torchvision.models import ResNet18_Weights +from torchvision.models._utils import IntermediateLayerGetter +from shadow_act.utils.misc import NestedTensor, is_main_process + + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101] + produce nans. + """ + + def __init__(self, n): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = 1e-5 + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class BackboneBase(nn.Module): + + def __init__( + self, + backbone: nn.Module, + train_backbone: bool, + num_channels: int, + return_interm_layers: bool, + ): + super().__init__() + # for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this? 
+ # if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: + # parameter.requires_grad_(False) + if return_interm_layers: + return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + else: + return_layers = {"layer4": "0"} + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + self.num_channels = num_channels + + def forward(self, tensor): + xs = self.body(tensor) + return xs + # out: Dict[str, NestedTensor] = {} + # for name, x in xs.items(): + # m = tensor_list.mask + # assert m is not None + # mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + # out[name] = NestedTensor(x, mask) + # return out + + +class Backbone(BackboneBase): + """ResNet backbone with frozen BatchNorm.""" + + def __init__( + self, + name: str, + train_backbone: bool, + return_interm_layers: bool, + dilation: bool, + ): + backbone = getattr(torchvision.models, name)( + replace_stride_with_dilation=[False, False, dilation], + weights=ResNet18_Weights.IMAGENET1K_V1 if is_main_process() else None, + norm_layer=FrozenBatchNorm2d, +) + # backbone = getattr(torchvision.models, name)( + # replace_stride_with_dilation=[False, False, dilation], + # pretrained=is_main_process(), + # norm_layer=FrozenBatchNorm2d, + # ) # pretrained # TODO do we want frozen batch_norm?? + num_channels = 512 if name in ("resnet18", "resnet34") else 2048 + super().__init__(backbone, train_backbone, num_channels, return_interm_layers) + + +class Joiner(nn.Sequential): + def __init__(self, backbone, position_embedding): + super().__init__(backbone, position_embedding) + + def forward(self, tensor_list: NestedTensor): + xs = self[0](tensor_list) + out: List[NestedTensor] = [] + pos = [] + for name, x in xs.items(): + out.append(x) + # position encoding + pos.append(self[1](x).to(x.dtype)) + + return out, pos + + +def build_backbone( + hidden_dim, position_embedding_type, lr_backbone, masks, backbone, dilation +): + + position_embedding = build_position_encoding( + hidden_dim=hidden_dim, position_embedding_type=position_embedding_type + ) + train_backbone = lr_backbone > 0 + return_interm_layers = masks + backbone = Backbone(backbone, train_backbone, return_interm_layers, dilation) + model = Joiner(backbone, position_embedding) + model.num_channels = backbone.num_channels + return model diff --git a/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/models/detr_vae.py b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/models/detr_vae.py new file mode 100644 index 000000000..1b1c04f0e --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/models/detr_vae.py @@ -0,0 +1,436 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR model and criterion classes. 
+""" +import torch +from torch import nn +from torch.autograd import Variable +import torch.nn.functional as F +from shadow_act.models.transformer import Transformer +from .backbone import build_backbone +from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer + +import numpy as np + + + +def reparametrize(mu, logvar): + std = logvar.div(2).exp() + eps = Variable(std.data.new(std.size()).normal_()) + return mu + std * eps + + +def get_sinusoid_encoding_table(n_position, d_hid): + def get_position_angle_vec(position): + return [ + position / np.power(10000, 2 * (hid_j // 2) / d_hid) + for hid_j in range(d_hid) + ] + + sinusoid_table = np.array( + [get_position_angle_vec(pos_i) for pos_i in range(n_position)] + ) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + +class DETRVAE(nn.Module): + """This is the DETR module that performs object detection""" + + def __init__( + self, + backbones, + transformer, + encoder, + state_dim, + num_queries, + camera_names, + vq, + vq_class, + vq_dim, + action_dim, + ): + """Initializes the model. + Parameters: + backbones: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + state_dim: robot state dimension of the environment + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + """ + super().__init__() + self.num_queries = num_queries + self.camera_names = camera_names + self.transformer = transformer + self.encoder = encoder + self.vq, self.vq_class, self.vq_dim = vq, vq_class, vq_dim + self.state_dim, self.action_dim = state_dim, action_dim + hidden_dim = transformer.d_model + self.action_head = nn.Linear(hidden_dim, action_dim) + self.is_pad_head = nn.Linear(hidden_dim, 1) + self.query_embed = nn.Embedding(num_queries, hidden_dim) + if backbones is not None: + self.input_proj = nn.Conv2d( + backbones[0].num_channels, hidden_dim, kernel_size=1 + ) + self.backbones = nn.ModuleList(backbones) + self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim) + else: + # input_dim = 14 + 7 # robot_state + env_state + self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim) + self.input_proj_env_state = nn.Linear(7, hidden_dim) + self.pos = torch.nn.Embedding(2, hidden_dim) + self.backbones = None + + # encoder extra parameters + self.latent_dim = 32 # final size of latent z # TODO tune + self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding + self.encoder_action_proj = nn.Linear( + action_dim, hidden_dim + ) # project action to embedding + self.encoder_joint_proj = nn.Linear( + action_dim, hidden_dim + ) # project qpos to embedding + if self.vq: + self.latent_proj = nn.Linear(hidden_dim, self.vq_class * self.vq_dim) + else: + self.latent_proj = nn.Linear( + hidden_dim, self.latent_dim * 2 + ) # project hidden state to latent std, var + self.register_buffer( + "pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim) + ) # [CLS], qpos, a_seq + + # decoder extra parameters + if self.vq: + self.latent_out_proj = nn.Linear(self.vq_class * self.vq_dim, hidden_dim) + else: + self.latent_out_proj = nn.Linear( + self.latent_dim, hidden_dim + ) # project 
latent sample to embedding + self.additional_pos_embed = nn.Embedding( + 2, hidden_dim + ) # learned position embedding for proprio and latent + + def encode(self, qpos, actions=None, is_pad=None, vq_sample=None): + bs, _ = qpos.shape + if self.encoder is None: + latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to( + qpos.device + ) + latent_input = self.latent_out_proj(latent_sample) + probs = binaries = mu = logvar = None + else: + # cvae encoder + is_training = actions is not None # train or val + ### Obtain latent z from action sequence + if is_training: + # project action sequence to embedding dim, and concat with a CLS token + action_embed = self.encoder_action_proj( + actions + ) # (bs, seq, hidden_dim) + qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim) + qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim) + cls_embed = self.cls_embed.weight # (1, hidden_dim) + cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat( + bs, 1, 1 + ) # (bs, 1, hidden_dim) + encoder_input = torch.cat( + [cls_embed, qpos_embed, action_embed], axis=1 + ) # (bs, seq+1, hidden_dim) + encoder_input = encoder_input.permute( + 1, 0, 2 + ) # (seq+1, bs, hidden_dim) + # do not mask cls token + cls_joint_is_pad = torch.full((bs, 2), False).to( + qpos.device + ) # False: not a padding + is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1) + # obtain position embedding + pos_embed = self.pos_table.clone().detach() + pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim) + # query model + encoder_output = self.encoder( + encoder_input, pos=pos_embed, src_key_padding_mask=is_pad + ) + encoder_output = encoder_output[0] # take cls output only + latent_info = self.latent_proj(encoder_output) + + if self.vq: + logits = latent_info.reshape( + [*latent_info.shape[:-1], self.vq_class, self.vq_dim] + ) + probs = torch.softmax(logits, dim=-1) + binaries = ( + F.one_hot( + torch.multinomial(probs.view(-1, self.vq_dim), 1).squeeze( + -1 + ), + self.vq_dim, + ) + .view(-1, self.vq_class, self.vq_dim) + .float() + ) + binaries_flat = binaries.view(-1, self.vq_class * self.vq_dim) + probs_flat = probs.view(-1, self.vq_class * self.vq_dim) + straigt_through = binaries_flat - probs_flat.detach() + probs_flat + latent_input = self.latent_out_proj(straigt_through) + mu = logvar = None + else: + probs = binaries = None + mu = latent_info[:, : self.latent_dim] + logvar = latent_info[:, self.latent_dim :] + latent_sample = reparametrize(mu, logvar) + latent_input = self.latent_out_proj(latent_sample) + + else: + mu = logvar = binaries = probs = None + if self.vq: + latent_input = self.latent_out_proj( + vq_sample.view(-1, self.vq_class * self.vq_dim) + ) + else: + latent_sample = torch.zeros( + [bs, self.latent_dim], dtype=torch.float32 + ).to(qpos.device) + latent_input = self.latent_out_proj(latent_sample) + + return latent_input, probs, binaries, mu, logvar + + def forward( + self, qpos, image, env_state, actions=None, is_pad=None, vq_sample=None + ): + """ + qpos: batch, qpos_dim + image: batch, num_cam, channel, height, width + env_state: None + actions: batch, seq, action_dim + """ + + latent_input, probs, binaries, mu, logvar = self.encode( + qpos, actions, is_pad, vq_sample + ) + + # cvae decoder + if self.backbones is not None: + # Image observation features and position embeddings + all_cam_features = [] + all_cam_pos = [] + for cam_id, cam_name in enumerate(self.camera_names): + # TODO: fix this error + features, pos = self.backbones[0](image[:, 
cam_id])
+                features = features[0]  # take the last layer feature
+                pos = pos[0]
+                all_cam_features.append(self.input_proj(features))
+                all_cam_pos.append(pos)
+            # proprioception features
+            proprio_input = self.input_proj_robot_state(qpos)
+            # fold camera dimension into width dimension
+            src = torch.cat(all_cam_features, axis=3)
+            pos = torch.cat(all_cam_pos, axis=3)
+            hs = self.transformer(
+                src,
+                None,
+                self.query_embed.weight,
+                pos,
+                latent_input,
+                proprio_input,
+                self.additional_pos_embed.weight,
+            )[0]
+        else:
+            qpos = self.input_proj_robot_state(qpos)
+            env_state = self.input_proj_env_state(env_state)
+            transformer_input = torch.cat([qpos, env_state], axis=1)  # seq length = 2
+            hs = self.transformer(
+                transformer_input, None, self.query_embed.weight, self.pos.weight
+            )[0]
+        a_hat = self.action_head(hs)
+        is_pad_hat = self.is_pad_head(hs)
+        return a_hat, is_pad_hat, [mu, logvar], probs, binaries
+
+
+class CNNMLP(nn.Module):
+    def __init__(self, backbones, state_dim, camera_names):
+        """Initializes the model.
+        Parameters:
+            backbones: torch module of the backbone to be used. See backbone.py
+            transformer: torch module of the transformer architecture. See transformer.py
+            state_dim: robot state dimension of the environment
+            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
+                         DETR can detect in a single image. For COCO, we recommend 100 queries.
+            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
+        """
+        super().__init__()
+        self.camera_names = camera_names
+        self.action_head = nn.Linear(1000, state_dim)  # TODO add more
+        if backbones is not None:
+            self.backbones = nn.ModuleList(backbones)
+            backbone_down_projs = []
+            for backbone in backbones:
+                down_proj = nn.Sequential(
+                    nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
+                    nn.Conv2d(128, 64, kernel_size=5),
+                    nn.Conv2d(64, 32, kernel_size=5),
+                )
+                backbone_down_projs.append(down_proj)
+            self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
+
+            mlp_in_dim = 768 * len(backbones) + state_dim
+            self.mlp = mlp(
+                input_dim=mlp_in_dim,
+                hidden_dim=1024,
+                output_dim=state_dim,
+                hidden_depth=2,
+            )
+        else:
+            raise NotImplementedError
+
+    def forward(self, qpos, image, env_state, actions=None):
+        """
+        qpos: batch, qpos_dim
+        image: batch, num_cam, channel, height, width
+        env_state: None
+        actions: batch, seq, action_dim
+        """
+        is_training = actions is not None  # train or val
+        bs, _ = qpos.shape
+        # Image observation features and position embeddings
+        all_cam_features = []
+        for cam_id, cam_name in enumerate(self.camera_names):
+            features, pos = self.backbones[cam_id](image[:, cam_id])
+            features = features[0]  # take the last layer feature
+            pos = pos[0]  # not used
+            all_cam_features.append(self.backbone_down_projs[cam_id](features))
+        # flatten everything
+        flattened_features = []
+        for cam_feature in all_cam_features:
+            flattened_features.append(cam_feature.reshape([bs, -1]))
+        flattened_features = torch.cat(flattened_features, axis=1)  # 768 each
+        features = torch.cat([flattened_features, qpos], axis=1)  # qpos: 14
+        a_hat = self.mlp(features)
+        return a_hat
+
+
+def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
+    if hidden_depth == 0:
+        mods = [nn.Linear(input_dim, output_dim)]
+    else:
+        mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
+        for i in range(hidden_depth - 1):
+            mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
+        mods.append(nn.Linear(hidden_dim, output_dim))
+    trunk = nn.Sequential(*mods)
+    return trunk
+
+
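+# NOTE (editor's sketch): mlp() builds a plain fully-connected trunk; for example
+# mlp(input_dim=14, hidden_dim=1024, output_dim=14, hidden_depth=2) yields
+# Linear(14, 1024) -> ReLU -> Linear(1024, 1024) -> ReLU -> Linear(1024, 14).
+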
+def build_encoder(
+    hidden_dim,  # 256
+    dropout,  # 0.1
+    nheads,  # 8
+    dim_feedforward,
+    num_encoder_layers,  # 4 # TODO shared with VAE decoder
+    normalize_before,  # False
+):
+    activation = "relu"
+
+    encoder_layer = TransformerEncoderLayer(
+        hidden_dim, nheads, dim_feedforward, dropout, activation, normalize_before
+    )
+    encoder_norm = nn.LayerNorm(hidden_dim) if normalize_before else None
+    encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+    return encoder
+
+
+def build_vae(
+    hidden_dim,
+    state_dim,
+    position_embedding_type,
+    lr_backbone,
+    masks,
+    backbone,
+    dilation,
+    dropout,
+    nheads,
+    dim_feedforward,
+    enc_layers,
+    dec_layers,
+    pre_norm,
+    num_queries,
+    camera_names,
+    vq,
+    vq_class,
+    vq_dim,
+    action_dim,
+    no_encoder,
+):
+    # TODO hardcode
+
+    # From state
+    # backbone = None # from state for now, no need for conv nets
+    # From image
+    backbones = []
+    backbone = build_backbone(
+        hidden_dim, position_embedding_type, lr_backbone, masks, backbone, dilation
+    )
+    backbones.append(backbone)
+
+    transformer = build_transformer(
+        hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers, pre_norm
+    )
+
+    if no_encoder:
+        encoder = None
+    else:
+        encoder = build_encoder(
+            hidden_dim,
+            dropout,
+            nheads,
+            dim_feedforward,
+            enc_layers,
+            pre_norm,
+        )
+
+    model = DETRVAE(
+        backbones,
+        transformer,
+        encoder,
+        state_dim,
+        num_queries,
+        camera_names,
+        vq,
+        vq_class,
+        vq_dim,
+        action_dim,
+    )
+
+    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    print("number of parameters: %.2fM" % (n_parameters / 1e6,))
+
+    return model
+
+# TODO
+def build_cnnmlp(args):
+    state_dim = 14  # TODO hardcode
+
+    # From state
+    # backbone = None # from state for now, no need for conv nets
+    # From image
+    backbones = []
+    for _ in args.camera_names:
+        backbone = build_backbone(args)
+        backbones.append(backbone)
+
+    model = CNNMLP(
+        backbones,
+        state_dim=state_dim,
+        camera_names=args.camera_names,
+    )
+
+    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    print("number of parameters: %.2fM" % (n_parameters / 1e6,))
+
+    return model
diff --git a/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/models/latent_model.py b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/models/latent_model.py
new file mode 100644
index 000000000..da03c1899
--- /dev/null
+++ b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/models/latent_model.py
@@ -0,0 +1,122 @@
+#!/usr/bin/env python3
+import torch.nn as nn
+from torch.nn import functional as F
+import torch
+
+DROPOUT_RATE = 0.1  # dropout rate
+
+# A causal Transformer block
+class Causal_Transformer_Block(nn.Module):
+    def __init__(self, seq_len, latent_dim, num_head) -> None:
+        """
+        Initialize the causal Transformer block
+
+        Args:
+            seq_len (int): sequence length
+            latent_dim (int): latent dimension
+            num_head (int): number of attention heads
+        """
+        super().__init__()
+        self.num_head = num_head
+        self.latent_dim = latent_dim
+        self.ln_1 = nn.LayerNorm(latent_dim)  # layer norm
+        self.attn = nn.MultiheadAttention(latent_dim, num_head, dropout=DROPOUT_RATE, batch_first=True)  # multi-head attention
+        self.ln_2 = nn.LayerNorm(latent_dim)  # layer norm
+        self.mlp = nn.Sequential(
+            nn.Linear(latent_dim, 4 * latent_dim),  # linear layer
+            nn.GELU(),  # GELU activation
+            nn.Linear(4 * latent_dim, latent_dim),  # linear layer
+            nn.Dropout(DROPOUT_RATE),  # dropout
+        )
+
+        # self.register_buffer("attn_mask", torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool())  # register the attention mask
+
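+    # NOTE (editor's sketch): the mask built in forward() below is the standard causal
+    # mask; torch.triu(..., diagonal=1) marks strictly-future positions, and
+    # nn.MultiheadAttention treats True entries as "may not attend". For seq_len = 3:
+    #     [[False,  True,  True],
+    #      [False, False,  True],
+    #      [False, False, False]]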
+    def forward(self, x):
+        """
+        Forward pass
+
+        Args:
+            x (torch.Tensor): input tensor
+
+        Returns:
+            torch.Tensor: output tensor
+        """
+        # Build an upper-triangular mask so no position can attend to the future
+        attn_mask = torch.triu(torch.ones(x.shape[1], x.shape[1], device=x.device, dtype=torch.bool), diagonal=1)
+        x = self.ln_1(x)  # layer norm
+        x = x + self.attn(x, x, x, attn_mask=attn_mask)[0]  # add the attention output
+        x = self.ln_2(x)  # layer norm
+        x = x + self.mlp(x)  # add the MLP output
+
+        return x
+
+# Model the latent-code sequence with self-attention instead of an RNN
+class Latent_Model_Transformer(nn.Module):
+    def __init__(self, input_dim, output_dim, seq_len, latent_dim=256, num_head=8, num_layer=3) -> None:
+        """
+        Initialize the latent-model Transformer
+
+        Args:
+            input_dim (int): input dimension
+            output_dim (int): output dimension
+            seq_len (int): sequence length
+            latent_dim (int, optional): latent dimension, default 256
+            num_head (int, optional): number of attention heads, default 8
+            num_layer (int, optional): number of Transformer layers, default 3
+        """
+        super().__init__()
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.seq_len = seq_len
+        self.latent_dim = latent_dim
+        self.num_head = num_head
+        self.num_layer = num_layer
+        self.input_layer = nn.Linear(input_dim, latent_dim)  # input projection
+        self.weight_pos_embed = nn.Embedding(seq_len, latent_dim)  # positional embedding
+        self.attention_blocks = nn.Sequential(
+            nn.Dropout(DROPOUT_RATE),  # dropout
+            *[Causal_Transformer_Block(seq_len, latent_dim, num_head) for _ in range(num_layer)],  # stacked causal Transformer blocks
+            nn.LayerNorm(latent_dim)  # layer norm
+        )
+        self.output_layer = nn.Linear(latent_dim, output_dim)  # output projection
+
+    def forward(self, x):
+        """
+        Forward pass
+
+        Args:
+            x (torch.Tensor): input tensor
+
+        Returns:
+            torch.Tensor: output logits
+        """
+        x = self.input_layer(x)  # input projection
+        x = x + self.weight_pos_embed(torch.arange(x.shape[1], device=x.device))  # add positional embeddings
+        x = self.attention_blocks(x)  # run through the attention blocks
+        logits = self.output_layer(x)  # output projection
+
+        return logits
+
+    @torch.no_grad()
+    def generate(self, n, temperature=0.1, x=None):
+        """
+        Autoregressively generate sequences
+
+        Args:
+            n (int): number of sequences to generate
+            temperature (float, optional): sampling temperature, default 0.1
+            x (torch.Tensor, optional): initial input tensor, default None
+
+        Returns:
+            torch.Tensor: generated sequences
+        """
+        if x is None:
+            x = torch.zeros((n, 1, self.input_dim), device=self.weight_pos_embed.weight.device)  # start from a zero token
+        for i in range(self.seq_len):
+            logits = self.forward(x)[:, -1]  # logits at the last timestep
+            probs = torch.softmax(logits / temperature, dim=-1)  # sampling distribution
+            samples = torch.multinomial(probs, num_samples=1)[..., 0]  # sample from the distribution
+            samples_one_hot = F.one_hot(samples.long(), num_classes=self.output_dim).float()  # one-hot encode the sample
+            x = torch.cat([x, samples_one_hot[:, None, :]], dim=1)  # append the new sample to the input
+
+        return x[:, 1:, :]  # return the generated sequence (dropping the initial zero token)
\ No newline at end of file
diff --git a/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/models/position_encoding.py b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/models/position_encoding.py
new file mode 100644
index 000000000..07350d7ea
--- /dev/null
+++ b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/models/position_encoding.py
@@ -0,0 +1,91 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Various positional encodings for the transformer.
+"""
+import math
+import torch
+from torch import nn
+
+from shadow_act.utils.misc import NestedTensor
+
+
+class PositionEmbeddingSine(nn.Module):
+    """
+    This is a more standard version of the position embedding, very similar to the one
+    used by the Attention is all you need paper, generalized to work on images.
+ """ + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor): + x = tensor + # mask = tensor_list.mask + # assert mask is not None + # not_mask = ~mask + + not_mask = torch.ones_like(x[0, [0]]) + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. + """ + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) + return pos + + +def build_position_encoding(hidden_dim, position_embedding_type): + N_steps = hidden_dim // 2 + if position_embedding_type in ('v2', 'sine'): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + elif position_embedding_type in ('v3', 'learned'): + position_embedding = PositionEmbeddingLearned(N_steps) + else: + raise ValueError(f"not supported {position_embedding_type}") + + return position_embedding diff --git a/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/models/transformer.py b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/models/transformer.py new file mode 100644 index 000000000..8cc7db679 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/models/transformer.py @@ -0,0 +1,424 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR Transformer class. 
+ +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import Optional, List + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + + + +class Transformer(nn.Module): + + def __init__( + self, + d_model=512, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + return_intermediate_dec=False, + ): + super().__init__() + + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder( + encoder_layer, num_encoder_layers, encoder_norm + ) + + decoder_layer = TransformerDecoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder( + decoder_layer, + num_decoder_layers, + decoder_norm, + return_intermediate=return_intermediate_dec, + ) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward( + self, + src, + mask, + query_embed, + pos_embed, + latent_input=None, + proprio_input=None, + additional_pos_embed=None, + ): + # TODO flatten only when input has H and W + if len(src.shape) == 4: # has H and W + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + # mask = mask.flatten(1) + + additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat( + 1, bs, 1 + ) # seq, bs, dim + pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0) + + addition_input = torch.stack([latent_input, proprio_input], axis=0) + src = torch.cat([addition_input, src], axis=0) + else: + assert len(src.shape) == 3 + # flatten NxHWxC to HWxNxC + bs, hw, c = src.shape + src = src.permute(1, 0, 2) + pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + + tgt = torch.zeros_like(query_embed) + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + hs = self.decoder( + tgt, + memory, + memory_key_padding_mask=mask, + pos=pos_embed, + query_pos=query_embed, + ) + hs = hs.transpose(1, 2) + return hs + + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + output = src + + for layer in self.layers: + output = layer( + output, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + pos=pos, + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + 
self.return_intermediate = return_intermediate + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, + query_pos=query_pos, + ) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn( + q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask + )[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn( + q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask + )[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, 
dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn( + q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask + )[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn( + q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask + )[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre( + tgt, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + return self.forward_post( + tgt, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def build_transformer( + hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers, pre_norm +): + return Transformer( + d_model=hidden_dim, + dropout=dropout, + nhead=nheads, + dim_feedforward=dim_feedforward, + num_encoder_layers=enc_layers, + num_decoder_layers=dec_layers, + normalize_before=pre_norm, + return_intermediate_dec=True, + ) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" 
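+    # Note: F.glu gates and halves the last dimension, so the "glu" option
+    # only fits where the surrounding layers expect half as many features.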
+    if activation == "relu":
+        return F.relu
+    if activation == "gelu":
+        return F.gelu
+    if activation == "glu":
+        return F.glu
+    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
diff --git a/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/network/__init__.py b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/network/__init__.py
new file mode 100644
index 000000000..541f859dc
--- /dev/null
+++ b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/network/__init__.py
@@ -0,0 +1 @@
+__version__ = '0.1.0'
\ No newline at end of file
diff --git a/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/network/policy.py b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/network/policy.py
new file mode 100644
index 000000000..dfcc213dc
--- /dev/null
+++ b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/network/policy.py
@@ -0,0 +1,522 @@
+import torch
+import logging
+import numpy as np
+import torch.nn as nn
+from torch.nn import functional as F
+import torchvision.transforms as transforms
+from shadow_act.models.detr_vae import build_vae, build_cnnmlp
+
+# Required by DiffusionPolicy below; the class raises NameError at
+# construction time without them.
+from diffusers.training_utils import EMAModel
+from robomimic.models.base_nets import ResNet18Conv, SpatialSoftmax
+from robomimic.algo.diffusion_policy import replace_bn_with_gn, ConditionalUnet1D
+
+# from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+from diffusers.schedulers.scheduling_ddim import DDIMScheduler
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# TODO: refactor the DiffusionPolicy class
+class DiffusionPolicy(nn.Module):
+    def __init__(self, args_override):
+        """
+        Initialize the DiffusionPolicy class.
+
+        Args:
+            args_override (dict): configuration dictionary
+        """
+        super().__init__()
+
+        self.camera_names = args_override["camera_names"]
+        self.observation_horizon = args_override["observation_horizon"]
+        self.action_horizon = args_override["action_horizon"]
+        self.prediction_horizon = args_override["prediction_horizon"]
+        self.num_inference_timesteps = args_override["num_inference_timesteps"]
+        self.ema_power = args_override["ema_power"]
+        self.lr = args_override["lr"]
+        self.weight_decay = 0
+
+        self.num_kp = 32
+        self.feature_dimension = 64
+        self.ac_dim = args_override["action_dim"]
+        self.obs_dim = self.feature_dimension * len(self.camera_names) + 14
+
+        backbones = []
+        pools = []
+        linears = []
+        for _ in self.camera_names:
+            backbones.append(
+                ResNet18Conv(input_channel=3, pretrained=False, input_coord_conv=False)
+            )
+            pools.append(
+                SpatialSoftmax(
+                    input_shape=[512, 15, 20],
+                    num_kp=self.num_kp,
+                    temperature=1.0,
+                    learnable_temperature=False,
+                    noise_std=0.0,
+                )
+            )
+            linears.append(
+                torch.nn.Linear(int(np.prod([self.num_kp, 2])), self.feature_dimension)
+            )
+        backbones = nn.ModuleList(backbones)
+        pools = nn.ModuleList(pools)
+        linears = nn.ModuleList(linears)
+
+        backbones = replace_bn_with_gn(backbones)
+
+        noise_pred_net = ConditionalUnet1D(
+            input_dim=self.ac_dim,
+            global_cond_dim=self.obs_dim * self.observation_horizon,
+        )
+
+        nets = nn.ModuleDict(
+            {
+                "policy": nn.ModuleDict(
+                    {
+                        "backbones": backbones,
+                        "pools": pools,
+                        "linears": linears,
+                        "noise_pred_net": noise_pred_net,
+                    }
+                )
+            }
+        )
+
+        nets = nets.float().cuda()
+        ENABLE_EMA = True
+        if ENABLE_EMA:
+            ema = EMAModel(model=nets, power=self.ema_power)
+        else:
+            ema = None
+        self.nets = nets
+        self.ema = ema
+
+        # Set up the noise scheduler
+        self.noise_scheduler = DDIMScheduler(
+            num_train_timesteps=50,
+            beta_schedule="squaredcos_cap_v2",
+            clip_sample=True,
+            set_alpha_to_one=True,
steps_offset=0, + prediction_type="epsilon", + ) + + n_parameters = sum(p.numel() for p in self.parameters()) + logger.info("number of parameters: %.2fM", n_parameters / 1e6) + + def configure_optimizers(self): + """ + 配置优化器 + + Returns: + optimizer: 配置的优化器 + """ + optimizer = torch.optim.AdamW( + self.nets.parameters(), lr=self.lr, weight_decay=self.weight_decay + ) + return optimizer + + def __call__(self, qpos, image, actions=None, is_pad=None): + """ + 前向传播函数 + + Args: + qpos (torch.Tensor): 位置张量 + image (torch.Tensor): 图像张量 + actions (torch.Tensor, optional): 动作张量 + is_pad (torch.Tensor, optional): 填充张量 + + Returns: + dict: 损失字典(训练时) + torch.Tensor: 动作张量(推理时) + """ + B = qpos.shape[0] + if actions is not None: # 训练时 + nets = self.nets + all_features = [] + for cam_id in range(len(self.camera_names)): + cam_image = image[:, cam_id] + cam_features = nets["policy"]["backbones"][cam_id](cam_image) + pool_features = nets["policy"]["pools"][cam_id](cam_features) + pool_features = torch.flatten(pool_features, start_dim=1) + out_features = nets["policy"]["linears"][cam_id](pool_features) + all_features.append(out_features) + + obs_cond = torch.cat(all_features + [qpos], dim=1) + + # 为动作添加噪声 + noise = torch.randn(actions.shape, device=obs_cond.device) + + # 为每个数据点采样一个扩散迭代 + timesteps = torch.randint( + 0, + self.noise_scheduler.config.num_train_timesteps, + (B,), + device=obs_cond.device, + ).long() + + # 根据每个扩散迭代的噪声幅度向干净动作添加噪声 + noisy_actions = self.noise_scheduler.add_noise(actions, noise, timesteps) + + # 预测噪声残差 + noise_pred = nets["policy"]["noise_pred_net"]( + noisy_actions, timesteps, global_cond=obs_cond + ) + + # L2损失 + all_l2 = F.mse_loss(noise_pred, noise, reduction="none") + loss = (all_l2 * ~is_pad.unsqueeze(-1)).mean() + + loss_dict = {} + loss_dict["l2_loss"] = loss + loss_dict["loss"] = loss + + if self.training and self.ema is not None: + self.ema.step(nets) + return loss_dict + else: # 推理时 + To = self.observation_horizon + Ta = self.action_horizon + Tp = self.prediction_horizon + action_dim = self.ac_dim + + nets = self.nets + if self.ema is not None: + nets = self.ema.averaged_model + + all_features = [] + for cam_id in range(len(self.camera_names)): + cam_image = image[:, cam_id] + cam_features = nets["policy"]["backbones"][cam_id](cam_image) + pool_features = nets["policy"]["pools"][cam_id](cam_features) + pool_features = torch.flatten(pool_features, start_dim=1) + out_features = nets["policy"]["linears"][cam_id](pool_features) + all_features.append(out_features) + + obs_cond = torch.cat(all_features + [qpos], dim=1) + + # 从高斯噪声初始化动作 + noisy_action = torch.randn((B, Tp, action_dim), device=obs_cond.device) + naction = noisy_action + + # 初始化调度器 + self.noise_scheduler.set_timesteps(self.num_inference_timesteps) + + for k in self.noise_scheduler.timesteps: + # 预测噪声 + noise_pred = nets["policy"]["noise_pred_net"]( + sample=naction, timestep=k, global_cond=obs_cond + ) + + # 逆扩散步骤(去除噪声) + naction = self.noise_scheduler.step( + model_output=noise_pred, timestep=k, sample=naction + ).prev_sample + + return naction + + def serialize(self): + """ + 序列化模型 + + Returns: + dict: 模型状态字典 + """ + return { + "nets": self.nets.state_dict(), + "ema": ( + self.ema.averaged_model.state_dict() if self.ema is not None else None + ), + } + + def deserialize(self, model_dict): + """ + 反序列化模型 + + Args: + model_dict (dict): 模型状态字典 + + Returns: + status: 加载状态 + """ + status = self.nets.load_state_dict(model_dict["nets"]) + logger.info("Loaded model") + if model_dict.get("ema", None) is not None: + 
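+            # Restore the EMA (exponential moving average) copy of the weights
+            # alongside the raw network weights stored under "nets".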
logger.info("Loaded EMA") + status_ema = self.ema.averaged_model.load_state_dict(model_dict["ema"]) + status = [status, status_ema] + return status + + +class ACTPolicy(nn.Module): + def __init__(self, act_config): + """ + 初始化ACTPolicy类 + + Args: + args_override (dict): 参数覆盖字典 + """ + super().__init__() + lr_backbone = act_config["lr_backbone"] + vq = act_config["vq"] + lr = act_config["lr"] + weight_decay = act_config["weight_decay"] + + model = build_vae( + act_config["hidden_dim"], + act_config["state_dim"], + act_config["position_embedding"], + lr_backbone, + act_config["masks"], + act_config["backbone"], + act_config["dilation"], + act_config["dropout"], + act_config["nheads"], + act_config["dim_feedforward"], + act_config["enc_layers"], + act_config["dec_layers"], + act_config["pre_norm"], + act_config["num_queries"], + act_config["camera_names"], + vq, + act_config["vq_class"], + act_config["vq_dim"], + act_config["action_dim"], + act_config["no_encoder"], + ) + model.cuda() + + param_dicts = [ + { + "params": [ + p + for n, p in model.named_parameters() + if "backbone" not in n and p.requires_grad + ] + }, + { + "params": [ + p + for n, p in model.named_parameters() + if "backbone" in n and p.requires_grad + ], + "lr": lr_backbone, + }, + ] + self.optimizer = torch.optim.AdamW( + param_dicts, lr=lr, weight_decay=weight_decay + ) + self.model = model # CVAE解码器 + self.kl_weight = act_config["kl_weight"] + self.vq = vq + logger.info(f"KL Weight {self.kl_weight}") + + def __call__(self, qpos, image, actions=None, is_pad=None, vq_sample=None): + """ + 前向传播函数 + + Args: + qpos (torch.Tensor): 角度张量 + image (torch.Tensor): 图像张量 + actions (torch.Tensor, optional): 动作张量 + is_pad (torch.Tensor, optional): 填充张量 + vq_sample (torch.Tensor, optional): VQ样本 + + Returns: + dict: 损失字典(训练时) + torch.Tensor: 动作张量(推理时) + """ + env_state = None + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + image = normalize(image) + if actions is not None: # 训练时 + actions = actions[:, : self.model.num_queries] + is_pad = is_pad[:, : self.model.num_queries] + + loss_dict = dict() + a_hat, is_pad_hat, (mu, logvar), probs, binaries = self.model( + qpos, image, env_state, actions, is_pad + ) + if self.vq or self.model.encoder is None: + total_kld = [torch.tensor(0.0)] + else: + total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar) + if self.vq: + loss_dict["vq_discrepancy"] = F.l1_loss( + probs, binaries, reduction="mean" + ) + all_l1 = F.l1_loss(actions, a_hat, reduction="none") + l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean() + loss_dict["l1"] = l1 + loss_dict["kl"] = total_kld[0] + loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight + return loss_dict + else: # 推理时 + a_hat, _, (_, _), _, _ = self.model( + qpos, image, env_state, vq_sample=vq_sample + ) # no action, sample from prior + return a_hat + + def configure_optimizers(self): + """ + 配置优化器 + + Returns: + optimizer: 配置的优化器 + """ + return self.optimizer + + @torch.no_grad() + def vq_encode(self, qpos, actions, is_pad): + """ + VQ编码 + + Args: + qpos (torch.Tensor): 位置张量 + actions (torch.Tensor): 动作张量 + is_pad (torch.Tensor): 填充张量 + + Returns: + torch.Tensor: 二进制编码 + """ + actions = actions[:, : self.model.num_queries] + is_pad = is_pad[:, : self.model.num_queries] + + _, _, binaries, _, _ = self.model.encode(qpos, actions, is_pad) + + return binaries + + def serialize(self): + """ + 序列化模型 + + Returns: + dict: 模型状态字典 + """ + return self.state_dict() + + def deserialize(self, model_dict): + """ 
+        Deserialize the model.
+
+        Args:
+            model_dict (dict): model state dict
+
+        Returns:
+            status: load status
+        """
+        return self.load_state_dict(model_dict)
+
+
+class CNNMLPPolicy(nn.Module):
+    def __init__(self, args_override):
+        """
+        Initialize the CNNMLPPolicy class.
+
+        Args:
+            args_override (dict): configuration dictionary
+        """
+        super().__init__()
+        # parser = argparse.ArgumentParser(
+        #     "DETR training and evaluation script", parents=[get_args_parser()]
+        # )
+        # args = parser.parse_args()
+
+        # for k, v in args_override.items():
+        #     setattr(args, k, v)
+
+        model = build_cnnmlp(args_override)
+        model.cuda()
+
+        param_dicts = [
+            {
+                "params": [
+                    p
+                    for n, p in model.named_parameters()
+                    if "backbone" not in n and p.requires_grad
+                ]
+            },
+            {
+                "params": [
+                    p
+                    for n, p in model.named_parameters()
+                    if "backbone" in n and p.requires_grad
+                ],
+                # args_override is a plain dict, so it must be indexed;
+                # attribute access (args_override.lr_backbone) raises
+                # AttributeError here.
+                "lr": args_override["lr_backbone"],
+            },
+        ]
+        self.model = model  # decoder
+        self.optimizer = torch.optim.AdamW(
+            param_dicts,
+            lr=args_override["lr"],
+            weight_decay=args_override["weight_decay"],
+        )
+
+    def __call__(self, qpos, image, actions=None, is_pad=None):
+        """
+        Forward pass.
+
+        Args:
+            qpos (torch.Tensor): position tensor
+            image (torch.Tensor): image tensor
+            actions (torch.Tensor, optional): action tensor
+            is_pad (torch.Tensor, optional): padding mask tensor
+
+        Returns:
+            dict: loss dictionary (training)
+            torch.Tensor: action tensor (inference)
+        """
+        env_state = None
+        normalize = transforms.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+        )
+        image = normalize(image)
+        if actions is not None:  # training
+            actions = actions[:, 0]
+            a_hat = self.model(qpos, image, env_state, actions)
+            mse = F.mse_loss(actions, a_hat)
+            loss_dict = dict()
+            loss_dict["mse"] = mse
+            loss_dict["loss"] = loss_dict["mse"]
+            return loss_dict
+        else:  # inference
+            a_hat = self.model(qpos, image, env_state)  # no actions, sample from the prior
+            return a_hat
+
+    def configure_optimizers(self):
+        """
+        Configure the optimizer.
+
+        Returns:
+            optimizer: the configured optimizer
+        """
+        return self.optimizer
+
+
+def kl_divergence(mu, logvar):
+    """
+    Compute the KL divergence.
+
+    Args:
+        mu (torch.Tensor): mean tensor
+        logvar (torch.Tensor): log-variance tensor
+
+    Returns:
+        tuple: total KLD, dimension-wise KLD, mean KLD
+    """
+    batch_size = mu.size(0)
+    assert batch_size != 0
+    if mu.data.ndimension() == 4:
+        mu = mu.view(mu.size(0), mu.size(1))
+    if logvar.data.ndimension() == 4:
+        logvar = logvar.view(logvar.size(0), logvar.size(1))
+
+    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
+    total_kld = klds.sum(1).mean(0, True)
+    dimension_wise_kld = klds.mean(0)
+    mean_kld = klds.mean(1).mean(0, True)
+
+    return total_kld, dimension_wise_kld, mean_kld
diff --git a/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/train/rm_act_train.py b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/train/rm_act_train.py
new file mode 100644
index 000000000..dd2faa795
--- /dev/null
+++ b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/train/rm_act_train.py
@@ -0,0 +1,245 @@
+import os
+import yaml
+import pickle
+import torch
+# import wandb
+import logging
+import numpy as np
+from tqdm import tqdm
+from copy import deepcopy
+from itertools import repeat
+from shadow_act.utils.utils import (
+    set_seed,
+    load_data,
+    compute_dict_mean,
+)
+from shadow_act.network.policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy
+from shadow_act.eval.rm_act_eval import RmActEvaluator
+
+# Configure logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+class RmActTrainer:
+    def __init__(self, config):
+        """
+        Initialize the trainer: set the random seed, load the data, and save
+        the dataset statistics.
+        """
+        self._config = config
+        self._num_steps = config["num_steps"]
+        self._ckpt_dir = config["checkpoint_dir"]
+        self._state_dim = 
config["state_dim"] + self._real_robot = config["real_robot"] + self._policy_class = config["policy_class"] + self._onscreen_render = config["onscreen_render"] + self._policy_config = config["policy_config"] + self._camera_names = config["camera_names"] + self._max_timesteps = config["episode_len"] + self._task_name = config["task_name"] + self._temporal_agg = config["temporal_agg"] + self._onscreen_cam = "angle" + self._vq = config["policy_config"]["vq"] + self._batch_size = config["batch_size"] + + self._seed = config["seed"] + self._eval_every = config["eval_every"] + self._validate_every = config["validate_every"] + self._save_every = config["save_every"] + self._load_pretrain = config["load_pretrain"] + self._resume_ckpt_path = config["resume_ckpt_path"] + + if config["name_filter"] is None: + name_filter = lambda n : True + else: + name_filter = config["name_filter"] + + self._eval = RmActEvaluator(config, True, 50) + # 加载数据 + self._train_dataloader, self._val_dataloader, self._stats, _ = load_data( + config["dataset_dir"], + name_filter, + self._camera_names, + self._batch_size, + self._batch_size, + config["chunk_size"], + config["skip_mirrored_data"], + self._load_pretrain, + self._policy_class, + config["stats_dir"], + config["sample_weights"], + config["train_ratio"], + ) + # 保存数据统计信息 + stats_path = os.path.join(self._ckpt_dir, "dataset_stats.pkl") + with open(stats_path, "wb") as f: + pickle.dump(self._stats, f) + expr_name = self._ckpt_dir.split("/")[-1] + + # wandb.init( + # project="train_rm_aloha", reinit=True, entity="train_rm_aloha", name=expr_name + # ) + + + def _make_policy(self): + """ + 根据策略类和配置创建策略对象 + """ + if self._policy_class == "ACT": + return ACTPolicy(self._policy_config) + elif self._policy_class == "CNNMLP": + return CNNMLPPolicy(self._policy_config) + elif self._policy_class == "Diffusion": + return DiffusionPolicy(self._policy_config) + else: + raise NotImplementedError(f"Policy class {self._policy_class} is not implemented") + + def _make_optimizer(self): + """ + 根据策略类创建优化器 + """ + if self._policy_class in ["ACT", "CNNMLP", "Diffusion"]: + return self._policy.configure_optimizers() + else: + raise NotImplementedError + + def _forward_pass(self, data): + """ + 前向传播,计算损失 + """ + image_data, qpos_data, action_data, is_pad = data + try: + image_data, qpos_data, action_data, is_pad = ( + image_data.cuda(), + qpos_data.cuda(), + action_data.cuda(), + is_pad.cuda(), + ) + except RuntimeError as e: + logging.error(f"CUDA error: {e}") + raise + return self._policy(qpos_data, image_data, action_data, is_pad) + + def _repeater(self): + """ + 数据加载器的重复器,生成数据 + """ + epoch = 0 + for loader in repeat(self._train_dataloader): + for data in loader: + yield data + logging.info(f"Epoch {epoch} done") + epoch += 1 + + def train(self): + """ + 训练模型,保存最佳模型 + """ + set_seed(self._seed) + self._policy = self._make_policy() + min_val_loss = np.inf + best_ckpt_info = None + + # 加载预训练模型 + if self._load_pretrain: + try: + loading_status = self._policy.deserialize( + torch.load( + os.path.join( + "/home/zfu/interbotix_ws/src/act/ckpts/pretrain_all", + "policy_step_50000_seed_0.ckpt", + ) + ) + ) + logging.info(f"loaded! 
{loading_status}") + except FileNotFoundError as e: + logging.error(f"Pretrain model not found: {e}") + except Exception as e: + logging.error(f"Error loading pretrain model: {e}") + + # 恢复检查点 + if self._resume_ckpt_path is not None: + try: + loading_status = self._policy.deserialize(torch.load(self._resume_ckpt_path)) + logging.info(f"Resume policy from: {self._resume_ckpt_path}, Status: {loading_status}") + except FileNotFoundError as e: + logging.error(f"Checkpoint not found: {e}") + except Exception as e: + logging.error(f"Error loading checkpoint: {e}") + + self._policy.cuda() + + self._optimizer = self._make_optimizer() + train_dataloader = self._repeater() # 重复器 + + for step in tqdm(range(self._num_steps + 1)): + # 验证模型 + if step % self._validate_every != 0: + continue + logging.info("validating") + with torch.inference_mode(): + self._policy.eval() + validation_dicts = [] + for batch_idx, data in enumerate(self._val_dataloader): + forward_dict = self._forward_pass(data) # forward_dict = {"loss": loss, "kl": kl, "mse": mse} + validation_dicts.append(forward_dict) + if batch_idx > 50: # 限制验证批次数 TODO 确定批次关联 + break + + validation_summary = compute_dict_mean(validation_dicts) + epoch_val_loss = validation_summary["loss"] + if epoch_val_loss < min_val_loss: + min_val_loss = epoch_val_loss + best_ckpt_info = ( + step, + min_val_loss, + deepcopy(self._policy.serialize()), + ) + + # wandb记录验证结果 + # for k in list(validation_summary.keys()): + # validation_summary[f"val_{k}"] = validation_summary.pop(k) + + # wandb.log(validation_summary, step=step) + logging.info(f"Val loss: {epoch_val_loss:.5f}") + summary_string = " ".join(f"{k}: {v.item():.3f}" for k, v in validation_summary.items()) + logging.info(summary_string) + + # 评估模型 + # if (step > 0) and (step % self._eval_every == 0): + # ckpt_name = f"policy_step_{step}_seed_{self._seed}.ckpt" + # ckpt_path = os.path.join(self._ckpt_dir, ckpt_name) + # torch.save(self._policy.serialize(), ckpt_path) + # success, _ = self._eval.evaluate(ckpt_name) + # wandb.log({"success": success}, step=step) + + # 训练模型 + self._policy.train() + self._optimizer.zero_grad() + data = next(train_dataloader) + forward_dict = self._forward_pass(data) + loss = forward_dict["loss"] + loss.backward() + self._optimizer.step() + # wandb.log(forward_dict, step=step) + + # 保存检查点 + if step % self._save_every == 0: + ckpt_path = os.path.join(self._ckpt_dir, f"policy_step_{step}_seed_{self._seed}.ckpt") + torch.save(self._policy.serialize(), ckpt_path) + + # 保存最后的模型 + ckpt_path = os.path.join(self._ckpt_dir, "policy_last.ckpt") + torch.save(self._policy.serialize(), ckpt_path) + + best_step, min_val_loss, best_state_dict = best_ckpt_info + ckpt_path = os.path.join(self._ckpt_dir, f"policy_step_{best_step}_seed_{self._seed}.ckpt") + torch.save(best_state_dict, ckpt_path) + logging.info(f"Training finished:\nSeed {self._seed}, val loss {min_val_loss:.6f} at step {best_step}") + + return best_ckpt_info + + +if __name__ == "__main__": + with open("/home/rm/aloha/shadow_rm_act/config/config.yaml") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + trainer = RmActTrainer(config) + trainer.train() diff --git a/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/utils/__init__.py b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/utils/__init__.py new file mode 100644 index 000000000..541f859dc --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/utils/__init__.py @@ -0,0 +1 @@ +__version__ = '0.1.0' \ No newline at end of file diff --git 
a/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/utils/box_ops.py b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/utils/box_ops.py new file mode 100644 index 000000000..9c088e5ba --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/utils/box_ops.py @@ -0,0 +1,88 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Utilities for bounding box manipulation and GIoU. +""" +import torch +from torchvision.ops.boxes import box_area + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), + (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, + (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + + +# modified from torchvision to also return the union +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/ + + The boxes should be in [x0, y0, x1, y1] format + + Returns a [N, M] pairwise matrix, where N = len(boxes1) + and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + iou, union = box_iou(boxes1, boxes2) + + lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + wh = (rb - lt).clamp(min=0) # [N,M,2] + area = wh[:, :, 0] * wh[:, :, 1] + + return iou - (area - union) / area + + +def masks_to_boxes(masks): + """Compute the bounding boxes around the provided masks + + The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. + + Returns a [N, 4] tensors, with the boxes in xyxy format + """ + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device) + + h, w = masks.shape[-2:] + + y = torch.arange(0, h, dtype=torch.float) + x = torch.arange(0, w, dtype=torch.float) + y, x = torch.meshgrid(y, x) + + x_mask = (masks * x.unsqueeze(0)) + x_max = x_mask.flatten(1).max(-1)[0] + x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + y_mask = (masks * y.unsqueeze(0)) + y_max = y_mask.flatten(1).max(-1)[0] + y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + return torch.stack([x_min, y_min, x_max, y_max], 1) diff --git a/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/utils/misc.py b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/utils/misc.py new file mode 100644 index 000000000..dfa9fb5b8 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/utils/misc.py @@ -0,0 +1,468 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. 
+""" +import os +import subprocess +import time +from collections import defaultdict, deque +import datetime +import pickle +from packaging import version +from typing import Optional, List + +import torch +import torch.distributed as dist +from torch import Tensor + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision +if version.parse(torchvision.__version__) < version.parse('0.7'): + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. 
Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() + sha = 'N/A' + diff = "clean" + branch = 'N/A' + try: + sha = _run(['git', 'rev-parse', 'HEAD']) + subprocess.check_output(['git', 'diff'], cwd=cwd) + diff = _run(['git', 'diff-index', 'HEAD']) + diff = "has uncommited changes" if diff else "clean" + branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: 
{branch}" + return message + + +def collate_fn(batch): + batch = list(zip(*batch)) + batch[0] = nested_tensor_from_tensor_list(batch[0]) + return tuple(batch) + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], :img.shape[2]] = False + else: + raise ValueError('not supported') + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
+@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty 
batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. + """ + if version.parse(torchvision.__version__) < version.parse('0.7'): + if input.numel() > 0: + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/utils/plot_utils.py b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/utils/plot_utils.py new file mode 100644 index 000000000..0f24bed0d --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/utils/plot_utils.py @@ -0,0 +1,107 @@ +""" +Plotting utilities to visualize training logs. +""" +import torch +import pandas as pd +import numpy as np +import seaborn as sns +import matplotlib.pyplot as plt + +from pathlib import Path, PurePath + + +def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'): + ''' + Function to plot specific fields from training log(s). Plots both training and test results. + + :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file + - fields = which results to plot from each log file - plots both training and test for each field. + - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots + - log_name = optional, name of log file if different than default 'log.txt'. + + :: Outputs - matplotlib plots of results in fields, color coded for each log file. + - solid lines are training results, dashed lines are test results. + + ''' + func_name = "plot_utils.py::plot_logs" + + # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path, + # convert single Path to list to avoid 'not iterable' error + + if not isinstance(logs, list): + if isinstance(logs, PurePath): + logs = [logs] + print(f"{func_name} info: logs param expects a list argument, converted to list[Path].") + else: + raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \ + Expect list[Path] or single Path obj, received {type(logs)}") + + # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir + for i, dir in enumerate(logs): + if not isinstance(dir, PurePath): + raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}") + if not dir.exists(): + raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}") + # verify log_name exists + fn = Path(dir / log_name) + if not fn.exists(): + print(f"-> missing {log_name}. 
Have you gotten to Epoch 1 in training?") + print(f"--> full path of missing log file: {fn}") + return + + # load log file(s) and plot + dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs] + + fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) + + for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))): + for j, field in enumerate(fields): + if field == 'mAP': + coco_eval = pd.DataFrame( + np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1] + ).ewm(com=ewm_col).mean() + axs[j].plot(coco_eval, c=color) + else: + df.interpolate().ewm(com=ewm_col).mean().plot( + y=[f'train_{field}', f'test_{field}'], + ax=axs[j], + color=[color] * 2, + style=['-', '--'] + ) + for ax, field in zip(axs, fields): + ax.legend([Path(p).name for p in logs]) + ax.set_title(field) + + +def plot_precision_recall(files, naming_scheme='iter'): + if naming_scheme == 'exp_id': + # name becomes exp_id + names = [f.parts[-3] for f in files] + elif naming_scheme == 'iter': + names = [f.stem for f in files] + else: + raise ValueError(f'not supported {naming_scheme}') + fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) + for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names): + data = torch.load(f) + # precision is n_iou, n_points, n_cat, n_area, max_det + precision = data['precision'] + recall = data['params'].recThrs + scores = data['scores'] + # take precision for all classes, all areas and 100 detections + precision = precision[0, :, :, 0, -1].mean(1) + scores = scores[0, :, :, 0, -1].mean(1) + prec = precision.mean() + rec = data['recall'][0, :, 0, -1].mean() + print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' + + f'score={scores.mean():0.3f}, ' + + f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}' + ) + axs[0].plot(recall, precision, c=color) + axs[1].plot(recall, scores, c=color) + + axs[0].set_title('Precision / Recall') + axs[0].legend(names) + axs[1].set_title('Scores / Recall') + axs[1].legend(names) + return fig, axs diff --git a/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/utils/utils.py b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/utils/utils.py new file mode 100644 index 000000000..a7b0f2f30 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_act/src/shadow_act/utils/utils.py @@ -0,0 +1,499 @@ +import numpy as np +import torch +import os +import h5py +import pickle +import fnmatch +import cv2 +from time import time +from torch.utils.data import TensorDataset, DataLoader +import torchvision.transforms as transforms + + + + +def flatten_list(l): + return [item for sublist in l for item in sublist] + + +class EpisodicDataset(torch.utils.data.Dataset): + def __init__( + self, + dataset_path_list, + camera_names, + norm_stats, + episode_ids, + episode_len, + chunk_size, + policy_class, + ): + super(EpisodicDataset).__init__() + self.episode_ids = episode_ids + self.dataset_path_list = dataset_path_list + self.camera_names = camera_names + self.norm_stats = norm_stats + self.episode_len = episode_len + self.chunk_size = chunk_size + self.cumulative_len = np.cumsum(self.episode_len) + self.max_episode_len = max(episode_len) + self.policy_class = policy_class + if self.policy_class == "Diffusion": + self.augment_images = True + else: + self.augment_images = False + self.transformations = None + self.__getitem__(0) # initialize self.is_sim and self.transformations + self.is_sim = False + + # def __len__(self): + # return sum(self.episode_len) + + def _locate_transition(self, index): + assert index < 
self.cumulative_len[-1] + episode_index = np.argmax( + self.cumulative_len > index + ) # argmax returns first True index + start_ts = index - ( + self.cumulative_len[episode_index] - self.episode_len[episode_index] + ) + episode_id = self.episode_ids[episode_index] + return episode_id, start_ts + + def __getitem__(self, index): + episode_id, start_ts = self._locate_transition(index) + dataset_path = self.dataset_path_list[episode_id] + try: + # print(dataset_path) + with h5py.File(dataset_path, "r") as root: + try: # some legacy data does not have this attribute + is_sim = root.attrs["sim"] + except: + is_sim = False + compressed = root.attrs.get("compress", False) + if "/base_action" in root: + base_action = root["/base_action"][()] + base_action = preprocess_base_action(base_action) + action = np.concatenate([root["/action"][()], base_action], axis=-1) + else: + # TODO + action = root["/action"][()] + # dummy_base_action = np.zeros([action.shape[0], 2]) + # action = np.concatenate([action, dummy_base_action], axis=-1) + original_action_shape = action.shape + episode_len = original_action_shape[0] + # get observation at start_ts only + qpos = root["/observations/qpos"][start_ts] + qvel = root["/observations/qvel"][start_ts] + image_dict = dict() + for cam_name in self.camera_names: + image_dict[cam_name] = root[f"/observations/images/{cam_name}"][ + start_ts + ] + + if compressed: + for cam_name in image_dict.keys(): + decompressed_image = cv2.imdecode(image_dict[cam_name], 1) + image_dict[cam_name] = np.array(decompressed_image) + + # get all actions after and including start_ts + if is_sim: + action = action[start_ts:] + action_len = episode_len - start_ts + else: + action = action[ + max(0, start_ts - 1) : + ] # hack, to make timesteps more aligned + action_len = episode_len - max( + 0, start_ts - 1 + ) # hack, to make timesteps more aligned + + # self.is_sim = is_sim + padded_action = np.zeros( + (self.max_episode_len, original_action_shape[1]), dtype=np.float32 + ) + padded_action[:action_len] = action + is_pad = np.zeros(self.max_episode_len) + is_pad[action_len:] = 1 + + padded_action = padded_action[: self.chunk_size] + is_pad = is_pad[: self.chunk_size] + + # new axis for different cameras + all_cam_images = [] + for cam_name in self.camera_names: + all_cam_images.append(image_dict[cam_name]) + all_cam_images = np.stack(all_cam_images, axis=0) + + # construct observations + image_data = torch.from_numpy(all_cam_images) + qpos_data = torch.from_numpy(qpos).float() + action_data = torch.from_numpy(padded_action).float() + is_pad = torch.from_numpy(is_pad).bool() + + # channel last + image_data = torch.einsum("k h w c -> k c h w", image_data) + + # augmentation + if self.transformations is None: + print("Initializing transformations") + original_size = image_data.shape[2:] + ratio = 0.95 + self.transformations = [ + transforms.RandomCrop( + size=[ + int(original_size[0] * ratio), + int(original_size[1] * ratio), + ] + ), + transforms.Resize(original_size, antialias=True), + transforms.RandomRotation(degrees=[-5.0, 5.0], expand=False), + transforms.ColorJitter( + brightness=0.3, contrast=0.4, saturation=0.5 + ), # , hue=0.08) + ] + + if self.augment_images: + for transform in self.transformations: + image_data = transform(image_data) + + # normalize image and change dtype to float + image_data = image_data / 255.0 + + if self.policy_class == "Diffusion": + # normalize to [-1, 1] + action_data = ( + (action_data - self.norm_stats["action_min"]) + / (self.norm_stats["action_max"] - 
self.norm_stats["action_min"]) + ) * 2 - 1 + else: + # normalize to mean 0 std 1 + action_data = ( + action_data - self.norm_stats["action_mean"] + ) / self.norm_stats["action_std"] + + qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats[ + "qpos_std" + ] + + except: + print(f"Error loading {dataset_path} in __getitem__") + quit() + + # print(image_data.dtype, qpos_data.dtype, action_data.dtype, is_pad.dtype) + return image_data, qpos_data, action_data, is_pad + + +def get_norm_stats(dataset_path_list): + all_qpos_data = [] + all_action_data = [] + all_episode_len = [] + + for dataset_path in dataset_path_list: + try: + with h5py.File(dataset_path, "r") as root: + qpos = root["/observations/qpos"][()] + qvel = root["/observations/qvel"][()] + if "/base_action" in root: + base_action = root["/base_action"][()] + # base_action = preprocess_base_action(base_action) + # action = np.concatenate([root["/action"][()], base_action], axis=-1) + else: + # TODO + action = root["/action"][()] + # dummy_base_action = np.zeros([action.shape[0], 2]) + # action = np.concatenate([action, dummy_base_action], axis=-1) + except Exception as e: + print(f"Error loading {dataset_path} in get_norm_stats") + print(e) + quit() + all_qpos_data.append(torch.from_numpy(qpos)) + all_action_data.append(torch.from_numpy(action)) + all_episode_len.append(len(qpos)) + all_qpos_data = torch.cat(all_qpos_data, dim=0) + all_action_data = torch.cat(all_action_data, dim=0) + + # normalize action data + action_mean = all_action_data.mean(dim=[0]).float() + action_std = all_action_data.std(dim=[0]).float() + action_std = torch.clip(action_std, 1e-2, np.inf) # clipping + + # normalize qpos data + qpos_mean = all_qpos_data.mean(dim=[0]).float() + qpos_std = all_qpos_data.std(dim=[0]).float() + qpos_std = torch.clip(qpos_std, 1e-2, np.inf) # clipping + + action_min = all_action_data.min(dim=0).values.float() + action_max = all_action_data.max(dim=0).values.float() + + eps = 0.0001 + stats = { + "action_mean": action_mean.numpy(), + "action_std": action_std.numpy(), + "action_min": action_min.numpy() - eps, + "action_max": action_max.numpy() + eps, + "qpos_mean": qpos_mean.numpy(), + "qpos_std": qpos_std.numpy(), + "example_qpos": qpos, + } + + return stats, all_episode_len + + +def find_all_hdf5(dataset_dir, skip_mirrored_data): + hdf5_files = [] + for root, dirs, files in os.walk(dataset_dir): + for filename in fnmatch.filter(files, "*.hdf5"): + if "features" in filename: + continue + if skip_mirrored_data and "mirror" in filename: + continue + hdf5_files.append(os.path.join(root, filename)) + print(f"Found {len(hdf5_files)} hdf5 files") + return hdf5_files + + +def BatchSampler(batch_size, episode_len_l, sample_weights): + sample_probs = ( + np.array(sample_weights) / np.sum(sample_weights) + if sample_weights is not None + else None + ) + # print("BatchSampler", sample_weights) + sum_dataset_len_l = np.cumsum( + [0] + [np.sum(episode_len) for episode_len in episode_len_l] + ) + while True: + batch = [] + for _ in range(batch_size): + episode_idx = np.random.choice(len(episode_len_l), p=sample_probs) + step_idx = np.random.randint( + sum_dataset_len_l[episode_idx], sum_dataset_len_l[episode_idx + 1] + ) + batch.append(step_idx) + yield batch + + +def load_data( + dataset_dir_l, + name_filter, + camera_names, + batch_size_train, + batch_size_val, + chunk_size, + skip_mirrored_data=False, + load_pretrain=False, + policy_class=None, + stats_dir_l=None, + sample_weights=None, + train_ratio=0.99, +): + if 
type(dataset_dir_l) == str: + dataset_dir_l = [dataset_dir_l] + dataset_path_list_list = [ + find_all_hdf5(dataset_dir, skip_mirrored_data) for dataset_dir in dataset_dir_l + ] + num_episodes_0 = len(dataset_path_list_list[0]) + dataset_path_list = flatten_list(dataset_path_list_list) + + dataset_path_list = [n for n in dataset_path_list if name_filter(n)] + num_episodes_l = [ + len(dataset_path_list) for dataset_path_list in dataset_path_list_list + ] + num_episodes_cumsum = np.cumsum(num_episodes_l) + + # obtain train test split on dataset_dir_l[0] + shuffled_episode_ids_0 = np.random.permutation(num_episodes_0) + train_episode_ids_0 = shuffled_episode_ids_0[: int(train_ratio * num_episodes_0)] + val_episode_ids_0 = shuffled_episode_ids_0[int(train_ratio * num_episodes_0) :] + train_episode_ids_l = [train_episode_ids_0] + [ + np.arange(num_episodes) + num_episodes_cumsum[idx] + for idx, num_episodes in enumerate(num_episodes_l[1:]) + ] + val_episode_ids_l = [val_episode_ids_0] + train_episode_ids = np.concatenate(train_episode_ids_l) + val_episode_ids = np.concatenate(val_episode_ids_l) + print( + f"\n\nData from: {dataset_dir_l}\n- Train on {[len(x) for x in train_episode_ids_l]} episodes\n- Test on {[len(x) for x in val_episode_ids_l]} episodes\n\n" + ) + + # obtain normalization stats for qpos and action + # if load_pretrain: + # with open(os.path.join('/home/zfu/interbotix_ws/src/act/ckpts/pretrain_all', 'dataset_stats.pkl'), 'rb') as f: + # norm_stats = pickle.load(f) + # print('Loaded pretrain dataset stats') + _, all_episode_len = get_norm_stats(dataset_path_list) + train_episode_len_l = [ + [all_episode_len[i] for i in train_episode_ids] + for train_episode_ids in train_episode_ids_l + ] + val_episode_len_l = [ + [all_episode_len[i] for i in val_episode_ids] + for val_episode_ids in val_episode_ids_l + ] + + train_episode_len = flatten_list(train_episode_len_l) + val_episode_len = flatten_list(val_episode_len_l) + if stats_dir_l is None: + stats_dir_l = dataset_dir_l + elif type(stats_dir_l) == str: + stats_dir_l = [stats_dir_l] + norm_stats, _ = get_norm_stats( + flatten_list( + [find_all_hdf5(stats_dir, skip_mirrored_data) for stats_dir in stats_dir_l] + ) + ) + print(f"Norm stats from: {stats_dir_l}") + + batch_sampler_train = BatchSampler( + batch_size_train, train_episode_len_l, sample_weights + ) + batch_sampler_val = BatchSampler(batch_size_val, val_episode_len_l, None) + + # print(f'train_episode_len: {train_episode_len}, val_episode_len: {val_episode_len}, train_episode_ids: {train_episode_ids}, val_episode_ids: {val_episode_ids}') + + # construct dataset and dataloader + train_dataset = EpisodicDataset( + dataset_path_list, + camera_names, + norm_stats, + train_episode_ids, + train_episode_len, + chunk_size, + policy_class, + ) + val_dataset = EpisodicDataset( + dataset_path_list, + camera_names, + norm_stats, + val_episode_ids, + val_episode_len, + chunk_size, + policy_class, + ) + train_num_workers = ( + (8 if os.getlogin() == "zfu" else 16) if train_dataset.augment_images else 2 + ) + val_num_workers = 8 if train_dataset.augment_images else 2 + print( + f"Augment images: {train_dataset.augment_images}, train_num_workers: {train_num_workers}, val_num_workers: {val_num_workers}" + ) + train_dataloader = DataLoader( + train_dataset, + batch_sampler=batch_sampler_train, + pin_memory=True, + num_workers=train_num_workers, + prefetch_factor=2, + ) + val_dataloader = DataLoader( + val_dataset, + batch_sampler=batch_sampler_val, + pin_memory=True, + 
num_workers=val_num_workers, + prefetch_factor=2, + ) + + return train_dataloader, val_dataloader, norm_stats, train_dataset.is_sim + + +def calibrate_linear_vel(base_action, c=None): + if c is None: + c = 0.0 # 0.19 + v = base_action[..., 0] + w = base_action[..., 1] + base_action = base_action.copy() + base_action[..., 0] = v - c * w + return base_action + + +def smooth_base_action(base_action): + return np.stack( + [ + np.convolve(base_action[:, i], np.ones(5) / 5, mode="same") + for i in range(base_action.shape[1]) + ], + axis=-1, + ).astype(np.float32) + + +def preprocess_base_action(base_action): + # base_action = calibrate_linear_vel(base_action) + base_action = smooth_base_action(base_action) + + return base_action + + +def postprocess_base_action(base_action): + linear_vel, angular_vel = base_action + linear_vel *= 1.0 + angular_vel *= 1.0 + # angular_vel = 0 + # if np.abs(linear_vel) < 0.05: + # linear_vel = 0 + return np.array([linear_vel, angular_vel]) + + +### env utils + + +def sample_box_pose(): + x_range = [0.0, 0.2] + y_range = [0.4, 0.6] + z_range = [0.05, 0.05] + + ranges = np.vstack([x_range, y_range, z_range]) + cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1]) + + cube_quat = np.array([1, 0, 0, 0]) + return np.concatenate([cube_position, cube_quat]) + + +def sample_insertion_pose(): + # Peg + x_range = [0.1, 0.2] + y_range = [0.4, 0.6] + z_range = [0.05, 0.05] + + ranges = np.vstack([x_range, y_range, z_range]) + peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1]) + + peg_quat = np.array([1, 0, 0, 0]) + peg_pose = np.concatenate([peg_position, peg_quat]) + + # Socket + x_range = [-0.2, -0.1] + y_range = [0.4, 0.6] + z_range = [0.05, 0.05] + + ranges = np.vstack([x_range, y_range, z_range]) + socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1]) + + socket_quat = np.array([1, 0, 0, 0]) + socket_pose = np.concatenate([socket_position, socket_quat]) + + return peg_pose, socket_pose + + +### helper functions + + +def compute_dict_mean(epoch_dicts): + result = {k: None for k in epoch_dicts[0]} + num_items = len(epoch_dicts) + for k in result: + value_sum = 0 + for epoch_dict in epoch_dicts: + value_sum += epoch_dict[k] + result[k] = value_sum / num_items + return result + + +def detach_dict(d): + new_d = dict() + for k, v in d.items(): + new_d[k] = v.detach() + return new_d + + +def set_seed(seed): + torch.manual_seed(seed) + np.random.seed(seed) diff --git a/realman_src/realman_aloha/shadow_rm_act/test/test_camera.py b/realman_src/realman_aloha/shadow_rm_act/test/test_camera.py new file mode 100644 index 000000000..edef62cc5 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_act/test/test_camera.py @@ -0,0 +1,163 @@ +from shadow_camera.realsense import RealSenseCamera +from shadow_rm_robot.realman_arm import RmArm +import yaml +import time +import multiprocessing +import numpy as np +import collections +import logging +import dm_env +import tracemalloc + + +class DeviceAloha: + def __init__(self, aloha_config): + """ + 初始化设备 + + Args: + device_name (str): 设备名称 + """ + config_left_arm = aloha_config["rm_left_arm"] + config_right_arm = aloha_config["rm_right_arm"] + config_head_camera = aloha_config["head_camera"] + config_bottom_camera = aloha_config["bottom_camera"] + config_left_camera = aloha_config["left_camera"] + config_right_camera = aloha_config["right_camera"] + self.init_left_arm_angle = aloha_config["init_left_arm_angle"] + self.init_right_arm_angle = aloha_config["init_right_arm_angle"] + self.arm_left = 
RmArm(config_left_arm)
+        self.arm_right = RmArm(config_right_arm)
+        self.camera_top = RealSenseCamera(config_head_camera, False)
+        self.camera_bottom = RealSenseCamera(config_bottom_camera, False)
+        self.camera_left = RealSenseCamera(config_left_camera, False)
+        self.camera_right = RealSenseCamera(config_right_camera, False)
+        self.camera_left.start_camera()
+        self.camera_right.start_camera()
+        self.camera_bottom.start_camera()
+        self.camera_top.start_camera()
+
+    def close(self):
+        """
+        Stop all cameras
+        """
+        self.camera_left.stop_camera()
+        self.camera_right.stop_camera()
+        self.camera_bottom.stop_camera()
+        self.camera_top.stop_camera()
+
+    def get_qpos(self):
+        """
+        Get the joint angles of both arms
+
+        Returns:
+            np.array: joint angles, left arm followed by right arm
+        """
+        left_slave_arm_angle = self.arm_left.get_joint_angle()
+        left_joint_angles_array = np.array(list(left_slave_arm_angle.values()))
+        right_slave_arm_angle = self.arm_right.get_joint_angle()
+        right_joint_angles_array = np.array(list(right_slave_arm_angle.values()))
+        return np.concatenate([left_joint_angles_array, right_joint_angles_array])
+
+    def get_qvel(self):
+        """
+        Get the joint velocities of both arms
+
+        Returns:
+            np.array: joint velocities, left arm followed by right arm
+        """
+        left_slave_arm_velocity = self.arm_left.get_joint_velocity()
+        left_joint_velocity_array = np.array(list(left_slave_arm_velocity.values()))
+        right_slave_arm_velocity = self.arm_right.get_joint_velocity()
+        right_joint_velocity_array = np.array(list(right_slave_arm_velocity.values()))
+        return np.concatenate([left_joint_velocity_array, right_joint_velocity_array])
+
+    def get_effort(self):
+        """
+        Get the joint torques of both arms
+
+        Returns:
+            np.array: joint torques, left arm followed by right arm
+        """
+        left_slave_arm_effort = self.arm_left.get_joint_effort()
+        left_joint_effort_array = np.array(list(left_slave_arm_effort.values()))
+        right_slave_arm_effort = self.arm_right.get_joint_effort()
+        right_joint_effort_array = np.array(list(right_slave_arm_effort.values()))
+        return np.concatenate([left_joint_effort_array, right_joint_effort_array])
+
+    def get_images(self):
+        """
+        Get one color frame from each camera
+
+        Returns:
+            dict: camera name to image
+        """
+        top_image, _, _, _ = self.camera_top.read_frame(True, False, False, False)
+        bottom_image, _, _, _ = self.camera_bottom.read_frame(True, False, False, False)
+        left_image, _, _, _ = self.camera_left.read_frame(True, False, False, False)
+        right_image, _, _, _ = self.camera_right.read_frame(True, False, False, False)
+        return {
+            "cam_high": top_image,
+            "cam_low": bottom_image,
+            "cam_left": left_image,
+            "cam_right": right_image,
+        }
+
+    def get_observation(self):
+        obs = collections.OrderedDict()
+        obs["qpos"] = self.get_qpos()
+        obs["qvel"] = self.get_qvel()
+        obs["effort"] = self.get_effort()
+        obs["images"] = self.get_images()
+        return obs
+
+    def reset(self):
+        logging.info("Resetting the environment")
+        _ = self.arm_left.set_joint_position(self.init_left_arm_angle[0:6])
+        _ = self.arm_right.set_joint_position(self.init_right_arm_angle[0:6])
+        self.arm_left.set_gripper_position(0)
+        self.arm_right.set_gripper_position(0)
+        return dm_env.TimeStep(
+            step_type=dm_env.StepType.FIRST,
+            reward=0,
+            discount=None,
+            observation=self.get_observation(),
+        )
+
+    def step(self, target_angle):
+        # target_angle layout: [left 6 joints, left gripper, right 6 joints, right gripper]
+        self.arm_left.set_joint_canfd_position(target_angle[0:6])
+        self.arm_right.set_joint_canfd_position(target_angle[7:13])
+        self.arm_left.set_gripper_position(target_angle[6])
+        self.arm_right.set_gripper_position(target_angle[13])
+        return dm_env.TimeStep(
+            step_type=dm_env.StepType.MID,
+            reward=0,
+            discount=None,
+            observation=self.get_observation(),
+        )
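A usage sketch for the class above, not part of the diff: it steps the device at a fixed control rate and sleeps only for the remainder of each cycle. The config path and the 30 Hz rate are assumptions; adjust both to the local setup.

import time
import yaml
import numpy as np

CONTROL_HZ = 30  # assumed control rate; tune to the arm's CANFD update rate

with open("config/config.yaml", "r") as f:  # hypothetical relative path
    cfg = yaml.safe_load(f)

device = DeviceAloha(cfg["robot_env"])
device.reset()
target = np.concatenate([device.init_left_arm_angle, device.init_right_arm_angle])

period = 1.0 / CONTROL_HZ
while True:
    t0 = time.time()
    ts = device.step(target)  # 14-dim target: 6 joints + gripper per arm
    # consume ts.observation["qpos"] / ts.observation["images"] here
    time.sleep(max(0.0, period - (time.time() - t0)))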
"r") as f: + config = yaml.safe_load(f) + aloha_config = config["robot_env"] + device = DeviceAloha(aloha_config) + device.reset() + image_list = [] + tager_angle = np.concatenate([device.init_left_arm_angle, device.init_right_arm_angle]) + while True: + tracemalloc.start() # 启动内存跟踪 + + tager_angle = np.array([angle + 0.1 if i not in [6, 13] else angle for i, angle in enumerate(tager_angle)]) + time_step = time.time() + timestep = device.step(tager_angle) + logging.info(f"Time: {time.time() - time_step}") + image_list.append(timestep.observation["images"]) + snapshot = tracemalloc.take_snapshot() + top_stats = snapshot.statistics('lineno') + # del timestep + print("[ Top 10 ]") + for stat in top_stats[:10]: + print(stat) + # logging.info(f"Images: {obs}") diff --git a/realman_src/realman_aloha/shadow_rm_act/test/test_h5.py b/realman_src/realman_aloha/shadow_rm_act/test/test_h5.py new file mode 100644 index 000000000..621c027ed --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_act/test/test_h5.py @@ -0,0 +1,32 @@ +import os +# import time +import yaml +import torch +import pickle +import dm_env +import logging +import collections +import numpy as np +from einops import rearrange +import matplotlib.pyplot as plt +from torchvision import transforms +from shadow_rm_robot.realman_arm import RmArm +from shadow_camera.realsense import RealSenseCamera +from shadow_act.models.latent_model import Latent_Model_Transformer +from shadow_act.network.policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy +from shadow_act.utils.utils import ( + load_data, + sample_box_pose, + sample_insertion_pose, + compute_dict_mean, + set_seed, + detach_dict, +) +logging.basicConfig(level=logging.DEBUG) + + + + + + +print('daasdas') \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_act/visualize_episodes.py b/realman_src/realman_aloha/shadow_rm_act/visualize_episodes.py new file mode 100644 index 000000000..4e55e4719 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_act/visualize_episodes.py @@ -0,0 +1,147 @@ +import os +import numpy as np +import cv2 +import h5py +import argparse + +import matplotlib.pyplot as plt +from constants import DT + +import IPython +e = IPython.embed + +JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"] +STATE_NAMES = JOINT_NAMES + ["gripper"] + +def load_hdf5(dataset_dir, dataset_name): + dataset_path = os.path.join(dataset_dir, dataset_name + '.hdf5') + if not os.path.isfile(dataset_path): + print(f'Dataset does not exist at \n{dataset_path}\n') + exit() + + with h5py.File(dataset_path, 'r') as root: + is_sim = root.attrs['sim'] + qpos = root['/observations/qpos'][()] + qvel = root['/observations/qvel'][()] + action = root['/action'][()] + image_dict = dict() + for cam_name in root[f'/observations/images/'].keys(): + image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()] + + return qpos, qvel, action, image_dict + +def main(args): + dataset_dir = args['dataset_dir'] + episode_idx = args['episode_idx'] + dataset_name = f'episode_{episode_idx}' + + qpos, qvel, action, image_dict = load_hdf5(dataset_dir, dataset_name) + save_videos(image_dict, DT, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4')) + visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png')) + # visualize_timestamp(t_list, dataset_path) # TODO addn timestamp back + + +def save_videos(video, dt, video_path=None): + if isinstance(video, list): + cam_names = list(video[0].keys()) + h, w, 
_ = video[0][cam_names[0]].shape + w = w * len(cam_names) + fps = int(1/dt) + out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) + for ts, image_dict in enumerate(video): + images = [] + for cam_name in cam_names: + image = image_dict[cam_name] + image = image[:, :, [2, 1, 0]] # swap B and R channel + images.append(image) + images = np.concatenate(images, axis=1) + out.write(images) + out.release() + print(f'Saved video to: {video_path}') + elif isinstance(video, dict): + cam_names = list(video.keys()) + all_cam_videos = [] + for cam_name in cam_names: + all_cam_videos.append(video[cam_name]) + all_cam_videos = np.concatenate(all_cam_videos, axis=2) # width dimension + + n_frames, h, w, _ = all_cam_videos.shape + fps = int(1 / dt) + out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) + for t in range(n_frames): + image = all_cam_videos[t] + image = image[:, :, [2, 1, 0]] # swap B and R channel + out.write(image) + out.release() + print(f'Saved video to: {video_path}') + + +def visualize_joints(qpos_list, command_list, plot_path=None, ylim=None, label_overwrite=None): + if label_overwrite: + label1, label2 = label_overwrite + else: + label1, label2 = 'State', 'Command' + + qpos = np.array(qpos_list) # ts, dim + command = np.array(command_list) + num_ts, num_dim = qpos.shape + h, w = 2, num_dim + num_figs = num_dim + fig, axs = plt.subplots(num_figs, 1, figsize=(w, h * num_figs)) + + # plot joint state + all_names = [name + '_left' for name in STATE_NAMES] + [name + '_right' for name in STATE_NAMES] + for dim_idx in range(num_dim): + ax = axs[dim_idx] + ax.plot(qpos[:, dim_idx], label=label1) + ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}') + ax.legend() + + # plot arm command + for dim_idx in range(num_dim): + ax = axs[dim_idx] + ax.plot(command[:, dim_idx], label=label2) + ax.legend() + + if ylim: + for dim_idx in range(num_dim): + ax = axs[dim_idx] + ax.set_ylim(ylim) + + plt.tight_layout() + plt.savefig(plot_path) + print(f'Saved qpos plot to: {plot_path}') + plt.close() + +def visualize_timestamp(t_list, dataset_path): + plot_path = dataset_path.replace('.pkl', '_timestamp.png') + h, w = 4, 10 + fig, axs = plt.subplots(2, 1, figsize=(w, h*2)) + # process t_list + t_float = [] + for secs, nsecs in t_list: + t_float.append(secs + nsecs * 10E-10) + t_float = np.array(t_float) + + ax = axs[0] + ax.plot(np.arange(len(t_float)), t_float) + ax.set_title(f'Camera frame timestamps') + ax.set_xlabel('timestep') + ax.set_ylabel('time (sec)') + + ax = axs[1] + ax.plot(np.arange(len(t_float)-1), t_float[:-1] - t_float[1:]) + ax.set_title(f'dt') + ax.set_xlabel('timestep') + ax.set_ylabel('time (sec)') + + plt.tight_layout() + plt.savefig(plot_path) + print(f'Saved timestamp plot to: {plot_path}') + plt.close() + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--dataset_dir', action='store', type=str, help='Dataset dir.', required=True) + parser.add_argument('--episode_idx', action='store', type=int, help='Episode index.', required=False) + main(vars(parser.parse_args())) diff --git a/realman_src/realman_aloha/shadow_rm_aloha/.gitignore b/realman_src/realman_aloha/shadow_rm_aloha/.gitignore new file mode 100644 index 000000000..bfa287bf9 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/.gitignore @@ -0,0 +1,10 @@ +__pycache__/ +build/ +devel/ +dist/ +data/ +.catkin_workspace +*.pyc +*.pyo +*.pt +.vscode/ diff --git a/realman_src/realman_aloha/shadow_rm_aloha/.idea/.gitignore 
b/realman_src/realman_aloha/shadow_rm_aloha/.idea/.gitignore new file mode 100644 index 000000000..50d9d22a7 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/.idea/.gitignore @@ -0,0 +1,3 @@ +# 默认忽略的文件 +/shelf/ +/workspace.xml diff --git a/realman_src/realman_aloha/shadow_rm_aloha/.idea/.name b/realman_src/realman_aloha/shadow_rm_aloha/.idea/.name new file mode 100644 index 000000000..d7c912694 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/.idea/.name @@ -0,0 +1 @@ +aloha_data_synchronizer.py \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_aloha/.idea/inspectionProfiles/Project_Default.xml b/realman_src/realman_aloha/shadow_rm_aloha/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 000000000..3cbb08ee5 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,17 @@ + + + + \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_aloha/.idea/inspectionProfiles/profiles_settings.xml b/realman_src/realman_aloha/shadow_rm_aloha/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 000000000..105ce2da2 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_aloha/.idea/misc.xml b/realman_src/realman_aloha/shadow_rm_aloha/.idea/misc.xml new file mode 100644 index 000000000..10210b412 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_aloha/.idea/modules.xml b/realman_src/realman_aloha/shadow_rm_aloha/.idea/modules.xml new file mode 100644 index 000000000..523608c52 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_aloha/.idea/shadow_rm_aloha.iml b/realman_src/realman_aloha/shadow_rm_aloha/.idea/shadow_rm_aloha.iml new file mode 100644 index 000000000..2946dc0d1 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/.idea/shadow_rm_aloha.iml @@ -0,0 +1,12 @@ + + + + + + + + + + \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_aloha/README.md b/realman_src/realman_aloha/shadow_rm_aloha/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/realman_src/realman_aloha/shadow_rm_aloha/config/data_synchronizer.yaml b/realman_src/realman_aloha/shadow_rm_aloha/config/data_synchronizer.yaml new file mode 100644 index 000000000..bea8d10cb --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/config/data_synchronizer.yaml @@ -0,0 +1,38 @@ +dataset_dir: '/home/rm/code/shadow_rm_aloha/data/dataset' +dataset_name: 'episode' +max_timesteps: 500 +state_dim: 14 +overwrite: False +arm_axis: 6 +camera_names: + - 'cam_high' + - 'cam_low' + - 'cam_left' + - 'cam_right' +ros_topics: + camera_left: '/camera_left/rgb/image_raw' + camera_right: '/camera_right/rgb/image_raw' + camera_bottom: '/camera_bottom/rgb/image_raw' + camera_head: '/camera_head/rgb/image_raw' + left_master_arm: '/left_master_arm_joint_states' + left_slave_arm: '/left_slave_arm_joint_states' + right_master_arm: '/right_master_arm_joint_states' + right_slave_arm: '/right_slave_arm_joint_states' + left_aloha_state: '/left_slave_arm_aloha_state' + right_aloha_state: '/right_slave_arm_aloha_state' 
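The ros_topics block above is consumed by aloha_data_synchronizer.py later in this diff. A minimal sketch of that pattern, not part of the diff (it assumes rospy and message_filters are available and loads the YAML from a hypothetical relative path):

import yaml
import rospy
from sensor_msgs.msg import Image, JointState
from message_filters import Subscriber, ApproximateTimeSynchronizer

rospy.init_node("sync_sketch", anonymous=True)
with open("config/data_synchronizer.yaml", "r") as f:  # hypothetical path
    topics = yaml.safe_load(f)["ros_topics"]

subs = [
    Subscriber(topics["camera_head"], Image),
    Subscriber(topics["left_slave_arm"], JointState),
    Subscriber(topics["right_slave_arm"], JointState),
]

def on_sync(head_img, left_js, right_js):
    # fires only when all messages fall within the slop window (50 ms here)
    rospy.loginfo(f"synced at {head_img.header.stamp.to_sec():.3f}")

ats = ApproximateTimeSynchronizer(subs, queue_size=1, slop=0.05)
ats.registerCallback(on_sync)
rospy.spin()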
+robot_env: { + # TODO change the path to the correct one + rm_left_arm: '/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml', + rm_right_arm: '/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml', + arm_axis: 6, + head_camera: '241122071186', + bottom_camera: '152122078546', + left_camera: '150622070125', + right_camera: '151222072576', + init_left_arm_angle: [7.235, 31.816, 51.237, 2.463, 91.054, 12.04, 0.0], + init_right_arm_angle: [-6.155, 33.925, 62.137, -1.672, 87.892, -3.868, 0.0] + # init_left_arm_angle: [6.681, 38.496, 66.093, -1.141, 74.529, 3.076, 0.0], + # init_right_arm_angle: [-4.79, 37.062, 72.393, -0.477, 68.593, -9.526, 0.0] + # init_left_arm_angle: [6.45, 66.093, 2.9, 20.919, -1.491, 100.756, 18.808, 0.617], + # init_right_arm_angle: [166.953, -33.575, -163.917, 73.3, -9.581, 69.51, 0.876] +} \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_aloha/config/rm_left_arm.yaml b/realman_src/realman_aloha/shadow_rm_aloha/config/rm_left_arm.yaml new file mode 100644 index 000000000..f7c044a24 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/config/rm_left_arm.yaml @@ -0,0 +1,9 @@ +arm_ip: "192.168.1.18" +arm_port: 8080 +arm_axis: 6 +local_ip: "192.168.1.101" +local_port: 8089 +# arm_ki: [7, 7, 7, 3, 3, 3, 3] # rm75 +arm_ki: [7, 7, 7, 3, 3, 3] # rm65 +get_vel: True +get_torque: True \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_aloha/config/rm_right_arm.yaml b/realman_src/realman_aloha/shadow_rm_aloha/config/rm_right_arm.yaml new file mode 100644 index 000000000..c028478e5 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/config/rm_right_arm.yaml @@ -0,0 +1,9 @@ +arm_ip: "192.168.1.19" +arm_port: 8080 +arm_axis: 6 +local_ip: "192.168.1.101" +local_port: 8090 +# arm_ki: [7, 7, 7, 3, 3, 3, 3] # rm75 +arm_ki: [7, 7, 7, 3, 3, 3] # rm65 +get_vel: True +get_torque: True \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_aloha/config/servo_left_arm.yaml b/realman_src/realman_aloha/shadow_rm_aloha/config/servo_left_arm.yaml new file mode 100644 index 000000000..de8946bf4 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/config/servo_left_arm.yaml @@ -0,0 +1,4 @@ +port: /dev/ttyUSB1 +baudrate: 460800 +hex_data: "55 AA 02 00 00 67" +arm_axis: 6 diff --git a/realman_src/realman_aloha/shadow_rm_aloha/config/servo_right_arm.yaml b/realman_src/realman_aloha/shadow_rm_aloha/config/servo_right_arm.yaml new file mode 100644 index 000000000..218a08397 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/config/servo_right_arm.yaml @@ -0,0 +1,4 @@ +port: /dev/ttyUSB0 +baudrate: 460800 +hex_data: "55 AA 02 00 00 67" +arm_axis: 6 diff --git a/realman_src/realman_aloha/shadow_rm_aloha/config/vis_data_path.yaml b/realman_src/realman_aloha/shadow_rm_aloha/config/vis_data_path.yaml new file mode 100644 index 000000000..2afb40d41 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/config/vis_data_path.yaml @@ -0,0 +1,6 @@ +dataset_dir: '/home/rm/code/shadow_rm_aloha/data/dataset/20250102' +dataset_name: 'episode' +episode_idx: 1 +FPS: 30 +# joint_names: ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate", "J7"] # 7 joints +joint_names: ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"] # 6 joints \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_aloha/pyproject.toml b/realman_src/realman_aloha/shadow_rm_aloha/pyproject.toml new file mode 100644 index 000000000..815c599a5 --- /dev/null 
+++ b/realman_src/realman_aloha/shadow_rm_aloha/pyproject.toml @@ -0,0 +1,39 @@ +[tool.poetry] +name = "shadow_rm_aloha" +version = "0.1.1" +description = "aloha package, use D435 and Realman robot arm to build aloha to collect data" +readme = "README.md" +authors = ["Shadow "] +license = "MIT" +# include = ["realman_vision/pytransform/_pytransform.so",] +classifiers = [ + "Operating System :: POSIX :: Linux amd64", + "Programming Language :: Python :: 3.10", +] + +[tool.poetry.dependencies] +python = ">=3.10" +matplotlib = ">=3.9.2" +h5py = ">=3.12.1" +# rospy = ">=1.17.0" +# shadow_rm_robot = { git = "https://github.com/Shadow2223/shadow_rm_robot.git", branch = "main" } +# shadow_camera = { git = "https://github.com/Shadow2223/shadow_camera.git", branch = "main" } + + + +[tool.poetry.dev-dependencies] # 列出开发时所需的依赖项,比如测试、文档生成等工具。 +pytest = ">=8.3" +black = ">=24.10.0" + + + +[tool.poetry.plugins."scripts"] # 定义命令行脚本,使得用户可以通过命令行运行指定的函数。 + + +[tool.poetry.group.dev.dependencies] + + + +[build-system] +requires = ["poetry-core>=1.8.4"] +build-backend = "poetry.core.masonry.api" \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/CMakeLists.txt b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/CMakeLists.txt new file mode 100644 index 000000000..fd5c5ba68 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/CMakeLists.txt @@ -0,0 +1,42 @@ +cmake_minimum_required(VERSION 3.0.2) +project(shadow_rm_aloha) + +find_package(catkin REQUIRED COMPONENTS + rospy + sensor_msgs + cv_bridge + image_transport + std_msgs + message_generation +) + +add_service_files( + FILES + GetArmStatus.srv + GetImage.srv + MoveArm.srv +) + +generate_messages( + DEPENDENCIES + sensor_msgs + std_msgs +) + +catkin_package( + CATKIN_DEPENDS message_runtime rospy std_msgs +) + +include_directories( + ${catkin_INCLUDE_DIRS} +) + +install(PROGRAMS + arm_node/slave_arm_publisher.py + arm_node/master_arm_publisher.py + arm_node/slave_arm_pub_sub.py + camera_node/camera_publisher.py + data_sub_process/aloha_data_synchronizer.py + data_sub_process/aloha_data_collect.py + DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +) \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/__init__.py b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/__init__.py new file mode 100644 index 000000000..541f859dc --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/__init__.py @@ -0,0 +1 @@ +__version__ = '0.1.0' \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/arm_node/master_arm_publisher.py b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/arm_node/master_arm_publisher.py new file mode 100644 index 000000000..8a4e6fe89 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/arm_node/master_arm_publisher.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 + +import rospy +import logging +from shadow_rm_robot.servo_robotic_arm import ServoArm +from sensor_msgs.msg import JointState + +# 配置日志记录 +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + +class MasterArmPublisher: + def __init__(self): + rospy.init_node("master_arm_publisher", anonymous=True) + arm_config = rospy.get_param("~arm_config","config/servo_left_arm.yaml") + hz = rospy.get_param("~hz", 250) + self.joint_states_topic = rospy.get_param("~joint_states_topic", 
"/joint_states") + + self.arm = ServoArm(arm_config) + self.publisher = rospy.Publisher(self.joint_states_topic, JointState, queue_size=1) + self.rate = rospy.Rate(hz) # 30 Hz + + def publish_joint_states(self): + while not rospy.is_shutdown(): + joint_state = JointState() + joint_pos = self.arm.get_joint_actions() + + joint_state.header.stamp = rospy.Time.now() + joint_state.name = list(joint_pos.keys()) + joint_state.position = list(joint_pos.values()) + joint_state.velocity = [0.0] * len(joint_pos) # 速度(可根据需要修改) + joint_state.effort = [0.0] * len(joint_pos) # 力矩(可根据需要修改) + + # rospy.loginfo(f"{self.joint_states_topic}: {joint_state}") + self.publisher.publish(joint_state) + self.rate.sleep() + +if __name__ == "__main__": + try: + arm_publisher = MasterArmPublisher() + arm_publisher.publish_joint_states() + except rospy.ROSInterruptException: + pass \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/arm_node/slave_arm_pub_sub.py b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/arm_node/slave_arm_pub_sub.py new file mode 100644 index 000000000..6356b1181 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/arm_node/slave_arm_pub_sub.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 + +import rospy +import logging +from shadow_rm_robot.realman_arm import RmArm +from sensor_msgs.msg import JointState +# 配置日志记录 +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + +class SlaveArmPublisher: + def __init__(self): + rospy.init_node("slave_arm_publisher", anonymous=True) + arm_config = rospy.get_param("~arm_config", default="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml") + hz = rospy.get_param("~hz", 250) + joint_states_topic = rospy.get_param("~joint_states_topic", "/joint_states") + joint_actions_topic = rospy.get_param("~joint_actions_topic", "/joint_actions") + self.arm = RmArm(arm_config) + self.publisher = rospy.Publisher(joint_states_topic, JointState, queue_size=1) + self.subscriber = rospy.Subscriber(joint_actions_topic, JointState, self.callback) + self.rate = rospy.Rate(hz) + + def publish_joint_states(self): + while not rospy.is_shutdown(): + joint_state = JointState() + data = self.arm.get_integrate_data() + # data = self.arm.get_arm_data() + joint_state.header.stamp = rospy.Time.now() + joint_state.name = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"] + # joint_state.position = data["joint_angle"] + joint_state.position = data['arm_angle'] + + # joint_state.position = list(data["arm_angle"]) + # joint_state.velocity = list(data["arm_velocity"]) + # joint_state.effort = list(data["arm_torque"]) + # rospy.loginfo(f"joint_states_topic: {joint_state}") + self.publisher.publish(joint_state) + self.rate.sleep() + + def callback(self, data): + # rospy.loginfo(f"Received joint_states_topic: {data}") + start_time = rospy.Time.now() + if data is None: + return + if data.name == ["joint_canfd"]: + self.arm.set_joint_canfd_position(data.position[0:6]) + elif data.name == ["joint_j"]: + self.arm.set_joint_position(data.position[0:6]) + + # self.arm.set_gripper_position(data.position[6]) + end_time = rospy.Time.now() + time_cost_ms = (end_time - start_time).to_sec() * 1000 + rospy.loginfo(f"Time cost: {data.name},{time_cost_ms}") + + + +if __name__ == "__main__": + try: + arm_publisher = SlaveArmPublisher() + arm_publisher.publish_joint_states() + except rospy.ROSInterruptException: + pass \ No newline at end of file diff --git 
a/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/arm_node/slave_arm_publisher.py b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/arm_node/slave_arm_publisher.py new file mode 100644 index 000000000..414aa3627 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/arm_node/slave_arm_publisher.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 + +import rospy +import logging +from shadow_rm_robot.realman_arm import RmArm +from sensor_msgs.msg import JointState +from std_msgs.msg import Int32MultiArray +# 配置日志记录 +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + +class SlaveArmPublisher: + def __init__(self): + rospy.init_node("slave_arm_publisher", anonymous=True) + arm_config = rospy.get_param("~arm_config", default="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml") + hz = rospy.get_param("~hz", 250) + joint_states_topic = rospy.get_param("~joint_states_topic", "/joint_states") + aloha_state_topic = rospy.get_param("~aloha_state_topic", "/aloha_state") + self.arm = RmArm(arm_config) + self.publisher = rospy.Publisher(joint_states_topic, JointState, queue_size=1) + self.aloha_state_pub = rospy.Publisher(aloha_state_topic, Int32MultiArray, queue_size=1) + self.rate = rospy.Rate(hz) + + def publish_joint_states(self): + while not rospy.is_shutdown(): + joint_state = JointState() + data = self.arm.get_integrate_data() + joint_state.header.stamp = rospy.Time.now() + joint_state.name = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6", "joint7"] + joint_state.position = list(data["arm_angle"]) + joint_state.velocity = list(data["arm_velocity"]) + joint_state.effort = list(data["arm_torque"]) + # rospy.loginfo(f"joint_states_topic: {joint_state}") + self.aloha_state_pub.publish(Int32MultiArray(data=data["aloha_state"].values())) + self.publisher.publish(joint_state) + self.rate.sleep() + +if __name__ == "__main__": + try: + arm_publisher = SlaveArmPublisher() + arm_publisher.publish_joint_states() + except rospy.ROSInterruptException: + pass \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/camera_node/camera_publisher.py b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/camera_node/camera_publisher.py new file mode 100644 index 000000000..9d95b826b --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/camera_node/camera_publisher.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 + +import rospy +from sensor_msgs.msg import Image +from std_msgs.msg import Header +import numpy as np +from shadow_camera.realsense import RealSenseCamera + +class CameraPublisher: + def __init__(self): + rospy.init_node('camera_publisher', anonymous=True) + self.serial_number = rospy.get_param('~serial_number', None) + hz = rospy.get_param('~hz', 30) + rospy.loginfo(f"Serial number: {self.serial_number}") + + self.rgb_topic = rospy.get_param('~rgb_topic', '/camera/rgb/image_raw') + self.depth_topic = rospy.get_param('~depth_topic', '/camera/depth/image_raw') + rospy.loginfo(f"RGB topic: {self.rgb_topic}") + rospy.loginfo(f"Depth topic: {self.depth_topic}") + + self.rgb_pub = rospy.Publisher(self.rgb_topic, Image, queue_size=10) + # self.depth_pub = rospy.Publisher(self.depth_topic, Image, queue_size=10) + self.rate = rospy.Rate(hz) # 30 Hz + self.camera = RealSenseCamera(self.serial_number, False) + + rospy.loginfo("Camera initialized") + + def publish_images(self): + self.camera.start_camera() + rospy.loginfo("Camera 
started") + while not rospy.is_shutdown(): + result = self.camera.read_frame(True, False, False, False) + if result is None: + rospy.logerr("Failed to read frame from camera") + continue + + color_image, depth_image, _, _ = result + + if color_image is not None or depth_image is not None: + rgb_msg = self.create_image_msg(color_image, "bgr8") + # depth_msg = self.create_image_msg(depth_image, "mono16") + + self.rgb_pub.publish(rgb_msg) + # self.depth_pub.publish(depth_msg) + # rospy.loginfo("Published RGB image") + else: + rospy.logwarn("Received None for color_image or depth_image") + + self.rate.sleep() + self.camera.stop_camera() + rospy.loginfo("Camera stopped") + + def create_image_msg(self, image, encoding): + msg = Image() + msg.header = Header() + msg.header.stamp = rospy.Time.now() + msg.height, msg.width = image.shape[:2] + msg.encoding = encoding + msg.is_bigendian = False + msg.step = image.strides[0] + msg.data = np.array(image).tobytes() + return msg + +if __name__ == '__main__': + try: + camera_publisher = CameraPublisher() + camera_publisher.publish_images() + except rospy.ROSInterruptException: + pass \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/data_analysis/data_fix.py b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/data_analysis/data_fix.py new file mode 100644 index 000000000..2f0474e08 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/data_analysis/data_fix.py @@ -0,0 +1,112 @@ +import os +import time +import h5py +import logging +from datetime import datetime + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + +class DataCollector: + def __init__(self, dataset_dir, dataset_name, max_timesteps, camera_names, state_dim, overwrite=False): + self.arm_axis = 7 + self.dataset_dir = dataset_dir + self.dataset_name = dataset_name + self.max_timesteps = max_timesteps + self.camera_names = camera_names + self.state_dim = state_dim + self.overwrite = overwrite + self.data_dict = { + '/observations/qpos': [], + '/observations/qvel': [], + '/observations/effort': [], + '/action': [], + } + for cam_name in camera_names: + self.data_dict[f'/observations/images/{cam_name}'] = [] + + # 自动检测和创建数据集目录 + self.create_dataset_dir() + self.timesteps_collected = 0 + + + def create_dataset_dir(self): + # 按照年月日创建目录 + date_str = datetime.now().strftime("%Y%m%d") + self.dataset_dir = os.path.join(self.dataset_dir, date_str) + if not os.path.exists(self.dataset_dir): + os.makedirs(self.dataset_dir) + + def collect_data(self, ts, action): + self.data_dict['/observations/qpos'].append(ts.observation['qpos']) + self.data_dict['/observations/qvel'].append(ts.observation['qvel']) + self.data_dict['/observations/effort'].append(ts.observation['effort']) + self.data_dict['/action'].append(action) + for cam_name in self.camera_names: + self.data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name]) + + def save_data(self): + t0 = time.time() + # 保存数据 + with h5py.File(self.dataset_path, mode='w', rdcc_nbytes=1024**2*2) as root: + root.attrs['sim'] = False + obs = root.create_group('observations') + image = obs.create_group('images') + for cam_name in self.camera_names: + _ = image.create_dataset(cam_name, (self.max_timesteps, 480, 640, 3), dtype='uint8', + chunks=(1, 480, 640, 3)) + _ = obs.create_dataset('qpos', (self.max_timesteps, self.state_dim)) + _ = obs.create_dataset('qvel', (self.max_timesteps, self.state_dim)) + _ = 
obs.create_dataset('effort', (self.max_timesteps, self.state_dim)) + _ = root.create_dataset('action', (self.max_timesteps, self.state_dim)) + + for name, array in self.data_dict.items(): + root[name][...] = array + print(f'Saving: {time.time() - t0:.1f} secs') + return True + + def load_hdf5(self, orign_path, file): + self.dataset_path = os.path.join(self.dataset_dir, file) + if not os.path.isfile(orign_path): + logging.error(f'Dataset does not exist at {orign_path}') + exit() + + with h5py.File(orign_path, 'r') as root: + self.is_sim = root.attrs['sim'] + self.qpos = root['/observations/qpos'][()] + self.qvel = root['/observations/qvel'][()] + self.effort = root['/observations/effort'][()] + self.action = root['/action'][()] + self.image_dict = {cam_name: root[f'/observations/images/{cam_name}'][()] + for cam_name in root[f'/observations/images/'].keys()} + + self.qpos[:, self.arm_axis] = self.action[:, self.arm_axis] + self.qpos[:, self.arm_axis*2+1] = self.action[:, self.arm_axis*2+1] + + self.data_dict['/observations/qpos'] = self.qpos + self.data_dict['/observations/qvel'] = self.qvel + self.data_dict['/observations/effort'] = self.effort + self.data_dict['/action'] = self.action + for cam_name in self.camera_names: + self.data_dict[f'/observations/images/{cam_name}'] = self.image_dict[cam_name] + return True + +if __name__ == '__main__': + """ + 用于更改夹爪数据,将从臂夹爪数据更改为主笔夹爪数据 + + """ + dataset_dir = '/home/wang/project/shadow_rm_aloha/data' + orign_dir = '/home/wang/project/shadow_rm_aloha/data/dataset/20241128' + dataset_name = 'test' + max_timesteps = 300 + camera_names = ['cam_high','cam_low','cam_left','cam_right'] + state_dim = 16 + collector = DataCollector(dataset_dir, dataset_name, max_timesteps, camera_names, state_dim) + for file in os.listdir(orign_dir): + collector.__init__(dataset_dir, dataset_name, max_timesteps, camera_names, state_dim) + orign_path = os.path.join(orign_dir, file) + print(orign_path) + collector.load_hdf5(orign_path, file) + collector.save_data() + + diff --git a/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/data_analysis/data_validation.py b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/data_analysis/data_validation.py new file mode 100644 index 000000000..b16f87f15 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/data_analysis/data_validation.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 + +import os +import numpy as np +import h5py +import yaml +import logging +import time +from shadow_rm_robot.realman_arm import RmArm + +logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s") + +class DataValidator: + def __init__(self, config): + self.dataset_dir = config['dataset_dir'] + self.episode_idx = config['episode_idx'] + self.joint_names = config['joint_names'] + self.dataset_name = f'episode_{self.episode_idx}' + self.dataset_path = os.path.join(self.dataset_dir, self.dataset_name + '.hdf5') + self.state_names = self.joint_names + ["gripper"] + self.arm = RmArm('/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml') + + def load_hdf5(self): + if not os.path.isfile(self.dataset_path): + logging.error(f'Dataset does not exist at {self.dataset_path}') + exit() + + with h5py.File(self.dataset_path, 'r') as root: + self.is_sim = root.attrs['sim'] + self.qpos = root['/observations/qpos'][()] + # self.qvel = root['/observations/qvel'][()] + # self.effort = root['/observations/effort'][()] + self.action = root['/action'][()] + self.image_dict = {cam_name: 
root[f'/observations/images/{cam_name}'][()] + for cam_name in root[f'/observations/images/'].keys()} + + def validate_data(self): + # 验证位置数据 + if not self.qpos.shape[1] == 14: + logging.error('qpos shape does not match expected number of joints') + return False + + logging.info('Data validation passed') + return True + + def control_robot(self): + self.arm.set_joint_position(self.qpos[0][0:6]) + for qpos in self.qpos: + logging.info(f'qpos: {qpos}') + self.arm.set_joint_canfd_position(qpos[7:13]) + self.arm.set_gripper_position(qpos[13]) + time.sleep(0.035) + + + def run(self): + self.load_hdf5() + if self.validate_data(): + self.control_robot() + +def load_config(config_path): + with open(config_path, 'r') as file: + return yaml.safe_load(file) + +if __name__ == '__main__': + config = load_config('/home/rm/code/shadow_rm_aloha/config/vis_data_path.yaml') + validator = DataValidator(config) + validator.run() \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/data_analysis/visualize_aloha.py b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/data_analysis/visualize_aloha.py new file mode 100644 index 000000000..4b9b7f1c2 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/data_analysis/visualize_aloha.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 + +import os +import numpy as np +import cv2 +import h5py +import yaml +import logging +import matplotlib.pyplot as plt + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + +class DataVisualizer: + def __init__(self, config): + self.dataset_dir = config['dataset_dir'] + self.episode_idx = config['episode_idx'] + self.dt = 1/config['FPS'] + self.joint_names = config['joint_names'] + self.state_names = self.joint_names + ["gripper"] + # self.camera_names = config['camera_names'] + + def join_file_path(self, file_name): + self.dataset_path = os.path.join(self.dataset_dir, file_name) + + def load_hdf5(self): + if not os.path.isfile(self.dataset_path): + logging.error(f'Dataset does not exist at {self.dataset_path}') + exit() + + with h5py.File(self.dataset_path, 'r') as root: + self.is_sim = root.attrs['sim'] + self.qpos = root['/observations/qpos'][()] + # self.qvel = root['/observations/qvel'][()] + # self.effort = root['/observations/effort'][()] + self.action = root['/action'][()] + self.image_dict = {cam_name: root[f'/observations/images/{cam_name}'][()] + for cam_name in root[f'/observations/images/'].keys()} + + def save_videos(self, video, dt, video_path=None): + if isinstance(video, list): + cam_names = list(video[0].keys()) + h, w, _ = video[0][cam_names[0]].shape + w = w * len(cam_names) + fps = int(1 / dt) + out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) + for image_dict in video: + images = [image_dict[cam_name][:, :, [2, 1, 0]] for cam_name in cam_names] + out.write(np.concatenate(images, axis=1)) + out.release() + logging.info(f'Saved video to: {video_path}') + elif isinstance(video, dict): + cam_names = list(video.keys()) + all_cam_videos = np.concatenate([video[cam_name] for cam_name in cam_names], axis=2) + n_frames, h, w, _ = all_cam_videos.shape + fps = int(1 / dt) + out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) + for t in range(n_frames): + out.write(all_cam_videos[t][:, :, [2, 1, 0]]) + out.release() + logging.info(f'Saved video to: {video_path}') + + def visualize_joints(self, qpos_list, command_list, plot_path, ylim=None, 
label_overwrite=None): + label1, label2 = label_overwrite if label_overwrite else ('State', 'Command') + qpos = np.array(qpos_list) + command = np.array(command_list) + num_ts, num_dim = qpos.shape + logging.info(f'qpos shape: {qpos.shape}, command shape: {command.shape}') + fig, axs = plt.subplots(num_dim, 1, figsize=(num_dim, 2 * num_dim)) + + all_names = [name + '_left' for name in self.state_names] + [name + '_right' for name in self.state_names] + for dim_idx in range(num_dim): + ax = axs[dim_idx] + ax.plot(qpos[:, dim_idx], label=label1) + ax.plot(command[:, dim_idx], label=label2) + ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}') + ax.legend() + if ylim: + ax.set_ylim(ylim) + + plt.tight_layout() + plt.savefig(plot_path) + logging.info(f'Saved qpos plot to: {plot_path}') + plt.close() + + def visualize_single(self, data_list, label, plot_path, ylim=None): + data = np.array(data_list) + num_ts, num_dim = data.shape + fig, axs = plt.subplots(num_dim, 1, figsize=(num_dim, 2 * num_dim)) + + all_names = [name + '_left' for name in self.state_names] + [name + '_right' for name in self.state_names] + for dim_idx in range(num_dim): + ax = axs[dim_idx] + ax.plot(data[:, dim_idx], label=label) + ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}') + ax.legend() + if ylim: + ax.set_ylim(ylim) + + plt.tight_layout() + plt.savefig(plot_path) + logging.info(f'Saved {label} plot to: {plot_path}') + plt.close() + + def visualize_timestamp(self, t_list): + plot_path = self.dataset_path.replace('.hdf5', '_timestamp.png') + fig, axs = plt.subplots(2, 1, figsize=(10, 8)) + t_float = np.array([secs + nsecs * 1e-9 for secs, nsecs in t_list]) + + axs[0].plot(np.arange(len(t_float)), t_float) + axs[0].set_title('Camera frame timestamps') + axs[0].set_xlabel('timestep') + axs[0].set_ylabel('time (sec)') + + axs[1].plot(np.arange(len(t_float) - 1), t_float[:-1] - t_float[1:]) + axs[1].set_title('dt') + axs[1].set_xlabel('timestep') + axs[1].set_ylabel('time (sec)') + + plt.tight_layout() + plt.savefig(plot_path) + logging.info(f'Saved timestamp plot to: {plot_path}') + plt.close() + + def run(self): + for file_name in os.listdir(self.dataset_dir): + if file_name.endswith('.hdf5'): + self.join_file_path(file_name) + self.load_hdf5() + video_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_video.mp4')) + self.save_videos(self.image_dict, self.dt, video_path) + qpos_plot_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_qpos.png')) + self.visualize_joints(self.qpos, self.action, qpos_plot_path) + # effort_plot_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_effort.png')) + # self.visualize_single(self.effort, 'effort', effort_plot_path) + # error_plot_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_error.png')) + # self.visualize_single(self.action - self.qpos, 'tracking_error', error_plot_path) + # self.visualize_timestamp(t_list) # TODO: Add timestamp visualization back + + +def load_config(config_path): + with open(config_path, 'r') as file: + return yaml.safe_load(file) + +if __name__ == '__main__': + config = load_config('/home/rm/code/shadow_rm_aloha/config/vis_data_path.yaml') + visualizer = DataVisualizer(config) + visualizer.run() \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/data_sub_process/aloha_data_collect.py b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/data_sub_process/aloha_data_collect.py new file mode 100644 index 000000000..88de92bea --- 
/dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/data_sub_process/aloha_data_collect.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +import time +import yaml +import rospy +import dm_env +import numpy as np +import collections +from datetime import datetime +from sensor_msgs.msg import Image, JointState +from shadow_rm_robot.realman_arm import RmArm +from message_filters import Subscriber, ApproximateTimeSynchronizer + + +class DataSynchronizer: + def __init__(self, config_path="config"): + rospy.init_node("synchronizer", anonymous=True) + + with open(config_path, "r") as file: + config = yaml.safe_load(file) + + self.init_left_arm_angle = config["robot_env"]["init_left_arm_angle"] + self.init_right_arm_angle = config["robot_env"]["init_right_arm_angle"] + self.arm_axis = config["robot_env"]["arm_axis"] + self.camera_names = config["camera_names"] + + # 创建订阅者 + self.camera_left_sub = Subscriber(config["ros_topics"]["camera_left"], Image) + self.camera_right_sub = Subscriber(config["ros_topics"]["camera_right"], Image) + self.camera_bottom_sub = Subscriber( + config["ros_topics"]["camera_bottom"], Image + ) + self.camera_head_sub = Subscriber(config["ros_topics"]["camera_head"], Image) + + self.left_slave_arm_sub = Subscriber( + config["ros_topics"]["left_slave_arm_sub"], JointState + ) + self.right_slave_arm_sub = Subscriber( + config["ros_topics"]["right_slave_arm_sub"], JointState + ) + self.left_slave_arm_pub = rospy.Publisher( + config["ros_topics"]["left_slave_arm_pub"], JointState, queue_size=1 + ) + self.right_slave_arm_pub = rospy.Publisher( + config["ros_topics"]["right_slave_arm_pub"], JointState, queue_size=1 + ) + + # 创建同步器 + self.ats = ApproximateTimeSynchronizer( + [ + self.camera_left_sub, + self.camera_right_sub, + self.camera_bottom_sub, + self.camera_head_sub, + self.left_slave_arm_sub, + self.right_slave_arm_sub, + ], + queue_size=1, + slop=0.1, + ) + + self.ats.registerCallback(self.callback) + self.ts = None + self.is_frist_step = True + + def callback( + self, + camera_left_img, + camera_right_img, + camera_bottom_img, + camera_head_img, + left_slave_arm, + right_slave_arm, + ): + + # 将ROS图像消息转换为NumPy数组 + camera_left_np_img = np.frombuffer( + camera_left_img.data, dtype=np.uint8 + ).reshape(camera_left_img.height, camera_left_img.width, -1) + camera_right_np_img = np.frombuffer( + camera_right_img.data, dtype=np.uint8 + ).reshape(camera_right_img.height, camera_right_img.width, -1) + camera_bottom_np_img = np.frombuffer( + camera_bottom_img.data, dtype=np.uint8 + ).reshape(camera_bottom_img.height, camera_bottom_img.width, -1) + camera_head_np_img = np.frombuffer( + camera_head_img.data, dtype=np.uint8 + ).reshape(camera_head_img.height, camera_head_img.width, -1) + + left_slave_arm_angle = left_slave_arm.position + left_slave_arm_velocity = left_slave_arm.velocity + left_slave_arm_force = left_slave_arm.effort + # 因时夹爪的角度与主臂的角度相同, 非因时夹爪请注释 + # left_slave_arm_angle[self.arm_axis] = left_master_arm_angle[self.arm_axis] + + right_slave_arm_angle = right_slave_arm.position + right_slave_arm_velocity = right_slave_arm.velocity + right_slave_arm_force = right_slave_arm.effort + # 因时夹爪的角度与主臂的角度相同,, 非因时夹爪请注释 + # right_slave_arm_angle[self.arm_axis] = right_master_arm_angle[self.arm_axis] + + # 收集数据 + obs = collections.OrderedDict( + { + "qpos": np.concatenate([left_slave_arm_angle, right_slave_arm_angle]), + "qvel": np.concatenate( + [left_slave_arm_velocity, right_slave_arm_velocity] + ), + "effort": np.concatenate([left_slave_arm_force, 
right_slave_arm_force]), + "images": { + self.camera_names[0]: camera_head_np_img, + self.camera_names[1]: camera_bottom_np_img, + self.camera_names[2]: camera_left_np_img, + self.camera_names[3]: camera_right_np_img, + }, + } + ) + self.ts = dm_env.TimeStep( + step_type=( + dm_env.StepType.FIRST if self.is_frist_step else dm_env.StepType.MID + ), + reward=0.0, + discount=1.0, + observation=obs, + ) + + def reset(self): + + left_joint_state = JointState() + left_joint_state.header.stamp = rospy.Time.now() + left_joint_state.name = ["joint_j"] + left_joint_state.position = self.init_left_arm_angle[0 : self.arm_axis + 1] + + right_joint_state = JointState() + right_joint_state.header.stamp = rospy.Time.now() + right_joint_state.name = ["joint_j"] + right_joint_state.position = self.init_right_arm_angle[0 : self.arm_axis + 1] + + self.left_slave_arm_pub.publish(left_joint_state) + self.right_slave_arm_pub.publish(right_joint_state) + while self.ts is None: + time.sleep(0.002) + + return self.ts + + + + def step(self, target_angle): + self.is_frist_step = False + left_joint_state = JointState() + left_joint_state.header.stamp = rospy.Time.now() + left_joint_state.name = ["joint_canfd"] + left_joint_state.position = target_angle[0 : self.arm_axis + 1] + # print("left_joint_state: ", left_joint_state) + + right_joint_state = JointState() + right_joint_state.header.stamp = rospy.Time.now() + right_joint_state.name = ["joint_canfd"] + right_joint_state.position = target_angle[self.arm_axis + 1 : (self.arm_axis + 1) * 2] + # print("right_joint_state: ", right_joint_state) + + self.left_slave_arm_pub.publish(left_joint_state) + self.right_slave_arm_pub.publish(right_joint_state) + # time.sleep(0.013) + return self.ts + + def run(self): + rospy.loginfo("Starting ROS spin") + data = np.concatenate([self.init_left_arm_angle, self.init_right_arm_angle]) + self.reset() + # print("data: ", data) + while not rospy.is_shutdown(): + self.step(data) + rospy.sleep(0.010) + + + +if __name__ == "__main__": + synchronizer = DataSynchronizer("/home/rm/code/shadow_act/config/config.yaml") + start_time = time.time() + synchronizer.run() diff --git a/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/data_sub_process/aloha_data_synchronizer.py b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/data_sub_process/aloha_data_synchronizer.py new file mode 100644 index 000000000..04775e81f --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/data_sub_process/aloha_data_synchronizer.py @@ -0,0 +1,280 @@ +#!/usr/bin/env python3 +import os +import time +import h5py +import yaml +import rospy +import dm_env +import numpy as np +import collections +from datetime import datetime +from std_msgs.msg import Int32MultiArray +from sensor_msgs.msg import Image, JointState +from message_filters import Subscriber, ApproximateTimeSynchronizer + + +class DataCollector: + def __init__( + self, + dataset_dir, + dataset_name, + max_timesteps, + camera_names, + state_dim, + overwrite=False, + ): + self.dataset_dir = dataset_dir + self.dataset_name = dataset_name + self.max_timesteps = max_timesteps + self.camera_names = camera_names + self.state_dim = state_dim + self.overwrite = overwrite + self.init_dict() + self.create_dataset_dir() + + def init_dict(self): + self.data_dict = { + "/observations/qpos": [], + "/observations/qvel": [], + "/observations/effort": [], + "/action": [], + } + for cam_name in self.camera_names: + self.data_dict[f"/observations/images/{cam_name}"] = [] + + def 
create_dataset_dir(self): + # 按照年月日创建目录 + date_str = datetime.now().strftime("%Y%m%d") + self.dataset_dir = os.path.join(self.dataset_dir, date_str) + if not os.path.exists(self.dataset_dir): + os.makedirs(self.dataset_dir) + + def create_file(self): + # 检查数据集名称是否存在,如果存在则递增名称 + counter = 0 + dataset_path = os.path.join(self.dataset_dir, f"{self.dataset_name}_{counter}") + if not self.overwrite: + while os.path.exists(dataset_path + ".hdf5"): + dataset_path = os.path.join( + self.dataset_dir, f"{self.dataset_name}_{counter}" + ) + counter += 1 + self.dataset_path = dataset_path + + def collect_data(self, ts, action): + self.data_dict["/observations/qpos"].append(ts.observation["qpos"]) + self.data_dict["/observations/qvel"].append(ts.observation["qvel"]) + self.data_dict["/observations/effort"].append(ts.observation["effort"]) + self.data_dict["/action"].append(action) + for cam_name in self.camera_names: + self.data_dict[f"/observations/images/{cam_name}"].append( + ts.observation["images"][cam_name] + ) + + def save_data(self): + self.create_file() + t0 = time.time() + # 保存数据 + with h5py.File( + self.dataset_path + ".hdf5", mode="w", rdcc_nbytes=1024**2 * 2 + ) as root: + root.attrs["sim"] = False + obs = root.create_group("observations") + image = obs.create_group("images") + for cam_name in self.camera_names: + _ = image.create_dataset( + cam_name, + (self.max_timesteps, 480, 640, 3), + dtype="uint8", + chunks=(1, 480, 640, 3), + ) + _ = obs.create_dataset("qpos", (self.max_timesteps, self.state_dim)) + _ = obs.create_dataset("qvel", (self.max_timesteps, self.state_dim)) + _ = obs.create_dataset("effort", (self.max_timesteps, self.state_dim)) + _ = root.create_dataset("action", (self.max_timesteps, self.state_dim)) + + for name, array in self.data_dict.items(): + root[name][...] 
= array + print(f"Saving: {time.time() - t0:.1f} secs") + return True + + +class DataSynchronizer: + def __init__(self, config_path="config"): + rospy.init_node("synchronizer", anonymous=True) + rospy.loginfo("ROS node initialized") + with open(config_path, "r") as file: + config = yaml.safe_load(file) + self.arm_axis = config["arm_axis"] + # 创建订阅者 + self.camera_left_sub = Subscriber(config["ros_topics"]["camera_left"], Image) + self.camera_right_sub = Subscriber(config["ros_topics"]["camera_right"], Image) + self.camera_bottom_sub = Subscriber( + config["ros_topics"]["camera_bottom"], Image + ) + self.camera_head_sub = Subscriber(config["ros_topics"]["camera_head"], Image) + + self.left_master_arm_sub = Subscriber( + config["ros_topics"]["left_master_arm"], JointState + ) + self.left_slave_arm_sub = Subscriber( + config["ros_topics"]["left_slave_arm"], JointState + ) + self.right_master_arm_sub = Subscriber( + config["ros_topics"]["right_master_arm"], JointState + ) + self.right_slave_arm_sub = Subscriber( + config["ros_topics"]["right_slave_arm"], JointState + ) + self.left_aloha_state_pub = rospy.Subscriber( + config["ros_topics"]["left_aloha_state"], + Int32MultiArray, + self.aloha_state_callback, + ) + + rospy.loginfo("Subscribers created") + self.camera_names = config["camera_names"] + + # 创建同步器 + self.ats = ApproximateTimeSynchronizer( + [ + self.camera_left_sub, + self.camera_right_sub, + self.camera_bottom_sub, + self.camera_head_sub, + self.left_master_arm_sub, + self.left_slave_arm_sub, + self.right_master_arm_sub, + self.right_slave_arm_sub, + ], + queue_size=1, + slop=0.05, + ) + self.ats.registerCallback(self.callback) + rospy.loginfo("Time synchronizer created and callback registered") + + self.data_collector = DataCollector( + dataset_dir=config["dataset_dir"], + dataset_name=config["dataset_name"], + max_timesteps=config["max_timesteps"], + camera_names=config["camera_names"], + state_dim=config["state_dim"], + overwrite=config["overwrite"], + ) + self.timesteps_collected = 0 + self.begin_collect = False + self.last_time = None + + def callback( + self, + camera_left_img, + camera_right_img, + camera_bottom_img, + camera_head_img, + left_master_arm, + left_slave_arm, + right_master_arm, + right_slave_arm, + ): + if self.begin_collect: + self.timesteps_collected += 1 + rospy.loginfo( + f"Collecting data: {self.timesteps_collected}/{self.data_collector.max_timesteps}" + ) + else: + self.timesteps_collected = 0 + return + if self.timesteps_collected == 0: + return + current_time = time.time() + if self.last_time is not None: + frequency = 1.0 / (current_time - self.last_time) + rospy.loginfo(f"Callback frequency: {frequency:.2f} Hz") + self.last_time = current_time + # 将ROS图像消息转换为NumPy数组 + camera_left_np_img = np.frombuffer( + camera_left_img.data, dtype=np.uint8 + ).reshape(camera_left_img.height, camera_left_img.width, -1) + camera_right_np_img = np.frombuffer( + camera_right_img.data, dtype=np.uint8 + ).reshape(camera_right_img.height, camera_right_img.width, -1) + camera_bottom_np_img = np.frombuffer( + camera_bottom_img.data, dtype=np.uint8 + ).reshape(camera_bottom_img.height, camera_bottom_img.width, -1) + camera_head_np_img = np.frombuffer( + camera_head_img.data, dtype=np.uint8 + ).reshape(camera_head_img.height, camera_head_img.width, -1) + + # 提取臂的角度,速度,力 + left_master_arm_angle = left_master_arm.position + # left_master_arm_velocity = left_master_arm.velocity + # left_master_arm_force = left_master_arm.effort + + left_slave_arm_angle = left_slave_arm.position + 
left_slave_arm_velocity = left_slave_arm.velocity + left_slave_arm_force = left_slave_arm.effort + # 因时夹爪的角度与主臂的角度相同, 非因时夹爪请注释 + # left_slave_arm_angle[self.arm_axis] = left_master_arm_angle[self.arm_axis] + + right_master_arm_angle = right_master_arm.position + # right_master_arm_velocity = right_master_arm.velocity + # right_master_arm_force = right_master_arm.effort + + right_slave_arm_angle = right_slave_arm.position + right_slave_arm_velocity = right_slave_arm.velocity + right_slave_arm_force = right_slave_arm.effort + # 因时夹爪的角度与主臂的角度相同, 非因时夹爪请注释 + # right_slave_arm_angle[self.arm_axis] = right_master_arm_angle[self.arm_axis] + + # 收集数据 + obs = collections.OrderedDict( + { + "qpos": np.concatenate([left_slave_arm_angle, right_slave_arm_angle]), + "qvel": np.concatenate( + [left_slave_arm_velocity, right_slave_arm_velocity] + ), + "effort": np.concatenate([left_slave_arm_force, right_slave_arm_force]), + "images": { + self.camera_names[0]: camera_head_np_img, + self.camera_names[1]: camera_bottom_np_img, + self.camera_names[2]: camera_left_np_img, + self.camera_names[3]: camera_right_np_img, + }, + } + ) + ts = dm_env.TimeStep( + step_type=dm_env.StepType.MID, reward=0.0, discount=1.0, observation=obs + ) + action = np.concatenate([left_master_arm_angle, right_master_arm_angle]) + self.data_collector.collect_data(ts, action) + + # 检查是否收集了足够的数据 + if self.timesteps_collected >= self.data_collector.max_timesteps: + self.data_collector.save_data() + + rospy.loginfo("Data collection complete") + + self.data_collector.init_dict() + self.begin_collect = False + self.timesteps_collected = 0 + + def aloha_state_callback(self, data): + if not self.begin_collect: + self.aloha_state = data.data + print(self.aloha_state[0], self.aloha_state[1]) + if self.aloha_state[0] == 1 and self.aloha_state[1] == 1: + self.begin_collect = True + + def run(self): + rospy.loginfo("Starting ROS spin") + rospy.spin() + +if __name__ == "__main__": + synchronizer = DataSynchronizer( + "/home/rm/code/shadow_rm_aloha/config/data_synchronizer.yaml" + ) + synchronizer.run() diff --git a/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/launch/aloha_65_data_publisher.launch b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/launch/aloha_65_data_publisher.launch new file mode 100644 index 000000000..03dd27218 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/launch/aloha_65_data_publisher.launch @@ -0,0 +1,63 @@ diff --git a/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/launch/aloha_65_eval.launch b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/launch/aloha_65_eval.launch new file mode 100644 index 000000000..86a141a88 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/launch/aloha_65_eval.launch @@ -0,0 +1,49 @@ diff --git a/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/launch/aloha_75_data_publisher.launch b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/launch/aloha_75_data_publisher.launch new file mode 100644 index 000000000..b8c4975f5 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/launch/aloha_75_data_publisher.launch @@ -0,0 +1,61 @@ diff --git a/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/mu_data/aloha_data_collect.py b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/mu_data/aloha_data_collect.py new file mode 100644 index 000000000..58c782688 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/mu_data/aloha_data_collect.py @@ -0,0 +1,284 @@ +#!/usr/bin/env python3 +import os +import cv2 +import time +import h5py +import yaml +import json +import rospy +import dm_env +import socket +import numpy as np +import collections +from datetime import datetime +from std_msgs.msg import Int32MultiArray +from sensor_msgs.msg import Image, JointState +from message_filters import Subscriber, ApproximateTimeSynchronizer + + +class DataCollector: + def __init__( + self, + dataset_dir, + dataset_name, + max_timesteps, + camera_names, + state_dim, + overwrite=False, + ): + self.dataset_dir = dataset_dir + self.dataset_name = dataset_name + self.max_timesteps = max_timesteps + self.camera_names = camera_names + self.state_dim = state_dim + self.overwrite = overwrite + self.init_dict() + self.create_dataset_dir() + + def init_dict(self): + self.data_dict = { + "/observations/qpos": [], + "/observations/qvel": [], + "/observations/effort": [], + "/action": [], + } + for cam_name in self.camera_names: + self.data_dict[f"/observations/images/{cam_name}"] = [] + + def create_dataset_dir(self): + # 按照年月日创建目录 + date_str = datetime.now().strftime("%Y%m%d") + self.dataset_dir = os.path.join(self.dataset_dir, date_str) + if not os.path.exists(self.dataset_dir): + os.makedirs(self.dataset_dir) + + def create_file(self): + # 检查数据集名称是否存在,如果存在则递增名称 + counter = 0 + dataset_path = os.path.join(self.dataset_dir, f"{self.dataset_name}_{counter}") + if not self.overwrite: + while os.path.exists(dataset_path + ".hdf5"): + dataset_path = os.path.join( + self.dataset_dir, f"{self.dataset_name}_{counter}" + ) + counter += 1 + self.dataset_path = dataset_path + + def collect_data(self, ts, action): + self.data_dict["/observations/qpos"].append(ts.observation["qpos"]) + self.data_dict["/observations/qvel"].append(ts.observation["qvel"]) + self.data_dict["/observations/effort"].append(ts.observation["effort"]) + self.data_dict["/action"].append(action) + for cam_name in self.camera_names: + self.data_dict[f"/observations/images/{cam_name}"].append( + ts.observation["images"][cam_name] + ) + + def save_data(self): + self.create_file() + t0 = time.time() + # 保存数据 + with h5py.File( + self.dataset_path + ".hdf5", mode="w", rdcc_nbytes=1024**2 * 2 + ) as root: + root.attrs["sim"] = False + obs = root.create_group("observations") + image = obs.create_group("images") + for cam_name in self.camera_names: + _ = image.create_dataset( + cam_name, + (self.max_timesteps, 480, 640, 3), + dtype="uint8", + chunks=(1, 480, 640, 3), + ) + _ = obs.create_dataset("qpos", (self.max_timesteps, self.state_dim)) + _ = obs.create_dataset("qvel", (self.max_timesteps, self.state_dim)) + _ = obs.create_dataset("effort", (self.max_timesteps, self.state_dim)) + _ = root.create_dataset("action", (self.max_timesteps, self.state_dim)) + + for name, array in self.data_dict.items(): + root[name][...] 
= array + print(f"Saving: {time.time() - t0:.1f} secs") + return True + + +class DataSynchronizer: + def __init__(self, config_path="config"): + rospy.init_node("synchronizer", anonymous=True) + rospy.loginfo("ROS node initialized") + with open(config_path, "r") as file: + config = yaml.safe_load(file) + self.arm_axis = config["arm_axis"] + # 创建订阅者 + self.camera_left_sub = Subscriber(config["ros_topics"]["camera_left"], Image) + self.camera_right_sub = Subscriber(config["ros_topics"]["camera_right"], Image) + self.camera_bottom_sub = Subscriber( + config["ros_topics"]["camera_bottom"], Image + ) + self.camera_head_sub = Subscriber(config["ros_topics"]["camera_head"], Image) + + self.left_master_arm_sub = Subscriber( + config["ros_topics"]["left_master_arm"], JointState + ) + self.left_slave_arm_sub = Subscriber( + config["ros_topics"]["left_slave_arm"], JointState + ) + self.right_master_arm_sub = Subscriber( + config["ros_topics"]["right_master_arm"], JointState + ) + self.right_slave_arm_sub = Subscriber( + config["ros_topics"]["right_slave_arm"], JointState + ) + self.left_aloha_state_pub = rospy.Subscriber( + config["ros_topics"]["left_aloha_state"], + Int32MultiArray, + self.aloha_state_callback, + ) + + rospy.loginfo("Subscribers created") + + # 创建同步器 + self.ats = ApproximateTimeSynchronizer( + [ + self.camera_left_sub, + self.camera_right_sub, + self.camera_bottom_sub, + self.camera_head_sub, + self.left_master_arm_sub, + self.left_slave_arm_sub, + self.right_master_arm_sub, + self.right_slave_arm_sub, + ], + queue_size=1, + slop=0.05, + ) + self.ats.registerCallback(self.callback) + rospy.loginfo("Time synchronizer created and callback registered") + + self.data_collector = DataCollector( + dataset_dir=config["dataset_dir"], + dataset_name=config["dataset_name"], + max_timesteps=config["max_timesteps"], + camera_names=config["camera_names"], + state_dim=config["state_dim"], + overwrite=config["overwrite"], + ) + self.timesteps_collected = 0 + self.begin_collect = False + self.last_time = None + + def callback( + self, + camera_left_img, + camera_right_img, + camera_bottom_img, + camera_head_img, + left_master_arm, + left_slave_arm, + right_master_arm, + right_slave_arm, + ): + if self.begin_collect: + self.timesteps_collected += 1 + rospy.loginfo( + f"Collecting data: {self.timesteps_collected}/{self.data_collector.max_timesteps}" + ) + else: + self.timesteps_collected = 0 + return + if self.timesteps_collected == 0: + return + current_time = time.time() + if self.last_time is not None: + frequency = 1.0 / (current_time - self.last_time) + rospy.loginfo(f"Callback frequency: {frequency:.2f} Hz") + self.last_time = current_time + # 将ROS图像消息转换为NumPy数组 + camera_left_np_img = np.frombuffer( + camera_left_img.data, dtype=np.uint8 + ).reshape(camera_left_img.height, camera_left_img.width, -1) + camera_right_np_img = np.frombuffer( + camera_right_img.data, dtype=np.uint8 + ).reshape(camera_right_img.height, camera_right_img.width, -1) + camera_bottom_np_img = np.frombuffer( + camera_bottom_img.data, dtype=np.uint8 + ).reshape(camera_bottom_img.height, camera_bottom_img.width, -1) + camera_head_np_img = np.frombuffer( + camera_head_img.data, dtype=np.uint8 + ).reshape(camera_head_img.height, camera_head_img.width, -1) + + # 提取臂的角度,速度,力 + left_master_arm_angle = left_master_arm.position + # left_master_arm_velocity = left_master_arm.velocity + # left_master_arm_force = left_master_arm.effort + + left_slave_arm_angle = left_slave_arm.position + left_slave_arm_velocity = 
left_slave_arm.velocity + left_slave_arm_force = left_slave_arm.effort + # 因时夹爪的角度与主臂的角度相同, 非因时夹爪请注释 + # left_slave_arm_angle[self.arm_axis] = left_master_arm_angle[self.arm_axis] + + right_master_arm_angle = right_master_arm.position + # right_master_arm_velocity = right_master_arm.velocity + # right_master_arm_force = right_master_arm.effort + + right_slave_arm_angle = right_slave_arm.position + right_slave_arm_velocity = right_slave_arm.velocity + right_slave_arm_force = right_slave_arm.effort + # 因时夹爪的角度与主臂的角度相同,, 非因时夹爪请注释 + # right_slave_arm_angle[self.arm_axis] = right_master_arm_angle[self.arm_axis] + + # 收集数据 + obs = collections.OrderedDict( + { + "qpos": np.concatenate([left_slave_arm_angle, right_slave_arm_angle]), + "qvel": np.concatenate( + [left_slave_arm_velocity, right_slave_arm_velocity] + ), + "effort": np.concatenate([left_slave_arm_force, right_slave_arm_force]), + "images": { + "cam_front": camera_head_np_img, + "cam_low": camera_bottom_np_img, + "cam_left": camera_left_np_img, + "cam_right": camera_right_np_img, + }, + } + ) + ts = dm_env.TimeStep( + step_type=dm_env.StepType.MID, reward=0, discount=None, observation=obs + ) + action = np.concatenate([left_master_arm_angle, right_master_arm_angle]) + self.data_collector.collect_data(ts, action) + + # 检查是否收集了足够的数据 + if self.timesteps_collected >= self.data_collector.max_timesteps: + self.data_collector.save_data() + + rospy.loginfo("Data collection complete") + + self.data_collector.init_dict() + self.begin_collect = False + self.timesteps_collected = 0 + + def aloha_state_callback(self, data): + if not self.begin_collect: + self.aloha_state = data.data + print(self.aloha_state[0], self.aloha_state[1]) + if self.aloha_state[0] == 1 and self.aloha_state[1] == 1: + self.begin_collect = True + + + def run(self): + rospy.loginfo("Starting ROS spin") + + rospy.spin() + +if __name__ == "__main__": + + synchronizer = DataSynchronizer( + "/home/rm/code/shadow_rm_aloha/config/data_synchronizer.yaml" + ) + synchronizer.run() + + + diff --git a/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/package.xml b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/package.xml new file mode 100644 index 000000000..d0a471006 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/package.xml @@ -0,0 +1,31 @@ + + + shadow_rm_aloha + 0.0.1 + The shadow_rm_aloha package + + Your Name + + TODO + + catkin + + rospy + sensor_msgs + std_msgs + cv_bridge + image_transport + message_generation + message_runtime + + rospy + sensor_msgs + std_msgs + cv_bridge + image_transport + message_runtime + + + + + \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/srv/GetArmStatus.srv b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/srv/GetArmStatus.srv new file mode 100644 index 000000000..78b271ad3 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/srv/GetArmStatus.srv @@ -0,0 +1,5 @@ +# GetArmStatus.srv + +--- +sensor_msgs/JointState joint_status + diff --git a/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/srv/GetImage.srv b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/srv/GetImage.srv new file mode 100644 index 000000000..9aed5b1cc --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/srv/GetImage.srv @@ -0,0 +1,4 @@ +# GetImage.srv +--- +bool success +sensor_msgs/Image image \ No newline at end of file diff --git 
a/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/srv/MoveArm.srv b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/srv/MoveArm.srv new file mode 100644 index 000000000..3ccaed68d --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/srv/MoveArm.srv @@ -0,0 +1,4 @@ +# MoveArm.srv +float32[] joint_angle +--- +bool success \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/srv/__init__.py b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/srv/__init__.py new file mode 100644 index 000000000..b794fd409 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/src/shadow_rm_aloha/srv/__init__.py @@ -0,0 +1 @@ +__version__ = '0.1.0' diff --git a/realman_src/realman_aloha/shadow_rm_aloha/test/mu_test.py b/realman_src/realman_aloha/shadow_rm_aloha/test/mu_test.py new file mode 100644 index 000000000..a561343ee --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/test/mu_test.py @@ -0,0 +1,49 @@ +import multiprocessing as mp +import time + +def collect_data(arm_id, cam_id, data_queue, lock): + while True: + # 模拟数据采集 + arm_data = f"Arm {arm_id} data" + cam_data = f"Cam {cam_id} data" + + # 获取当前时间戳 + timestamp = time.time() + + # 将数据放入队列 + with lock: + data_queue.put((timestamp, arm_data, cam_data)) + + # 模拟高帧率 + time.sleep(0.01) + +def main(): + num_arms = 4 + num_cams = 4 + + # 创建队列和锁 + data_queue = mp.Queue() + lock = mp.Lock() + + # 创建进程 + processes = [] + for i in range(num_arms): + p = mp.Process(target=collect_data, args=(i, i, data_queue, lock)) + processes.append(p) + p.start() + + # 主进程处理数据 + try: + while True: + if not data_queue.empty(): + with lock: + timestamp, arm_data, cam_data = data_queue.get() + print(f"Timestamp: {timestamp}, {arm_data}, {cam_data}") + except KeyboardInterrupt: + for p in processes: + p.terminate() + for p in processes: + p.join() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_aloha/test/test_aloha_data_synchronizer.py b/realman_src/realman_aloha/shadow_rm_aloha/test/test_aloha_data_synchronizer.py new file mode 100644 index 000000000..7e678837c --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/test/test_aloha_data_synchronizer.py @@ -0,0 +1,38 @@ +import os +import shutil +from datetime import datetime +from shadow_rm_aloha.data_sub_process.aloha_data_synchronizer import DataCollector + +def test_create_dataset_dir(): + # 设置测试参数 + dataset_dir = './test_data/dataset' + dataset_name = 'test_episode' + max_timesteps = 100 + camera_names = ['cam1', 'cam2'] + overwrite = False + + # 清理旧的测试数据 + if os.path.exists(dataset_dir): + shutil.rmtree(dataset_dir) + + # 创建 DataCollector 实例并调用 create_dataset_dir + collector = DataCollector(dataset_dir, dataset_name, max_timesteps, camera_names, overwrite) + + # 检查目录是否按预期创建 + date_str = datetime.now().strftime("%Y%m%d") + expected_dir = os.path.join(dataset_dir, date_str) + assert os.path.exists(expected_dir), f"Expected directory {expected_dir} does not exist." 
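+ + # Naming scheme exercised below: create_file() numbers episodes + # {dataset_name}_0, _1, ... inside the date directory and, with + # overwrite=False, advances the counter until it finds an unused name.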
+ + # 检查文件名是否按预期生成:首次调用 create_file,应得到 test_episode_0 + collector.create_file() + expected_file = os.path.join(expected_dir, dataset_name + '_0') + assert collector.dataset_path == expected_file, f"Expected file path {expected_file}, but got {collector.dataset_path}" + + # 创建同名文件后再次调用 create_file,检查文件名是否递增 + open(expected_file + '.hdf5', 'w').close() + collector.create_file() + expected_file_incremented = os.path.join(expected_dir, dataset_name + '_1') + assert collector.dataset_path == expected_file_incremented, f"Expected file path {expected_file_incremented}, but got {collector.dataset_path}" + + print("All tests passed.") + +if __name__ == '__main__': + test_create_dataset_dir() \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_aloha/test/udp_test.py b/realman_src/realman_aloha/shadow_rm_aloha/test/udp_test.py new file mode 100644 index 000000000..d7e808cab --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_aloha/test/udp_test.py @@ -0,0 +1,105 @@ + +import multiprocessing +import time +import random +import socket +import json +import logging + +# 设置日志级别 +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + + +class test_udp(): + def __init__(self): + arm_ip = '192.168.1.19' + arm_port = 8080 + self.arm = socket.socket() + self.arm.connect((arm_ip, arm_port)) + set_udp = {"command":"set_realtime_push","cycle":1,"enable":True,"port":8090,"ip":"192.168.1.101","custom":{"aloha_state":True,"joint_speed":True,"arm_current_status":True,"hand":False, "expand_state":True}} + self.arm.send(json.dumps(set_udp).encode('utf-8')) + state = self.arm.recv(1024) + + logging.info(f"Send data to {arm_ip}:{arm_port}: {state}") + + self.udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + + # 设置套接字选项,允许端口复用 + self.udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + + local_ip = "192.168.1.101" + local_port = 8090 + self.udp_socket.bind((local_ip, local_port)) + self.BUFFER_SIZE = 1024 + + def set_udp(self): + try: + while True: + data, addr = self.udp_socket.recvfrom(self.BUFFER_SIZE) + # 将接收到的UDP数据解码并解析为JSON + data = json.loads(data.decode('utf-8')) + print(f"Received data {data}") + finally: + self.udp_socket.close() + + +def collect_arm_data(arm_id, queue, event): + while True: + data = f"Arm {arm_id} data {random.random()}" + queue.put((arm_id, data)) + event.set() + time.sleep(1) + +def collect_camera_data(camera_id, queue, event): + while True: + data = f"Camera {camera_id} data {random.random()}" + queue.put((camera_id, data)) + event.set() + time.sleep(1) + +def main(): + arm_queues = [multiprocessing.Queue() for _ in range(4)] + camera_queues = [multiprocessing.Queue() for _ in range(4)] + arm_events = [multiprocessing.Event() for _ in range(4)] + camera_events = [multiprocessing.Event() for _ in range(4)] + + arm_processes = [multiprocessing.Process(target=collect_arm_data, args=(i, arm_queues[i], arm_events[i])) for i in range(4)] + camera_processes = [multiprocessing.Process(target=collect_camera_data, args=(i, camera_queues[i], camera_events[i])) for i in range(4)] + + for p in arm_processes + camera_processes: + p.start() + + try: + while True: + for event in arm_events + camera_events: + event.wait() + + for i in range(4): + if not arm_queues[i].empty(): + arm_id, arm_data = arm_queues[i].get() + print(f"Received from Arm {arm_id}: {arm_data}") + arm_events[i].clear() + + if not camera_queues[i].empty(): + camera_id, camera_data = camera_queues[i].get() + print(f"Received from Camera {camera_id}: {camera_data}") + camera_events[i].clear() 
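+ + # Handshake note: each producer set()s its Event right after enqueueing and the + # consumer clear()s it after draining, so the wait() calls above block until + # every arm and camera process has published a fresh sample.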
+ + time.sleep(0.1) + except KeyboardInterrupt: + for p in arm_processes + camera_processes: + p.terminate() + +if __name__ == "__main__": + main() + + +# if __name__ == "__main__": +# test_udp = test_udp() +# test_udp.set_udp() diff --git a/realman_src/realman_aloha/shadow_rm_robot/.gitignore b/realman_src/realman_aloha/shadow_rm_robot/.gitignore new file mode 100644 index 000000000..17567f898 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_robot/.gitignore @@ -0,0 +1,4 @@ +__pycache__/ +*.pyc +*.pyo +*.pt \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_robot/README.md b/realman_src/realman_aloha/shadow_rm_robot/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/realman_src/realman_aloha/shadow_rm_robot/config/rm_arm.yaml b/realman_src/realman_aloha/shadow_rm_robot/config/rm_arm.yaml new file mode 100644 index 000000000..fa05ca2fd --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_robot/config/rm_arm.yaml @@ -0,0 +1,5 @@ +arm_ip: "192.168.1.18" +arm_port: 8080 +arm_axis: 6 +# arm_ki: [7, 7, 7, 3, 3, 3, 3] # rm75 +arm_ki: [7, 7, 7, 3, 3, 3] # rm65 \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_robot/config/servo_arm.yaml b/realman_src/realman_aloha/shadow_rm_robot/config/servo_arm.yaml new file mode 100644 index 000000000..7822e5f2b --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_robot/config/servo_arm.yaml @@ -0,0 +1,5 @@ +port: /dev/ttyUSB0 +right_port: /dev/ttyUSB1 +baudrate: 460800 +hex_data: "55 AA 02 00 00 67" +arm_axis: 7 diff --git a/realman_src/realman_aloha/shadow_rm_robot/docs/睿尔曼机械臂JSON通信协议v3.7.1.pdf b/realman_src/realman_aloha/shadow_rm_robot/docs/睿尔曼机械臂JSON通信协议v3.7.1.pdf new file mode 100644 index 000000000..835b37c28 Binary files /dev/null and b/realman_src/realman_aloha/shadow_rm_robot/docs/睿尔曼机械臂JSON通信协议v3.7.1.pdf differ diff --git a/realman_src/realman_aloha/shadow_rm_robot/docs/睿尔曼机械臂接口函数说明(Python)V1.5.pdf b/realman_src/realman_aloha/shadow_rm_robot/docs/睿尔曼机械臂接口函数说明(Python)V1.5.pdf new file mode 100644 index 000000000..1cc6b0ed4 Binary files /dev/null and b/realman_src/realman_aloha/shadow_rm_robot/docs/睿尔曼机械臂接口函数说明(Python)V1.5.pdf differ diff --git a/realman_src/realman_aloha/shadow_rm_robot/pyproject.toml b/realman_src/realman_aloha/shadow_rm_robot/pyproject.toml new file mode 100644 index 000000000..c6ef75d8d --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_robot/pyproject.toml @@ -0,0 +1,36 @@ +[tool.poetry] +name = "shadow_rm_robot" +version = "0.1.0" +description = "Robot package, including operations such as reading and controlling robots" +readme = "README.md" +authors = ["Shadow "] +license = "MIT" +# include = ["realman_vision/pytransform/_pytransform.so",] +classifiers = [ + "Operating System :: POSIX :: Linux amd64", + "Programming Language :: Python :: 3.10", +] + +[tool.poetry.dependencies] +python = ">=3.10" +pyyaml = ">=6.0" +pyserial = ">=3.5" +pymodbus = ">=3.7" + + +[tool.poetry.dev-dependencies] # 列出开发时所需的依赖项,比如测试、文档生成等工具。 +pytest = ">=8.3" +black = ">=24.10.0" + + + +[tool.poetry.plugins."scripts"] # 定义命令行脚本,使得用户可以通过命令行运行指定的函数。 + + +[tool.poetry.group.dev.dependencies] + + + +[build-system] +requires = ["poetry-core>=1.8.4"] +build-backend = "poetry.core.masonry.api" \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_robot/src/shadow_rm_robot/log_setting.py b/realman_src/realman_aloha/shadow_rm_robot/src/shadow_rm_robot/log_setting.py new file mode 100644 index 000000000..6f1f549d1 --- /dev/null +++ 
b/realman_src/realman_aloha/shadow_rm_robot/src/shadow_rm_robot/log_setting.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# cython: language_level=3 +import os + +import logging +from logging.handlers import TimedRotatingFileHandler + +class CommonLog(object): + """ + 日志记录 + """ + + def __init__(self, logger, logname='web-log'): + self.logname = os.path.join(os.path.dirname(os.path.abspath(__file__)), '%s' % logname) + self.logger = logger + self.logger.setLevel(logging.DEBUG) + self.logger.propagate = False # 禁止使用logger对象parent的处理器 + self.formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s: %(message)s', '%Y-%m-%d %H:%M:%S') + + def __console(self, level, message): + # 创建一个FileHandler,用于写到本地 + + # fh = TimedRotatingFileHandler(self.logname, when='MIDNIGHT', interval=1, encoding='utf-8') + # # fh = logging.FileHandler(self.logname, 'a', encoding='utf-8') + # fh.suffix = '%Y-%m-%d.log' + # fh.setLevel(logging.DEBUG) + # fh.setFormatter(self.formatter) + # self.logger.addHandler(fh) + + # 创建一个StreamHandler,用于输出到控制台 + ch = logging.StreamHandler() + ch.setLevel(logging.DEBUG) + ch.setFormatter(self.formatter) + self.logger.addHandler(ch) + + if level == 'info': + self.logger.info(message) + elif level == 'debug': + self.logger.debug(message) + elif level == 'warning': + self.logger.warning(message) + elif level == 'error': + self.logger.error(message, exc_info=1) # 显示错误栈 + # self.logger.error(message) + + elif level == 'error_': + self.logger.error(message) # 不显示错误栈 + + + # 这两行代码是为了避免日志输出重复问题 + self.logger.removeHandler(ch) + # self.logger.removeHandler(fh) + # # 关闭打开的文件 + # fh.close() + + def debug(self, message): + self.__console('debug', message) + + def info(self, message): + self.__console('info', message) + + def warning(self, message): + self.__console('warning', message) + + def error(self, message): + self.__console('error', message) + + def error_(self, message): + self.__console('error_', message) + + + + + diff --git a/realman_src/realman_aloha/shadow_rm_robot/src/shadow_rm_robot/realman_arm.py b/realman_src/realman_aloha/shadow_rm_robot/src/shadow_rm_robot/realman_arm.py new file mode 100644 index 000000000..e769f48f4 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_robot/src/shadow_rm_robot/realman_arm.py @@ -0,0 +1,288 @@ +import json +import yaml +import time +import logging +import socket +import numpy as np + +# 配置日志记录 +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + + +class RmArm: + def __init__(self, config_file="config.yaml"): + """初始化机械臂的网络连接并发送初始命令。 + + Args: + config_file (str): 配置文件的路径。 + """ + self.config = self._load_config(config_file) + self.arm_ip = self.config.get("arm_ip", "192.168.1.18") + self.get_vel = self.config.get("get_vel", True) + self.get_torque = self.config.get("get_torque", True) + arm_port = self.config.get("arm_port", 8080) + local_ip = self.config.get("local_ip", '192.168.1.101') + local_port = self.config.get("local_port", 8089) + + self.arm = socket.socket() + self.arm.connect((self.arm_ip, arm_port)) + + set_udp = {"command":"set_realtime_push","cycle":6,"enable":True,"port":local_port,"ip":local_ip,"custom":{"aloha_state":True,"joint_speed":True,"arm_current_status":True,"hand":False, "expand_state":True}} + + self.arm.send(json.dumps(set_udp).encode('utf-8')) + _ = self.arm.recv(1024) + + self.arm_axis = self.config.get("arm_axis", 6) + self.arm_ki = self.config.get("arm_ki", [7, 7, 7, 3, 3, 3]) + + + self.udp_socket = 
socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.udp_socket.bind((local_ip, local_port)) + + self.cmd_get_current_arm_state = '{"command":"get_current_arm_state"}\r\n' + self.cmd_get_gripper_state = '{"command":"get_gripper_state"}\r\n' + + self.cmd_set_gripper_release = ( + '{"command": "set_gripper_release", "speed": 500, "block": false}\r\n' + ) + self.cmd_set_gripper_route = ( + '{"command":"set_gripper_route","min":0,"max":1000}\r\n' + ) + + self.arm.send(self.cmd_set_gripper_route.encode("utf-8")) + _ = self.arm.recv(1024) + + self.pre_gripper_actpos = None + self.cur_gripper_actpos = None + self.pre_actpos_time = None + self.cur_actpos_time = None + + def _load_config(self, config_file): + """加载配置文件。 + + Args: + config_file (str): 配置文件的路径。 + + Returns: + dict: 配置文件内容。 + """ + with open(config_file, "r") as file: + return yaml.safe_load(file) + + def _json_to_numpy(self, byte_data, key): + """将字节数据解析为 NumPy 数组。 + + Args: + byte_data (bytes): 字节数据。 + key (str): JSON 数据中的键。 + + Returns: + np.ndarray: 解析后的 NumPy 数组。 + """ + str_data = byte_data.decode("utf-8") + logging.debug(f"Received KEY: {key}") + logging.debug(f"Received JSON data: {str_data}") + try: + data_list = json.loads(str_data)[key] + if isinstance(data_list, dict): + return data_list + except KeyError: + logging.error(f"Key '{key}' not found in JSON data") + logging.error(f"Received JSON data: {str_data}") + return None + return np.array(data_list, dtype=float) + + def set_joint_position(self, joint_angle): + """设置机械臂的位置。 + + Args: + arm_pos (np.ndarray): 机械臂的位置 + + """ + joint_angle = np.array(joint_angle) + data = np.floor(joint_angle * 1000).astype(int).tolist() + cmd = ( + json.dumps( + {"command": "movej", "joint": data, "block": True, "v": 40, "r": 0} + ) + + "\r\n" + ) + self.arm.send(cmd.encode("utf-8")) + # TODO: Pending + state = self.arm.recv(1024) + # state = self.arm.recv(1024) + + def set_joint_canfd_position(self, joint_angle): + """设置机械臂的位置。 + + Args: + arm_pos (np.ndarray): 机械臂的位置 + """ + joint_angle = np.array(joint_angle) + data = np.floor(joint_angle * 1000).astype(int).tolist() + cmd = ( + json.dumps({"command": "movej_canfd", "joint": data, "follow": False}) + + "\r\n" + ) + self.arm.send(cmd.encode("utf-8")) + + def set_gripper_position(self, actpos): + """设置夹爪的位置。 + + Args: + actpos (np.ndarray): 夹爪的位置,单位为毫米。 + """ + data = np.array(actpos) * 1000 + data = np.floor(data).astype(int).tolist() + cmd = ( + json.dumps( + { + "command": "set_gripper_position", + "position": data, + "block": False, + } + ) + + "\r\n" + ) + self.arm.send(cmd.encode("utf-8")) + aaa = self.arm.recv(1024) + # print(aaa) + + def _update_state(self, gripper_actpos, actpos_time): + """更新关节和夹爪状态及时间。 + + Args: + actpos_time (float): 夹爪时间戳。 + """ + self.pre_gripper_actpos, self.cur_gripper_actpos = self.cur_gripper_actpos, gripper_actpos + self.pre_actpos_time, self.cur_actpos_time = self.cur_actpos_time, actpos_time + + def get_arm_data(self): + """获取机械臂数据""" + data, addr = self.udp_socket.recvfrom(1024) + data = json.loads(data.decode('utf-8')) + # logging.info(f"Received data: {data}") + joint_angle = np.array(data['joint_status']['joint_position']) * 0.001 + joint_velocity = np.array(data['joint_status']['joint_speed']) * 0.001 if self.get_vel else None + joint_current = np.array(data['joint_status']['joint_current']) / 1000000 if self.get_torque else None + joint_torque = self.current_to_torque(joint_current) if self.get_torque else None + 
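# aloha_state: teach-pendant state flags pushed by set_realtime_push -- apparently + # the same start/stop flags the collectors' aloha_state_callback reacts to. + 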
aloha_state = data['aloha_state'] + # logging.info(f"Time consumed: {time.time() - start_time}") + + result = {'joint_angle': joint_angle, 'aloha_state': aloha_state} + if self.get_vel: + result['joint_velocity'] = joint_velocity + if self.get_torque: + result['joint_torque'] = joint_torque + + return result + + def get_gripper_data(self): + """获取夹爪数据""" + try: + actpos_time = time.time() + self.arm.send(self.cmd_get_gripper_state.encode("utf-8")) + while True: + gripper_qpos = self.arm.recv(1024) + data = json.loads(gripper_qpos.decode("utf-8")) + if "actpos" in data: + break + else: + self.arm.send(self.cmd_get_gripper_state.encode("utf-8")) + gripper_actpos = self._json_to_numpy(gripper_qpos, "actpos") * 0.001 + # 记录本次读数及时间戳,供 get_gripper_velocity 差分使用 + self._update_state(gripper_actpos, actpos_time) + gripper_velocity = self.get_gripper_velocity() if self.get_vel else None + gripper_force = self._json_to_numpy(gripper_qpos, "current_force") / 100 if self.get_torque else None + + result = {'gripper_actpos': gripper_actpos} + if self.get_vel: + result['gripper_velocity'] = gripper_velocity + if self.get_torque: + result['gripper_force'] = gripper_force + + return result + except Exception as e: + logging.error(f"Error getting gripper data: {e}") + return None + + def get_integrate_data(self): + """获取整合数据""" + arm_data = self.get_arm_data() + gripper_data = self.get_gripper_data() + + if not arm_data or not gripper_data: + return None + + result = {'aloha_state': arm_data['aloha_state'], 'arm_angle': np.append(arm_data['joint_angle'], gripper_data['gripper_actpos'])} + if self.get_vel: + result['arm_velocity'] = np.append(arm_data['joint_velocity'], gripper_data['gripper_velocity']) + if self.get_torque: + result['arm_torque'] = arm_data['joint_torque'] + [gripper_data['gripper_force']] + + return result + + def get_gripper_velocity(self): + """获取夹爪速度""" + if self.pre_actpos_time is None or self.cur_actpos_time is None: + logging.debug("Previous or current gripper positions are not available.") + return 0 + delta_time = self.cur_actpos_time - self.pre_actpos_time + return (self.cur_gripper_actpos - self.pre_gripper_actpos) / delta_time if self.pre_gripper_actpos is not None else 0 + + def current_to_torque(self, current): + """将电流转换为扭矩""" + return [c * k for c, k in zip(current, self.arm_ki)] + + def get_arm_position(self): + """获取机械臂的位置。 + + Returns: + dict: 机械臂的位置,单位为毫米和弧度。 + 包含以下键: + - 'x': x 轴位置 + - 'y': y 轴位置 + - 'z': z 轴位置 + - 'roll': 滚转角 + - 'pitch': 俯仰角 + - 'yaw': 偏航角 + - 单位 : mm, rad + """ + self.arm.send(self.cmd_get_current_arm_state.encode("utf-8")) + _arm_state = self.arm.recv(1024) + arm_state = self._json_to_numpy(_arm_state, "arm_state") + arm_pos = np.array(arm_state["pose"], dtype=float) * 0.001 + + return { + "x": arm_pos[0], + "y": arm_pos[1], + "z": arm_pos[2], + "roll": arm_pos[3], + "pitch": arm_pos[4], + "yaw": arm_pos[5], + } + +if __name__ == "__main__": + arm_left = RmArm("/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml") + # arm_right = RmArm("/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml") + # test_left_narry = [7.235, 31.816, 51.237, 2.463, 91.054, 12.04] + test_right_narry = [-6.155, 33.925, 62.137, -1.672, 87.892, -3.868] + while True: + start_time = time.time() + arm_left.set_gripper_position(0.2) + # left_qpos = arm_left.get_integrate_data() + left_qpos = arm_left.get_gripper_data() + logging.info(left_qpos) + # right_qpos = arm_right.get_arm_data() + # logging.info(right_qpos) + # time.sleep(0.02) + # 
arm_right.set_joint_canfd_position(test_right_narry) + # arm_right.set_gripper_position(0.2) + + + logging.info(f"Time consumed: {time.time() - start_time}") \ No newline at end of file diff --git a/realman_src/realman_aloha/shadow_rm_robot/src/shadow_rm_robot/robotic_arm.py b/realman_src/realman_aloha/shadow_rm_robot/src/shadow_rm_robot/robotic_arm.py new file mode 100644 index 000000000..a8a01d7a3 --- /dev/null +++ b/realman_src/realman_aloha/shadow_rm_robot/src/shadow_rm_robot/robotic_arm.py @@ -0,0 +1,5603 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# cython: language_level=3 + +import os +import time +import ctypes +import logging +import platform +from enum import IntEnum +from typing import Tuple, List + +# 此处为了兼容绝对路径和相对路径写了多种导入方式,推荐用户根据包的结构选择一种清晰的导入方式 +if __package__ is None or __package__ == '': + # 当作为脚本运行时,__package__ 为 None 或者空字符串 + from .log_setting import CommonLog +else: + # 当作为模块导入时,__package__ 为模块的包名 + from .log_setting import CommonLog + +logger_ = logging.getLogger(__name__) +logger_ = CommonLog(logger_) + +# 定义机械臂型号 +RM65 = 65 +RML63_I = 631 +RML63_II = 632 +ECO65 = 651 +RM75 = 75 +ECO62 = 62 +GEN72 = 72 + +ARM_DOF = 7 +MOVEJ_CANFD_CB = 0x0001 # 角度透传非阻 +MOVEP_CANFD_CB = 0x0002 # 位姿透传非阻 +FORCE_POSITION_MOVE_CB = 0x0003 # 力位混合透传 + +errro_message = {1: '1: CONTROLLER_DATA_RETURN_FALSE', 2: "2: INIT_MODE_ERR", 3: '3: INIT_TIME_ERR', + 4: '4: INIT_SOCKET_ERR', 5: '5: SOCKET_CONNECT_ERR', 6: '6: SOCKET_SEND_ERR', 7: '7: SOCKET_TIME_OUT', + 8: '8: UNKNOWN_ERR', 9: '9: CONTROLLER_DATA_LOSE_ERR', 10: '10: CONTROLLER_DATE_ARR_NUM_ERR', + 11: '11: WRONG_DATA_TYPE', 12: '12: MODEL_TYPE_ERR', 13: '13: CALLBACK_NOT_FIND', + 14: '14: ARM_ABNORMAL_STOP', + 15: '15: TRAJECTORY_FILE_LENGTH_ERR', 16: '16: TRAJECTORY_FILE_CHECK_ERR', + 17: '17: TRAJECTORY_FILE_READ_ERR', 18: '18: CONTROLLER_BUSY', 19: '19: ILLEGAL_INPUT', + 20: '20: QUEUE_LENGTH_FULL', + 21: '21 CALCULATION_FAILED', 22: '22: FILE_OPEN_ERR', 23: '23: FORCE_AUTO_STOP', + 24: '24: DRAG_TEACH_FLAG_FALSE', 25: '25: LISTENER_RUNNING_ERR'} + + +class POS_TEACH_MODES(IntEnum): + X_Dir = 0 # X轴方向 + Y_Dir = 1 # Y轴方向 + Z_Dir = 2 # Z轴方向 + + +class ARM_CTRL_MODES(IntEnum): + None_Mode = 0, # 无规划 + Joint_Mode = 1, # 关节空间规划 + Line_Mode = 2, # 笛卡尔空间直线规划 + Circle_Mode = 3, # 笛卡尔空间圆弧规划 + Replay_Mode = 4, # 拖动示教轨迹复现 + Moves_Mode = 5 # 样条曲线运动 + + +class RobotType(IntEnum): + RM65 = 0 + RM75 = 1 + RML63I = 2 + RML63II = 3 + RML63III = 4 + ECO65 = 5 + ECO62 = 6 + GEN72 = 7 + UNIVERSAL = 8 + + +class SensorType(IntEnum): + B = 0 + ZF = 1 + SF = 2 + + +class JOINT_STATE(ctypes.Structure): + _fields_ = [ + # ("joint", ctypes.c_float * ARM_DOF), + ("temperature", ctypes.c_float * ARM_DOF), + ("voltage", ctypes.c_float * ARM_DOF), + ("current", ctypes.c_float * ARM_DOF), + ("en_state", ctypes.c_byte * ARM_DOF), + ("err_flag", ctypes.c_uint16 * ARM_DOF), + ("sys_err", ctypes.c_uint16), + ] + + +class Quat(ctypes.Structure): + _fields_ = [ + ('w', ctypes.c_float), + ('x', ctypes.c_float), + ('y', ctypes.c_float), + ('z', ctypes.c_float) + ] + + +class Pos(ctypes.Structure): + _fields_ = [ + ('x', ctypes.c_float), + ('y', ctypes.c_float), + ('z', ctypes.c_float) + ] + + +class Euler(ctypes.Structure): + _fields_ = [ + ('rx', ctypes.c_float), + ('ry', ctypes.c_float), + ('rz', ctypes.c_float) + ] + + +class Pose(ctypes.Structure): + _fields_ = [ + ('position', Pos), # 位置 + ('quaternion', Quat), # 四元数 + ('euler', Euler) # 欧拉角 + ] + + +class Matrix(ctypes.Structure): + _fields_ = [ + ('irow', ctypes.c_short), + ('iline', ctypes.c_short), + 
('data', (ctypes.c_float * 4) * 4) + ] + + +class FRAME_NAME(ctypes.Structure): + _fields_ = [('name', ctypes.c_char * 12)] + + +class FRAME(ctypes.Structure): + _fields_ = [('frame_name', FRAME_NAME), # 坐标系名称 + ('pose', Pose), # 坐标系位姿 + ('payload', ctypes.c_float), # 坐标系末端负载重量 + ('x', ctypes.c_float), # 坐标系末端负载位置 + ('y', ctypes.c_float), # 坐标系末端负载位置 + ('z', ctypes.c_float)] # 坐标系末端负载位置 + + +class POSE_QUAT(ctypes.Structure): + _fields_ = [('px', ctypes.c_float), + ('py', ctypes.c_float), + ('pz', ctypes.c_float), + ('w', ctypes.c_float), + ('x', ctypes.c_float), + ('y', ctypes.c_float), + ('z', ctypes.c_float)] + + +class ExpandConfig(ctypes.Structure): + _fields_ = [("rpm_max", ctypes.c_int), + ("rpm_acc", ctypes.c_int), + ("conversin_coe", ctypes.c_int), + ("limit_min", ctypes.c_int), + ("limit_max", ctypes.c_int)] + + +class WiFi_Info(ctypes.Structure): + _fields_ = [("channel", ctypes.c_int), + ("ip", ctypes.c_char * 16), + ("mac", ctypes.c_char * 18), + ("mask", ctypes.c_char * 16), + ("mode", ctypes.c_char * 5), + ("password", ctypes.c_char * 16), + ("ssid", ctypes.c_char * 32)] + + +CUR_PATH = os.path.dirname(os.path.realpath(__file__)) + +# 获取当前操作系统的名称 +os_name = platform.system() + +if os_name == 'Windows': + dllPath = os.path.join(CUR_PATH, "RM_Base.dll") +elif os_name == 'Linux': + dllPath = os.path.join(CUR_PATH, "libRM_Base.so") +else: + print("当前操作系统:", os_name) + + +class CallbackData(ctypes.Structure): + _fields_ = [ + ("sockhand", ctypes.c_int), # 返回调用时句柄 + ("codeKey", ctypes.c_int), # 调用透传接口类型 + ("errCode", ctypes.c_int), # API解析错误码 + ("pose", Pose), # 当前位姿 + ("joint", ctypes.c_float * 7), # 当前关节角度 + ("nforce", ctypes.c_int), # 力控方向上所受的力 + ("sys_err", ctypes.c_uint16) # 系统错误 + ] + + +# Define the JointStatus structure +class JointStatus(ctypes.Structure): + _fields_ = [ + ("joint_current", ctypes.c_float * ARM_DOF), + ("joint_en_flag", ctypes.c_ubyte * ARM_DOF), + ("joint_err_code", ctypes.c_uint16 * ARM_DOF), + ("joint_position", ctypes.c_float * ARM_DOF), + ("joint_temperature", ctypes.c_float * ARM_DOF), + ("joint_voltage", ctypes.c_float * ARM_DOF) + ] + + +# Define the ForceData structure +class ForceData(ctypes.Structure): + _fields_ = [ + ("force", ctypes.c_float * 6), + ("zero_force", ctypes.c_float * 6), + ("coordinate", ctypes.c_int) + ] + + +# Define the RobotStatus structure +class RobotStatus(ctypes.Structure): + _fields_ = [ + ("errCode", ctypes.c_int), # API解析错误码 + ("arm_ip", ctypes.c_char_p), # 返回消息的机械臂IP + ("arm_err", ctypes.c_uint16), # 机械臂错误码 + ("joint_status", JointStatus), # 当前关节状态 + ("force_sensor", ForceData), # 力数据 + ("sys_err", ctypes.c_uint16), # 系统错误吗 + ("waypoint", Pose) # 路点信息 + ] + + +CANFD_Callback = ctypes.CFUNCTYPE(None, CallbackData) +RealtimePush_Callback = ctypes.CFUNCTYPE(None, RobotStatus) + + +class TrajectoryData(ctypes.Structure): + _fields_ = [ + ("id", ctypes.c_int), + ("size", ctypes.c_int), + ("speed", ctypes.c_int), + ("trajectory_name", ctypes.c_char * 32) + ] + + +class ProgramTrajectoryData(ctypes.Structure): + _fields_ = [ + ("page_num", ctypes.c_int), + ("page_size", ctypes.c_int), + ("total_size", ctypes.c_int), + ("vague_search", ctypes.c_char * 32), + ("list", TrajectoryData * 100) + ] + + +class ProgramRunState(ctypes.Structure): + _fields_ = [ + ("run_state", ctypes.c_int), + ("id", ctypes.c_int), + ("plan_num", ctypes.c_int), + ("loop_num", ctypes.c_int * 10), + ("loop_cont", ctypes.c_int * 10), + ("step_mode", ctypes.c_int), + ("plan_speed", ctypes.c_int) + ] + + +# 电子围栏名称 +class 
ElectronicFenceNames(ctypes.Structure): + _fields_ = [('name', ctypes.c_char * 12)] + + +# 电子围栏配置参数 +class ElectronicFenceConfig(ctypes.Structure): + _fields_ = [ + ("form", ctypes.c_int), # 形状,1 表示立方体,2 表示点面矢量平面,3 表示球体 + ("name", ctypes.c_char * 12), # 几何模型名称,不超过10个字节,支持字母、数字、下划线 + # 立方体 + ("x_min_limit", ctypes.c_float), # 立方体基于世界坐标系 X 方向最小位置,单位 0.001m + ("x_max_limit", ctypes.c_float), # 立方体基于世界坐标系 X 方向最大位置,单位 0.001m + ("y_min_limit", ctypes.c_float), # 立方体基于世界坐标系 Y 方向最小位置,单位 0.001m + ("y_max_limit", ctypes.c_float), # 立方体基于世界坐标系 Y 方向最大位置,单位 0.001m + ("z_min_limit", ctypes.c_float), # 立方体基于世界坐标系 Z 方向最小位置,单位 0.001m + ("z_max_limit", ctypes.c_float), # 立方体基于世界坐标系 Z 方向最大位置,单位 0.001m + # 点面矢量平面 + ("x1", ctypes.c_float), # 表示点面矢量平面三点法中的第一个点坐标,单位 0.001m + ("z1", ctypes.c_float), + ("y1", ctypes.c_float), + ("x2", ctypes.c_float), # 表示点面矢量平面三点法中的第二个点坐标,单位 0.001m + ("y2", ctypes.c_float), + ("z2", ctypes.c_float), + ("x3", ctypes.c_float), # 表示点面矢量平面三点法中的第三个点坐标,单位 0.001m + ("y3", ctypes.c_float), + ("z3", ctypes.c_float), + # 球体 + ("radius", ctypes.c_float), # 表示半径,单位 0.001m + ("x", ctypes.c_float), # 表示球心在世界坐标系 X 轴、Y轴、Z轴的坐标,单位 0.001m + ("y", ctypes.c_float), + ("z", ctypes.c_float), + ] + + def to_output(self): + name = self.name.decode("utf-8").strip() # 去除字符串两端的空白字符 + output_dict = {"name": name} + + if self.form == 1: # 立方体 + output_dict.update({ + "form": "cube", + "x_min_limit": float(format(self.x_min_limit, ".3f")), + "x_max_limit": float(format(self.x_max_limit, ".3f")), + "y_min_limit": float(format(self.y_min_limit, ".3f")), + "y_max_limit": float(format(self.y_max_limit, ".3f")), + "z_min_limit": float(format(self.z_min_limit, ".3f")), + "z_max_limit": float(format(self.z_max_limit, ".3f")), + }) + elif self.form == 2: # 点面矢量平面 + output_dict.update({ + "form": "point_face_vector_plane", + "x1": float(format(self.x1, ".3f")), + "y1": float(format(self.y1, ".3f")), + "z1": float(format(self.z1, ".3f")), + "x2": float(format(self.x2, ".3f")), + "y2": float(format(self.y2, ".3f")), + "z2": float(format(self.z2, ".3f")), + "x3": float(format(self.x3, ".3f")), + "y3": float(format(self.y3, ".3f")), + "z3": float(format(self.z3, ".3f")), + }) + elif self.form == 3: # 球体 + output_dict.update({ + "form": "sphere", + "radius": float(format(self.radius, ".3f")), + "x": float(format(self.x, ".3f")), + "y": float(format(self.y, ".3f")), + "z": float(format(self.z, ".3f")), + }) + + return output_dict + + +# 夹爪状态 +class GripperState(ctypes.Structure): + _fields_ = [ + ("enable_state", ctypes.c_bool), # 夹爪使能标志,0 表示未使能,1 表示使能 + ("status", ctypes.c_int), # 夹爪在线状态,0 表示离线, 1表示在线 + ("error", ctypes.c_int), # 夹爪错误信息,低8位表示夹爪内部的错误信息bit5-7 保留bit4 内部通bit3 驱动器bit2 过流 bit1 过温bit0 + ("mode", ctypes.c_int), # 当前工作状态:1 夹爪张开到最大且空闲,2 夹爪闭合到最小且空闲,3 夹爪停止且空闲,4 夹爪正在闭合,5 夹爪正在张开,6 夹爪 + ("current_force", ctypes.c_int), # 夹爪当前的压力,单位g + ("temperature", ctypes.c_int), # 当前温度,单位℃ + ("actpos", ctypes.c_int), # 夹爪开口度 + ] + + +class CtrlInfo(ctypes.Structure): + _fields_ = [ + ("build_time", ctypes.c_char * 20), + ("version", ctypes.c_char * 10), + ] + + +class DynamicInfo(ctypes.Structure): + _fields_ = [ + ("model_version", ctypes.c_char * 5), + ] + + +class PlanInfo(ctypes.Structure): + _fields_ = [ + ("build_time", ctypes.c_char * 20), + ("version", ctypes.c_char * 10), + ] + + +class AlgorithmInfo(ctypes.Structure): + _fields_ = [ + ("version", ctypes.c_char * 10), + ] + + +# 机械臂软件信息 +class ArmSoftwareInfo(ctypes.Structure): + _fields_ = [ + ("product_version", ctypes.c_char * 10), + ("algorithm_info", AlgorithmInfo), + 
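# nested build/version info blocks (CtrlInfo / DynamicInfo / PlanInfo above) + 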
("ctrl_info", CtrlInfo), + ("dynamic_info", DynamicInfo), + ("plan_info", PlanInfo), + ] + + +# 定义ToolEnvelope结构体 +class ToolEnvelope(ctypes.Structure): + _fields_ = [ + ("name", ctypes.c_char * 12), + ("radius", ctypes.c_float), # 工具包络球体的半径,单位 m + ("x", ctypes.c_float), + ("y", ctypes.c_float), + ("z", ctypes.c_float), + ] + + def __init__(self, name=None, radius=None, x=None, y=None, z=None): + if all(param is None for param in [name, radius, x, y, z]): + return + else: + # 转换name + self.name = name.encode('utf-8') + self.radius = radius + self.x = x + self.y = y + self.z = z + + def to_output(self): + name = self.name.decode("utf-8") + # 创建一个字典,包含ToolEnvelope的所有属性 + output_dict = { + "name": name, + "radius": float(format(self.radius, ".3f")), + "x": float(format(self.x, ".3f")), + "y": float(format(self.y, ".3f")), + "z": float(format(self.z, ".3f")) + } + return output_dict + + +# 定义ToolEnvelopeList结构体,其中包含一个ToolEnvelope数组 +class ToolEnvelopeList(ctypes.Structure): + _fields_ = [ + ("tool_name", ctypes.c_char * 12), # 坐标系名称 + ("list", ToolEnvelope * 5), # 包络参数列表,最多5个 + ("count", ctypes.c_int), # 包络参数 + ] + + def __init__(self, tool_name=None, list=None, count=None): + if all(param is None for param in [tool_name, list, count]): + return + else: + # 转换tool_name + self.tool_name = tool_name.encode('utf-8') + + self.list = (ToolEnvelope * 5)(*list) + self.count = count + + def to_output(self): + name = self.tool_name.decode("utf-8") + + output_dict = { + "tool_name": name, + "List": [self.list[i].to_output() for i in range(self.count)], + "count": self.count, + } + return output_dict + + +class Waypoint(ctypes.Structure): + _fields_ = [("point_name", ctypes.c_char * 16), + ("joint", ctypes.c_float * ARM_DOF), + ("pose", Pose), + ("work_frame", ctypes.c_char * 12), + ("tool_frame", ctypes.c_char * 12), + ("time", ctypes.c_char * 20)] + + def __init__(self, point_name=None, joint=None, pose=None, work_frame=None, tool_frame=None, time=''): + if all(param is None for param in [point_name, joint, pose, work_frame, tool_frame]): + return + else: + # 转换point_name + self.point_name = point_name.encode('utf-8') + + # 转换joint + self.joint = (ctypes.c_float * ARM_DOF)(*joint) + + pose_value = Pose() + pose_value.position = Pos(*pose[:3]) + pose_value.euler = Euler(*pose[3:]) + + self.pose = pose_value + + # 转换work_frame和tool_frame + self.work_frame = work_frame.encode('utf-8') + self.tool_frame = tool_frame.encode('utf-8') + + # 转换time + self.time = time.encode('utf-8') + + def to_output(self): + name = self.point_name.decode("utf-8") + wname = self.work_frame.decode("utf-8") + tname = self.tool_frame.decode("utf-8") + time = self.time.decode("utf-8") + position = self.pose.position + euler = self.pose.euler + + output_dict = { + "point_name": name, + "joint": [float(format(self.joint[i], ".3f")) for i in range(ARM_DOF)], + "pose": [position.x, position.y, position.z, euler.rx, euler.ry, euler.rz], + "work_frame": wname, + "tool_frame": tname, + "time": time, + } + return output_dict + + +# 定义WaypointsList结构体 +class WaypointsList(ctypes.Structure): + _fields_ = [("page_num", ctypes.c_int), + ("page_size", ctypes.c_int), + ("total_size", ctypes.c_int), + ("vague_search", ctypes.c_char * 32), + ("points_list", Waypoint * 100)] + + def to_output(self): + vague_search = self.vague_search.decode("utf-8") + non_empty_outputs = [] + for i in range(self.total_size): + if self.points_list[i].point_name != b'': # 判断列表是否为空 + output = self.points_list[i].to_output() + non_empty_outputs.append(output) + + 
output_dict = { + "total_size": self.total_size, + "vague_search": vague_search, + "points_list": non_empty_outputs, + } + return output_dict + + +class Send_Project_Params(ctypes.Structure): + _fields_ = [ + ('project_path', ctypes.c_char * 300), + ('project_path_len', ctypes.c_int), + ('plan_speed', ctypes.c_int), + ('only_save', ctypes.c_int), + ('save_id', ctypes.c_int), + ('step_flag', ctypes.c_int), + ('auto_start', ctypes.c_int), + ] + + def __init__(self, project_path: str = None, plan_speed: int = None, only_save: int = None, save_id: int = None, + step_flag: int = None, auto_start: int = None): + """ + 在线编程文件下发结构体 + + @param project_path (str, optional): 下发文件路径文件路径及名称,默认为None + @param plan_speed (int, optional): 规划速度比例系数,默认为None + @param only_save (int, optional): 0-运行文件,1-仅保存文件,不运行,默认为None + @param save_id (int, optional): 保存到控制器中的编号,默认为None + @param step_flag (int, optional): 设置单步运行方式模式,1-设置单步模式 0-设置正常运动模式,默认为None + @param auto_start (int, optional): 设置默认在线编程文件,1-设置默认 0-设置非默认,默认为None + """ + if all(param is None for param in [project_path, plan_speed, only_save, save_id, step_flag, auto_start]): + return + else: + if project_path is not None: + self.project_path = project_path.encode('utf-8') + + # 路径及名称长度 + self.project_path_len = len(project_path.encode('utf-8')) + 1 # 包括null终止符 + + # 规划速度比例系数 + self.plan_speed = plan_speed if plan_speed is not None else 0 + # 0-运行文件,1-仅保存文件,不运行 + self.only_save = only_save if only_save is not None else 0 + # 保存到控制器中的编号 + self.save_id = save_id if save_id is not None else 0 + # 设置单步运行方式模式,1-设置单步模式 0-设置正常运动模式 + self.step_flag = step_flag if step_flag is not None else 0 + # 设置默认在线编程文件,1-设置默认 0-设置非默认 + self.auto_start = auto_start if auto_start is not None else 0 + + +class Set_Joint(): + def Set_Joint_Speed(self, joint_num, speed, block=True): + """ + Set_Joint_Speed 设置关节最大速度 + ArmSocket socket句柄 + joint_num 关节序号,1~7 + speed 关节转速,单位:°/s + block RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + return 0-成功,失败返回:错误码, rm_define.h查询. + """ + self.pDll.Set_Joint_Speed.argtypes = (ctypes.c_int, ctypes.c_byte, ctypes.c_float, ctypes.c_bool) + self.pDll.Set_Joint_Speed.restype = self.check_error + + tag = self.pDll.Set_Joint_Speed(self.nSocket, joint_num, speed, block) + + logger_.info(f'Set_Joint_Speed:{tag}') + + return tag + + def Set_Joint_Acc(self, joint_num, acc, block=True): + """ + Set_Joint_Acc 设置关节最大加速度 + ArmSocket socket句柄 + joint_num 关节序号,1~7 + acc 关节转速,单位:°/s² + block RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + return 0-成功,失败返回:错误码, rm_define.h查询. + """ + + self.pDll.Set_Joint_Acc.argtypes = (ctypes.c_int, ctypes.c_byte, ctypes.c_float, ctypes.c_bool) + self.pDll.Set_Joint_Acc.restype = self.check_error + + tag = self.pDll.Set_Joint_Acc(self.nSocket, joint_num, acc, block) + + logger_.info(f'Set_Joint_Acc:{tag}') + + return tag + + def Set_Joint_Min_Pos(self, joint_num, joint, block=True): + """ + Set_Joint_Min_Pos 设置关节最小限位 + ArmSocket socket句柄 + joint_num 关节序号,1~7 + joint 关节最小位置,单位:° + block RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + return 0-成功,失败返回:错误码, rm_define.h查询. 
+ """ + + self.pDll.Set_Joint_Min_Pos.argtypes = (ctypes.c_int, ctypes.c_byte, ctypes.c_float, ctypes.c_bool) + self.pDll.Set_Joint_Min_Pos.restype = self.check_error + + tag = self.pDll.Set_Joint_Min_Pos(self.nSocket, joint_num, joint, block) + + logger_.info(f'Set_Joint_Min_Pos:{tag}') + + return tag + + def Set_Joint_Max_Pos(self, joint_num, joint, block=True): + """ + Set_Joint_Max_Pos 设置关节最大限位 + ArmSocket socket句柄 + joint_num 关节序号,1~7 + joint 关节最小位置,单位:° + block RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + return 0-成功,失败返回:错误码, rm_define.h查询. + """ + + self.pDll.Set_Joint_Max_Pos.argtypes = (ctypes.c_int, ctypes.c_byte, ctypes.c_float, ctypes.c_bool) + self.pDll.Set_Joint_Max_Pos.restype = self.check_error + + tag = self.pDll.Set_Joint_Max_Pos(self.nSocket, joint_num, joint, block) + + logger_.info(f'Set_Joint_Max_Pos:{tag}') + + return tag + + def Set_Joint_Drive_Speed(self, joint_num, speed, block=True): + """ + Set_Joint_Drive_Speed 设置关节最大速度(驱动器) + ArmSocket socket句柄 + joint_num 关节序号,1~7 + speed 关节转速,单位:°/s + block RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + return 0-成功,失败返回:错误码, rm_define.h查询. + """ + self.pDll.Set_Joint_Drive_Speed.argtypes = (ctypes.c_int, ctypes.c_byte, ctypes.c_float, ctypes.c_bool) + self.pDll.Set_Joint_Drive_Speed.restype = self.check_error + + tag = self.pDll.Set_Joint_Drive_Speed(self.nSocket, joint_num, speed, block) + + logger_.info(f'Set_Joint_Drive_Speed:{tag}') + + return tag + + def Set_Joint_Drive_Acc(self, joint_num, acc, block=True): + """ + Set_Joint_Drive_Acc 设置关节最大加速度(驱动器) + ArmSocket socket句柄 + joint_num 关节序号,1~7 + acc 关节转速,单位:°/s² + block RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + return 0-成功,失败返回:错误码, rm_define.h查询. + """ + + self.pDll.Set_Joint_Drive_Acc.argtypes = (ctypes.c_int, ctypes.c_byte, ctypes.c_float, ctypes.c_bool) + self.pDll.Set_Joint_Drive_Acc.restype = self.check_error + + tag = self.pDll.Set_Joint_Drive_Acc(self.nSocket, joint_num, acc, block) + + logger_.info(f'Set_Joint_Drive_Acc:{tag}') + + return tag + + def Set_Joint_Drive_Min_Pos(self, joint_num, joint, block=True): + """ + Set_Joint_Drive_Min_Pos 设置关节最小限位(驱动器) + ArmSocket socket句柄 + joint_num 关节序号,1~7 + joint 关节最小位置,单位:° + block RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + return 0-成功,失败返回:错误码, rm_define.h查询. + """ + + self.pDll.Set_Joint_Drive_Min_Pos.argtypes = (ctypes.c_int, ctypes.c_byte, ctypes.c_float, ctypes.c_bool) + self.pDll.Set_Joint_Drive_Min_Pos.restype = self.check_error + + tag = self.pDll.Set_Joint_Drive_Min_Pos(self.nSocket, joint_num, joint, block) + + logger_.info(f'Set_Joint_Drive_Min_Pos:{tag}') + + return tag + + def Set_Joint_Drive_Max_Pos(self, joint_num, joint, block=True): + """ + Set_Joint_Drive_Max_Pos 设置关节最大限位(驱动器) + ArmSocket socket句柄 + joint_num 关节序号,1~7 + joint 关节最小位置,单位:° + block RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + return 0-成功,失败返回:错误码, rm_define.h查询. 
+ """ + + self.pDll.Set_Joint_Drive_Max_Pos.argtypes = (ctypes.c_int, ctypes.c_byte, ctypes.c_float, ctypes.c_bool) + self.pDll.Set_Joint_Drive_Max_Pos.restype = self.check_error + + tag = self.pDll.Set_Joint_Drive_Max_Pos(self.nSocket, joint_num, joint, block) + + logger_.info(f'Set_Joint_Drive_Max_Pos:{tag}') + + return tag + + def Set_Joint_EN_State(self, joint_num, state, block=True): + """ + Set_Joint_EN_State 设置关节使能状态 + :param joint_num: 关节序号,1~7 + :param state: true-上使能,false-掉使能 + :param block: RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return: + """ + + self.pDll.Set_Joint_EN_State.astypes = (ctypes.c_int, ctypes.c_byte, ctypes.c_bool, ctypes.c_bool) + self.pDll.restype = self.check_error + + tag = self.pDll.Set_Joint_EN_State(self.nSocket, joint_num, state, block) + + logger_.info(f'Set_Joint_EN_State:{tag}') + + return tag + + def Set_Joint_Zero_Pos(self, joint_num, block): + """ + Set_Joint_Zero_Pos 将当前位置设置为关节零位 + :param joint_num: 关节序号,1~7 + :param block: RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return: 0-成功,失败返回:错误码, rm_define.h查询. + """ + + self.pDll.Set_Joint_Zero_Pos.argtypes = (ctypes.c_int, ctypes.c_byte, ctypes.c_bool) + self.pDll.Set_Joint_Zero_Pos.restype = self.check_error + + tag = self.pDll.Set_Joint_Zero_Pos(self.nSocket, joint_num, block) + + logger_.info(f'Set_Joint_Zero_Pos:{tag}') + + return tag + + def Set_Joint_Err_Clear(self, joint_num, block=True): + """ + Set_Joint_Err_Clear 清楚关节错误 + :param joint_num: 关节序号,1~7 + :param block: RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return: 0-成功,失败返回:错误码, rm_define.h查询. + """ + + self.pDll.Set_Joint_Err_Clear.argtypes = (ctypes.c_int, ctypes.c_byte, ctypes.c_bool) + self.pDll.Set_Joint_Err_Clear.restype = self.check_error + + tag = self.pDll.Set_Joint_Err_Clear(self.nSocket, joint_num, block) + + logger_.info(f'Set_Joint_Err_Clear:{tag}') + + return tag + + def Auto_Set_Joint_Limit(self, limit_mode): + """ + Auto_Set_Joint_Limit 自动设置关节限位 + :param limit_mode: 设置类型,1-正式模式,各关节限位为规格参数中的软限位和硬限位 + :param block: RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return: 0-成功,失败返回:错误码, rm_define.h查询. + """ + + self.pDll.Auto_Set_Joint_Limit.argtypes = (ctypes.c_int, ctypes.c_byte) + self.pDll.Auto_Set_Joint_Limit.restype = self.check_error + + tag = self.pDll.Auto_Set_Joint_Limit(self.nSocket, limit_mode) + + logger_.info(f'Auto_Set_Joint_Limit:{tag}') + + return tag + + def Auto_Fix_Joint_Over_Soft_Limit(self, block=True): + """ + Auto_Fix_Joint_Over_Soft_Limit 超出限位后,自动运动到限位内 + :param block: RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return: 0-成功,失败返回:错误码, rm_define.h查询. 
+ """ + self.pDll.Auto_Fix_Joint_Over_Soft_Limit.argtypes = (ctypes.c_int, ctypes.c_bool) + self.pDll.Auto_Fix_Joint_Over_Soft_Limit.restype = self.check_error + + tag = self.pDll.Auto_Fix_Joint_Over_Soft_Limit(self.nSocket, block) + + logger_.info(f'Auto_Fix_Joint_Over_Soft_Limit:{tag}') + + return tag + + +class Get_Joint(): + + def Get_Joint_Speed(self, retry=0): + """ + Get_Joint_Speed 查询关节最大速度 + :return: + """ + le = self.code + speed = (ctypes.c_float * le)() # 关节1~7转速数组,单位:°/s + tag = self.pDll.Get_Joint_Speed(self.nSocket, speed) + + while tag and retry: + logger_.info(f'Get_Joint_Speed:{tag},retry is :{6 - retry}') + tag = self.pDll.Get_Joint_Speed(self.nSocket, speed) + retry -= 1 + + logger_.info(f'Get_Joint_Speed:{tag}') + + return tag, list(speed) + + def Get_Joint_Acc(self, retry=0): + + """ + Get_Joint_Acc 查询关节最大加速度 + :return: + """ + le = self.code + acc = (ctypes.c_float * le)() # 关节1~7加速度数组,单位:°/s² + tag = self.pDll.Get_Joint_Acc(self.nSocket, acc) + + while tag and retry: + logger_.info(f'Get_Joint_Acc:{tag},retry is :{6 - retry}') + tag = self.pDll.Get_Joint_Acc(self.nSocket, acc) + retry -= 1 + + logger_.info(f'Get_Joint_Acc:{tag}') + + return tag, list(acc) + + def Get_Joint_Min_Pos(self, retry=0): + + """ + Get_Joint_Min_Pos 获取关节最小限位 + :return: + """ + le = self.code + min_joint = (ctypes.c_float * le)() # 关节1~7最小位置数组,单位:° + tag = self.pDll.Get_Joint_Min_Pos(self.nSocket, min_joint) + + while tag and retry: + logger_.info(f'Get_Joint_Min_Pos:{tag},retry is :{6 - retry}') + tag = self.pDll.Get_Joint_Min_Pos(self.nSocket, min_joint) + retry -= 1 + + logger_.info(f'Get_Joint_Min_Pos:{tag}') + + return tag, list(min_joint) + + def Get_Joint_Max_Pos(self, retry=0): + + """ + Get_Joint_Max_Pos 获取关节最大限位 + :return: + + """ + le = self.code + max_joint = (ctypes.c_float * le)() # 关节1~7最大位置数组,单位:° + tag = self.pDll.Get_Joint_Max_Pos(self.nSocket, max_joint) + + while tag and retry: + logger_.info(f'Get_Joint_Max_Pos:{tag},retry is :{6 - retry}') + tag = self.pDll.Get_Joint_Max_Pos(self.nSocket, max_joint) + retry -= 1 + + logger_.info(f'Get_Joint_Max_Pos:{tag}') + + return tag, list(max_joint) + + def Get_Joint_Drive_Speed(self, retry=0): + """ + Get_Joint_Drive_Speed 查询关节最大速度(驱动器) + :return: + """ + le = self.code + speed = (ctypes.c_float * le)() # 关节1~7转速数组,单位:°/s + tag = self.pDll.Get_Joint_Drive_Speed(self.nSocket, speed) + + while tag and retry: + logger_.info(f'Get_Joint_Drive_Speed:{tag},retry is :{6 - retry}') + tag = self.pDll.Get_Joint_Drive_Speed(self.nSocket, speed) + retry -= 1 + + logger_.info(f'Get_Joint_Drive_Speed:{tag}') + + return tag, list(speed) + + def Get_Joint_Drive_Acc(self, retry=0): + + """ + Get_Joint_Drive_Acc 查询关节最大加速度(驱动器) + :return: + """ + le = self.code + acc = (ctypes.c_float * le)() # 关节1~7加速度数组,单位:°/s² + tag = self.pDll.Get_Joint_Drive_Acc(self.nSocket, acc) + + while tag and retry: + logger_.info(f'Get_Joint_Drive_Acc:{tag},retry is :{6 - retry}') + tag = self.pDll.Get_Joint_Drive_Acc(self.nSocket, acc) + retry -= 1 + + logger_.info(f'Get_Joint_Drive_Acc:{tag}') + + return tag, list(acc) + + def Get_Joint_Drive_Min_Pos(self, retry=0): + + """ + Get_Joint_Drive_Min_Pos 获取关节最小限位(驱动器) + :return: + """ + le = self.code + min_joint = (ctypes.c_float * le)() # 关节1~7最小位置数组,单位:° + tag = self.pDll.Get_Joint_Drive_Min_Pos(self.nSocket, min_joint) + + while tag and retry: + logger_.info(f'Get_Joint_Drive_Min_Pos:{tag},retry is :{6 - retry}') + tag = self.pDll.Get_Joint_Drive_Min_Pos(self.nSocket, min_joint) + retry -= 1 + + 
+        logger_.info(f'Get_Joint_Drive_Min_Pos:{tag}')
+
+        return tag, list(min_joint)
+
+    def Get_Joint_Drive_Max_Pos(self, retry=0):
+
+        """
+        Get_Joint_Drive_Max_Pos 获取关节最大限位(驱动器)
+        :return:
+        """
+        le = self.code
+        max_joint = (ctypes.c_float * le)()  # 关节1~7最大位置数组,单位:°
+        tag = self.pDll.Get_Joint_Drive_Max_Pos(self.nSocket, max_joint)
+
+        while tag and retry:
+            logger_.info(f'Get_Joint_Drive_Max_Pos:{tag},retry is :{6 - retry}')
+            tag = self.pDll.Get_Joint_Drive_Max_Pos(self.nSocket, max_joint)
+            retry -= 1
+
+        logger_.info(f'Get_Joint_Drive_Max_Pos:{tag}')
+
+        return tag, list(max_joint)
+
+    def Get_Joint_EN_State(self, retry=0):
+        """
+        Get_Joint_EN_State 获取关节使能状态
+        :return:
+        """
+        le = self.code
+        state = (ctypes.c_ubyte * le)()  # 关节1~7使能状态数组,1-使能状态,0-掉使能状态
+        tag = self.pDll.Get_Joint_EN_State(self.nSocket, state)
+
+        # fixed: the original looped on `while retry:` alone, retrying even after a successful read
+        while tag and retry:
+            logger_.info(f'Get_Joint_EN_State:{tag},retry is :{6 - retry}')
+            tag = self.pDll.Get_Joint_EN_State(self.nSocket, state)
+            retry -= 1
+
+        logger_.info(f'Get_Joint_EN_State:{tag}')
+        return tag, list(state)
+
+    def Get_Joint_Err_Flag(self, retry=0):
+        """
+        Get_Joint_Err_Flag 获取关节Err Flag
+        :return: state 存放关节错误码(请参考api文档中的关节错误码)
+                 bstate 关节抱闸状态(1代表抱闸未打开,0代表抱闸已打开)
+        """
+        # le = int(str(self.code)[0])
+        le = self.code
+
+        state = (ctypes.c_uint16 * le)()
+        bstate = (ctypes.c_uint16 * le)()
+
+        tag = self.pDll.Get_Joint_Err_Flag(self.nSocket, state, bstate)
+
+        while tag and retry:
+            logger_.info(f'Get_Joint_Err_Flag:{tag},retry is :{6 - retry}')
+            tag = self.pDll.Get_Joint_Err_Flag(self.nSocket, state, bstate)
+            retry -= 1
+
+        logger_.info(f'Get_Joint_Err_Flag:{tag}')
+        return tag, list(state), list(bstate)
+
+    def Get_Tool_Software_Version(self):
+
+        """
+        Get_Tool_Software_Version 查询末端接口板软件版本号
+        :return:
+        """
+        version = ctypes.c_int()
+        tag = self.pDll.Get_Tool_Software_Version(self.nSocket, ctypes.byref(version))
+
+        logger_.info(f'Get_Tool_Software_Version:{tag}')
+        return tag, hex(version.value)
+
+    def Get_Joint_Software_Version(self):
+
+        """
+        Get_Joint_Software_Version 查询关节软件版本号
+        :return: 关节软件版本号
+        """
+
+        if self.code == 6:
+            self.pDll.Get_Joint_Software_Version.argtypes = (ctypes.c_int, ctypes.c_int * 6)
+            self.pDll.Get_Joint_Software_Version.restype = self.check_error
+
+            version = (ctypes.c_int * 6)()
+
+        else:
+            self.pDll.Get_Joint_Software_Version.argtypes = (ctypes.c_int, ctypes.c_int * 7)
+            self.pDll.Get_Joint_Software_Version.restype = self.check_error
+
+            version = (ctypes.c_int * 7)()
+
+        tag = self.pDll.Get_Joint_Software_Version(self.nSocket, version)
+
+        return tag, [hex(i) for i in version]
+
+
+class Tcp_Config():
+    def Set_Arm_Line_Speed(self, speed, block=True):
+
+        """
+        Set_Arm_Line_Speed 设置机械臂末端最大线速度
+        :param speed: 末端最大线速度,单位m/s
+        :param block: RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令
+        :return: 0-成功,失败返回:错误码, rm_define.h查询.
+        """
+
+        self.pDll.Set_Arm_Line_Speed.argtypes = (ctypes.c_int, ctypes.c_float, ctypes.c_bool)
+        self.pDll.Set_Arm_Line_Speed.restype = self.check_error
+
+        tag = self.pDll.Set_Arm_Line_Speed(self.nSocket, speed, block)
+
+        logger_.info(f'Set_Arm_Line_Speed:{tag}')
+
+        return tag
+
+    def Set_Arm_Line_Acc(self, acc, block=True):
+        """
+        Set_Arm_Line_Acc 设置机械臂末端最大线加速度
+        :param acc: 末端最大线加速度,单位m/s^2
+        :param block: RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令
+        :return: 0-成功,失败返回:错误码, rm_define.h查询.
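+
+        Example (illustrative sketch; `arm` assumed connected):
+            >>> tag = arm.Set_Arm_Line_Acc(0.5)   # cap Cartesian linear acceleration at 0.5 m/s^2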
+ """ + self.pDll.Set_Arm_Line_Acc.argtypes = (ctypes.c_int, ctypes.c_float, ctypes.c_bool) + self.pDll.Set_Arm_Line_Acc.restype = self.check_error + + tag = self.pDll.Set_Arm_Line_Acc(self.nSocket, acc, block) + + logger_.info(f'Set_Arm_Line_Acc: {tag}') + + return tag + + def Set_Arm_Angular_Speed(self, speed, block=True): + """ + Set_Arm_Angular_Speed 设置机械臂末端最大角速度 + :param speed: 末端最大角速度,单位rad/s + :param block: RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return: 0-成功,失败返回:错误码, rm_define.h查询. + """ + self.pDll.Set_Arm_Angular_Speed.argtypes = (ctypes.c_int, ctypes.c_float, ctypes.c_bool) + self.pDll.Set_Arm_Angular_Speed.restype = self.check_error + + tag = self.pDll.Set_Arm_Angular_Speed(self.nSocket, speed, block) + + logger_.info(f'Set_Arm_Angular_Speed: {tag}') + + return tag + + def Set_Arm_Angular_Acc(self, acc, block=True): + """ + Set_Arm_Angular_Acc 设置机械臂末端最大角加速度 + :param acc: 末端最大角加速度,单位rad/s^2 + :param block: RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return: 0-成功,失败返回:错误码, rm_define.h查询. + """ + self.pDll.Set_Arm_Angular_Acc.argtypes = (ctypes.c_int, ctypes.c_float, ctypes.c_bool) + self.pDll.Set_Arm_Angular_Acc.restype = self.check_error + + tag = self.pDll.Set_Arm_Angular_Acc(self.nSocket, acc, block) + + logger_.info(f'Set_Arm_Angular_Acc: {tag}') + + return tag + + def Get_Arm_Line_Speed(self, retry=0): + """ + Get_Arm_Line_Speed 获取机械臂末端最大线速度 + :return: + """ + + speed = ctypes.c_float() + speed_u = ctypes.pointer(speed) + + tag = self.pDll.Get_Arm_Line_Speed(self.nSocket, speed_u) + while tag and retry: + logger_.info(f'Get_Arm_Line_Speed:{tag},retry is :{6 - retry}') + tag = self.pDll.Get_Arm_Line_Speed(self.nSocket, speed_u) + retry -= 1 + + logger_.info(f'Get_Arm_Line_Speed:{tag}') + return tag, speed.value + + def Get_Arm_Line_Acc(self, retry=0): + """ + Get_Arm_Line_Acc 获取机械臂末端最大线加速度 + :return: + """ + + acc = ctypes.c_float() + acc_u = ctypes.pointer(acc) + + tag = self.pDll.Get_Arm_Line_Acc(self.nSocket, acc_u) + + while tag and retry: + logger_.info(f'Get_Arm_Line_Acc:{tag},retry is :{6 - retry}') + tag = self.pDll.Get_Arm_Line_Acc(self.nSocket, acc_u) + retry -= 1 + + logger_.info(f'Get_Arm_Line_Acc:{tag}') + return tag, acc.value + + def Get_Arm_Angular_Speed(self, retry=0): + """ + Get_Arm_Angular_Speed 获取机械臂末端最大角速度 + :return: + """ + + speed = ctypes.c_float() + speed_u = ctypes.pointer(speed) + + tag = self.pDll.Get_Arm_Angular_Speed(self.nSocket, speed_u) + + while tag and retry: + logger_.info(f'Get_Arm_Angular_Speed:{tag},retry is :{6 - retry}') + tag = self.pDll.Get_Arm_Angular_Speed(self.nSocket, speed_u) + retry -= 1 + + logger_.info(f'Get_Arm_Angular_Speed:{tag}') + return tag, speed.value + + def Get_Arm_Angular_Acc(self, retry=0): + """ + Get_Arm_Angular_Acc 获取机械臂末端最大角加速度 + :return: + """ + + acc = ctypes.c_float() + acc_u = ctypes.pointer(acc) + + tag = self.pDll.Get_Arm_Angular_Acc(self.nSocket, acc_u) + + while tag and retry: + logger_.info(f'Get_Arm_Angular_Acc:{tag},retry is :{6 - retry}') + tag = self.pDll.Get_Arm_Angular_Acc(self.nSocket, acc_u) + retry -= 1 + + logger_.info(f'Get_Arm_Angular_Acc:{tag}') + return tag, acc.value + + def Set_Arm_Tip_Init(self): + # 设置机械臂末端参数为初始值 + tag = self.pDll.Set_Arm_Tip_Init(self.nSocket, 1) + + logger_.info(f'Set_Arm_Tip_Init:{tag}') + logger_.info(f'设置机械臂末端参数为初始值') + + return tag + + def Set_Collision_Stage(self, stage, block=True): + """ + Set_Collision_Stage 设置机械臂动力学碰撞检测等级 + :param stage: 等级:0~8,0-无碰撞,8-碰撞最灵敏 + :param block: RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + 
:return:0-成功,失败返回:错误码, rm_define.h查询. + """ + self.pDll.Set_Collision_Stage.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_bool) + self.pDll.Set_Collision_Stage.restype = self.check_error + + tag = self.pDll.Set_Collision_Stage(self.nSocket, stage, block) + + logger_.info(f'Set_Collision_Stage:{tag}') + + return tag + + def Get_Collision_Stage(self, retry=0): + """ + Get_Collision_Stage 查询碰撞防护等级 + :return: 碰撞防护等级 + """ + self.pDll.Get_Collision_Stage.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int)] + self.pDll.Get_Collision_Stage.restype = self.check_error + + stage = ctypes.c_int() + stage_u = ctypes.pointer(stage) + + tag = self.pDll.Get_Collision_Stage(self.nSocket, stage_u) + + while tag and retry: + logger_.info(f'Get_Collision_Stage:{tag},retry is :{6 - retry}') + tag = self.pDll.Get_Collision_Stage(self.nSocket, stage_u) + retry -= 1 + + logger_.info(f'防撞等级是:{stage.value}') + + logger_.info(f'Get_Collision_Stage:{tag}') + + return tag, stage.value + + def Set_Joint_Zero_Offset(self, offset, block=True): + """ + Set_Joint_Zero_Offset 该函数用于设置机械臂各关节零位补偿角度,一般在对机械臂零位进行标定后调用该函数 + :param offset: 关节1~6的零位补偿角度数组, 单位:度 + :param block: block RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return: 0-成功,失败返回:错误码, rm_define.h查询. + """ + le = self.code + self.pDll.Set_Joint_Zero_Offset.argtypes = [ctypes.c_void_p, ctypes.c_float * le, ctypes.c_bool] + self.pDll.Set_Joint_Zero_Offset.restype = self.check_error + + offset_arr = (ctypes.c_float * le)(*offset) + + tag = self.pDll.Set_Joint_Zero_Offset(self.nSocket, offset_arr, block) + + logger_.info(f'Set_Joint_Zero_Offset:{tag}') + + return tag + + +class Tool_Frame(): + def Auto_Set_Tool_Frame(self, point_num, block=True): + """ + Auto_Set_Tool_Frame 六点法自动设置工具坐标系 标记点位 + :param point_num: 1~6代表6个标定点 + :param block: RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return: + """ + + self.pDll.Auto_Set_Tool_Frame.argtypes = (ctypes.c_int, ctypes.c_byte, ctypes.c_bool) + self.pDll.Auto_Set_Tool_Frame.restype = self.check_error + + tag = self.pDll.Auto_Set_Tool_Frame(self.nSocket, point_num, block) + + logger_.info(f'Auto_Set_Tool_Frame:{tag}') + + return tag + + def Generate_Auto_Tool_Frame(self, name, payload, x, y, z, block=True): + + """ + Generate_Auto_Tool_Frame 六点法自动设置工具坐标系 提交 + :param name: 工具坐标系名称,不能超过十个字节。 + :param payload: 新工具执行末端负载重量 单位kg + :param x: 新工具执行末端负载位置 位置x 单位mm + :param y: 新工具执行末端负载位置 位置y 单位mm + :param z: 新工具执行末端负载位置 位置z 单位mm + :param block: block RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return:0-成功,失败返回:错误码, rm_define.h查询. + + """ + self.pDll.Generate_Auto_Tool_Frame.argtypes = ( + ctypes.c_int, ctypes.c_char_p, ctypes.c_float, ctypes.c_float, ctypes.c_float, + ctypes.c_float, ctypes.c_bool) + self.pDll.Generate_Auto_Tool_Frame.restype = self.check_error + + name = ctypes.c_char_p(name.encode('utf-8')) + + tag = self.pDll.Generate_Auto_Tool_Frame(self.nSocket, name, payload, x, y, z, block) + + logger_.info(f'Generate_Auto_Tool_Frame:{tag}') + + return tag + + def Manual_Set_Tool_Frame(self, name, pose, payload, x, y, z, block=True): + + """ + Manual_Set_Tool_Frame 手动设置工具坐标系 + :param name: 工具坐标系名称,不能超过十个字节 + :param pose: 新工具执行末端相对于机械臂法兰中心的位姿 + :param payload: 新工具执行末端负载重量 单位kg + :param x: 新工具执行末端负载位置 位置x 单位m + :param y: 新工具执行末端负载位置 位置y 单位m + :param z: 新工具执行末端负载位置 位置z 单位m + :param block: RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return: 0-成功,失败返回:错误码, rm_define.h查询. 
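+
+        Example (illustrative sketch; the pose list is a hypothetical gripper offset, position in m, orientation in rad):
+            >>> pose = [0.0, 0.0, 0.12, 0.0, 0.0, 0.0]
+            >>> tag = arm.Manual_Set_Tool_Frame('grip', pose, 0.5, 0.0, 0.0, 0.04)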
+ """ + + self.pDll.Manual_Set_Tool_Frame.argtypes = ( + ctypes.c_int, ctypes.c_char_p, Pose, ctypes.c_float, ctypes.c_float, ctypes.c_float + , ctypes.c_float, ctypes.c_bool) + self.pDll.Manual_Set_Tool_Frame.restype = self.check_error + + name = ctypes.c_char_p(name.encode('utf-8')) + + pose1 = Pose() + + pose1.position = Pos(*pose[:3]) + pose1.euler = Euler(*pose[3:]) + + tag = self.pDll.Manual_Set_Tool_Frame(self.nSocket, name, pose1, payload, x, y, z, block) + + logger_.info(f'Manual_Set_Tool_Frame:{tag}') + + return tag + + def Change_Tool_Frame(self, name, block=True): + """ + Change_Tool_Frame 切换当前工具坐标系 + :param name: 目标工具坐标系名称 + :param block: RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return: 0-成功,失败返回:错误码, rm_define.h查询. + + """ + + self.pDll.Change_Tool_Frame.argtypes = (ctypes.c_int, ctypes.c_char_p, ctypes.c_bool) + self.pDll.Change_Tool_Frame.restype = self.check_error + + name = ctypes.c_char_p(name.encode('utf-8')) + + tag = self.pDll.Change_Tool_Frame(self.nSocket, name, block) + + logger_.info(f'Change_Tool_Frame:{tag}') + + return tag + + def Delete_Tool_Frame(self, name, block=True): + """ + Delete_Tool_Frame 删除指定工具坐标系 + :param name: 要删除的工具坐标系名称 + :param block: RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return: 0-成功,失败返回:错误码, rm_define.h查询. + 备注:删除坐标系后,机械臂将切换到机械臂法兰末端工具坐标系 + """ + + self.pDll.Delete_Tool_Frame.argtypes = (ctypes.c_int, ctypes.c_char_p, ctypes.c_bool) + self.pDll.Delete_Tool_Frame.restype = self.check_error + + name = ctypes.c_char_p(name.encode('utf-8')) + + tag = self.pDll.Delete_Tool_Frame(self.nSocket, name, block) + + logger_.info(f'Delete_Tool_Frame:{tag}') + + return tag + + def Update_Tool_Frame(self, name, pose, payload, x, y, z): + + """ + Update_Tool_Frame 修改指定工具坐标系 + :param name: 要修改的工具坐标系名称 + :param pose: 更新执行末端相对于机械臂法兰中心的位姿 + :param payload: 更新新工具执行末端负载重量 单位kg + :param x: 更新工具执行末端负载位置 位置x 单位m + :param y: 更新工具执行末端负载位置 位置y 单位m + :param z: 更新工具执行末端负载位置 位置z 单位m + :return: 0-成功,失败返回:错误码, rm_define.h查询. + """ + + self.pDll.Update_Tool_Frame.argtypes = ( + ctypes.c_int, ctypes.c_char_p, Pose, ctypes.c_float, ctypes.c_float, ctypes.c_float + , ctypes.c_float) + self.pDll.Update_Tool_Frame.restype = self.check_error + + name = ctypes.c_char_p(name.encode('utf-8')) + + pose1 = Pose() + + pose1.position = Pos(*pose[:3]) + pose1.euler = Euler(*pose[3:]) + + tag = self.pDll.Update_Tool_Frame(self.nSocket, name, pose1, payload, x, y, z) + + logger_.info(f'Update_Tool_Frame:{tag}') + + return tag + + def Set_Tool_Envelope(self, envelop_list: ToolEnvelopeList): + """ + Set_Tool_Envelope 设置工具坐标系的包络参数 + :param envelop_list: 包络参数列表,每个工具最多支持 5 个包络球,可以没有包络 + :return: 0-成功,失败返回:错误码, rm_define.h查询. + """ + self.pDll.Set_Tool_Envelope.argtypes = (ctypes.c_int, ctypes.POINTER(ToolEnvelopeList)) + self.pDll.Set_Tool_Envelope.restype = self.check_error + + # tel_list = ToolEnvelopeList() + + tag = self.pDll.Set_Tool_Envelope(self.nSocket, ctypes.pointer(envelop_list)) + + logger_.info(f'Set_Tool_Envelope:{tag}') + + return tag + + def Get_Tool_Envelope(self, tool_name) -> (int, dict): # type: ignore + """ + 获取指定工具坐标系的包络参数 + :param tool_name: 指定工具坐标系名称 + :return: 0-成功,失败返回:错误码, rm_define.h查询. 
+ + """ + + self.pDll.Get_Tool_Envelope.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.POINTER(ToolEnvelopeList)] + self.pDll.Get_Tool_Envelope.restype = self.check_error + + tool_name = tool_name.encode("utf-8") + tel_list = ToolEnvelopeList() + tag = self.pDll.Get_Tool_Envelope(self.nSocket, tool_name, ctypes.pointer(tel_list)) + logger_.info(f'Get_Tool_Envelope:{tag}') + + return tag, tel_list.to_output() + + def Get_Current_Tool_Frame(self, retry=0): + """ + Get_Current_Tool_Frame 获取当前工具坐标系 + :param tool:返回的坐标系 + :return: + """ + + self.pDll.Get_Current_Tool_Frame.argtypes = (ctypes.c_int, ctypes.POINTER(FRAME)) + self.pDll.Get_Current_Tool_Frame.restype = self.check_error + + frame = FRAME() + + tag = self.pDll.Get_Current_Tool_Frame(self.nSocket, ctypes.byref(frame)) + + while tag and retry: + logger_.info(f'Get_Current_Tool_Frame run failed :{tag},retry is :{6 - retry}') + tag = self.pDll.Get_Current_Tool_Frame(self.nSocket, ctypes.byref(frame)) + + retry -= 1 + + logger_.info(f'Get_Current_Tool_Frame:{tag}') + + return tag, frame + + def Get_Given_Tool_Frame(self, name, retry=0): + """ + Get_Given_Tool_Frame 获取指定工具坐标系 + :param name:指定的工具名称 + :param tool:返回的工具参数 + :return: + """ + + self.pDll.Get_Given_Tool_Frame.argtypes = (ctypes.c_int, ctypes.c_char_p, ctypes.POINTER(FRAME)) + + self.pDll.Get_Given_Tool_Frame.restype = self.check_error + + name = ctypes.c_char_p(name.encode('utf-8')) + frame = FRAME() + + tag = self.pDll.Get_Given_Tool_Frame(self.nSocket, name, ctypes.byref(frame)) + + while tag and retry: + logger_.info(f'Get_Given_Tool_Frame run failed :{tag},retry is :{6 - retry}') + + tag = self.pDll.Get_Given_Tool_Frame(self.nSocket, name, ctypes.byref(frame)) + + retry -= 1 + + logger_.info(f'Get_Given_Tool_Frame:{tag}') + + return tag, frame + + def Get_All_Tool_Frame(self, retry=0): + + """ + Get_All_Tool_Frame 获取所有工具坐标系名称 + :return: + """ + + self.pDll.Get_All_Tool_Frame.argtypes = (ctypes.c_int, ctypes.POINTER(FRAME_NAME), ctypes.POINTER(ctypes.c_int)) + + self.pDll.Get_All_Tool_Frame.restype = self.check_error + + max_len = 10 # maximum number of tools + + names = (FRAME_NAME * max_len)() # 创建 FRAME_NAME 数组 + names_ptr = ctypes.POINTER(FRAME_NAME)(names) # + + len_ = ctypes.c_int() + + tag = self.pDll.Get_All_Tool_Frame(self.nSocket, names_ptr, ctypes.byref(len_)) + + while tag and retry: + logger_.info(f'Get_All_Tool_Frame run failed :{tag},retry is :{6 - retry}') + tag = self.pDll.Get_All_Tool_Frame(self.nSocket, names_ptr, ctypes.byref(len_)) + retry -= 1 + + logger_.info(f'Get_All_Tool_Frame:{tag}') + + tool_names = [names[i].name.decode('utf-8') for i in range(len_.value)] + return tag, tool_names, len_.value + + +class Work_Frame(): + def Auto_Set_Work_Frame(self, name, point_num, block=True): + + """ + Auto_Set_Work_Frame 三点法自动设置工作坐标系 + :param name: 工作坐标系名称,不能超过十个字节。 + :param point_num: 1~3代表3个标定点,依次为原点、X轴一点、Y轴一点,4代表生成坐标系。 + :param block: 0-成功,失败返回:错误码, rm_define.h查询. 
+ :return: + """ + + self.pDll.Auto_Set_Work_Frame.argtypes = (ctypes.c_int, ctypes.c_char_p, ctypes.c_byte, ctypes.c_bool) + self.pDll.Auto_Set_Work_Frame.restype = self.check_error + + name = ctypes.c_char_p(name.encode('utf-8')) + tag = self.pDll.Auto_Set_Work_Frame(self.nSocket, name, point_num, block) + + logger_.info(f'Auto_Set_Work_Frame:{tag}') + + return tag + + def Manual_Set_Work_Frame(self, name, pose, block=True): + """ + Manual_Set_Work_Frame 手动设置工作坐标系 + :param name: 工作坐标系名称,不能超过十个字节。 + :param pose: 新工作坐标系相对于基坐标系的位姿 + :param block:RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return:0-成功,失败返回:错误码, rm_define.h查询. + """ + + self.pDll.Manual_Set_Work_Frame.argtypes = (ctypes.c_int, ctypes.c_char_p, Pose, ctypes.c_bool) + self.pDll.Manual_Set_Work_Frame.restype = self.check_error + + name = ctypes.c_char_p(name.encode('utf-8')) + pose1 = Pose() + + pose1.position = Pos(*pose[:3]) + pose1.euler = Euler(*pose[3:]) + + tag = self.pDll.Manual_Set_Work_Frame(self.nSocket, name, pose1, block) + + logger_.info(f'Manual_Set_Work_Fram:{tag}') + + return tag + + def Change_Work_Frame(self, name="Base"): + """ + 切换到某个工作坐标系,默认是base坐标系 + """ + + self.pDll.Change_Work_Frame.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_bool] + name = ctypes.c_char_p(name.encode('utf-8')) + tag = self.pDll.Change_Work_Frame(self.nSocket, name, 1) + logger_.info(f'Change_Work_Frame:{tag}') + time.sleep(1) + + return tag + + def Delete_Work_Frame(self, name, block=True): + """ + Delete_Work_Frame 删除指定工作坐标系 + :param name: 要删除的工具坐标系名称 + :param block: RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return:0-成功,失败返回:错误码, rm_define.h查询. + """ + self.pDll.Delete_Work_Frame.argtypes = (ctypes.c_int, ctypes.c_char_p, ctypes.c_bool) + + self.pDll.Delete_Work_Frame.restype = self.check_error + + name = ctypes.c_char_p(name.encode('utf-8')) + + tag = self.pDll.Delete_Work_Frame(self.nSocket, name, block) + + logger_.info(f'Delete_Work_Frame:{tag}') + + return tag + + def Update_Work_Frame(self, name, pose): + + """ + Update_Work_Frame 修改指定工作坐标系 + :param name: 要修改的工作坐标系名称 + :param pose: 更新工作坐标系相对于基坐标系的位姿 + :return: 0-成功,失败返回:错误码, rm_define.h查询. 
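+
+        Example (illustrative sketch; re-anchor a hypothetical work frame 'table' 10 cm higher):
+            >>> tag = arm.Update_Work_Frame('table', [0.0, 0.0, 0.10, 0.0, 0.0, 0.0])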
+ """ + + self.pDll.Update_Work_Frame.argtypes = ( + ctypes.c_int, ctypes.c_char_p, Pose) + self.pDll.Update_Work_Frame.restype = self.check_error + + name = ctypes.c_char_p(name.encode('utf-8')) + + pose1 = Pose() + + pose1.position = Pos(*pose[:3]) + pose1.euler = Euler(*pose[3:]) + + tag = self.pDll.Update_Work_Frame(self.nSocket, name, pose1) + + logger_.info(f'Update_Work_Frame:{tag}') + + return tag + + def Get_Current_Work_Frame(self, retry=0): + """ + Get_Current_Work_Frame 获取当前工作坐标系 + :return: + """ + + self.pDll.Get_Current_Work_Frame.argtypes = (ctypes.c_int, ctypes.POINTER(FRAME)) + + self.pDll.Get_Current_Work_Frame.restype = self.check_error + + frame = FRAME() + + tag = self.pDll.Get_Current_Work_Frame(self.nSocket, ctypes.byref(frame)) + + while tag and retry: + logger_.info(f'Get_Current_Work_Frame run failed :{tag},retry is :{6 - retry}') + tag = self.pDll.Get_Current_Work_Frame(self.nSocket, ctypes.byref(frame)) + + retry -= 1 + + logger_.info(f'Get_Current_Work_Frame:{tag}') + + return tag, frame + + def Get_Given_Work_Frame(self, name, retry=0): + """ + Get_Given_Work_Frame 获取指定工作坐标系 + :return:指定工作坐标系得位姿 + """ + + self.pDll.Get_Given_Work_Frame.argtypes = (ctypes.c_int, ctypes.c_char_p, ctypes.POINTER(Pose)) + self.pDll.Get_Given_Work_Frame.restype = self.check_error + + name = ctypes.c_char_p(name.encode('utf-8')) + + pose = Pose() + + tag = self.pDll.Get_Given_Work_Frame(self.nSocket, name, ctypes.byref(pose)) + + while tag and retry: + logger_.info(f'Get_Given_Work_Frame run failed :{tag},retry is :{6 - retry}') + + tag = self.pDll.Get_Given_Work_Frame(self.nSocket, name, ctypes.byref(pose)) + + retry -= 1 + + logger_.info(f'Get_Given_Work_Frame:{tag}') + + position = pose.position + euler = pose.euler + return tag, [position.x, position.y, position.z, euler.rx, euler.ry, euler.rz] + + def Get_All_Work_Frame(self, retry=0): + """ + Get_All_Work_Frame 获取所有工作坐标系名称 + :return: + """ + + self.pDll.Get_All_Work_Frame.argtypes = (ctypes.c_int, ctypes.POINTER(FRAME_NAME), ctypes.POINTER(ctypes.c_int)) + + max_len = 10 # maximum number of tools + names = (FRAME_NAME * max_len)() # creates an array of FRAME_NAME + names_ptr = ctypes.POINTER(FRAME_NAME)(names) # + len_ = ctypes.c_int() + + tag = self.pDll.Get_All_Work_Frame(self.nSocket, names_ptr, ctypes.byref(len_)) + + while tag and retry: + logger_.info(f'Get_All_Work_Frame run failed :{tag},retry is :{6 - retry}') + tag = self.pDll.Get_All_Work_Frame(self.nSocket, names_ptr, ctypes.byref(len_)) + retry -= 1 + + logger_.info(f'Get_All_Work_Frame:{tag}') + + job_names = [names[i].name.decode('utf-8') for i in range(len_.value)] + return tag, job_names, len_.value + + +class Arm_State(): + def Get_Current_Arm_State(self, retry=0): + """获取机械臂当前状态 + + :return (error_code,joints,curr_pose,arm_err,sys_err) + error_code 0-成功,失败返回:错误码, rm_define.h查询. 
+                 joint 关节角度数组
+                 pose 机械臂当前位姿数组
+                 arm_err 机械臂运行错误代码
+                 sys_err 控制器错误代码
+        """
+
+        le = self.code
+
+        self.pDll.Get_Current_Arm_State.argtypes = (ctypes.c_int, ctypes.c_float * le, ctypes.POINTER(Pose),
+                                                    ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_uint16))
+        self.pDll.Get_Current_Arm_State.restype = self.check_error
+        joints = (ctypes.c_float * le)()
+        curr_pose = Pose()
+        cp_ptr = ctypes.pointer(curr_pose)
+        arm_err_ptr = ctypes.pointer(ctypes.c_uint16())
+        sys_err_ptr = ctypes.pointer(ctypes.c_uint16())
+        error_code = self.pDll.Get_Current_Arm_State(self.nSocket, joints, cp_ptr, arm_err_ptr, sys_err_ptr)
+        while error_code and retry:
+            # sleep(0.3)
+            logger_.warning(f"Failed to get curr arm states. Error Code: {error_code}\tRetry Count: {retry}")
+            error_code = self.pDll.Get_Current_Arm_State(self.nSocket, joints, cp_ptr, arm_err_ptr, sys_err_ptr)
+            retry -= 1
+
+        logger_.info(f'Get_Current_Arm_State:{error_code}')
+
+        position = curr_pose.position
+        euler = curr_pose.euler
+        curr_pose = [position.x, position.y, position.z, euler.rx, euler.ry, euler.rz]
+        return error_code, list(joints), curr_pose, arm_err_ptr.contents.value, sys_err_ptr.contents.value
+
+    def Get_Joint_Temperature(self):
+        """
+        Get_Joint_Temperature 获取关节当前温度
+        :return: (error_code,temperature)
+                 error_code 0-成功,失败返回:错误码, rm_define.h查询.
+                 temperature 关节温度数组
+        """
+
+        le = self.code
+
+        self.pDll.Get_Joint_Temperature.argtypes = (ctypes.c_int, ctypes.c_float * le)
+        self.pDll.Get_Joint_Temperature.restype = self.check_error
+
+        temperature = (ctypes.c_float * le)()
+
+        tag = self.pDll.Get_Joint_Temperature(self.nSocket, temperature)
+
+        logger_.info(f'Get_Joint_Temperature:{tag}')
+
+        return tag, list(temperature)
+
+    def Get_Joint_Current(self):
+        """
+        Get_Joint_Current 获取关节当前电流
+        :return: (error_code,current)
+                 error_code 0-成功,失败返回:错误码, rm_define.h查询.
+                 current 关节电流数组
+        """
+        le = self.code
+
+        self.pDll.Get_Joint_Current.argtypes = (ctypes.c_int, ctypes.c_float * le)
+        self.pDll.Get_Joint_Current.restype = self.check_error
+
+        current = (ctypes.c_float * le)()
+
+        tag = self.pDll.Get_Joint_Current(self.nSocket, current)
+
+        logger_.info(f'Get_Joint_Current:{tag}')
+
+        return tag, list(current)
+
+    def Get_Joint_Voltage(self):
+        """
+        Get_Joint_Voltage 获取关节当前电压
+        :return: (error_code,voltage)
+                 error_code 0-成功,失败返回:错误码, rm_define.h查询.
+                 voltage 关节电压数组
+        """
+        le = self.code
+
+        self.pDll.Get_Joint_Voltage.argtypes = (ctypes.c_int, ctypes.c_float * le)
+        self.pDll.Get_Joint_Voltage.restype = self.check_error
+
+        voltage = (ctypes.c_float * le)()
+
+        tag = self.pDll.Get_Joint_Voltage(self.nSocket, voltage)
+
+        logger_.info(f'Get_Joint_Voltage:{tag}')
+
+        return tag, list(voltage)
+
+    def Get_Joint_Degree(self):
+        """
+        Get_Joint_Degree 获取关节当前角度
+        :return: (error_code,joint)
+                 error_code 0-成功,失败返回:错误码, rm_define.h查询.
+ joint 关节角度数组 + """ + + self.pDll.Get_Joint_Degree.argtypes = (ctypes.c_int, ctypes.c_float * 7) + + self.pDll.Get_Joint_Degree.restype = self.check_error + + joint = (ctypes.c_float * 7)() + + tag = self.pDll.Get_Joint_Degree(self.nSocket, joint) + + logger_.info(f'Get_Joint_Degree:{tag}') + + return tag, list(joint) + + def Get_Arm_All_State(self, retry=0) -> (int, JOINT_STATE): # type: ignore + """ + Get_Arm_All_State 获取机械臂所有状态信息 + :return: + """ + self.pDll.Get_Arm_All_State.argtypes = (ctypes.c_int, ctypes.POINTER(JOINT_STATE)) + self.pDll.Get_Arm_All_State.restype = self.check_error + + joint_status = JOINT_STATE() + + # joint_status_p = ctypes.pointer(joint_status) + tag = self.pDll.Get_Arm_All_State(self.nSocket, joint_status) + + while tag and retry: + logger_.info(f'Get_Arm_All_State:{tag},retry is :{6 - retry}') + tag = self.pDll.Get_Arm_All_State(self.nSocket, joint_status) + retry -= 1 + + logger_.info(f'Get_Arm_All_State:{tag}') + + return tag, joint_status + + def Get_Arm_Plan_Num(self, retry=0): + + """ + Get_Arm_Plan_Num 查询规划计数 + :return: + """ + + self.pDll.Get_Arm_Plan_Num.argtypes = (ctypes.c_int, ctypes.POINTER(ctypes.c_int)) + self.pDll.Get_Arm_Plan_Num.restype = self.check_error + + plan_num = ctypes.c_int() + plan_num_p = ctypes.pointer(plan_num) + + tag = self.pDll.Get_Arm_Plan_Num(self.nSocket, plan_num_p) + + while tag and retry: + logger_.info(f'Get_Arm_Plan_Num:{tag},retry is :{6 - retry}') + tag = self.pDll.Get_Arm_Plan_Num(self.nSocket, plan_num_p) + + retry -= 1 + + logger_.info(f'Get_Arm_Plan_Num:{tag}') + + return tag, plan_num.value + + +class Initial_Pose(): + def Set_Arm_Init_Pose(self, target, block=True): + """ + Set_Arm_Init_Pose 设置机械臂的初始位置角度 + :param target: 机械臂初始位置关节角度数组 + :param block: RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return: + """ + + if self.code == 6: + self.pDll.Set_Arm_Init_Pose.argtypes = (ctypes.c_int, ctypes.c_float * 6, ctypes.c_bool) + self.pDll.Set_Arm_Init_Pose.restype = self.check_error + + target = (ctypes.c_float * 6)(*target) + + tag = self.pDll.Set_Arm_Init_Pose(self.nSocket, target, block) + + else: + self.pDll.Set_Arm_Init_Pose.argtypes = (ctypes.c_int, ctypes.c_float * 7, ctypes.c_bool) + self.pDll.Set_Arm_Init_Pose.restype = self.check_error + + target = (ctypes.c_float * 7)(*target) + + tag = self.pDll.Set_Arm_Init_Pose(self.nSocket, target, block) + + logger_.info(f'Set_Arm_Init_Pose:{tag}') + return tag + + def Get_Arm_Init_Pose(self): + """ + Set_Arm_Init_Pose 获取机械臂初始位置角度 + :return:joint 机械臂初始位置关节角度数组 + """ + + if self.code == 6: + self.pDll.Get_Arm_Init_Pose.argtypes = (ctypes.c_int, ctypes.c_float * 6) + self.pDll.Get_Arm_Init_Pose.restype = self.check_error + + target = (ctypes.c_float * 6)() + + tag = self.pDll.Get_Arm_Init_Pose(self.nSocket, target) + + else: + self.pDll.Get_Arm_Init_Pose.argtypes = (ctypes.c_int, ctypes.c_float * 7) + self.pDll.Get_Arm_Init_Pose.restype = self.check_error + + target = (ctypes.c_float * 7)() + + tag = self.pDll.Get_Arm_Init_Pose(self.nSocket, target) + + logger_.info(f'Get_Arm_Init_Pose:{tag}') + + return tag, list(target) + + def Set_Install_Pose(self, x, y, z, block=True): + """ + Set_Install_Pose 设置安装方式参数 + + :param x: 旋转角 单位 ° + :param y: 俯仰角 单位 ° + :param z: 方位角 单位 ° + :param block: RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return:0-成功,失败返回:错误码, rm_define.h查询. 
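+
+        Example (illustrative sketch; a hypothetical ceiling mount expressed as roll=180°, pitch=0°, yaw=0°):
+            >>> tag = arm.Set_Install_Pose(180.0, 0.0, 0.0)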
+ """ + + self.pDll.Set_Install_Pose.argtypes = ( + ctypes.c_int, ctypes.c_float, ctypes.c_float, ctypes.c_float, ctypes.c_bool) + self.pDll.Set_Install_Pose.restype = self.check_error + + tag = self.pDll.Set_Install_Pose(self.nSocket, x, y, z, block) + + logger_.info(f'Set_Install_Pose:{tag}') + + return tag + + def Get_Install_Pose(self): + """ + Get_Install_Pose 获取安装方式参数 + + err_code: 0-成功,失败返回:错误码, rm_define.h查询. + x: 旋转角 单位 ° + y: 俯仰角 单位 ° + z: 方位角 单位 ° + :return:(err_code,x,y,z) + """ + self.pDll.Get_Install_Pose.argtypes = (ctypes.c_int, + ctypes.POINTER(ctypes.c_float), ctypes.POINTER(ctypes.c_float), + ctypes.POINTER(ctypes.c_float)) + x = ctypes.c_float() + y = ctypes.c_float() + z = ctypes.c_float() + tag = self.pDll.Get_Install_Pose(self.nSocket, x, y, z) + logger_.info(f'Get_Install_Pose:{tag}') + + return tag, x.value, y.value, z.value + + +class Move_Plan: + def Movej_Cmd(self, joint, v, trajectory_connect=0, r=0, block=True): + """ + Movej_Cmd 关节空间运动 + ArmSocket socket句柄 + joint 目标关节1~7角度数组 + v 速度比例1~100,即规划速度和加速度占关节最大线转速和加速度的百分比 + r 轨迹交融半径,目前默认0。 + trajectory_connect 代表是否和下一条运动一起规划,0代表立即规划,1代表和下一条轨迹一起规划,当为1时,轨迹不会立即执行 + block True 阻塞 False 非阻塞 + return 0-成功,失败返回:错误码, rm_define.h查询. + """ + + le = self.code + float_joint = ctypes.c_float * le + joint = float_joint(*joint) + self.pDll.Movej_Cmd.argtypes = (ctypes.c_int, ctypes.c_float * le, ctypes.c_byte, + ctypes.c_float, ctypes.c_int, ctypes.c_bool) + + self.pDll.Movej_Cmd.restype = self.check_error + + tag = self.pDll.Movej_Cmd(self.nSocket, joint, v, r, trajectory_connect, block) + logger_.info(f'Movej_Cmd:{tag}') + + return tag + + def Movel_Cmd(self, pose, v, trajectory_connect=0, r=0, block=True): + """ + 笛卡尔空间直线运动 + + pose 目标位姿,位置单位:米,姿态单位:弧度 + v 速度比例1~100,即规划速度和加速度占机械臂末端最大线速度和线加速度的百分比 + trajectory_connect 代表是否和下一条运动一起规划,0代表立即规划,1代表和下一条轨迹一起规划,当为1时,轨迹不会立即执行 + r 轨迹交融半径,目前默认0。 + block True 阻塞 False 非阻塞 + + return:0-成功,失败返回:错误码, rm_define.h查询 + """ + + po1 = Pose() + po1.position = Pos(*pose[:3]) + po1.euler = Euler(*pose[3:]) + + self.pDll.Movel_Cmd.argtypes = (ctypes.c_int, Pose, ctypes.c_byte, ctypes.c_float, ctypes.c_int, ctypes.c_bool) + self.pDll.Movel_Cmd.restype = self.check_error + tag = self.pDll.Movel_Cmd(self.nSocket, po1, v, r, trajectory_connect, block) + logger_.info(f'Movel_Cmd:{tag}') + + return tag + + def Movec_Cmd(self, pose_via, pose_to, v, loop, trajectory_connect=0, r=0, block=True): + """ + Movec_Cmd 笛卡尔空间圆弧运动 + :param pose_via: 中间点位姿,位置单位:米,姿态单位:弧度 + :param pose_to: 终点位姿 + :param v: 速度比例1~100,即规划速度和加速度占机械臂末端最大角速度和角加速度的百分比 + :param trajectory_connect: 代表是否和下一条运动一起规划,0代表立即规划,1代表和下一条轨迹一起规划,当为1时,轨迹不会立即执行 + :param r: 轨迹交融半径,目前默认0。 + :param loop:规划圈数,目前默认0. + :param block:RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待机械臂到达位置或者规划失败 + :return: + """ + + self.pDll.Movec_Cmd.argtypes = ( + ctypes.c_int, Pose, Pose, ctypes.c_byte, ctypes.c_float, ctypes.c_byte, ctypes.c_int, ctypes.c_bool) + self.pDll.Movec_Cmd.restype = self.check_error + + pose1 = Pose() + + pose1.position = Pos(*pose_via[:3]) + pose1.euler = Euler(*pose_via[3:]) + + pose2 = Pose() + + pose2.position = Pos(*pose_to[:3]) + pose2.euler = Euler(*pose_to[3:]) + + tag = self.pDll.Movec_Cmd(self.nSocket, pose1, pose2, v, r, loop, trajectory_connect, block) + + logger_.info(f'Movec_Cmd:{tag}') + + return tag + + def Movej_P_Cmd(self, pose, v, trajectory_connect=0, r=0, block=True): + """ + 该函数用于关节空间运动到目标位姿 + param ArmSocket socket句柄 + pose: 目标位姿,位置单位:米,姿态单位:弧度。 注意:目标位姿必须是机械臂当前工具坐标系相对于当前工作坐标系的位姿, + 用户在使用该指令前务必确保,否则目标位姿会出错!! 
+        v: 速度比例1~100,即规划速度和加速度占机械臂末端最大线速度和线加速度的百分比
+        trajectory_connect: 代表是否和下一条运动一起规划,0代表立即规划,1代表和下一条轨迹一起规划,当为1时,轨迹不会立即执行
+        r: 轨迹交融半径,目前默认0。
+        block True 阻塞 False 非阻塞
+        return 0-成功,失败返回:错误码
+
+        """
+        po1 = Pose()
+
+        po1.position = Pos(*pose[:3])
+        po1.euler = Euler(*pose[3:])
+
+        self.pDll.Movej_P_Cmd.argtypes = (
+            ctypes.c_int, Pose, ctypes.c_byte, ctypes.c_float, ctypes.c_int, ctypes.c_bool)
+        self.pDll.Movej_P_Cmd.restype = self.check_error
+
+        tag = self.pDll.Movej_P_Cmd(self.nSocket, po1, v, r, trajectory_connect, block)
+        logger_.info(f'Movej_P_Cmd执行结果:{tag}')
+
+        return tag
+
+    def Moves_Cmd(self, pose, v, trajectory_connect=0, r=0, block=True):
+        """
+        该函数用于样条曲线运动
+        :param ArmSocket: socket句柄
+        :param pose: 目标位姿,位置单位:米,姿态单位:弧度。
+        :param v: 速度比例1~100,即规划速度和加速度占机械臂末端最大线速度和线加速度的百分比
+        :param trajectory_connect: 代表是否和下一条运动一起规划,0代表立即规划,1代表和下一条轨迹一起规划,当为1时,轨迹不会立即执行,样条曲线运动需至少连续下发三个点位,否则运动轨迹为直线
+        :param r: 轨迹交融半径,目前默认0。
+        :param block: True 阻塞 False 非阻塞
+        :return: 0-成功,失败返回:错误码
+        """
+        po1 = Pose()
+
+        po1.position = Pos(*pose[:3])
+        po1.euler = Euler(*pose[3:])
+
+        self.pDll.Moves_Cmd.argtypes = (
+            ctypes.c_int, Pose, ctypes.c_byte, ctypes.c_float, ctypes.c_int, ctypes.c_bool)
+        self.pDll.Moves_Cmd.restype = self.check_error
+
+        tag = self.pDll.Moves_Cmd(self.nSocket, po1, v, r, trajectory_connect, block)
+        logger_.info(f'Moves_Cmd执行结果:{tag}')
+
+        return tag
+
+    def Movej_CANFD(self, joint, follow, expand=0):
+        """
+        Movej_CANFD 角度不经规划,直接通过CANFD透传给机械臂
+        :param joint: 关节1~7目标角度数组
+        :param follow: 是否高跟随
+        :param expand: 扩展关节透传值,默认0
+        :return: 0-成功,失败返回:错误码, rm_define.h查询.
+                 只要控制器运行正常并且目标角度在可达范围内,机械臂立即返回成功指令,此时机械臂可能仍在运行
+        """
+
+        if self.code == 6:
+            self.pDll.Movej_CANFD.argtypes = (ctypes.c_int, ctypes.c_float * 6, ctypes.c_bool, ctypes.c_float)
+            self.pDll.Movej_CANFD.restype = self.check_error
+
+            joints = (ctypes.c_float * 6)(*joint)
+
+        else:
+            self.pDll.Movej_CANFD.argtypes = (ctypes.c_int, ctypes.c_float * 7, ctypes.c_bool, ctypes.c_float)
+            self.pDll.Movej_CANFD.restype = self.check_error
+
+            joints = (ctypes.c_float * 7)(*joint)
+
+        tag = self.pDll.Movej_CANFD(self.nSocket, joints, follow, expand)
+
+        logger_.info(f'Movej_CANFD:{tag}')
+
+        return tag
+
+    def Movep_CANFD(self, pose, follow):
+        """
+        Movep_CANFD 位姿不经规划,直接通过CANFD透传给机械臂
+        :param pose: 目标位姿数组:[x,y,z]+欧拉角[rx,ry,rz],或[x,y,z]+四元数[w,x,y,z]
+        :param follow: 是否高跟随
+        :return: 0-成功,失败返回:错误码, rm_define.h查询.
+        """
+        if len(pose) > 6:
+            # 7个元素:位置 + 四元数
+            po1 = Pose()
+            po1.position = Pos(*pose[:3])
+            po1.quaternion = Quat(*pose[3:])
+        else:
+            # 6个元素:位置 + 欧拉角
+            po1 = Pose()
+            po1.position = Pos(*pose[:3])
+            po1.euler = Euler(*pose[3:])
+
+        self.pDll.Movep_CANFD.argtypes = (ctypes.c_int, Pose, ctypes.c_bool)
+        self.pDll.Movep_CANFD.restype = self.check_error
+        tag = self.pDll.Movep_CANFD(self.nSocket, po1, follow)
+        logger_.info(f'Movep_CANFD:{tag}')
+
+        return tag
+
+    def MoveRotate_Cmd(self, rotateAxis, rotateAngle, choose_axis, v, trajectory_connect=0, r=0, block=True):
+
+        """
+        MoveRotate_Cmd 计算环绕运动位姿并按照结果运动
+        :param rotateAxis: 旋转轴:1-x轴,2-y轴,3-z轴
+        :param rotateAngle: 旋转角度,单位(度)
+        :param choose_axis: 指定计算时使用的坐标系
+        :param v: 速度
+        :param trajectory_connect: 代表是否和下一条运动一起规划,0代表立即规划,1代表和下一条轨迹一起规划,当为1时,轨迹不会立即执行
+        :param r: 交融半径
+        :param block: RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞
+        :return: 0-成功,失败返回:错误码, rm_define.h查询.
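+
+        Example (illustrative sketch; rotate 90° about the z axis of a frame located at the given pose):
+            >>> tag = arm.MoveRotate_Cmd(3, 90.0, [0.3, 0.0, 0.3, 0.0, 0.0, 0.0], v=20)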
+ """ + + self.pDll.MoveRotate_Cmd.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_float, Pose, ctypes.c_byte, + ctypes.c_float, ctypes.c_int, ctypes.c_bool) + + self.pDll.MoveRotate_Cmd.restype = self.check_error + + pose = Pose() + + pose.position = Pos(*choose_axis[:3]) + pose.euler = Euler(*choose_axis[3:]) + + tag = self.pDll.MoveRotate_Cmd(self.nSocket, rotateAxis, rotateAngle, pose, v, r, trajectory_connect, block) + + logger_.info(f'MoveRotate_Cmd:{tag}') + + return tag + + def MoveCartesianTool_Cmd(self, joint_cur, movelengthx, movelengthy, movelengthz, m_dev, v, trajectory_connect=0, + r=0, + block=True): + """ + cartesian_tool 沿工具端位姿移动 + :param joint_cur: 当前关节角度 + :param movelengthx: 沿X轴移动长度,米为单位 + :param movelengthy: 沿Y轴移动长度,米为单位 + :param movelengthz: 沿Z轴移动长度,米为单位 + :param m_dev: 机械臂型号 + :param v: 速度 + :param trajectory_connect: 代表是否和下一条运动一起规划,0代表立即规划,1代表和下一条轨迹一起规划,当为1时,轨迹不会立即执行 + :param r: 交融半径 + :param block:RM_NONBLOCK-非阻塞,发送后立即返回; RM_BLOCK-阻塞,等待机械臂到达位置或者规划失败 + :return: + """ + + if self.code == 6: + + self.pDll.MoveCartesianTool_Cmd.argtypes = ( + ctypes.c_int, ctypes.c_float * 6, ctypes.c_float, ctypes.c_float, ctypes.c_float, ctypes.c_int, + ctypes.c_byte, ctypes.c_float, ctypes.c_int, ctypes.c_bool) + self.pDll.MoveCartesianTool_Cmd.restype = self.check_error + + joints = (ctypes.c_float * 6)(*joint_cur) + + + else: + + self.pDll.MoveCartesianTool_Cmd.argtypes = ( + ctypes.c_int, ctypes.c_float * 7, ctypes.c_float, ctypes.c_float, ctypes.c_float, ctypes.c_int, + ctypes.c_byte, ctypes.c_float, ctypes.c_int, ctypes.c_bool) + self.pDll.MoveCartesianTool_Cmd.restype = self.check_error + + joints = (ctypes.c_float * 7)(*joint_cur) + + tag = self.pDll.MoveCartesianTool_Cmd(self.nSocket, joints, movelengthx, movelengthy, movelengthz, m_dev, v, r, + trajectory_connect, block) + + logger_.info(f'MoveCartesianTool_Cmd:{tag}') + + return tag + + def Get_Current_Trajectory(self) -> Tuple[int, int, List[float]]: + """ + Get_Current_Trajectory 获取当前轨迹规划类型 + + :return: + tuple[int, int, list[float]]: 一个包含三个元素的元组,分别表示: + - int: 0-成功,失败返回:错误码, errro_message查询.。 + - int: 轨迹规划类型(由 ARM_CTRL_MODES 枚举定义的值)。 + - list[float]: 包含7个浮点数的列表,关节规划及无规划时,该列表为关节角度数组;其他类型为末端位姿数组[x,y,z,rx,ry,rz]。 + """ + + self.pDll.Get_Current_Trajectory.argtypes = [ctypes.c_int, ctypes.POINTER(ctypes.c_int), + ctypes.c_float * 7] + self.pDll.Get_Current_Trajectory.restype = self.check_error + + type = ctypes.c_int() + data = (ctypes.c_float * 7)() + tag = self.pDll.Get_Current_Trajectory(self.nSocket, ctypes.byref(type), data) + + logger_.info(f'Get_Current_Trajectory result:{tag}') + return tag, type.value, list(data) + + def Move_Stop_Cmd(self, block=True): + + """ + Move_Stop_Cmd 突发状况 机械臂以最快速度急停,轨迹不可恢复 + :param block:RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return:0-成功,失败返回:错误码, rm_define.h查询. + """ + + self.pDll.Move_Stop_Cmd.argtypes = (ctypes.c_int, ctypes.c_bool) + self.pDll.Move_Stop_Cmd.restype = self.check_error + + tag = self.pDll.Move_Stop_Cmd(self.nSocket, block) + + logger_.info(f'Move_Stop_Cmd:{tag}') + + return tag + + def Move_Pause_Cmd(self, block=True): + + """ + Move_Pause_Cmd 轨迹暂停,暂停在规划轨迹上,轨迹可恢复 + :param block:RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return:0-成功,失败返回:错误码, rm_define.h查询. 
+ """ + + self.pDll.Move_Pause_Cmd.argtypes = (ctypes.c_int, ctypes.c_bool) + self.pDll.Move_Pause_Cmd.restype = self.check_error + + tag = self.pDll.Move_Pause_Cmd(self.nSocket, block) + + logger_.info(f'Move_Pause_Cmd:{tag}') + + return tag + + def Move_Continue_Cmd(self, block=True): + + """ + Move_Continue_Cmd 轨迹暂停后,继续当前轨迹运动 + :param block:RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return:0-成功,失败返回:错误码, rm_define.h查询. + """ + + self.pDll.Move_Continue_Cmd.argtypes = (ctypes.c_int, ctypes.c_bool) + self.pDll.Move_Continue_Cmd.restype = self.check_error + + tag = self.pDll.Move_Continue_Cmd(self.nSocket, block) + + logger_.info(f'Move_Continue_Cmd:{tag}') + + return tag + + def Clear_Current_Trajectory(self, block=True): + + """ + Clear_Current_Trajectory 清除当前轨迹,必须在暂停后使用,否则机械臂会发生意外!!!! + :param block:RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return:0-成功,失败返回:错误码, rm_define.h查询. + """ + + self.pDll.Clear_Current_Trajectory.argtypes = (ctypes.c_int, ctypes.c_bool) + self.pDll.Clear_Current_Trajectory.restype = self.check_error + + tag = self.pDll.Clear_Current_Trajectory(self.nSocket, block) + + logger_.info(f'Clear_Current_Trajectory:{tag}') + + return tag + + def Clear_All_Trajectory(self, block=True): + + """ + Clear_All_Trajectory 清除所有轨迹,必须在暂停后使用,否则机械臂会发生意外!!!! + :param block:RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return:0-成功,失败返回:错误码, rm_define.h查询. + """ + + self.pDll.Clear_All_Trajectory.argtypes = (ctypes.c_int, ctypes.c_bool) + self.pDll.Clear_All_Trajectory.restype = self.check_error + + tag = self.pDll.Clear_All_Trajectory(self.nSocket, block) + + logger_.info(f'Clear_All_Trajectory:{tag}') + + return tag + + +class Teaching: + def Joint_Teach_Cmd(self, num, direction, v, block=True): + """ + Joint_Teach_Cmd 关节示教 + :param num: 示教关节的序号,1~7 + :param direction: 示教方向,0-负方向,1-正方向 + :param v: 速度比例1~100,即规划速度和加速度占关节最大线转速和加速度的百分比 + :param block: + :return: + """ + + self.pDll.Joint_Teach_Cmd.argtypes = (ctypes.c_int, ctypes.c_byte, ctypes.c_byte, ctypes.c_byte, ctypes.c_bool) + self.pDll.Joint_Teach_Cmd.restype = self.check_error + + tag = self.pDll.Joint_Teach_Cmd(self.nSocket, num, direction, v, block) + + logger_.info(f'Joint_Teach_Cmd:{tag}') + + return tag + + def Joint_Step_Cmd(self, num, step, v, block=True): + + """ + Joint_Step_Cmd 关节步进 + :param num: 关节序号,1~7 + :param step: 步进的角度 + :param v: 速度比例1~100,即规划速度和加速度占机械臂末端最大线速度和线加速度的百分比 + :param block: RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待机械臂返回失败或者到达位置指令 + :return: 0-成功,失败返回:错误码, rm_define.h查询. 
+ """ + + self.pDll.Joint_Step_Cmd.argtypes = (ctypes.c_int, ctypes.c_byte, ctypes.c_float, ctypes.c_byte, ctypes.c_bool) + + self.pDll.Joint_Step_Cmd.restype = self.check_error + + tag = self.pDll.Joint_Step_Cmd(self.nSocket, num, step, v, block) + + logger_.info(f'Joint_Step_Cmd:{tag}') + + return tag + + def Ort_Step_Cmd(self, type, step, v, block=True): + + """ + Ort_Step_Cmd 当前工作坐标系下,姿态步进 + :param type:示教类型 0:RX 1:RY 2:RZ + :param step:步进的弧度,单位rad,精确到0.001rad + :param v:速度比例1~100,即规划速度和加速度占机械臂末端最大线速度和线加速度的百分比 + :param block:RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待机械臂返回失败或者到达位置指令 + :return: + """ + + self.pDll.Ort_Step_Cmd.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_float, ctypes.c_byte, ctypes.c_bool) + self.pDll.Ort_Step_Cmd.restype = self.check_error + + tag = self.pDll.Ort_Step_Cmd(self.nSocket, type, step, v, block) + + logger_.info(f'Ort_Step_Cmd:{tag}') + + return tag + + def Pos_Teach_Cmd(self, type, direction, v, block=True): + + """ + Pos_Teach_Cmd 当前工作坐标系下,笛卡尔空间位置示教 + :param type:示教类型 0:x轴方向 1:y轴方向 2:z轴方向 + :param direction:示教方向,0-负方向,1-正方向 + :param v:速度比例1~100,即规划速度和加速度占机械臂末端最大线速度和线加速度的百分比 + :param block:RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return:0-成功,失败返回:错误码, rm_define.h查询. + """ + + self.pDll.Pos_Teach_Cmd.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_byte, ctypes.c_byte, ctypes.c_bool) + self.pDll.Pos_Teach_Cmd.restype = self.check_error + + tag = self.pDll.Pos_Teach_Cmd(self.nSocket, type, direction, v, block) + + logger_.info(f'Pos_Teach_Cmd:{tag}') + + return tag + + def Pos_Step_Cmd(self, type_, step, v, block=True): + + """ + Pos_Step_Cmd 当前工作坐标系下,位置步进 + ArmSocket socket句柄 + type 示教类型 x:0 y:1 z:2 + step 步进的距离,单位m,精确到0.001mm + v 速度比例1~100,即规划速度和加速度占机械臂末端最大线速度和线加速度的百分比 + block RM_NONBLOCK-非阻塞,发送后立即返回; RM_BLOCK-阻塞,等待机械臂返回失败或者到达位置指令 + + + return 0-成功,失败返回:错误码, rm_define.h查询. + """ + + if type_ == 0: + type_ = POS_TEACH_MODES.X_Dir + elif type_ == 1: + type_ = POS_TEACH_MODES.Y_Dir + elif type_ == 2: + type_ = POS_TEACH_MODES.Z_Dir + + self.pDll.Pos_Step_Cmd.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_float, ctypes.c_byte, ctypes.c_bool) + self.pDll.Pos_Step_Cmd.restype = self.check_error + tag = self.pDll.Pos_Step_Cmd(self.nSocket, type_, step, v, block) + logger_.info(f'Pos_Step_Cmd: {tag}') + return tag + + def Ort_Teach_Cmd(self, type, direction, v, block=True): + """ + + :param type: + 0, // RX轴方向 + 1, // RY轴方向 + 2, // RZ轴方向 + :param direction: 示教方向,0-负方向,1-正方向 + :param v: 速度比例1~100,即规划速度和加速度占机械臂末端最大角速度和角加速度的百分比 + :param block: RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return: + """ + + self.pDll.Ort_Teach_Cmd.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_byte, ctypes.c_byte, ctypes.c_bool) + self.pDll.Ort_Teach_Cmd.restype = self.check_error + + tag = self.pDll.Ort_Teach_Cmd(self.nSocket, type, direction, v, block) + + logger_.info(f'Ort_Teach_Cmd:{tag}') + + return tag + + def Teach_Stop_Cmd(self, block=True): + """ + Teach_Stop_Cmd 示教停止 + :param block:RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return:0-成功,失败返回:错误码, rm_define.h查询. + """ + + self.pDll.Teach_Stop_Cmd.argtypes = (ctypes.c_int, ctypes.c_bool) + self.pDll.Teach_Stop_Cmd.restype = self.check_error + + tag = self.pDll.Teach_Stop_Cmd(self.nSocket, block) + + logger_.info(f'Teach_Stop_Cmd:{tag}') + + return tag + + def Set_Teach_Frame(self, type, block=True): + """ + Set_Teach_Frame 切换示教运动坐标系 + :param type: 0: 基座标运动, 1: 工具坐标系运动 + :param block: RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return: 0-成功,失败返回:错误码, rm_define.h查询. 
+ """ + + self.pDll.Set_Teach_Frame.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_int) + + self.pDll.Set_Teach_Frame.restype = self.check_error + + tag = self.pDll.Set_Teach_Frame(self.nSocket, type, block) + logger_.info(f'Set_Teach_Frame:{tag}') + + return tag + + def Get_Teach_Frame(self): + """ + Get_Teach_Frame 获取示教参考坐标系 + :return: type: 0: 基座标运动, 1: 工具坐标系运动 + """ + + self.pDll.Get_Teach_Frame.argtypes = (ctypes.c_int, ctypes.POINTER(ctypes.c_int)) + + self.pDll.Get_Teach_Frame.restype = self.check_error + + type = ctypes.c_int() + tag = self.pDll.Get_Teach_Frame(self.nSocket, ctypes.byref(type)) + logger_.info(f'Get_Teach_Frame:{tag}') + + return tag, type.value + + +class Set_controller(): + + def Get_Controller_State(self, retry=0): + """ + Get_Controller_State 获取控制器状态 + :return:电压,电流,温度 + """ + + self.pDll.Get_Controller_State.argtypes = ( + ctypes.c_int, ctypes.POINTER(ctypes.c_float), ctypes.POINTER(ctypes.c_float + ), ctypes.POINTER(ctypes.c_float), + ctypes.POINTER(ctypes.c_uint16)) + self.pDll.Get_Controller_State.restype = self.check_error + voltage = ctypes.c_float() + current = ctypes.c_float() + temperature = ctypes.c_float() + sys_err = ctypes.c_uint16() + + tag = self.pDll.Get_Controller_State(self.nSocket, ctypes.byref(voltage), ctypes.byref(current), + ctypes.byref(temperature + ), ctypes.byref(sys_err)) + + while tag and retry: + logger_.info(f'Get_Controller_State:{tag},retry is :{6 - retry}') + tag = self.pDll.Get_Controller_State(self.nSocket, ctypes.byref(voltage), ctypes.byref(current), + ctypes.byref(temperature + ), ctypes.byref(sys_err)) + + retry -= 1 + + return tag, voltage.value, current.value, temperature.value + + def Set_WiFi_AP_Data(self, wifi_name, password): + + """ + Set_WiFi_AP_Data 开启控制器WiFi AP模式设置 + :param wifi_name: 控制器wifi名称 + :param password: wifi密码 + :return: 返回值:0-成功,失败返回:错误码, rm_define.h查询. + 非阻塞模式,下发后,机械臂进入WIFI AP通讯模式 + """ + + self.pDll.Set_WiFi_AP_Data.argytypes = (ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p) + self.pDll.Set_WiFi_AP_Data.restype = self.check_error + + wifi_name = ctypes.c_char_p(wifi_name.encode('utf-8')) + password = ctypes.c_char_p(password.encode('utf-8')) + + tag = self.pDll.Set_WiFi_AP_Data(self.nSocket, wifi_name, password) + + logger_.info(f'Set_WiFi_AP_Data:{tag}') + + return tag + + def Set_WiFI_STA_Data(self, router_name, password): + + """ + Set_WiFI_STA_Data 控制器WiFi STA模式设置 + :param router_name: 路由器名称 + :param password: 路由器Wifi密码 + :return: 返回值:0-成功,失败返回:错误码, rm_define.h查询. 
+                 非阻塞模式:设置成功后,机械臂进入WIFI STA通信模式
+        """
+
+        # fixed: the original assigned to a misspelled 'argytypes' attribute, so no ctypes signature was applied
+        self.pDll.Set_WiFI_STA_Data.argtypes = (ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p)
+        self.pDll.Set_WiFI_STA_Data.restype = self.check_error
+
+        router_name = ctypes.c_char_p(router_name.encode('utf-8'))
+        password = ctypes.c_char_p(password.encode('utf-8'))
+
+        tag = self.pDll.Set_WiFI_STA_Data(self.nSocket, router_name, password)
+
+        logger_.info(f'Set_WiFI_STA_Data:{tag}')
+
+        return tag
+
+    def Set_USB_Data(self, baudrate):
+        """
+        Set_USB_Data 控制器UART_USB接口波特率设置
+        :param baudrate: 波特率:9600,19200,38400,115200和460800,若用户设置其他数据,控制器会默认按照460800处理。
+        :return: 0-成功,失败返回:错误码, rm_define.h查询.
+        """
+
+        self.pDll.Set_USB_Data.argtypes = (ctypes.c_int, ctypes.c_int)
+        self.pDll.Set_USB_Data.restype = self.check_error
+
+        tag = self.pDll.Set_USB_Data(self.nSocket, baudrate)
+
+        logger_.info(f'Set_USB_Data:{tag}')
+
+        return tag
+
+    def Set_RS485(self, baudrate):
+        """
+        Set_RS485 控制器RS485接口波特率设置
+        :param baudrate: 波特率:9600,19200,38400,115200和460800,若用户设置其他数据,控制器会默认按照460800处理。
+        :return: 0-成功,失败返回:错误码, rm_define.h查询.
+        """
+
+        self.pDll.Set_RS485.argtypes = (ctypes.c_int, ctypes.c_int)
+        self.pDll.Set_RS485.restype = self.check_error
+
+        tag = self.pDll.Set_RS485(self.nSocket, baudrate)
+
+        logger_.info(f'Set_RS485:{tag}')
+
+        return tag
+
+    def Set_Arm_Power(self, cmd, block=True):
+        """
+        Set_Arm_Power 设置机械臂电源
+        :param cmd: true-上电,false-断电
+        :param block: RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令
+        :return: 0-成功,失败返回:错误码, rm_define.h查询.
+        """
+
+        self.pDll.Set_Arm_Power.argtypes = (ctypes.c_int, ctypes.c_bool, ctypes.c_bool)
+        self.pDll.Set_Arm_Power.restype = self.check_error
+
+        tag = self.pDll.Set_Arm_Power(self.nSocket, cmd, block)
+
+        logger_.info(f'Set_Arm_Power:{tag}')
+
+        return tag
+
+    def Get_Arm_Power_State(self, retry=0):
+        """
+        Get_Arm_Power_State 读取机械臂电源状态
+        :return:
+        """
+
+        self.pDll.Get_Arm_Power_State.argtypes = (ctypes.c_int, ctypes.POINTER(ctypes.c_int))
+        self.pDll.Get_Arm_Power_State.restype = self.check_error
+
+        power = ctypes.c_int()
+
+        tag = self.pDll.Get_Arm_Power_State(self.nSocket, ctypes.byref(power))
+
+        while tag and retry:
+            logger_.info(f'Get_Arm_Power_State:{tag},retry is :{6 - retry}')
+            tag = self.pDll.Get_Arm_Power_State(self.nSocket, ctypes.byref(power))
+            retry -= 1
+
+        return tag, power.value
+
+    def Get_Arm_Software_Version(self, retry=0):
+        """
+        Get_Arm_Software_Version 读取软件版本号
+        :return: 读取到的用户接口内核版本号,实时内核版本号,实时内核子核心1版本号,实时内核子核心2版本号,机械臂型号,仅I系列机械臂支持[-I]
+        """
+
+        self.pDll.Get_Arm_Software_Version.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p,
+                                                       ctypes.c_char_p, ctypes.c_char_p]
+        self.pDll.Get_Arm_Software_Version.restype = self.check_error
+
+        # 创建字符串变量
+        plan_version = ctypes.create_string_buffer(256)
+        ctrl_version = ctypes.create_string_buffer(256)
+        kernal1 = ctypes.create_string_buffer(256)
+        kernal2 = ctypes.create_string_buffer(256)
+        product_version = ctypes.create_string_buffer(256)  # or None if not needed
+
+        # 调用 Get_Arm_Software_Version 函数
+        tag = self.pDll.Get_Arm_Software_Version(self.nSocket, plan_version, ctrl_version, kernal1, kernal2,
+                                                 product_version)
+
+        while tag and retry:
+            logger_.info(f'Get_Arm_Software_Version:{tag},retry is :{6 - retry}')
+            tag = self.pDll.Get_Arm_Software_Version(self.nSocket, plan_version, ctrl_version, kernal1, kernal2,
+                                                     product_version)
+            retry -= 1
+
+        return tag, plan_version.value.decode(), ctrl_version.value.decode(), kernal1.value.decode(), kernal2.value.decode(), product_version.value.decode()
+
+    def Get_System_Runtime(self, retry=0):
+        """
+        Get_System_Runtime 读取控制器的累计运行时间
+        :param retry: 读取失败时的最大重试次数
:return:读取结果,读取到的时间day,读取到的时间hour,读取到的时间min,读取到的时间sec + """ + + self.pDll.Get_System_Runtime.argtypes = ( + ctypes.c_int, ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_int), + ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_int)) + self.pDll.Get_System_Runtime.restype = self.check_error + + day = ctypes.c_int() + hour = ctypes.c_int() + min = ctypes.c_int() + sec = ctypes.c_int() + + tag = self.pDll.Get_System_Runtime(self.nSocket, ctypes.byref(day), ctypes.byref(hour), + ctypes.byref(min), ctypes.byref(sec)) + + while tag and retry: + logger_.info(f'Get_System_Runtime:{tag},retry is :{6 - retry}') + tag = self.pDll.Get_System_Runtime(self.nSocket, ctypes.byref(day), ctypes.byref(hour), + ctypes.byref(min), ctypes.byref(sec)) + + retry -= 1 + + return tag, day.value, hour.value, min.value, sec.value + + def Clear_System_Runtime(self, block=True): + + """ + Clear_System_Runtime 清零控制器的累计运行时间 + param block RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return: 0-成功,失败返回:错误码, rm_define.h查询. + """ + + self.pDll.Clear_System_Runtime.argtypes = (ctypes.c_int, ctypes.c_bool) + + self.pDll.Clear_System_Runtime.restype = self.check_error + + tag = self.pDll.Clear_System_Runtime(self.nSocket, block) + + logger_.info(f'Clear_System_Runtime:{tag}') + + return tag + + def Get_Joint_Odom(self): + """ + Get_Joint_Odom 读取关节的累计转动角度 + :param retry: 如果失败一共尝试读取五次 + :return: + """ + if self.code == 6: + + self.pDll.Get_Joint_Odom.argtypes = (ctypes.c_int, ctypes.c_float * 6) + self.pDll.Get_Joint_Odom.restype = self.check_error + + odom = (ctypes.c_float * 6)() + + tag = self.pDll.Get_Joint_Odom(self.nSocket, odom) + + else: + self.pDll.Get_Joint_Odom.argtypes = (ctypes.c_int, ctypes.c_float * 7) + self.pDll.Get_Joint_Odom.restype = self.check_error + + odom = (ctypes.c_float * 7)() + + tag = self.pDll.Get_Joint_Odom(self.nSocket, odom) + + logger_.info(f'Get_Joint_Odom 关节的累计转动角度:{list(odom)}') + return tag, list(odom) + + def Clear_Joint_Odom(self, block=True): + + """ + Clear_Joint_Odom 清零关节的累计转动角度 + param block RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return: 0-成功,失败返回:错误码, rm_define.h查询. + """ + + self.pDll.Clear_Joint_Odom.argtypes = (ctypes.c_int, ctypes.c_bool) + + self.pDll.Clear_Joint_Odom.restype = self.check_error + + tag = self.pDll.Clear_Joint_Odom(self.nSocket, block) + + logger_.info(f'Clear_Joint_Odom:{tag}') + + return tag + + def Set_High_Speed_Eth(self, num, block=True): + + """ + Set_High_Speed_Eth 设置高速网口 + :param num 0-关闭 1-开启 + param block RM_NONBLOCK-非阻塞,发送后立即返回;RM_BLOCK-阻塞,等待控制器返回设置成功指令 + :return: 0-成功,失败返回:错误码, rm_define.h查询. + """ + + self.pDll.Set_High_Speed_Eth.argtypes = (ctypes.c_int, ctypes.c_byte, ctypes.c_bool) + + self.pDll.Set_High_Speed_Eth.restype = self.check_error + + tag = self.pDll.Set_High_Speed_Eth(self.nSocket, num, block) + + logger_.info(f'Set_High_Speed_Eth:{tag}') + + return tag + + def Set_High_Ethernet(self, ip, mask, gateway): + + """ + Set_High_Ethernet 设置高速网口网络配置[配置通讯内容] + :param ip: 网络地址 + :param mask: 子网掩码 + :param gateway: 网关 + :return: 0-成功,失败返回:错误码, rm_define.h查询. 
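+
+        Example (illustrative sketch; the addresses are placeholders for your own network plan):
+            >>> tag = arm.Set_High_Ethernet('192.168.1.18', '255.255.255.0', '192.168.1.1')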
+    def Set_High_Ethernet(self, ip, mask, gateway):
+        """
+        Set_High_Ethernet  Configure the high-speed Ethernet port (communication settings).
+
+        :param ip: IP address
+        :param mask: subnet mask
+        :param gateway: gateway
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Set_High_Ethernet.argtypes = (ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p)
+        self.pDll.Set_High_Ethernet.restype = self.check_error
+
+        ip = ctypes.c_char_p(ip.encode('utf-8'))
+        mask = ctypes.c_char_p(mask.encode('utf-8'))
+        gateway = ctypes.c_char_p(gateway.encode('utf-8'))
+
+        tag = self.pDll.Set_High_Ethernet(self.nSocket, ip, mask, gateway)
+
+        logger_.info(f'Set_High_Ethernet:{tag}')
+
+        return tag
+
+    def Get_High_Ethernet(self, retry=0):
+        """
+        Get_High_Ethernet  Read back the high-speed Ethernet configuration.
+
+        :param retry: number of extra attempts if the call fails
+        :return: result code plus ip, mask, gateway, mac
+        """
+
+        self.pDll.Get_High_Ethernet.argtypes = (
+            ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p)
+        self.pDll.Get_High_Ethernet.restype = self.check_error
+
+        ip = ctypes.create_string_buffer(255)
+        mask = ctypes.create_string_buffer(255)
+        gateway = ctypes.create_string_buffer(255)
+        mac = ctypes.create_string_buffer(255)
+
+        tag = self.pDll.Get_High_Ethernet(self.nSocket, ip, mask, gateway, mac)
+
+        while tag and retry:
+            logger_.info(f'Get_High_Ethernet:{tag},retry is :{6 - retry}')
+            tag = self.pDll.Get_High_Ethernet(self.nSocket, ip, mask, gateway, mac)
+
+            retry -= 1
+
+        return tag, ip.value.decode(), mask.value.decode(), gateway.value.decode(), mac.value.decode()
+
+    def Save_Device_Info_All(self):
+        """
+        Save_Device_Info_All  Persist all device parameters.
+
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        tag = self.pDll.Save_Device_Info_All(self.nSocket)
+        logger_.info(f'Save_Device_Info_All:{tag}')
+
+        return tag
+
+    def Set_NetIP(self, ip):
+        """
+        Set_NetIP  Configure the wired NIC IP address [-I].
+
+        :param ip: IP address
+        :return:
+        """
+
+        self.pDll.Set_NetIP.argtypes = (ctypes.c_int, ctypes.c_char_p)
+
+        ip = ctypes.c_char_p(ip.encode('utf-8'))
+
+        tag = self.pDll.Set_NetIP(self.nSocket, ip)
+
+        logger_.info(f'Set_NetIP:{tag}')
+
+        return tag
+
+    def Get_Wired_Net(self, retry=0):
+        """
+        Get_Wired_Net  Query the wired NIC network info [-I].
+
+        :param retry: number of extra attempts if the call fails
+        :return: result code plus ip, mask, mac
+        """
+
+        self.pDll.Get_Wired_Net.argtypes = (
+            ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p)
+        self.pDll.Get_Wired_Net.restype = self.check_error
+
+        ip = ctypes.create_string_buffer(255)
+        mask = ctypes.create_string_buffer(255)
+        mac = ctypes.create_string_buffer(255)
+
+        tag = self.pDll.Get_Wired_Net(self.nSocket, ip, mask, mac)
+
+        while tag and retry:
+            logger_.info(f'Get_Wired_Net:{tag},retry is :{6 - retry}')
+            tag = self.pDll.Get_Wired_Net(self.nSocket, ip, mask, mac)
+
+            retry -= 1
+
+        return tag, ip.value.decode(), mask.value.decode(), mac.value.decode()
+
+    def Get_Wifi_Net(self, retry=0):
+        """
+        Get_Wifi_Net  Query the wireless NIC network info [-I].
+
+        :param retry: number of extra attempts if the call fails
+        :return: result code, WiFi_Info structure
+        """
+
+        self.pDll.Get_Wifi_Net.argtypes = (
+            ctypes.c_int, ctypes.POINTER(WiFi_Info))
+        self.pDll.Get_Wifi_Net.restype = self.check_error
+
+        wifi_net = WiFi_Info()
+
+        tag = self.pDll.Get_Wifi_Net(self.nSocket, wifi_net)
+
+        while tag and retry:
+            logger_.info(f'Get_Wifi_Net:{tag},retry is :{6 - retry}')
+            tag = self.pDll.Get_Wifi_Net(self.nSocket, wifi_net)
+
+            retry -= 1
+
+        # if tag == 0:
+        #     wifi_net = [wifi_net.ip, wifi_net.mac, wifi_net.mask, wifi_net.mode, wifi_net.password, wifi_net.ssid]
+
+        return tag, wifi_net
+
+    def Set_Net_Default(self):
+        """
+        Set_Net_Default  Restore the factory network settings.
+
+        :return:
+        """
+
+        tag = self.pDll.Set_Net_Default(self.nSocket)
+
+        logger_.info(f'Set_Net_Default:{tag}')
+
+        return tag
+
+    def Clear_System_Err(self, block=True):
+        """
+        Clear_System_Err  Clear the system error code.
+
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Clear_System_Err.argtypes = (ctypes.c_int, ctypes.c_bool)
+        self.pDll.Clear_System_Err.restype = self.check_error
+
+        tag = self.pDll.Clear_System_Err(self.nSocket, block)
+        logger_.info(f'Clear_System_Err:{tag}')
+
+        return tag
+
+    def Get_Arm_Software_Info(self):
+        """
+        Get_Arm_Software_Info  Read the arm software info [-I].
+
+        :return: result code, ArmSoftwareInfo structure
+        """
+
+        self.pDll.Get_Arm_Software_Info.argtypes = (
+            ctypes.c_int, ctypes.POINTER(ArmSoftwareInfo))
+        self.pDll.Get_Arm_Software_Info.restype = self.check_error
+
+        software_info = ArmSoftwareInfo()
+
+        tag = self.pDll.Get_Arm_Software_Info(self.nSocket, software_info)
+
+        return tag, software_info
+
+
+class Set_IO():
+
+    def Set_IO_Mode(self, io_num, io_mode):
+        """
+        Set the digital IO mode [-I].
+
+        :param io_num: IO port number, range 1-2
+        :param io_mode: 0 - input, 1 - output, 2 - input mapped to "start", 3 - input mapped to "pause", 4 - input mapped to "resume", 5 - input mapped to "emergency stop"
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Set_IO_Mode.argtypes = (ctypes.c_int, ctypes.c_byte, ctypes.c_byte)
+        self.pDll.Set_IO_Mode.restype = self.check_error
+
+        tag = self.pDll.Set_IO_Mode(self.nSocket, io_num, io_mode)
+
+        logger_.info(f'Set_IO_Mode:{tag}')
+
+        return tag
+
+    def Set_DO_State(self, io_num, state, block=True):
+        """
+        Set a digital output.
+
+        :param io_num: channel number, 1-4
+        :param state: True - high, False - low
+        :param block: 0 - return immediately after sending; 1 - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+        self.pDll.Set_DO_State.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_bool, ctypes.c_bool)
+        self.pDll.Set_DO_State.restype = self.check_error
+
+        tag = self.pDll.Set_DO_State(self.nSocket, io_num, state, block)
+        logger_.info(f'Set_DO_State result:{tag}')
+
+        return tag
+
+    def Get_IO_State(self, num):
+        """
+        Get_IO_State  Get the state of an IO channel.
+
+        :param num: channel number, 1-4
+        :return: result code, state, mode
+        """
+
+        self.pDll.Get_IO_State.argtypes = (
+            ctypes.c_int, ctypes.c_byte, ctypes.POINTER(ctypes.c_byte), ctypes.POINTER(ctypes.c_byte))
+
+        self.pDll.Get_IO_State.restype = self.check_error
+
+        state = ctypes.c_byte()
+        mode = ctypes.c_byte()
+        tag = self.pDll.Get_IO_State(self.nSocket, num, ctypes.byref(state), ctypes.byref(mode))
+
+        logger_.info(f'Get_IO_State:{tag}')
+
+        return tag, state.value, mode.value
+
+    def Get_DO_State(self, io_num):
+        """
+        Get_DO_State  Query a digital output state (base series).
+
+        :param io_num: channel number, 1-4
+        :return: result code, state of the channel: 1 - high, 0 - low
+        """
+
+        self.pDll.Get_DO_State.argtypes = (
+            ctypes.c_int, ctypes.c_byte, ctypes.POINTER(ctypes.c_byte))
+
+        self.pDll.Get_DO_State.restype = self.check_error
+
+        state = ctypes.c_byte()
+        tag = self.pDll.Get_DO_State(self.nSocket, io_num, ctypes.byref(state))
+
+        logger_.info(f'Get_DO_State result:{tag}')
+
+        return tag, state.value
+
+    def Get_DI_State(self, io_num):
+        """
+        Get_DI_State  Query a digital input state (base series).
+
+        :param io_num: channel number, 1-3
+        :return: result code, state of the channel: 1 - high, 0 - low
+        """
+
+        self.pDll.Get_DI_State.argtypes = (
+            ctypes.c_int, ctypes.c_byte, ctypes.POINTER(ctypes.c_byte))
+
+        self.pDll.Get_DI_State.restype = self.check_error
+
+        state = ctypes.c_byte()
+        tag = self.pDll.Get_DI_State(self.nSocket, io_num, ctypes.byref(state))
+
+        logger_.info(f'Get_DI_State result:{tag}')
+
+        return tag, state.value
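A hedged sketch of the digital-IO helpers above (channel numbers are illustrative; `arm` is again an assumed connected wrapper instance):

```python
# Hypothetical sketch: pulse digital output 1 and read it back.
arm.Set_DO_State(1, True, block=True)   # drive channel 1 high
tag, state = arm.Get_DO_State(1)        # expect state == 1 on success
arm.Set_DO_State(1, False, block=True)
```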
+    def Set_AO_State(self, io_num, voltage, block=True):
+        """
+        Set an analog output (base series).
+
+        :param io_num: channel number, 1-4
+        :param voltage: output voltage, resolution 0.001 V, range 0-10000 for 0-10 V
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Set_AO_State.argtypes = (ctypes.c_int, ctypes.c_byte, ctypes.c_float, ctypes.c_bool)
+        self.pDll.Set_AO_State.restype = self.check_error
+
+        tag = self.pDll.Set_AO_State(self.nSocket, io_num, voltage, block)
+
+        logger_.info(f'Set_AO_State result:{tag}')
+
+        return tag
+
+    def Get_AO_State(self, io_num):
+        """
+        Get_AO_State  Query an analog output state (base series).
+
+        :param io_num: channel number, 1-4
+        :return: result code, output voltage: resolution 0.001 V, range 0-10000 for 0-10 V
+        """
+
+        self.pDll.Get_AO_State.argtypes = (
+            ctypes.c_int, ctypes.c_byte, ctypes.POINTER(ctypes.c_byte))
+
+        self.pDll.Get_AO_State.restype = self.check_error
+
+        voltage = ctypes.c_byte()
+        tag = self.pDll.Get_AO_State(self.nSocket, io_num, ctypes.byref(voltage))
+
+        logger_.info(f'Get_AO_State result:{tag}')
+
+        return tag, voltage.value
+
+    def Get_AI_State(self, io_num):
+        """
+        Get_AI_State  Query an analog input state (base series).
+
+        :param io_num: channel number, 1-4
+        :return: result code, input voltage: resolution 0.001 V, range 0-10000 for 0-10 V
+        """
+
+        self.pDll.Get_AI_State.argtypes = (
+            ctypes.c_int, ctypes.c_byte, ctypes.POINTER(ctypes.c_byte))
+
+        self.pDll.Get_AI_State.restype = self.check_error
+
+        voltage = ctypes.c_byte()
+        tag = self.pDll.Get_AI_State(self.nSocket, io_num, ctypes.byref(voltage))
+
+        logger_.info(f'Get_AI_State result:{tag}')
+
+        return tag, voltage.value
+
+    def Get_IO_Input(self):
+        """
+        Get_IO_Input  Query the input state of all digital and analog IO.
+
+        :return: result code, 4 digital input states, 4 analog input voltages
+        """
+
+        self.pDll.Get_IO_Input.argtypes = (ctypes.c_int, ctypes.c_int * 4, ctypes.c_float * 4)
+        self.pDll.Get_IO_Input.restype = self.check_error
+
+        DI_state = (ctypes.c_int * 4)()
+        AI_voltage = (ctypes.c_float * 4)()
+
+        tag = self.pDll.Get_IO_Input(self.nSocket, DI_state, AI_voltage)
+
+        logger_.info(f'Get_IO_Input:{tag}')
+
+        return tag, list(DI_state), list(AI_voltage)
+
+    def Get_IO_Output(self):
+        """
+        Get_IO_Output  Query the output state of all digital and analog IO.
+
+        :return: result code, 4 digital output states, 4 analog output voltages
+        """
+
+        self.pDll.Get_IO_Output.argtypes = (ctypes.c_int, ctypes.c_int * 4, ctypes.c_float * 4)
+        self.pDll.Get_IO_Output.restype = self.check_error
+
+        DO_state = (ctypes.c_int * 4)()
+        AO_voltage = (ctypes.c_float * 4)()
+
+        tag = self.pDll.Get_IO_Output(self.nSocket, DO_state, AO_voltage)
+
+        logger_.info(f'Get_IO_Output:{tag}')
+
+        return tag, list(DO_state), list(AO_voltage)
+
+    def Set_Voltage(self, voltage_type):
+        """
+        Set_Voltage  Set the power supply output [-I].
+
+        :param voltage_type: output type, range 0-3 (0 - 0 V, 2 - 12 V, 3 - 24 V)
+        :return:
+        """
+
+        self.pDll.Set_Voltage.argtypes = (ctypes.c_int, ctypes.c_byte)
+
+        self.pDll.Set_Voltage.restype = self.check_error
+
+        tag = self.pDll.Set_Voltage(self.nSocket, voltage_type)
+
+        logger_.info(f'Set_Voltage:{tag}')
+
+        return tag
+
+    def Get_Voltage(self):
+        """
+        Get_Voltage  Get the power supply output type [-I].
+
+        :return: output type, range 0-3 (0 - 0 V, 2 - 12 V, 3 - 24 V)
+        """
+
+        self.pDll.Get_Voltage.argtypes = (ctypes.c_int, ctypes.POINTER(ctypes.c_byte))
+        self.pDll.Get_Voltage.restype = self.check_error
+
+        voltage_type = ctypes.c_byte()
+
+        tag = self.pDll.Get_Voltage(self.nSocket, ctypes.byref(voltage_type))
+
+        logger_.info(f'Get_Voltage:{tag}')
+
+        return voltage_type.value
+
+
+class Set_Tool_IO():
+
+    def Set_Tool_DO_State(self, num, state, block=True):
+        """
+        Set_Tool_DO_State  Set a tool-end digital output.
+
+        :param num: channel number, 1-2
+        :param state: True - high, False - low
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return:
+        """
+
+        self.pDll.Set_Tool_DO_State.argtypes = (ctypes.c_int, ctypes.c_byte, ctypes.c_bool, ctypes.c_bool)
+        self.pDll.Set_Tool_DO_State.restype = self.check_error  # was `restypes = ctypes.c_int`, a silent typo
+
+        tag = self.pDll.Set_Tool_DO_State(self.nSocket, num, state, block)
+
+        logger_.info(f'Set_Tool_DO_State:{tag}')
+
+        return tag
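A sketch of the analog-IO helpers, assuming the 0-10000 to 0-10 V scaling stated in the docstrings (channel and value are illustrative):

```python
# Hypothetical sketch: set analog output 1 to 5 V, then poll all inputs.
arm.Set_AO_State(1, 5000, block=True)   # 5000 -> 5.000 V per the docstring
tag, di, ai = arm.Get_IO_Input()        # 4 digital states, 4 analog voltages
```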
+    def Set_Tool_IO_Mode(self, num, state, block=True):
+        """
+        Set_Tool_IO_Mode  Set the direction of a tool-end digital IO channel.
+
+        :param num: channel number, 1-2
+        :param state: 0 - input, 1 - output
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Set_Tool_IO_Mode.argtypes = (ctypes.c_int, ctypes.c_byte, ctypes.c_bool, ctypes.c_bool)
+        self.pDll.Set_Tool_IO_Mode.restype = self.check_error
+
+        tag = self.pDll.Set_Tool_IO_Mode(self.nSocket, num, state, block)
+
+        logger_.info(f'Set_Tool_IO_Mode:{tag}')
+
+        return tag
+
+    def Get_Tool_IO_State(self):
+        """
+        Get_Tool_IO_State  Get the tool-end digital IO state.
+
+        :return: result code, io_mode per channel (0 - input, 1 - output), io_state per channel (0 - low, 1 - high)
+        """
+
+        self.pDll.Get_Tool_IO_State.argtypes = (ctypes.c_int, ctypes.c_float * 2, ctypes.c_float * 2)
+        self.pDll.Get_Tool_IO_State.restype = self.check_error
+
+        io_mode = (ctypes.c_float * 2)()
+        io_state = (ctypes.c_float * 2)()
+
+        tag = self.pDll.Get_Tool_IO_State(self.nSocket, io_mode, io_state)
+
+        return tag, list(io_mode), list(io_state)
+
+    def Set_Tool_Voltage(self, type, block=True):
+        """
+        Set_Tool_Voltage  Set the tool-end voltage output (e.g. to power a gripper).
+
+        :param type: output type, 0 - 0 V, 1 - 5 V, 2 - 12 V, 3 - 24 V
+        :param block: True - blocking, False - non-blocking
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+        self.pDll.Set_Tool_Voltage.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_bool)
+        self.pDll.Set_Tool_Voltage.restype = self.check_error
+
+        tag = self.pDll.Set_Tool_Voltage(self.nSocket, type, block)
+        logger_.info(f'Set_Tool_Voltage result:{tag}')
+        return tag
+
+    def Get_Tool_Voltage(self):
+        """
+        Get_Tool_Voltage  Query the tool-end voltage output.
+
+        :return: result code, tool-end voltage output type
+        """
+
+        self.pDll.Get_Tool_Voltage.argtypes = (ctypes.c_int, ctypes.POINTER(ctypes.c_byte))
+
+        voltage = ctypes.c_byte()
+
+        tag = self.pDll.Get_Tool_Voltage(self.nSocket, ctypes.byref(voltage))
+
+        logger_.info(f'Get_Tool_Voltage:{tag}')
+
+        return tag, voltage.value
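A sketch of powering a tool from the end-effector connector with the mixin above (the attached device is illustrative; voltage codes follow the docstring):

```python
# Hypothetical sketch: supply 24 V at the tool flange and raise tool output 1.
arm.Set_Tool_Voltage(3, block=True)         # 3 -> 24 V
arm.Set_Tool_DO_State(1, True, block=True)
```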
+class Set_Gripper():
+    def Set_Gripper_Pick(self, speed, force, block=True, timeout=30):
+        """
+        Set_Gripper_Pick  Force-controlled grip.
+
+        :param speed: grip speed, range 1-1000, dimensionless
+        :param force: force threshold, range 50-1000, dimensionless
+        :param block: True - blocking, False - non-blocking
+        :param timeout: timeout in seconds, effective in blocking mode
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Set_Gripper_Pick.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_bool, ctypes.c_int)
+        self.pDll.Set_Gripper_Pick.restype = self.check_error
+
+        tag = self.pDll.Set_Gripper_Pick(self.nSocket, speed, force, block, timeout)
+        logger_.info(f'Set_Gripper_Pick result:{tag}')
+
+        return tag
+
+    def Set_Gripper_Release(self, speed, block=True, timeout=30):
+        """
+        Set_Gripper_Release  Open the gripper.
+
+        :param speed: opening speed, range 1-1000, dimensionless
+        :param block: True - blocking, False - non-blocking
+        :param timeout: timeout in seconds, effective in blocking mode
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Set_Gripper_Release.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_bool, ctypes.c_int)
+        self.pDll.Set_Gripper_Release.restype = self.check_error
+
+        tag = self.pDll.Set_Gripper_Release(self.nSocket, speed, block, timeout)
+        logger_.info(f'Set_Gripper_Release result:{tag}')
+        return tag
+
+    def Set_Gripper_Route(self, min_limit, max_limit, block=True):
+        """
+        Set_Gripper_Route  Set the gripper travel limits.
+
+        :param min_limit: minimum opening, range 0-1000, dimensionless
+        :param max_limit: maximum opening, range 0-1000, dimensionless
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Set_Gripper_Route.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_bool)
+        self.pDll.Set_Gripper_Route.restype = self.check_error
+
+        tag = self.pDll.Set_Gripper_Route(self.nSocket, min_limit, max_limit, block)
+
+        logger_.info(f'Set_Gripper_Route:{tag}')
+
+        return tag
+
+    def Set_Gripper_Pick_On(self, speed, force, block=True, timeout=30):
+        """
+        Set_Gripper_Pick_On  Continuous force-controlled grip (keeps clamping).
+
+        :param speed: grip speed, range 1-1000, dimensionless
+        :param force: force threshold, range 50-1000, dimensionless
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :param timeout: timeout in seconds, effective in blocking mode
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Set_Gripper_Pick_On.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_bool, ctypes.c_int)
+        self.pDll.Set_Gripper_Pick_On.restype = self.check_error
+
+        tag = self.pDll.Set_Gripper_Pick_On(self.nSocket, speed, force, block, timeout)
+
+        logger_.info(f'Set_Gripper_Pick_On:{tag}')
+
+        return tag
+
+    def Set_Gripper_Position(self, position, block=True, timeout=30):
+        """
+        Set_Gripper_Position  Set the gripper opening.
+
+        :param position: opening position, range 1-1000, dimensionless
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :param timeout: timeout in seconds, effective in blocking mode
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Set_Gripper_Position.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_bool, ctypes.c_int)
+        self.pDll.Set_Gripper_Position.restype = self.check_error
+
+        tag = self.pDll.Set_Gripper_Position(self.nSocket, position, block, timeout)
+
+        logger_.info(f'Set_Gripper_Position:{tag}')
+
+        return tag
+
+    def Get_Gripper_State(self):
+        """
+        Get_Gripper_State  Get the gripper state.
+
+        :return: result code, GripperState structure
+        """
+
+        self.pDll.Get_Gripper_State.argtypes = (ctypes.c_int, ctypes.POINTER(GripperState))
+        self.pDll.Get_Gripper_State.restype = self.check_error
+
+        state = GripperState()
+        tag = self.pDll.Get_Gripper_State(self.nSocket, ctypes.byref(state))
+        logger_.info(f'Get_Gripper_State:{tag}')
+
+        return tag, state
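A sketch of a grip-and-release cycle with the gripper mixin (speed/force values are illustrative, inside the documented ranges):

```python
# Hypothetical sketch: force-controlled grip, check state, then release.
arm.Set_Gripper_Pick(speed=500, force=200, block=True, timeout=30)
tag, state = arm.Get_Gripper_State()
arm.Set_Gripper_Release(speed=500, block=True, timeout=30)
```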
+class Drag_Teach():
+    def Start_Drag_Teach(self, block=True):
+        """
+        Start_Drag_Teach  Put the arm into drag-teach mode.
+
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Start_Drag_Teach.argtypes = [ctypes.c_int, ctypes.c_bool]
+        self.pDll.Start_Drag_Teach.restype = self.check_error
+
+        tag = self.pDll.Start_Drag_Teach(self.nSocket, block)
+
+        logger_.info(f'Start_Drag_Teach:{tag}')
+
+        return tag
+
+    def Stop_Drag_Teach(self, block=True):
+        """
+        Stop_Drag_Teach  Take the arm out of drag-teach mode.
+
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Stop_Drag_Teach.argtypes = [ctypes.c_int, ctypes.c_bool]
+        self.pDll.Stop_Drag_Teach.restype = self.check_error
+
+        tag = self.pDll.Stop_Drag_Teach(self.nSocket, block)
+
+        logger_.info(f'Stop_Drag_Teach:{tag}')
+
+        return tag
+
+    def Run_Drag_Trajectory(self, block=True):
+        """
+        Run_Drag_Trajectory  Replay the drag-taught trajectory. Only valid after drag teaching
+        has finished, and the arm must be at the trajectory start point; if it is not, call
+        Drag_Trajectory_Origin first, otherwise an error is returned.
+
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Run_Drag_Trajectory.argtypes = [ctypes.c_int, ctypes.c_bool]
+        self.pDll.Run_Drag_Trajectory.restype = self.check_error
+
+        tag = self.pDll.Run_Drag_Trajectory(self.nSocket, block)
+
+        logger_.info(f'Run_Drag_Trajectory:{tag}')
+
+        return tag
+
+    def Pause_Drag_Trajectory(self, block=True):
+        """
+        Pause_Drag_Trajectory  Pause trajectory replay.
+
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Pause_Drag_Trajectory.argtypes = [ctypes.c_int, ctypes.c_bool]
+        self.pDll.Pause_Drag_Trajectory.restype = self.check_error
+
+        tag = self.pDll.Pause_Drag_Trajectory(self.nSocket, block)
+
+        logger_.info(f'Pause_Drag_Trajectory:{tag}')
+
+        return tag
+
+    def Continue_Drag_Trajectory(self, block=True):
+        """
+        Continue_Drag_Trajectory  Resume a paused replay. The arm must still be at the pause
+        position, otherwise an error is returned and the trajectory can only be replayed again
+        from the start.
+
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Continue_Drag_Trajectory.argtypes = [ctypes.c_int, ctypes.c_bool]
+        self.pDll.Continue_Drag_Trajectory.restype = self.check_error
+
+        tag = self.pDll.Continue_Drag_Trajectory(self.nSocket, block)
+
+        logger_.info(f'Continue_Drag_Trajectory:{tag}')
+
+        return tag
+
+    def Stop_Drag_Trajectory(self, block=True):
+        """
+        Stop_Drag_Trajectory  Stop trajectory replay.
+
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Stop_Drag_Trajectory.argtypes = [ctypes.c_int, ctypes.c_bool]
+        self.pDll.Stop_Drag_Trajectory.restype = self.check_error
+
+        tag = self.pDll.Stop_Drag_Trajectory(self.nSocket, block)
+
+        logger_.info(f'Stop_Drag_Trajectory:{tag}')
+
+        return tag
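A sketch of a record-and-replay cycle; `Drag_Trajectory_Origin` is defined just below:

```python
# Hypothetical sketch: hand-guide the arm, then replay the taught path.
arm.Start_Drag_Teach(block=True)        # move the arm by hand now
arm.Stop_Drag_Teach(block=True)
arm.Drag_Trajectory_Origin(block=True)  # go to the start point at 20% speed
arm.Run_Drag_Trajectory(block=True)
```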
+    def Drag_Trajectory_Origin(self, block=True):
+        """
+        Drag_Trajectory_Origin  Move the arm to the trajectory start point; this must be done
+        before replaying. If configured correctly the arm moves there at 20% speed.
+
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Drag_Trajectory_Origin.argtypes = [ctypes.c_int, ctypes.c_bool]
+        self.pDll.Drag_Trajectory_Origin.restype = self.check_error
+
+        tag = self.pDll.Drag_Trajectory_Origin(self.nSocket, block)
+
+        logger_.info(f'Drag_Trajectory_Origin:{tag}')
+
+        return tag
+
+    def Start_Multi_Drag_Teach(self, mode, singular_wall, block=True):
+        """
+        Start_Multi_Drag_Teach  Start compound-mode drag teaching.
+
+        :param mode: 0 - current-loop mode, 1 - end 6-axis force, position only, 2 - end 6-axis force, orientation only, 3 - end 6-axis force, position and orientation
+        :param singular_wall: singularity-wall option, passed straight through to the DLL
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return:
+        """
+
+        self.pDll.Start_Multi_Drag_Teach.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_bool)
+        self.pDll.Start_Multi_Drag_Teach.restype = self.check_error
+
+        tag = self.pDll.Start_Multi_Drag_Teach(self.nSocket, mode, singular_wall, block)
+        logger_.info(f'Start_Multi_Drag_Teach:{tag}')
+
+        return tag
+
+    def Set_Force_Postion(self, sensor, mode, direction, N, block=True):
+        """
+        Set_Force_Postion  Hybrid force-position control.
+
+        :param sensor: 0 - 1-axis force, 1 - 6-axis force
+        :param mode: 0 - force control in the base frame, 1 - force control in the tool frame
+        :param direction: force direction: 0 - along X, 1 - along Y, 2 - along Z, 3 - about RX, 4 - about RY, 5 - about RZ
+        :param N: force magnitude, in N
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Set_Force_Postion.argtypes = (
+            ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_bool)
+        self.pDll.Set_Force_Postion.restype = self.check_error
+
+        tag = self.pDll.Set_Force_Postion(self.nSocket, sensor, mode, direction, N, block)
+
+        logger_.info(f'Set_Force_Postion:{tag}')
+
+        return tag
+
+    def Stop_Force_Postion(self, block=True):
+        """
+        Stop_Force_Postion  End hybrid force-position control.
+
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Stop_Force_Postion.argtypes = [ctypes.c_int, ctypes.c_bool]
+        self.pDll.Stop_Force_Postion.restype = self.check_error
+
+        tag = self.pDll.Stop_Force_Postion(self.nSocket, block)
+
+        logger_.info(f'Stop_Force_Postion:{tag}')
+
+        return tag
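A sketch of hybrid force-position control (axis choice and force value are illustrative):

```python
# Hypothetical sketch: hold 5 N along the tool Z axis, then stop force control.
arm.Set_Force_Postion(sensor=1, mode=1, direction=2, N=5, block=True)
# ... run the position move that should keep contact ...
arm.Stop_Force_Postion(block=True)
```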
+    def Save_Trajectory(self, file_name):
+        """
+        Save_Trajectory  Save the trajectory that was just drag-taught; call it after drag teaching.
+
+        :param file_name: path and file name to save the trajectory to, e.g. c:/rm_test.txt
+        :return: result code (0 on success; otherwise an error code, see rm_define.h), plus the count reported by the controller
+        """
+        self.pDll.Save_Trajectory.argtypes = (ctypes.c_int, ctypes.c_char_p, ctypes.POINTER(ctypes.c_int))
+        self.pDll.Save_Trajectory.restype = self.check_error
+        file_name = ctypes.create_string_buffer(file_name.encode('utf-8'))
+        num = ctypes.c_int()
+        tag = self.pDll.Save_Trajectory(self.nSocket, file_name, ctypes.byref(num))
+        time.sleep(1)
+        logger_.info(f'Save_Trajectory:{tag}')
+
+        return tag, num.value
+
+
+class Six_Force():
+    def Get_Force_Data(self):
+        """
+        Get_Force_Data  Query the force and torque reading of the six-axis force sensor.
+        When polling periodically, keep the period at 50 ms or more.
+
+        :return: result code plus the force, zero_force, work_zero and tool_zero arrays
+        """
+
+        self.pDll.Get_Force_Data.argtypes = (ctypes.c_int, ctypes.c_float * 6, ctypes.c_float * 6,
+                                             ctypes.c_float * 6, ctypes.c_float * 6)
+
+        self.pDll.Get_Force_Data.restype = self.check_error
+
+        force = (ctypes.c_float * 6)()
+        zero_force = (ctypes.c_float * 6)()
+        work_zero = (ctypes.c_float * 6)()
+        tool_zero = (ctypes.c_float * 6)()
+
+        tag = self.pDll.Get_Force_Data(self.nSocket, force, zero_force, work_zero, tool_zero)
+
+        logger_.info(f'Get_Force_Data:{tag}')
+
+        return tag, list(force), list(zero_force), list(work_zero), list(tool_zero)
+
+    def Set_Force_Sensor(self):
+        """
+        Set_Force_Sensor  Thin pass-through to the DLL call of the same name.
+        """
+
+        tag = self.pDll.Set_Force_Sensor(self.nSocket)
+        logger_.info(f'Set_Force_Sensor:{tag}')
+
+        return tag
+
+    def Manual_Set_Force(self, type, joints):
+        """
+        Manual_Set_Force  Manually calibrate the six-axis force sensor's gravity-center parameters.
+        After the sensor is remounted, the initial force it sees and its gravity center must be recomputed.
+
+        :param type: calibration point, 1-4; call this function four times, once per point
+        :param joints: joint angles at the calibration point
+        :return:
+        """
+
+        if self.code == 6:
+            self.pDll.Manual_Set_Force.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_float * 6)
+            self.pDll.Manual_Set_Force.restype = self.check_error
+
+            joints = (ctypes.c_float * 6)(*joints)
+
+            tag = self.pDll.Manual_Set_Force(self.nSocket, type, joints)
+
+        else:
+            self.pDll.Manual_Set_Force.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_float * 7)
+            self.pDll.Manual_Set_Force.restype = self.check_error
+
+            joints = (ctypes.c_float * 7)(*joints)
+
+            tag = self.pDll.Manual_Set_Force(self.nSocket, type, joints)
+
+        logger_.info(f'Manual_Set_Force:{tag}')
+        return tag
+
+    def Stop_Set_Force_Sensor(self, block=True):
+        """
+        Stop_Set_Force_Sensor  Abort the six-/one-axis force calibration: if anything goes wrong
+        during calibration, send this command to stop arm motion and exit the calibration flow.
+
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Stop_Set_Force_Sensor.argtypes = [ctypes.c_int, ctypes.c_bool]
+        self.pDll.Stop_Set_Force_Sensor.restype = self.check_error
+
+        tag = self.pDll.Stop_Set_Force_Sensor(self.nSocket, block)
+
+        logger_.info(f'Stop_Set_Force_Sensor:{tag}')
+
+        return tag
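A polling sketch for the six-axis sensor, respecting the minimum 50 ms period from the docstring:

```python
import time

# Hypothetical sketch: sample the force/torque reading at 20 Hz for 5 s.
for _ in range(100):
    tag, force, zero_force, work_zero, tool_zero = arm.Get_Force_Data()
    time.sleep(0.05)
```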
+    def Clear_Force_Data(self, block=True):
+        """
+        Clear_Force_Data  Zero the six-axis force data, so all subsequent readings are offsets
+        from the current value.
+
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Clear_Force_Data.argtypes = (ctypes.c_int, ctypes.c_bool)
+        self.pDll.Clear_Force_Data.restype = self.check_error
+
+        tag = self.pDll.Clear_Force_Data(self.nSocket, block)
+
+        logger_.info(f'Clear_Force_Data:{tag}')
+
+        return tag
+
+
+class Set_Hand():
+
+    def Set_Hand_Seq(self, seq_num, block=1):
+        """
+        Run a target action sequence on the dexterous hand.
+        """
+        tag = self.pDll.Set_Hand_Seq(self.nSocket, seq_num, block)
+        logger_.info(f'Set_Hand_Seq:{tag}')
+        time.sleep(0.5)
+
+        return tag
+
+    def Set_Hand_Posture(self, posture_num, block=1):
+        """
+        Set a target posture of the dexterous hand.
+        """
+        tag = self.pDll.Set_Hand_Posture(self.nSocket, posture_num, block)
+        logger_.info(f'Set_Hand_Posture:{tag}')
+        time.sleep(1)
+
+        return tag
+
+    def Set_Hand_Angle(self, angle, block=True):
+        """
+        Set_Hand_Angle  Set the joint angles of the dexterous hand.
+
+        :param angle: array of 6 finger angles, one per degree of freedom, range 0-1000; -1 leaves that degree of freedom unchanged
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Set_Hand_Angle.argtypes = (ctypes.c_int, ctypes.c_int * 6, ctypes.c_bool)
+        self.pDll.Set_Hand_Angle.restype = self.check_error
+
+        angle = (ctypes.c_int * 6)(*angle)
+
+        tag = self.pDll.Set_Hand_Angle(self.nSocket, angle, block)
+
+        logger_.info(f'Set_Hand_Angle:{tag}')
+
+        return tag
+
+    def Set_Hand_Speed(self, speed, block=True):
+        """
+        Set_Hand_Speed  Set the joint speed of the dexterous hand.
+
+        :param speed: joint speed, range 1-1000
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Set_Hand_Speed.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_bool)
+        self.pDll.Set_Hand_Speed.restype = self.check_error
+
+        tag = self.pDll.Set_Hand_Speed(self.nSocket, speed, block)
+
+        logger_.info(f'Set_Hand_Speed:{tag}')
+
+        return tag
+
+    def Set_Hand_Force(self, force, block=True):
+        """
+        Set_Hand_Force  Set the joint force threshold of the dexterous hand.
+
+        :param force: force threshold, range 1-1000, the torque threshold per joint (0-10 N grip for the four fingers, 0-15 N for the thumb)
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Set_Hand_Force.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_bool)
+        self.pDll.Set_Hand_Force.restype = self.check_error
+
+        tag = self.pDll.Set_Hand_Force(self.nSocket, force, block)
+
+        logger_.info(f'Set_Hand_Force:{tag}')
+
+        return tag
+
+
+class one_force():
+    def Get_Fz(self):
+        """
+        Get_Fz  Query the end-effector one-axis force data.
+
+        :return: result code plus the fz, zero_fz, work_fz and tool_fz values
+        """
+
+        self.pDll.Get_Fz.argtypes = (ctypes.c_int, ctypes.POINTER(ctypes.c_float), ctypes.POINTER(ctypes.c_float),
+                                     ctypes.POINTER(ctypes.c_float), ctypes.POINTER(ctypes.c_float))
+        self.pDll.Get_Fz.restype = self.check_error
+
+        fz = ctypes.c_float()
+        zero_fz = ctypes.c_float()
+        work_fz = ctypes.c_float()
+        tool_fz = ctypes.c_float()
+
+        tag = self.pDll.Get_Fz(self.nSocket, ctypes.byref(fz), ctypes.byref(zero_fz), ctypes.byref(work_fz),
+                               ctypes.byref(tool_fz))
+
+        logger_.info(f'Get_Fz:{tag}')
+
+        return tag, fz.value, zero_fz.value, work_fz.value, tool_fz.value
+
+    def Clear_Fz(self, block=True):
+        """
+        Clear_Fz  Zero the end-effector one-axis force data.
+
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Clear_Fz.argtypes = (ctypes.c_int, ctypes.c_bool)
+        self.pDll.Clear_Fz.restype = self.check_error
+
+        tag = self.pDll.Clear_Fz(self.nSocket, block)
+
+        logger_.info(f'Clear_Fz:{tag}')
+
+        return tag
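A sketch of the dexterous-hand mixin above (angles inside the documented 0-1000 range; -1 leaves a degree of freedom unchanged):

```python
# Hypothetical sketch: configure speed/force, then command a finger pose.
arm.Set_Hand_Speed(500, block=True)
arm.Set_Hand_Force(300, block=True)
arm.Set_Hand_Angle([400, 400, 400, 400, 400, -1], block=True)
```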
+    def Auto_Set_Fz(self):
+        """
+        Auto_Set_Fz  Automatically calibrate the one-axis force data.
+
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        tag = self.pDll.Auto_Set_Fz(self.nSocket)
+        logger_.info(f'Auto_Set_Fz:{tag}')
+
+        return tag
+
+    def Manual_Set_Fz(self, joint1, joint2):
+        """
+        Manual_Set_Fz  Manually calibrate the one-axis force data from two joint poses.
+
+        :param joint1: joint angles of the first calibration pose
+        :param joint2: joint angles of the second calibration pose
+        :return:
+        """
+
+        le = self.code
+
+        self.pDll.Manual_Set_Fz.argtypes = (ctypes.c_int, ctypes.c_float * le, ctypes.c_float * le)
+        self.pDll.Manual_Set_Fz.restype = self.check_error
+
+        joint1 = (ctypes.c_float * le)(*joint1)
+        joint2 = (ctypes.c_float * le)(*joint2)
+
+        tag = self.pDll.Manual_Set_Fz(self.nSocket, joint1, joint2)
+
+        logger_.info(f'Manual_Set_Fz:{tag}')
+
+        return tag
+
+
+class ModbusRTU():
+    def Set_Modbus_Mode(self, port, baudrate, timeout, block=True):
+        """
+        Configure a communication port in Modbus RTU mode.
+
+        :param port: 0 - controller RS485 port as RTU master, 1 - end-effector board RS485 port as RTU slave, 2 - controller RS485 port as RTU slave
+        :param baudrate: baud rate; the common values 9600, 115200 and 460800 are supported
+        :param timeout: timeout, in units of 100 ms
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Set_Modbus_Mode.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_bool)
+        self.pDll.Set_Modbus_Mode.restype = self.check_error
+
+        tag = self.pDll.Set_Modbus_Mode(self.nSocket, port, baudrate, timeout, block)
+
+        logger_.info(f'Set_Modbus_Mode:{tag}')
+
+        return tag
+
+    def Close_Modbus_Mode(self, port, block=True):
+        """
+        Close_Modbus_Mode  Close Modbus RTU mode on a communication port.
+
+        :param port: 0 - controller RS485 port as RTU master, 1 - end-effector board RS485 port as RTU slave, 2 - controller RS485 port as RTU slave
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Close_Modbus_Mode.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_bool)
+        self.pDll.Close_Modbus_Mode.restype = self.check_error  # was `restypes = ctypes.c_int`, a silent typo
+
+        tag = self.pDll.Close_Modbus_Mode(self.nSocket, port, block)
+
+        logger_.info(f'Close_Modbus_Mode:{tag}')
+
+        return tag
+
+    def Set_Modbustcp_Mode(self, ip, port, timeout):
+        """
+        Set_Modbustcp_Mode  Connect to a Modbus TCP slave (I-series).
+
+        :param ip: slave IP address
+        :param port: port number
+        :param timeout: timeout, in seconds
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+        self.pDll.Set_Modbustcp_Mode.argtypes = (ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_int)
+        self.pDll.Set_Modbustcp_Mode.restype = self.check_error
+
+        ip = ctypes.c_char_p(ip.encode('utf-8'))
+        tag = self.pDll.Set_Modbustcp_Mode(self.nSocket, ip, port, timeout)
+
+        logger_.info(f'Set_Modbustcp_Mode:{tag}')
+
+        return tag
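A sketch of an RTU session; `Get_Read_Holding_Registers` is defined just below, and the register address and device ID mirror the values used in test_modbus.py later in this diff:

```python
# Hypothetical sketch: controller RS485 as RTU master, one register read.
arm.Set_Modbus_Mode(port=0, baudrate=115200, timeout=10, block=True)  # timeout in 100 ms units
tag, value = arm.Get_Read_Holding_Registers(port=0, address=14, device=2)
arm.Close_Modbus_Mode(port=0, block=True)
```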
+    def Close_Modbustcp_Mode(self):
+        """
+        Close_Modbustcp_Mode  Close the Modbus TCP slave connection (I-series).
+
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Close_Modbustcp_Mode.argtypes = (ctypes.c_int,)  # was the no-op `argtype`
+        self.pDll.Close_Modbustcp_Mode.restype = self.check_error  # was `restypes = ctypes.c_int`, a silent typo
+
+        tag = self.pDll.Close_Modbustcp_Mode(self.nSocket)
+
+        logger_.info(f'Close_Modbustcp_Mode:{tag}')
+
+        return tag
+
+    def Get_Read_Coils(self, port, address, num, device):
+        """
+        Get_Read_Coils  Read coils.
+
+        :param port: communication port, 0 - controller RS485, 1 - end-effector board RS485, 3 - controller Modbus TCP device
+        :param address: coil start address
+        :param num: number of coils to read; at most 8 per call, so the returned data fits in one byte
+        :param device: peripheral device address
+        :return: result code, the coil data read
+        """
+
+        self.pDll.Get_Read_Coils.argtypes = (
+            ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_int))
+        self.pDll.Get_Read_Coils.restype = self.check_error
+
+        coils_data = ctypes.c_int()
+
+        tag = self.pDll.Get_Read_Coils(self.nSocket, port, address, num, device, ctypes.byref(coils_data))
+
+        return tag, coils_data.value
+
+    def Get_Read_Input_Status(self, port, address, num, device):
+        """
+        Get_Read_Input_Status  Read discrete inputs.
+
+        :param port: communication port, 0 - controller RS485, 1 - end-effector board RS485, 3 - controller Modbus TCP device
+        :param address: input start address
+        :param num: number of inputs to read; at most 8 per call, so the returned data fits in one byte
+        :param device: peripheral device address
+        :return: result code, the input data read
+        """
+
+        self.pDll.Get_Read_Input_Status.argtypes = (
+            ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_int))
+        self.pDll.Get_Read_Input_Status.restype = self.check_error
+
+        coils_data = ctypes.c_int()
+
+        tag = self.pDll.Get_Read_Input_Status(self.nSocket, port, address, num, device, ctypes.byref(coils_data))
+
+        return tag, coils_data.value
+
+    def Get_Read_Holding_Registers(self, port, address, device):
+        """
+        Get_Read_Holding_Registers  Read a holding register.
+
+        :param port: communication port, 0 - controller RS485, 1 - end-effector board RS485, 3 - controller Modbus TCP device
+        :param address: register start address
+        :param device: peripheral device address
+        :return: result code, the register value read
+        """
+
+        self.pDll.Get_Read_Holding_Registers.argtypes = (
+            ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_int))
+
+        self.pDll.Get_Read_Holding_Registers.restype = self.check_error
+
+        coils_data = ctypes.c_int()
+
+        tag = self.pDll.Get_Read_Holding_Registers(self.nSocket, port, address, device, ctypes.byref(coils_data))
+
+        return tag, coils_data.value
+
+    def Get_Read_Input_Registers(self, port, address, device):
+        """
+        Get_Read_Input_Registers  Read an input register.
+
+        :param port: communication port, 0 - controller RS485, 1 - end-effector board RS485, 3 - controller Modbus TCP device
+        :param address: register start address
+        :param device: peripheral device address
+        :return: result code, the register value read
+        """
+
+        self.pDll.Get_Read_Input_Registers.argtypes = (
+            ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_int))
+
+        self.pDll.Get_Read_Input_Registers.restype = self.check_error
+
+        coils_data = ctypes.c_int()
+
+        tag = self.pDll.Get_Read_Input_Registers(self.nSocket, port, address, device, ctypes.byref(coils_data))
+
+        return tag, coils_data.value
+    def Write_Single_Coil(self, port, address, data, device, block=True):
+        """
+        Write_Single_Coil  Write a single coil.
+
+        :param port: communication port, 0 - controller RS485, 1 - end-effector board RS485, 3 - controller Modbus TCP device
+        :param address: coil address
+        :param data: value to write to the coil
+        :param device: peripheral device address
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Write_Single_Coil.argtypes = (
+            ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_bool)
+        self.pDll.Write_Single_Coil.restype = self.check_error
+
+        tag = self.pDll.Write_Single_Coil(self.nSocket, port, address, data, device, block)
+
+        logger_.info(f'Write_Single_Coil:{tag}')
+
+        return tag
+
+    def Write_Coils(self, port, address, num, coils_data, device, block=True):
+        """
+        Write_Coils  Write multiple coils.
+
+        :param port: communication port, 0 - controller RS485, 1 - end-effector board RS485, 3 - controller Modbus TCP device
+        :param address: coil start address
+        :param num: number of coils to write, at most 160 per call
+        :param coils_data: byte array of coil data; a single byte if num is at most 8, otherwise an array of bytes
+        :param device: peripheral device address
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return:
+        """
+        device_num = int(num // 8 + 1)  # one byte carries up to 8 coils
+        self.pDll.Write_Coils.argtypes = (
+            ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_byte * device_num, ctypes.c_int,
+            ctypes.c_bool)
+        self.pDll.Write_Coils.restype = self.check_error
+
+        coils_data = (ctypes.c_byte * device_num)(*coils_data)
+
+        tag = self.pDll.Write_Coils(self.nSocket, port, address, num, coils_data, device, block)
+
+        logger_.info(f'Write_Coils:{tag}')
+
+        return tag
+
+    def Write_Single_Register(self, port, address, data, device, block=True):
+        """
+        Write_Single_Register  Write a single register.
+
+        :param port: communication port, 0 - controller RS485, 1 - end-effector board RS485, 3 - controller Modbus TCP device
+        :param address: register address
+        :param data: value to write
+        :param device: peripheral device address
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+
+        self.pDll.Write_Single_Register.argtypes = (
+            ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_bool)
+        self.pDll.Write_Single_Register.restype = self.check_error
+
+        tag = self.pDll.Write_Single_Register(self.nSocket, port, address, data, device, block)
+
+        logger_.info(f'Write_Single_Register:{tag}')
+
+        return tag
+    def Write_Registers(self, port, address, num, single_data, device, block=True):
+        """
+        Write_Registers  Write multiple registers.
+
+        :param port: communication port, 0 - controller RS485, 1 - end-effector board RS485, 3 - controller Modbus TCP device
+        :param address: register start address
+        :param num: number of registers to write, at most 10 per call
+        :param single_data: byte array of register data
+        :param device: peripheral device address
+        :param block: RM_NONBLOCK - return immediately after sending; RM_BLOCK - wait for the controller to confirm
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+        single_data_num = int(num * 2)  # two bytes per register
+
+        self.pDll.Write_Registers.argtypes = (
+            ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_byte * single_data_num, ctypes.c_int,
+            ctypes.c_bool)
+        self.pDll.Write_Registers.restype = self.check_error
+
+        single_data = (ctypes.c_byte * single_data_num)(*single_data)
+
+        tag = self.pDll.Write_Registers(self.nSocket, port, address, num, single_data, device, block)
+
+        logger_.info(f'Write_Registers:{tag}')
+
+        return tag
+
+    def Read_Multiple_Holding_Registers(self, port, address, num, device):
+        """
+        Read_Multiple_Holding_Registers  Read multiple holding registers.
+
+        :param port: communication port, 0 - controller RS485, 1 - end-effector board RS485, 3 - controller Modbus TCP device
+        :param address: register start address
+        :param num: 2
+
+        while self.code >= 10:
+            self.code //= 10
+
+        if pCallback is None:
+            self.pDll.RM_API_Init(dev_mode, 0)  # initialize the API
+        else:
+            self.pDll.RM_API_Init(dev_mode, pCallback)  # initialize the API with the state callback
+
+        logger_.info('Arm API initialization finished')
+
+        self.API_Version()
+        self.Algo_Version()
+
+        # Connect to the arm
+        byteIP = bytes(ip, "gbk")
+        self.nSocket = self.pDll.Arm_Socket_Start(byteIP, 8080, 200)  # connect to the arm
+
+        state = self.pDll.Arm_Socket_State(self.nSocket)  # query the connection state
+
+        if state:
+            logger_.info(f'Arm connection failed: {state}')
+
+        else:
+            logger_.info(f'Arm connected, handle: {self.nSocket}')
+
+    def Arm_Socket_State(self):
+        """
+        Arm_Socket_State  Query the arm connection state.
+
+        :return: 0 on success; otherwise an error code (see rm_define.h).
+        """
+        state = self.pDll.Arm_Socket_State(self.nSocket)  # query the connection state
+
+        if state == 0:
+            return state
+        else:
+            return errro_message[state]
+
+    def API_Version(self):
+        """
+        API_Version  Query the API version.
+
+        :return: the API version string
+        """
+        self.pDll.API_Version.restype = ctypes.c_char_p
+        api_name = self.pDll.API_Version()
+        logger_.info(f'API_Version:{api_name.decode()}')
+        time.sleep(0.5)
+
+        return api_name.decode()
+
+    def Algo_Version(self):
+        """
+        Algo_Version  Query the algorithm library version.
+
+        :return: the algorithm library version string
+        """
+        self.pDll.Algo_Version.restype = ctypes.c_char_p
+        api_name = self.pDll.Algo_Version()
+        logger_.info(f'Algo_Version:{api_name.decode()}')
+        time.sleep(0.5)
+
+        return api_name.decode()
+
+    def RM_API_UnInit(self):
+        """
+        Uninitialize the API and release its resources.
+
+        :return:
+        """
+        tag = self.pDll.RM_API_UnInit()
+        logger_.info(f'RM_API_UnInit: resources released')
+        return tag
+
+    def Set_Arm_Run_Mode(self, mode):
+        """
+        Set the arm run mode (0 - simulation, 1 - real arm).
+        """
+        self.pDll.Set_Arm_Run_Mode.argtypes = [ctypes.c_int, ctypes.c_int]
+        self.pDll.Set_Arm_Run_Mode.restype = self.check_error
+
+        result = self.pDll.Set_Arm_Run_Mode(self.nSocket, mode)
+        logger_.info(f'Set_Arm_Run_Mode:{result}')
+
+        return result
+
+    def Get_Arm_Run_Mode(self):
+        """
+        Get the arm run mode (0 - simulation, 1 - real arm).
+        """
+        self.pDll.Get_Arm_Run_Mode.argtypes = [ctypes.c_int, ctypes.POINTER(ctypes.c_int)]
+        self.pDll.Get_Arm_Run_Mode.restype = self.check_error
+
+        mode = ctypes.c_int()
+        result = self.pDll.Get_Arm_Run_Mode(self.nSocket, ctypes.byref(mode))
+        logger_.info(f'Get_Arm_Run_Mode:{result}')
+
+        return result, mode.value
+
+    def Arm_Socket_Close(self):
+        """
+        Close the socket connection to the arm.
+
+        :return:
+        """
+        self.pDll.Arm_Socket_Close(self.nSocket)
+        logger_.info(f'Socket connection to the arm closed')
+
+    @staticmethod
+    def check_error(tag):
+
+        if tag == 0:
+            return tag
+        else:
+            return errro_message[tag]
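The constructor body is partially truncated in this diff, so the exact `Arm.__init__` signature is not visible; a hedged construction sketch based on the visible fragment (`dev_mode`, `ip`, and `pCallback` are parameter names taken from that fragment, the IP is illustrative):

```python
# Hypothetical sketch: construct the wrapper, check the link, pick the mode.
arm = Arm(dev_mode, "192.168.1.18")     # connects to port 8080 internally
if arm.Arm_Socket_State() == 0:
    arm.Set_Arm_Run_Mode(1)             # 1 - real arm, 0 - simulation
arm.Arm_Socket_Close()
```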
diff --git a/realman_src/realman_aloha/shadow_rm_robot/src/shadow_rm_robot/servo_robotic_arm.py b/realman_src/realman_aloha/shadow_rm_robot/src/shadow_rm_robot/servo_robotic_arm.py
new file mode 100644
index 000000000..9941e28d7
--- /dev/null
+++ b/realman_src/realman_aloha/shadow_rm_robot/src/shadow_rm_robot/servo_robotic_arm.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import time
+import yaml
+import serial
+import logging
+import binascii
+import numpy as np
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+
+
+class ServoArm:
+    def __init__(self, config_file="config.yaml"):
+        """Open the serial connection to the arm and send the initial frame.
+
+        Args:
+            config_file (str): path to the config file.
+        """
+        self.config = self._load_config(config_file)
+        self.port = self.config["port"]
+        self.baudrate = self.config["baudrate"]
+        self.hex_data = self.config["hex_data"]
+        self.arm_axis = self.config.get("arm_axis", 7)
+
+        self.serial_conn = serial.Serial(self.port, self.baudrate, timeout=0)
+
+        self.bytes_to_send = binascii.unhexlify(self.hex_data.replace(" ", ""))
+        self.serial_conn.write(self.bytes_to_send)
+        time.sleep(1)
+
+    def _load_config(self, config_file):
+        """Load the YAML config file.
+
+        Args:
+            config_file (str): path to the config file.
+
+        Returns:
+            dict: the parsed config.
+        """
+        with open(config_file, "r") as file:
+            config = yaml.safe_load(file)
+        return config
+
+    def _bytes_to_signed_int(self, byte_data):
+        """Convert little-endian byte data to a signed integer.
+
+        Args:
+            byte_data (bytes): the raw bytes.
+
+        Returns:
+            int: the signed integer.
+        """
+        return int.from_bytes(byte_data, byteorder="little", signed=True)
+
+    def _parse_joint_data(self, hex_received):
+        """Parse the received hex string and extract the joint data.
+
+        Args:
+            hex_received (str): the received data as a hex string.
+
+        Returns:
+            dict: the parsed joint data.
+        """
+        logging.debug(f"hex_received: {hex_received}")
+        joints = {}
+        for i in range(self.arm_axis):
+            start = 14 + i * 10
+            end = start + 8
+            joint_hex = hex_received[start:end]
+            joint_byte_data = bytearray.fromhex(joint_hex)
+            joint_value = self._bytes_to_signed_int(joint_byte_data) / 10000.0
+            joints[f"joint_{i+1}"] = joint_value
+        grasp_start = 14 + self.arm_axis * 10
+        grasp_hex = hex_received[grasp_start:grasp_start + 8]
+        grasp_byte_data = bytearray.fromhex(grasp_hex)
+        # Normalize the gripper value
+        grasp_value = self._bytes_to_signed_int(grasp_byte_data) / 1000
+
+        joints["grasp"] = grasp_value
+        return joints
+
+    def get_joint_actions(self):
+        """Read from the serial port and parse the joint actions.
+
+        Returns:
+            dict: dictionary of joint data.
+        """
+        self.serial_conn.write(self.bytes_to_send)
+        bytes_received = self.serial_conn.read(self.serial_conn.in_waiting)  # in_waiting replaces the deprecated inWaiting()
+        hex_received = binascii.hexlify(bytes_received).decode("utf-8").upper()
+        actions = self._parse_joint_data(hex_received)
+        return actions
+
+    def set_gripper_action(self, action):
+        """Set the gripper action by patching the command frame.
+
+        Args:
+            action (float): normalized gripper value.
+        """
+        action = int(action * 1000)
+        action_bytes = action.to_bytes(4, byteorder="little", signed=True)
+        self.bytes_to_send = self.bytes_to_send[:74] + action_bytes + self.bytes_to_send[78:]
+
+
+if __name__ == "__main__":
+    servo_arm = ServoArm("/home/rm/code/shadow_rm_aloha/config/servo_left_arm.yaml")
+    while True:
+        joint_actions = servo_arm.get_joint_actions()
+        logging.info(joint_actions)
+        time.sleep(1)
\ No newline at end of file
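To make the frame layout concrete, a standalone decode of one joint field under the same convention `_parse_joint_data` uses (4 little-endian bytes, scaled by 1/10000; the sample value is invented):

```python
# Hypothetical sketch: decode a single 4-byte joint field from a hex frame.
joint_hex = "10270000"  # 0x00002710 = 10000 little-endian
value = int.from_bytes(bytearray.fromhex(joint_hex), byteorder="little", signed=True) / 10000.0
print(value)  # 1.0
```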
diff --git a/realman_src/realman_aloha/shadow_rm_robot/test/test_modbus.py b/realman_src/realman_aloha/shadow_rm_robot/test/test_modbus.py
new file mode 100644
index 000000000..52ea73aaa
--- /dev/null
+++ b/realman_src/realman_aloha/shadow_rm_robot/test/test_modbus.py
@@ -0,0 +1,26 @@
+import pytest
+import json
+import yaml
+import numpy as np
+from shadow_rm_robot.realman_arm import RmArm
+import logging
+
+logging.basicConfig(
+    level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+
+
+arm = RmArm("./data/rm_arm.yaml")
+
+arm.arm.send(
+    '{"command":"set_modbus_mode","port":0,"baudrate":115200,"timeout":2}\r\n'.encode("utf-8"))  # key was '"timeout "' with a stray space
+
+# arm.arm.send(
+#     '{"command":"close_modbus_mode","port":1}\r\n'.encode("utf-8"))
+
+a = arm.arm.recv(1024)
+
+logging.debug(a)
+
+arm.arm.send(
+    '{"command":"read_holding_registers","port":1,"address":14,"device":2}\r\n'.encode("utf-8"))
+
+b = arm.arm.recv(1024)
+logging.debug(b)
\ No newline at end of file
diff --git a/realman_src/realman_aloha/shadow_rm_robot/test/test_realman_arm.py b/realman_src/realman_aloha/shadow_rm_robot/test/test_realman_arm.py
new file mode 100644
index 000000000..792d9b71d
--- /dev/null
+++ b/realman_src/realman_aloha/shadow_rm_robot/test/test_realman_arm.py
@@ -0,0 +1,90 @@
+import pytest
+from unittest.mock import patch, MagicMock
+import json
+import yaml
+import numpy as np
+from shadow_rm_robot.realman_arm import RmArm
+
+
+class TestRmArm:
+    @pytest.fixture(autouse=True)
+    def setup_method(self, tmpdir):
+        # Mock config file
+        self.config_data = {
+            "arm_ip": "192.168.1.18",
+            "arm_port": 8080
+        }
+        self.config_file = tmpdir.join("test_config.yaml")
+        with open(self.config_file, "w") as file:
+            yaml.dump(self.config_data, file)
+
+        # Construct the RmArm under test; the socket is patched so the
+        # fixture does not need a live arm on the network
+        with patch("socket.socket"):
+            self.rm_arm = RmArm(self.config_file)
+
+    @patch("socket.socket")
+    def test_initialization(self, mock_socket):
+        # Exercise construction
+        mock_socket_instance = MagicMock()
+        mock_socket.return_value = mock_socket_instance
+
+        rm_arm = RmArm(self.config_file)
+        assert rm_arm.arm_ip == self.config_data["arm_ip"]
+        assert rm_arm.arm_port == self.config_data["arm_port"]
+
+        # Verify the network connection was opened against the configured endpoint
+        mock_socket_instance.connect.assert_called_with((self.config_data["arm_ip"], self.config_data["arm_port"]))
+
+    def test_json_to_numpy(self):
+        # Parse JSON data into a NumPy array
+        json_data = json.dumps({"joint": [1, 2, 3, 4, 5, 6]})
+        byte_data = json_data.encode('utf-8')
+        result = self.rm_arm._json_to_numpy(byte_data, 'joint')
+        expected_result = np.array([1, 2, 3, 4, 5, 6], dtype=float)
+        np.testing.assert_array_equal(result, expected_result)
+
+        # A missing key yields an empty array
+        json_data = json.dumps({"other_key": [1, 2, 3]})
+        byte_data = json_data.encode('utf-8')
+        result = self.rm_arm._json_to_numpy(byte_data, 'joint')
+        expected_result = np.array([])
+        np.testing.assert_array_equal(result, expected_result)
+
+    def test_generate_command(self):
+        # Joint command generation
+        data = np.array([0.1, 0.2, 0.3])
+        cmd_type = 'joint'
+        result = self.rm_arm._generate_command(data, cmd_type)
+        expected_result = json.dumps({"command": "movej", "joint": [100, 200, 300], "v": 40, "r": 0}) + '\r\n'
+        assert result == expected_result
+
+        # Gripper command generation
+        data = np.array([500])
+        cmd_type = 'gripper'
+        result = self.rm_arm._generate_command(data, cmd_type)
+        expected_result = json.dumps({"command": "set_gripper_position", "position": [500], "block": False}) + '\r\n'
+        assert result == expected_result
+
+    # @patch("socket.socket")
+    # def test_get_qpos(self, mock_socket):
+    #     # Mock the network responses
+    #     mock_socket_instance = MagicMock()
+    #     mock_socket.return_value = mock_socket_instance
+    #     mock_socket_instance.recv.side_effect = [
+    #         json.dumps({"joint": [1000, 2000, 3000, 4000, 5000, 6000]}).encode('utf-8'),
+    #         json.dumps({"actpos": [700]}).encode('utf-8')
+    #     ]
+
+    #     rm_arm = RmArm(self.config_file)
+    #     qpos = rm_arm.get_qpos()
+    #     expected_qpos = {
+    #         "joint_1": 1.0,
+    #         "joint_2": 2.0,
+    #         "joint_3": 3.0,
+    #         "joint_4": 4.0,
+    #         "joint_5": 5.0,
+    #         "joint_6": 6.0,
+    #         "gripper": 700.0
+    #     }
+    #     assert qpos == expected_qpos
+
+
+if __name__ == "__main__":
+    pytest.main()
\ No newline at end of file
diff --git a/realman_src/realman_aloha/shadow_rm_robot/test/test_servo_joint_action.py b/realman_src/realman_aloha/shadow_rm_robot/test/test_servo_joint_action.py
new file mode 100644
index 000000000..0c9b819e1
--- /dev/null
+++ b/realman_src/realman_aloha/shadow_rm_robot/test/test_servo_joint_action.py
@@ -0,0 +1,123 @@
+import yaml
+import pytest
+import binascii
+from unittest.mock import patch, MagicMock
+from shadow_rm_robot.servo_robotic_arm import ServoArm
+
+
+class TestServoArm:
+    @pytest.fixture(autouse=True)
+    def setup_method(self, tmpdir):
+        # Mock config file; flat keys, matching what ServoArm._load_config reads
+        self.config_data = {
+            "port": "/dev/ttyUSB0",
+            "baudrate": 460800,
+            "hex_data": "55 AA 02 00 00 67",
+            "arm_axis": 6,  # the sample frames below carry 6 joints plus the gripper
+        }
+        self.config_file = tmpdir.join("test_config.yaml")
+        with open(self.config_file, "w") as file:
+            yaml.dump(self.config_data, file)
+
+        # Construct the ServoArm under test; the serial port is patched so
+        # the fixture does not need real hardware
+        with patch("serial.Serial"):
+            self.servo_arm = ServoArm(self.config_file)
+
+    @patch("serial.Serial")
+    def test_initialization(self, mock_serial):
+        # Exercise construction
+        mock_serial_instance = MagicMock()
+        mock_serial.return_value = mock_serial_instance
+
+        servo_arm = ServoArm(self.config_file)
+        assert servo_arm.port == self.config_data["port"]
+        assert servo_arm.baudrate == self.config_data["baudrate"]
+        assert servo_arm.hex_data == self.config_data["hex_data"]
+
+        # Verify the serial port was opened with the configured settings
+        mock_serial.assert_any_call(
+            self.config_data["port"],
+            self.config_data["baudrate"],
+            timeout=0,
+        )
+
+    def test_bytes_to_signed_int(self):
+        # Convert bytes to signed integers
+        byte_data = b"\x01\x00"
+        result = self.servo_arm._bytes_to_signed_int(byte_data)
+        assert result == 1
+
+        byte_data = b"\xff\xff"
+        result = self.servo_arm._bytes_to_signed_int(byte_data)
+        assert result == -1
+
+    def test_parse_joint_data(self):
+        # Parse the joint data out of a synthetic frame
+        hex_received = (
+            "00" * 7
+            + "01000000"
+            + "00" * 1
+            + "02000000"
+            + "00" * 1
+            + "03000000"
+            + "00" * 1
+            + "04000000"
+            + "00" * 1
+            + "05000000"
+            + "00" * 1
+            + "06000000"
+            + "00" * 1
+            + "07000000"
+        )
+        joints = self.servo_arm._parse_joint_data(hex_received)
+        expected_joints = {
+            "joint_1": 0.0001,
+            "joint_2": 0.0002,
+            "joint_3": 0.0003,
+            "joint_4": 0.0004,
+            "joint_5": 0.0005,
+            "joint_6": 0.0006,
+            "grasp": 0.007,  # 0x00000007 = 7, normalized by /1000 in _parse_joint_data
+        }
+        assert joints == expected_joints
+
+    @patch("serial.Serial")
+    def test_get_joint_actions(self, mock_serial):
+        # Mock the serial port's returned frame
+        mock_serial_instance = MagicMock()
+        mock_serial.return_value = mock_serial_instance
+        mock_serial_instance.read.side_effect = [
+            binascii.unhexlify(
+                "00" * 7
+                + "01000000"
+                + "00" * 1
+                + "02000000"
+                + "00" * 1
+                + "03000000"
+                + "00" * 1
+                + "04000000"
+                + "00" * 1
+                + "05000000"
+                + "00" * 1
+                + "06000000"
+                + "00" * 1
+                + "07000000"
+            )
+        ]
+
+        servo_arm = ServoArm(self.config_file)
+        joint_actions = servo_arm.get_joint_actions()
+        expected_joint_actions = {
+            "joint_1": 0.0001,
+            "joint_2": 0.0002,
+            "joint_3": 0.0003,
+            "joint_4": 0.0004,
+            "joint_5": 0.0005,
+            "joint_6": 0.0006,
+            "grasp": 0.007,  # 0x00000007 = 7, normalized by /1000 in _parse_joint_data
+        }
+        assert joint_actions == expected_joint_actions
+
+
+if __name__ == "__main__":
+    pytest.main()
\ No newline at end of file
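Taken together, the two classes suggest a leader-follower loop in the ALOHA style; a heavily hedged sketch (config paths are illustrative, and `_generate_command` is assumed to behave as exercised in test_realman_arm.py above):

```python
import numpy as np
from shadow_rm_robot.servo_robotic_arm import ServoArm
from shadow_rm_robot.realman_arm import RmArm

# Hypothetical sketch: stream leader (servo) joints to the follower arm.
leader = ServoArm("config/servo_left_arm.yaml")
follower = RmArm("data/rm_arm.yaml")
while True:
    q = leader.get_joint_actions()
    joints = np.array([q[f"joint_{i + 1}"] for i in range(6)])
    follower.arm.send(follower._generate_command(joints, "joint").encode("utf-8"))
```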