forked from tangger/lerobot
add realman shadow src
This commit is contained in:
4
realman_src/realman_aloha/shadow_rm_robot/.gitignore
vendored
Normal file
4
realman_src/realman_aloha/shadow_rm_robot/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pt
|
||||
0
realman_src/realman_aloha/shadow_rm_robot/README.md
Normal file
0
realman_src/realman_aloha/shadow_rm_robot/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
arm_ip: "192.168.1.18"
|
||||
arm_port: 8080
|
||||
arm_axis: 6
|
||||
# arm_ki: [7, 7, 7, 3, 3, 3, 3] # rm75
|
||||
arm_ki: [7, 7, 7, 3, 3, 3] # rm65
|
||||
@@ -0,0 +1,5 @@
|
||||
port: /dev/ttyUSB0
|
||||
right_port: /dev/ttyUSB1
|
||||
baudrate: 460800
|
||||
hex_data: "55 AA 02 00 00 67"
|
||||
arm_axis: 7
|
||||
Binary file not shown.
Binary file not shown.
36
realman_src/realman_aloha/shadow_rm_robot/pyproject.toml
Normal file
36
realman_src/realman_aloha/shadow_rm_robot/pyproject.toml
Normal file
@@ -0,0 +1,36 @@
|
||||
[tool.poetry]
|
||||
name = "shadow_rm_robot"
|
||||
version = "0.1.0"
|
||||
description = "Robot package, including operations such as reading and controlling robots"
|
||||
readme = "README.md"
|
||||
authors = ["Shadow <qiuchengzhan@gmail.com>"]
|
||||
license = "MIT"
|
||||
# include = ["realman_vision/pytransform/_pytransform.so",]
|
||||
classifiers = [
|
||||
"Operating System :: POSIX :: Linux amd64",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10"
|
||||
pyyaml = ">=6.0"
|
||||
pyserial = ">=3.5"
|
||||
pymodbus = ">=3.7"
|
||||
|
||||
|
||||
[tool.poetry.dev-dependencies] # 列出开发时所需的依赖项,比如测试、文档生成等工具。
|
||||
pytest = ">=8.3"
|
||||
black = ">=24.10.0"
|
||||
|
||||
|
||||
|
||||
[tool.poetry.plugins."scripts"] # 定义命令行脚本,使得用户可以通过命令行运行指定的函数。
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.8.4"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
@@ -0,0 +1,75 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# cython: language_level=3
|
||||
import os
|
||||
|
||||
import logging
|
||||
from logging.handlers import TimedRotatingFileHandler
|
||||
|
||||
class CommonLog(object):
|
||||
"""
|
||||
日志记录
|
||||
"""
|
||||
|
||||
def __init__(self, logger, logname='web-log'):
|
||||
self.logname = os.path.join(os.path.dirname(os.path.abspath(__file__)), '%s' % logname)
|
||||
self.logger = logger
|
||||
self.logger.setLevel(logging.DEBUG)
|
||||
self.logger.propagate = False # 禁止使用logger对象parent的处理器
|
||||
self.formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s: %(message)s', '%Y-%m-%d %H:%M:%S')
|
||||
|
||||
def __console(self, level, message):
|
||||
# 创建一个FileHandler,用于写到本地
|
||||
|
||||
# fh = TimedRotatingFileHandler(self.logname, when='MIDNIGHT', interval=1, encoding='utf-8')
|
||||
# # fh = logging.FileHandler(self.logname, 'a', encoding='utf-8')
|
||||
# fh.suffix = '%Y-%m-%d.log'
|
||||
# fh.setLevel(logging.DEBUG)
|
||||
# fh.setFormatter(self.formatter)
|
||||
# self.logger.addHandler(fh)
|
||||
|
||||
# 创建一个StreamHandler,用于输出到控制台
|
||||
ch = logging.StreamHandler()
|
||||
ch.setLevel(logging.DEBUG)
|
||||
ch.setFormatter(self.formatter)
|
||||
self.logger.addHandler(ch)
|
||||
|
||||
if level == 'info':
|
||||
self.logger.info(message)
|
||||
elif level == 'debug':
|
||||
self.logger.debug(message)
|
||||
elif level == 'warning':
|
||||
self.logger.warning(message)
|
||||
elif level == 'error':
|
||||
self.logger.error(message, exc_info=1) # 显示错误栈
|
||||
# self.logger.error(message)
|
||||
|
||||
elif level == 'error_':
|
||||
self.logger.error(message) # 不显示错误栈
|
||||
|
||||
|
||||
# 这两行代码是为了避免日志输出重复问题
|
||||
self.logger.removeHandler(ch)
|
||||
# self.logger.removeHandler(fh)
|
||||
# # 关闭打开的文件
|
||||
# fh.close()
|
||||
|
||||
def debug(self, message):
|
||||
self.__console('debug', message)
|
||||
|
||||
def info(self, message):
|
||||
self.__console('info', message)
|
||||
|
||||
def warning(self, message):
|
||||
self.__console('warning', message)
|
||||
|
||||
def error(self, message):
|
||||
self.__console('error', message)
|
||||
|
||||
def error_(self, message):
|
||||
self.__console('error_', message)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,288 @@
|
||||
import json
|
||||
import yaml
|
||||
import time
|
||||
import logging
|
||||
import socket
|
||||
import numpy as np
|
||||
|
||||
# 配置日志记录
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
|
||||
class RmArm:
|
||||
def __init__(self, config_file="config.yaml"):
|
||||
"""初始化机械臂的网络连接并发送初始命令。
|
||||
|
||||
Args:
|
||||
config_file (str): 配置文件的路径。
|
||||
"""
|
||||
self.config = self._load_config(config_file)
|
||||
self.arm_ip = self.config.get("arm_ip", "192.168.1.18")
|
||||
self.get_vel = self.config.get("get_vel", True)
|
||||
self.get_torque = self.config.get("get_torque", True)
|
||||
arm_port = self.config.get("arm_port", 8080)
|
||||
local_ip = self.config.get("local_ip", '192.168.1.101')
|
||||
local_port = self.config.get("local_port", 8089)
|
||||
|
||||
self.arm = socket.socket()
|
||||
self.arm.connect((self.arm_ip, arm_port))
|
||||
|
||||
set_udp = {"command":"set_realtime_push","cycle":6,"enable":True,"port":local_port,"ip":local_ip,"custom":{"aloha_state":True,"joint_speed":True,"arm_current_status":True,"hand":False, "expand_state":True}}
|
||||
|
||||
self.arm.send(json.dumps(set_udp).encode('utf-8'))
|
||||
_ = self.arm.recv(1024)
|
||||
|
||||
self.arm_axis = self.config.get("arm_axis", 6)
|
||||
self.arm_ki = self.config.get("arm_ki", [7, 7, 7, 3, 3, 3])
|
||||
|
||||
|
||||
self.udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.udp_socket.bind((local_ip, local_port))
|
||||
|
||||
self.cmd_get_current_arm_state = '{"command":"get_current_arm_state"}\r\n'
|
||||
self.cmd_get_gripper_state = '{"command":"get_gripper_state"}\r\n'
|
||||
|
||||
self.cmd_set_gripper_release = (
|
||||
'{"command": "set_gripper_release", "speed": 500, "block": false}\r\n'
|
||||
)
|
||||
self.cmd_set_gripper_route = (
|
||||
'{"command":"set_gripper_route","min":0,"max":1000}\r\n'
|
||||
)
|
||||
|
||||
self.arm.send(self.cmd_set_gripper_route.encode("utf-8"))
|
||||
_ = self.arm.recv(1024)
|
||||
|
||||
self.pre_gripper_actpos = None
|
||||
self.cur_gripper_actpos = None
|
||||
self.pre_actpos_time = None
|
||||
self.cur_actpos_time = None
|
||||
|
||||
def _load_config(self, config_file):
|
||||
"""加载配置文件。
|
||||
|
||||
Args:
|
||||
config_file (str): 配置文件的路径。
|
||||
|
||||
Returns:
|
||||
dict: 配置文件内容。
|
||||
"""
|
||||
with open(config_file, "r") as file:
|
||||
return yaml.safe_load(file)
|
||||
|
||||
def _json_to_numpy(self, byte_data, key):
|
||||
"""将字节数据解析为 NumPy 数组。
|
||||
|
||||
Args:
|
||||
byte_data (bytes): 字节数据。
|
||||
key (str): JSON 数据中的键。
|
||||
|
||||
Returns:
|
||||
np.ndarray: 解析后的 NumPy 数组。
|
||||
"""
|
||||
str_data = byte_data.decode("utf-8")
|
||||
logging.debug(f"Received KEY: {key}")
|
||||
logging.debug(f"Received JSON data: {str_data}")
|
||||
try:
|
||||
data_list = json.loads(str_data)[key]
|
||||
if isinstance(data_list, dict):
|
||||
return data_list
|
||||
except KeyError:
|
||||
logging.error(f"Key '{key}' not found in JSON data")
|
||||
logging.error(f"Received JSON data: {str_data}")
|
||||
return None
|
||||
return np.array(data_list, dtype=float)
|
||||
|
||||
def set_joint_position(self, joint_angle):
|
||||
"""设置机械臂的位置。
|
||||
|
||||
Args:
|
||||
arm_pos (np.ndarray): 机械臂的位置
|
||||
|
||||
"""
|
||||
joint_angle = np.array(joint_angle)
|
||||
data = np.floor(joint_angle * 1000).astype(int).tolist()
|
||||
cmd = (
|
||||
json.dumps(
|
||||
{"command": "movej", "joint": data, "block": True, "v": 40, "r": 0}
|
||||
)
|
||||
+ "\r\n"
|
||||
)
|
||||
self.arm.send(cmd.encode("utf-8"))
|
||||
# TODO: Pending
|
||||
state = self.arm.recv(1024)
|
||||
# state = self.arm.recv(1024)
|
||||
|
||||
def set_joint_canfd_position(self, joint_angle):
|
||||
"""设置机械臂的位置。
|
||||
|
||||
Args:
|
||||
arm_pos (np.ndarray): 机械臂的位置
|
||||
"""
|
||||
joint_angle = np.array(joint_angle)
|
||||
data = np.floor(joint_angle * 1000).astype(int).tolist()
|
||||
cmd = (
|
||||
json.dumps({"command": "movej_canfd", "joint": data, "follow": False})
|
||||
+ "\r\n"
|
||||
)
|
||||
self.arm.send(cmd.encode("utf-8"))
|
||||
|
||||
def set_gripper_position(self, actpos):
|
||||
"""设置夹爪的位置。
|
||||
|
||||
Args:
|
||||
actpos (np.ndarray): 夹爪的位置,单位为毫米。
|
||||
"""
|
||||
data = np.array(actpos) * 1000
|
||||
data = np.floor(data).astype(int).tolist()
|
||||
cmd = (
|
||||
json.dumps(
|
||||
{
|
||||
"command": "set_gripper_position",
|
||||
"position": data,
|
||||
"block": False,
|
||||
}
|
||||
)
|
||||
+ "\r\n"
|
||||
)
|
||||
self.arm.send(cmd.encode("utf-8"))
|
||||
aaa = self.arm.recv(1024)
|
||||
# print(aaa)
|
||||
|
||||
def _update_state(self, gripper_actpos, actpos_time):
|
||||
"""更新关节和夹爪状态及时间。
|
||||
|
||||
Args:
|
||||
actpos_time (float): 夹爪时间戳。
|
||||
"""
|
||||
self.pre_gripper_actpos, self.cur_gripper_actpos = self.cur_gripper_actpos, gripper_actpos
|
||||
self.pre_actpos_time, self.cur_actpos_time = self.cur_actpos_time, actpos_time
|
||||
|
||||
def get_arm_data(self):
|
||||
"""获取机械臂数据"""
|
||||
data, addr = self.udp_socket.recvfrom(1024)
|
||||
data = json.loads(data.decode('utf-8'))
|
||||
# logging.info(f"Received data: {data}")
|
||||
joint_angle = np.array(data['joint_status']['joint_position']) * 0.001
|
||||
joint_velocity = np.array(data['joint_status']['joint_speed']) * 0.001 if self.get_vel else None
|
||||
joint_current = np.array(data['joint_status']['joint_current']) / 1000000 if self.get_torque else None
|
||||
joint_torque = self.current_to_torque(joint_current) if self.get_torque else None
|
||||
aloha_state = data['aloha_state']
|
||||
# logging.info(f"Time consumed: {time.time() - start_time}")
|
||||
|
||||
result = {'joint_angle': joint_angle, 'aloha_state': aloha_state}
|
||||
if self.get_vel:
|
||||
result['joint_velocity'] = joint_velocity
|
||||
if self.get_torque:
|
||||
result['joint_torque'] = joint_torque
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_gripper_data(self):
|
||||
"""获取夹爪数据"""
|
||||
try:
|
||||
actpos_time = time.time()
|
||||
self.arm.send(self.cmd_get_gripper_state.encode("utf-8"))
|
||||
# gripper_qpos = self.arm.recv(1024)
|
||||
while True:
|
||||
gripper_qpos = self.arm.recv(1024)
|
||||
data = json.loads(gripper_qpos.decode("utf-8"))
|
||||
if "actpos" in data:
|
||||
break
|
||||
else:
|
||||
self.arm.send(self.cmd_get_gripper_state.encode("utf-8"))
|
||||
gripper_actpos = self._json_to_numpy(gripper_qpos, "actpos") * 0.001
|
||||
gripper_velocity = self.get_gripper_velocity() if self.get_vel else None
|
||||
gripper_force = self._json_to_numpy(gripper_qpos, "current_force") / 100 if self.get_torque else None
|
||||
|
||||
result = {'gripper_actpos': gripper_actpos}
|
||||
if self.get_vel:
|
||||
result['gripper_velocity'] = gripper_velocity
|
||||
if self.get_torque:
|
||||
result['gripper_force'] = gripper_force
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting gripper data: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_integrate_data(self):
|
||||
"""获取整合数据"""
|
||||
arm_data = self.get_arm_data()
|
||||
gripper_data = self.get_gripper_data()
|
||||
|
||||
if not arm_data or not gripper_data:
|
||||
return None
|
||||
|
||||
result = {'aloha_state': arm_data['aloha_state'], 'arm_angle': np.append(arm_data['joint_angle'], gripper_data['gripper_actpos'])}
|
||||
if self.get_vel:
|
||||
result['arm_velocity'] = np.append(arm_data['joint_velocity'], gripper_data['gripper_velocity'])
|
||||
if self.get_torque:
|
||||
result['arm_torque'] = arm_data['joint_torque'] + [gripper_data['gripper_force']]
|
||||
|
||||
return result
|
||||
|
||||
def get_gripper_velocity(self):
|
||||
"""获取夹爪速度"""
|
||||
if self.pre_actpos_time is None or self.cur_actpos_time is None:
|
||||
logging.debug("Previous or current joint positions are not available.")
|
||||
return 0
|
||||
delta_time = self.cur_actpos_time - self.pre_actpos_time
|
||||
return (self.cur_gripper_actpos - self.pre_gripper_actpos["gripper"]) / delta_time if self.cur_gripper_actpos is not None else 0
|
||||
|
||||
def current_to_torque(self, current):
|
||||
"""将电流转换为扭矩"""
|
||||
return [c * k for c, k in zip(current, self.arm_ki)]
|
||||
|
||||
def get_arm_position(self):
|
||||
"""获取机械臂的位置。
|
||||
|
||||
Returns:
|
||||
dict: 机械臂的位置,单位为毫米和弧度。
|
||||
包含以下键:
|
||||
- 'x': x 轴位置
|
||||
- 'y': y 轴位置
|
||||
- 'z': z 轴位置
|
||||
- 'roll': 滚转角
|
||||
- 'pitch': 俯仰角
|
||||
- 'yaw': 偏航角
|
||||
- 单位 : mm, rad
|
||||
"""
|
||||
self.arm.send(self.cmd_get_current_arm_state.encode("utf-8"))
|
||||
_arm_state = self.arm.recv(1024)
|
||||
arm_state = self._json_to_numpy(_arm_state, "arm_state")
|
||||
arm_pos = np.array(arm_state["pose"], dtype=float) * 0.001
|
||||
|
||||
return {
|
||||
"x": arm_pos[0],
|
||||
"y": arm_pos[1],
|
||||
"z": arm_pos[2],
|
||||
"roll": arm_pos[3],
|
||||
"pitch": arm_pos[4],
|
||||
"yaw": arm_pos[5],
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
arm_left = RmArm("/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml")
|
||||
# arm_right = RmArm("/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml")
|
||||
# test_left_narry = [7.235, 31.816, 51.237, 2.463, 91.054, 12.04]
|
||||
test_right_narry = [-6.155, 33.925, 62.137, -1.672, 87.892, -3.868]
|
||||
while True:
|
||||
start_time = time.time()
|
||||
arm_left.set_gripper_position(0.2)
|
||||
# left_qpos = arm_left.get_integrate_data()
|
||||
left_qpos = arm_left.get_gripper_data()
|
||||
logging.info(left_qpos)
|
||||
# right_qpos = arm_right.get_arm_data()
|
||||
# logging.info(left_qpos)
|
||||
# logging.info(right_qpos)
|
||||
# time.sleep(0.02)
|
||||
# arm_right.set_joint_canfd_position(test_right_narry)
|
||||
# arm_right.set_gripper_position(0.2)
|
||||
|
||||
|
||||
logging.info(f"Time consumed: {time.time() - start_time}")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,112 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
import time
|
||||
import yaml
|
||||
import serial
|
||||
import logging
|
||||
import binascii
|
||||
import numpy as np
|
||||
|
||||
# 配置日志记录
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
|
||||
class ServoArm:
|
||||
def __init__(self, config_file="config.yaml"):
|
||||
"""初始化机械臂的串口连接并发送初始数据。
|
||||
|
||||
Args:
|
||||
config_file (str): 配置文件的路径。
|
||||
"""
|
||||
self.config = self._load_config(config_file)
|
||||
self.port = self.config["port"]
|
||||
self.baudrate = self.config["baudrate"]
|
||||
self.hex_data = self.config["hex_data"]
|
||||
self.arm_axis = self.config.get("arm_axis", 7)
|
||||
|
||||
self.serial_conn = serial.Serial(self.port, self.baudrate, timeout=0)
|
||||
|
||||
self.bytes_to_send = binascii.unhexlify(self.hex_data.replace(" ", ""))
|
||||
self.serial_conn.write(self.bytes_to_send)
|
||||
time.sleep(1)
|
||||
|
||||
def _load_config(self, config_file):
|
||||
"""加载配置文件。
|
||||
|
||||
Args:
|
||||
config_file (str): 配置文件的路径。
|
||||
|
||||
Returns:
|
||||
dict: 配置文件内容。
|
||||
"""
|
||||
with open(config_file, "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
return config
|
||||
|
||||
def _bytes_to_signed_int(self, byte_data):
|
||||
"""将字节数据转换为有符号整数。
|
||||
|
||||
Args:
|
||||
byte_data (bytes): 字节数据。
|
||||
|
||||
Returns:
|
||||
int: 有符号整数。
|
||||
"""
|
||||
return int.from_bytes(byte_data, byteorder="little", signed=True)
|
||||
|
||||
def _parse_joint_data(self, hex_received):
|
||||
"""解析接收到的十六进制数据并提取关节数据。
|
||||
|
||||
Args:
|
||||
hex_received (str): 接收到的十六进制字符串数据。
|
||||
|
||||
Returns:
|
||||
dict: 解析后的关节数据。
|
||||
"""
|
||||
logging.debug(f"hex_received: {hex_received}")
|
||||
joints = {}
|
||||
for i in range(self.arm_axis):
|
||||
start = 14 + i * 10
|
||||
end = start + 8
|
||||
joint_hex = hex_received[start:end]
|
||||
joint_byte_data = bytearray.fromhex(joint_hex)
|
||||
joint_value = self._bytes_to_signed_int(joint_byte_data) / 10000.0
|
||||
joints[f"joint_{i+1}"] = joint_value
|
||||
grasp_start = 14 + self.arm_axis*10
|
||||
grasp_hex = hex_received[grasp_start:grasp_start+8]
|
||||
grasp_byte_data = bytearray.fromhex(grasp_hex)
|
||||
# 夹爪进行归一化处理
|
||||
grasp_value = self._bytes_to_signed_int(grasp_byte_data)/1000
|
||||
|
||||
joints["grasp"] = grasp_value
|
||||
return joints
|
||||
|
||||
def get_joint_actions(self):
|
||||
"""从串口读取数据并解析关节动作。
|
||||
|
||||
Returns:
|
||||
dict: 包含关节数据的字典。
|
||||
"""
|
||||
self.serial_conn.write(self.bytes_to_send)
|
||||
bytes_received = self.serial_conn.read(self.serial_conn.inWaiting())
|
||||
hex_received = binascii.hexlify(bytes_received).decode("utf-8").upper()
|
||||
actions = self._parse_joint_data(hex_received)
|
||||
return actions
|
||||
def set_gripper_action(self, action):
|
||||
"""设置夹爪动作。
|
||||
|
||||
Args:
|
||||
action (int): 夹爪动作值。
|
||||
"""
|
||||
action = int(action * 1000)
|
||||
action_bytes = action.to_bytes(4, byteorder="little", signed=True)
|
||||
self.bytes_to_send = self.bytes_to_send[:74] + action_bytes + self.bytes_to_send[78:]
|
||||
|
||||
if __name__ == "__main__":
|
||||
servo_arm = ServoArm("/home/rm/code/shadow_rm_aloha/config/servo_left_arm.yaml")
|
||||
while True:
|
||||
joint_actions = servo_arm.get_joint_actions()
|
||||
logging.info(joint_actions)
|
||||
time.sleep(1)
|
||||
@@ -0,0 +1,26 @@
|
||||
import pytest
|
||||
import json
|
||||
import yaml
|
||||
import numpy as np
|
||||
from shadow_rm_robot.realman_arm import RmArm
|
||||
import logging
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
|
||||
arm = RmArm("./data/rm_arm.yaml")
|
||||
|
||||
arm.arm.send( '{"command":"set_modbus_mode","port":0,"baudrate":115200,"timeout ":2}\r\n'.encode("utf-8"))
|
||||
|
||||
# arm.arm.send( '{"command":"close_modbus_mode","port":1}\r\n'.encode("utf-8"))
|
||||
|
||||
a = arm.arm.recv(1024)
|
||||
|
||||
logging.debug(a)
|
||||
|
||||
arm.arm.send( '{"command":"read_holding_registers","port":1,"address":14,"device":2}\r\n'.encode("utf-8"))
|
||||
|
||||
b = arm.arm.recv(1024)
|
||||
logging.debug(b)
|
||||
@@ -0,0 +1,90 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import json
|
||||
import yaml
|
||||
import numpy as np
|
||||
from shadow_rm_robot.realman_arm import RmArm
|
||||
|
||||
class TestRmArm:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self, tmpdir):
|
||||
# 模拟配置文件
|
||||
self.config_data = {
|
||||
"arm_ip": "192.168.1.18",
|
||||
"arm_port": 8080
|
||||
}
|
||||
self.config_file = tmpdir.join("test_config.yaml")
|
||||
with open(self.config_file, "w") as file:
|
||||
yaml.dump(self.config_data, file)
|
||||
|
||||
# 初始化 RmArm 对象
|
||||
self.rm_arm = RmArm(self.config_file)
|
||||
|
||||
@patch("socket.socket")
|
||||
def test_initialization(self, mock_socket):
|
||||
# 测试初始化
|
||||
mock_socket_instance = MagicMock()
|
||||
mock_socket.return_value = mock_socket_instance
|
||||
|
||||
rm_arm = RmArm(self.config_file)
|
||||
assert rm_arm.arm_ip == self.config_data["arm_ip"]
|
||||
assert rm_arm.arm_port == self.config_data["arm_port"]
|
||||
|
||||
# 检查网络连接初始化
|
||||
mock_socket_instance.connect.assert_called_with((self.config_data["arm_ip"], self.config_data["arm_port"]))
|
||||
|
||||
def test_json_to_numpy(self):
|
||||
# 测试 JSON 数据解析为 NumPy 数组
|
||||
json_data = json.dumps({"joint": [1, 2, 3, 4, 5, 6]})
|
||||
byte_data = json_data.encode('utf-8')
|
||||
result = self.rm_arm._json_to_numpy(byte_data, 'joint')
|
||||
expected_result = np.array([1, 2, 3, 4, 5, 6], dtype=float)
|
||||
np.testing.assert_array_equal(result, expected_result)
|
||||
|
||||
# 测试键不存在的情况
|
||||
json_data = json.dumps({"other_key": [1, 2, 3]})
|
||||
byte_data = json_data.encode('utf-8')
|
||||
result = self.rm_arm._json_to_numpy(byte_data, 'joint')
|
||||
expected_result = np.array([])
|
||||
np.testing.assert_array_equal(result, expected_result)
|
||||
|
||||
def test_generate_command(self):
|
||||
# 测试生成关节命令
|
||||
data = np.array([0.1, 0.2, 0.3])
|
||||
cmd_type = 'joint'
|
||||
result = self.rm_arm._generate_command(data, cmd_type)
|
||||
expected_result = json.dumps({"command": "movej", "joint": [100, 200, 300], "v": 40, "r": 0}) + '\r\n'
|
||||
assert result == expected_result
|
||||
|
||||
# 测试生成夹爪命令
|
||||
data = np.array([500])
|
||||
cmd_type = 'gripper'
|
||||
result = self.rm_arm._generate_command(data, cmd_type)
|
||||
expected_result = json.dumps({"command": "set_gripper_position", "position": [500], "block": False}) + '\r\n'
|
||||
assert result == expected_result
|
||||
|
||||
# @patch("socket.socket")
|
||||
# def test_get_qpos(self, mock_socket):
|
||||
# # 模拟网络返回数据
|
||||
# mock_socket_instance = MagicMock()
|
||||
# mock_socket.return_value = mock_socket_instance
|
||||
# mock_socket_instance.recv.side_effect = [
|
||||
# json.dumps({"joint": [1000, 2000, 3000, 4000, 5000, 6000]}).encode('utf-8'),
|
||||
# json.dumps({"actpos": [700]}).encode('utf-8')
|
||||
# ]
|
||||
|
||||
# rm_arm = RmArm(self.config_file)
|
||||
# qpos = rm_arm.get_qpos()
|
||||
# expected_qpos = {
|
||||
# "joint_1": 1.0,
|
||||
# "joint_2": 2.0,
|
||||
# "joint_3": 3.0,
|
||||
# "joint_4": 4.0,
|
||||
# "joint_5": 5.0,
|
||||
# "joint_6": 6.0,
|
||||
# "gripper": 700.0
|
||||
# }
|
||||
# assert qpos == expected_qpos
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
||||
@@ -0,0 +1,123 @@
|
||||
import yaml
|
||||
import pytest
|
||||
import binascii
|
||||
from unittest.mock import patch, MagicMock
|
||||
from shadow_rm_robot.servo_robotic_arm import ServoArm
|
||||
|
||||
|
||||
class TestServoArm:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self, tmpdir):
|
||||
# 模拟配置文件
|
||||
self.config_data = {
|
||||
"SerialConfig": {
|
||||
"port": "/dev/ttyUSB0",
|
||||
"baudrate": 460800,
|
||||
"hex_data": "55 AA 02 00 00 67",
|
||||
}
|
||||
}
|
||||
self.config_file = tmpdir.join("test_config.yaml")
|
||||
with open(self.config_file, "w") as file:
|
||||
yaml.dump(self.config_data, file)
|
||||
|
||||
# 初始化 ServoArm 对象
|
||||
self.servo_arm = ServoArm(self.config_file)
|
||||
|
||||
@patch("serial.Serial")
|
||||
def test_initialization(self, mock_serial):
|
||||
# 测试初始化
|
||||
mock_serial_instance = MagicMock()
|
||||
mock_serial.return_value = mock_serial_instance
|
||||
|
||||
servo_arm = ServoArm(self.config_file)
|
||||
assert servo_arm.port == self.config_data["SerialConfig"]["port"]
|
||||
assert servo_arm.baudrate == self.config_data["SerialConfig"]["baudrate"]
|
||||
assert servo_arm.hex_data == self.config_data["SerialConfig"]["hex_data"]
|
||||
|
||||
# 检查串口初始化
|
||||
mock_serial.assert_any_call(
|
||||
self.config_data["SerialConfig"]["port"],
|
||||
self.config_data["SerialConfig"]["baudrate"],
|
||||
timeout=0,
|
||||
)
|
||||
|
||||
def test_bytes_to_signed_int(self):
|
||||
# 测试字节转换为有符号整数
|
||||
byte_data = b"\x01\x00"
|
||||
result = self.servo_arm._bytes_to_signed_int(byte_data)
|
||||
assert result == 1
|
||||
|
||||
byte_data = b"\xff\xff"
|
||||
result = self.servo_arm._bytes_to_signed_int(byte_data)
|
||||
assert result == -1
|
||||
|
||||
def test_parse_joint_data(self):
|
||||
# 测试解析关节数据
|
||||
hex_received = (
|
||||
"00" * 7
|
||||
+ "01000000"
|
||||
+ "00" * 1
|
||||
+ "02000000"
|
||||
+ "00" * 1
|
||||
+ "03000000"
|
||||
+ "00" * 1
|
||||
+ "04000000"
|
||||
+ "00" * 1
|
||||
+ "05000000"
|
||||
+ "00" * 1
|
||||
+ "06000000"
|
||||
+ "00" * 1
|
||||
+ "07000000"
|
||||
)
|
||||
joints = self.servo_arm._parse_joint_data(hex_received)
|
||||
expected_joints = {
|
||||
"joint_1": 0.0001,
|
||||
"joint_2": 0.0002,
|
||||
"joint_3": 0.0003,
|
||||
"joint_4": 0.0004,
|
||||
"joint_5": 0.0005,
|
||||
"joint_6": 0.0006,
|
||||
"grasp": 7,
|
||||
}
|
||||
assert joints == expected_joints
|
||||
|
||||
@patch("serial.Serial")
|
||||
def test_get_joint_actions(self, mock_serial):
|
||||
# 模拟串口返回数据
|
||||
mock_serial_instance = MagicMock()
|
||||
mock_serial.return_value = mock_serial_instance
|
||||
mock_serial_instance.read.side_effect = [
|
||||
binascii.unhexlify(
|
||||
"00" * 7
|
||||
+ "01000000"
|
||||
+ "00" * 1
|
||||
+ "02000000"
|
||||
+ "00" * 1
|
||||
+ "03000000"
|
||||
+ "00" * 1
|
||||
+ "04000000"
|
||||
+ "00" * 1
|
||||
+ "05000000"
|
||||
+ "00" * 1
|
||||
+ "06000000"
|
||||
+ "00" * 1
|
||||
+ "07000000"
|
||||
)
|
||||
]
|
||||
|
||||
servo_arm = ServoArm(self.config_file)
|
||||
joint_actions = servo_arm.get_joint_actions()
|
||||
expected_joint_actions = {
|
||||
"joint_1": 0.0001,
|
||||
"joint_2": 0.0002,
|
||||
"joint_3": 0.0003,
|
||||
"joint_4": 0.0004,
|
||||
"joint_5": 0.0005,
|
||||
"joint_6": 0.0006,
|
||||
"grasp": 7,
|
||||
}
|
||||
assert joint_actions == expected_joint_actions
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
||||
Reference in New Issue
Block a user