This commit is contained in:
@@ -0,0 +1,26 @@
|
||||
import pytest
|
||||
import json
|
||||
import yaml
|
||||
import numpy as np
|
||||
from shadow_rm_robot.realman_arm import RmArm
|
||||
import logging
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
|
||||
arm = RmArm("./data/rm_arm.yaml")
|
||||
|
||||
arm.arm.send( '{"command":"set_modbus_mode","port":0,"baudrate":115200,"timeout ":2}\r\n'.encode("utf-8"))
|
||||
|
||||
# arm.arm.send( '{"command":"close_modbus_mode","port":1}\r\n'.encode("utf-8"))
|
||||
|
||||
a = arm.arm.recv(1024)
|
||||
|
||||
logging.debug(a)
|
||||
|
||||
arm.arm.send( '{"command":"read_holding_registers","port":1,"address":14,"device":2}\r\n'.encode("utf-8"))
|
||||
|
||||
b = arm.arm.recv(1024)
|
||||
logging.debug(b)
|
||||
@@ -0,0 +1,90 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import json
|
||||
import yaml
|
||||
import numpy as np
|
||||
from shadow_rm_robot.realman_arm import RmArm
|
||||
|
||||
class TestRmArm:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self, tmpdir):
|
||||
# 模拟配置文件
|
||||
self.config_data = {
|
||||
"arm_ip": "192.168.1.18",
|
||||
"arm_port": 8080
|
||||
}
|
||||
self.config_file = tmpdir.join("test_config.yaml")
|
||||
with open(self.config_file, "w") as file:
|
||||
yaml.dump(self.config_data, file)
|
||||
|
||||
# 初始化 RmArm 对象
|
||||
self.rm_arm = RmArm(self.config_file)
|
||||
|
||||
@patch("socket.socket")
|
||||
def test_initialization(self, mock_socket):
|
||||
# 测试初始化
|
||||
mock_socket_instance = MagicMock()
|
||||
mock_socket.return_value = mock_socket_instance
|
||||
|
||||
rm_arm = RmArm(self.config_file)
|
||||
assert rm_arm.arm_ip == self.config_data["arm_ip"]
|
||||
assert rm_arm.arm_port == self.config_data["arm_port"]
|
||||
|
||||
# 检查网络连接初始化
|
||||
mock_socket_instance.connect.assert_called_with((self.config_data["arm_ip"], self.config_data["arm_port"]))
|
||||
|
||||
def test_json_to_numpy(self):
|
||||
# 测试 JSON 数据解析为 NumPy 数组
|
||||
json_data = json.dumps({"joint": [1, 2, 3, 4, 5, 6]})
|
||||
byte_data = json_data.encode('utf-8')
|
||||
result = self.rm_arm._json_to_numpy(byte_data, 'joint')
|
||||
expected_result = np.array([1, 2, 3, 4, 5, 6], dtype=float)
|
||||
np.testing.assert_array_equal(result, expected_result)
|
||||
|
||||
# 测试键不存在的情况
|
||||
json_data = json.dumps({"other_key": [1, 2, 3]})
|
||||
byte_data = json_data.encode('utf-8')
|
||||
result = self.rm_arm._json_to_numpy(byte_data, 'joint')
|
||||
expected_result = np.array([])
|
||||
np.testing.assert_array_equal(result, expected_result)
|
||||
|
||||
def test_generate_command(self):
|
||||
# 测试生成关节命令
|
||||
data = np.array([0.1, 0.2, 0.3])
|
||||
cmd_type = 'joint'
|
||||
result = self.rm_arm._generate_command(data, cmd_type)
|
||||
expected_result = json.dumps({"command": "movej", "joint": [100, 200, 300], "v": 40, "r": 0}) + '\r\n'
|
||||
assert result == expected_result
|
||||
|
||||
# 测试生成夹爪命令
|
||||
data = np.array([500])
|
||||
cmd_type = 'gripper'
|
||||
result = self.rm_arm._generate_command(data, cmd_type)
|
||||
expected_result = json.dumps({"command": "set_gripper_position", "position": [500], "block": False}) + '\r\n'
|
||||
assert result == expected_result
|
||||
|
||||
# @patch("socket.socket")
|
||||
# def test_get_qpos(self, mock_socket):
|
||||
# # 模拟网络返回数据
|
||||
# mock_socket_instance = MagicMock()
|
||||
# mock_socket.return_value = mock_socket_instance
|
||||
# mock_socket_instance.recv.side_effect = [
|
||||
# json.dumps({"joint": [1000, 2000, 3000, 4000, 5000, 6000]}).encode('utf-8'),
|
||||
# json.dumps({"actpos": [700]}).encode('utf-8')
|
||||
# ]
|
||||
|
||||
# rm_arm = RmArm(self.config_file)
|
||||
# qpos = rm_arm.get_qpos()
|
||||
# expected_qpos = {
|
||||
# "joint_1": 1.0,
|
||||
# "joint_2": 2.0,
|
||||
# "joint_3": 3.0,
|
||||
# "joint_4": 4.0,
|
||||
# "joint_5": 5.0,
|
||||
# "joint_6": 6.0,
|
||||
# "gripper": 700.0
|
||||
# }
|
||||
# assert qpos == expected_qpos
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
||||
@@ -0,0 +1,123 @@
|
||||
import yaml
|
||||
import pytest
|
||||
import binascii
|
||||
from unittest.mock import patch, MagicMock
|
||||
from shadow_rm_robot.servo_robotic_arm import ServoArm
|
||||
|
||||
|
||||
class TestServoArm:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self, tmpdir):
|
||||
# 模拟配置文件
|
||||
self.config_data = {
|
||||
"SerialConfig": {
|
||||
"port": "/dev/ttyUSB0",
|
||||
"baudrate": 460800,
|
||||
"hex_data": "55 AA 02 00 00 67",
|
||||
}
|
||||
}
|
||||
self.config_file = tmpdir.join("test_config.yaml")
|
||||
with open(self.config_file, "w") as file:
|
||||
yaml.dump(self.config_data, file)
|
||||
|
||||
# 初始化 ServoArm 对象
|
||||
self.servo_arm = ServoArm(self.config_file)
|
||||
|
||||
@patch("serial.Serial")
|
||||
def test_initialization(self, mock_serial):
|
||||
# 测试初始化
|
||||
mock_serial_instance = MagicMock()
|
||||
mock_serial.return_value = mock_serial_instance
|
||||
|
||||
servo_arm = ServoArm(self.config_file)
|
||||
assert servo_arm.port == self.config_data["SerialConfig"]["port"]
|
||||
assert servo_arm.baudrate == self.config_data["SerialConfig"]["baudrate"]
|
||||
assert servo_arm.hex_data == self.config_data["SerialConfig"]["hex_data"]
|
||||
|
||||
# 检查串口初始化
|
||||
mock_serial.assert_any_call(
|
||||
self.config_data["SerialConfig"]["port"],
|
||||
self.config_data["SerialConfig"]["baudrate"],
|
||||
timeout=0,
|
||||
)
|
||||
|
||||
def test_bytes_to_signed_int(self):
|
||||
# 测试字节转换为有符号整数
|
||||
byte_data = b"\x01\x00"
|
||||
result = self.servo_arm._bytes_to_signed_int(byte_data)
|
||||
assert result == 1
|
||||
|
||||
byte_data = b"\xff\xff"
|
||||
result = self.servo_arm._bytes_to_signed_int(byte_data)
|
||||
assert result == -1
|
||||
|
||||
def test_parse_joint_data(self):
|
||||
# 测试解析关节数据
|
||||
hex_received = (
|
||||
"00" * 7
|
||||
+ "01000000"
|
||||
+ "00" * 1
|
||||
+ "02000000"
|
||||
+ "00" * 1
|
||||
+ "03000000"
|
||||
+ "00" * 1
|
||||
+ "04000000"
|
||||
+ "00" * 1
|
||||
+ "05000000"
|
||||
+ "00" * 1
|
||||
+ "06000000"
|
||||
+ "00" * 1
|
||||
+ "07000000"
|
||||
)
|
||||
joints = self.servo_arm._parse_joint_data(hex_received)
|
||||
expected_joints = {
|
||||
"joint_1": 0.0001,
|
||||
"joint_2": 0.0002,
|
||||
"joint_3": 0.0003,
|
||||
"joint_4": 0.0004,
|
||||
"joint_5": 0.0005,
|
||||
"joint_6": 0.0006,
|
||||
"grasp": 7,
|
||||
}
|
||||
assert joints == expected_joints
|
||||
|
||||
@patch("serial.Serial")
|
||||
def test_get_joint_actions(self, mock_serial):
|
||||
# 模拟串口返回数据
|
||||
mock_serial_instance = MagicMock()
|
||||
mock_serial.return_value = mock_serial_instance
|
||||
mock_serial_instance.read.side_effect = [
|
||||
binascii.unhexlify(
|
||||
"00" * 7
|
||||
+ "01000000"
|
||||
+ "00" * 1
|
||||
+ "02000000"
|
||||
+ "00" * 1
|
||||
+ "03000000"
|
||||
+ "00" * 1
|
||||
+ "04000000"
|
||||
+ "00" * 1
|
||||
+ "05000000"
|
||||
+ "00" * 1
|
||||
+ "06000000"
|
||||
+ "00" * 1
|
||||
+ "07000000"
|
||||
)
|
||||
]
|
||||
|
||||
servo_arm = ServoArm(self.config_file)
|
||||
joint_actions = servo_arm.get_joint_actions()
|
||||
expected_joint_actions = {
|
||||
"joint_1": 0.0001,
|
||||
"joint_2": 0.0002,
|
||||
"joint_3": 0.0003,
|
||||
"joint_4": 0.0004,
|
||||
"joint_5": 0.0005,
|
||||
"joint_6": 0.0006,
|
||||
"grasp": 7,
|
||||
}
|
||||
assert joint_actions == expected_joint_actions
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
||||
Reference in New Issue
Block a user