forked from tangger/lerobot
Compare commits
23 Commits
main...realman-du
| SHA1 |
|---|
| 3685542bf1 |
| 7c1699898b |
| b3e9e11e11 |
| b04e6e0c7b |
| 96804bc86c |
| ef45ea9649 |
| bc351a0134 |
| 68986f6fc0 |
| 2f124e34de |
| c28e774234 |
| 80b1a97e4c |
| f4fec8f51c |
| f4f82c916f |
| ecbe154709 |
| d00c154db9 |
| 55f284b306 |
| cf8df17d3a |
| e079566597 |
| 83d6419d70 |
| a0ec9e1cb1 |
| 3eede4447d |
| 9c6a7d9701 |
| 7b201773f3 |
@@ -58,7 +58,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
     log_dt("dt", dt_s)

     # TODO(aliberts): move robot-specific logs logic in robot.print_logs()
-    if not robot.robot_type.startswith("stretch"):
+    if not robot.robot_type.startswith(("stretch", "realman")):
         for name in robot.leader_arms:
             key = f"read_leader_{name}_pos_dt_s"
             if key in robot.logs:
@@ -39,3 +39,24 @@ class FeetechMotorsBusConfig(MotorsBusConfig):
    port: str
    motors: dict[str, tuple[int, str]]
    mock: bool = False


@MotorsBusConfig.register_subclass("realman")
@dataclass
class RealmanMotorsBusConfig(MotorsBusConfig):
    ip: str
    port: int
    motors: dict[str, tuple[int, str]]
    init_joint: dict[str, list]


@MotorsBusConfig.register_subclass("realman_dual")
@dataclass
class RealmanDualMotorsBusConfig(MotorsBusConfig):
    left_ip: str
    right_ip: str
    left_port: int
    right_port: int
    motors: dict[str, tuple[int, str]]
    init_joint: dict[str, list]
    axis: dict[str, int]
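For orientation, a minimal sketch of how the subclass registered above might be instantiated; the values mirror the defaults that RealmanRobotConfig uses further down in this diff, the tuple form follows the declared motors annotation, and nothing here is hardware-verified:

    # Sketch only: constructing the new config subclass.
    from lerobot.common.robot_devices.motors.configs import RealmanMotorsBusConfig

    cfg = RealmanMotorsBusConfig(
        ip="192.168.3.18",
        port=8080,
        motors={  # name: (index, model)
            "joint_1": (1, "realman"), "joint_2": (2, "realman"),
            "joint_3": (3, "realman"), "joint_4": (4, "realman"),
            "joint_5": (5, "realman"), "joint_6": (6, "realman"),
            "gripper": (7, "realman"),
        },
        init_joint={"joint": [-90, 90, 90, 90, 90, -90, 10]},  # 6 joints + gripper
    )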
lerobot/common/robot_devices/motors/realman.py (new file, 150 lines)
@@ -0,0 +1,150 @@
import time
from typing import Dict
from lerobot.common.robot_devices.motors.configs import RealmanMotorsBusConfig
from Robotic_Arm.rm_robot_interface import *


class RealmanMotorsBus:
    """
    A thin wrapper around the Realman SDK.
    """
    def __init__(self,
                 config: RealmanMotorsBusConfig):
        self.rmarm = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
        self.handle = self.rmarm.rm_create_robot_arm(config.ip, config.port)
        self.motors = config.motors
        self.init_joint_position = config.init_joint['joint']  # [6 joints + 1 gripper]
        self.safe_disable_position = config.init_joint['joint']
        self.rmarm.rm_movej(self.init_joint_position[:-1], 5, 0, 0, 1)
        time.sleep(3)
        ret = self.rmarm.rm_get_current_arm_state()
        self.init_pose = ret[1]['pose']

    @property
    def motor_names(self) -> list[str]:
        return list(self.motors.keys())

    @property
    def motor_models(self) -> list[str]:
        return [model for _, model in self.motors.values()]

    @property
    def motor_indices(self) -> list[int]:
        return [idx for idx, _ in self.motors.values()]

    def connect(self, enable=True) -> bool:
        '''
        Enable the arm and poll its enable state for up to 5 s;
        bail out once the timeout expires.
        '''
        enable_flag = False
        loop_flag = False
        # timeout in seconds
        timeout = 5
        # time at which the polling loop starts
        start_time = time.time()
        elapsed_time_flag = False

        while not loop_flag:
            elapsed_time = time.time() - start_time
            print("--------------------")

            if enable:
                # query the arm state
                ret = self.rmarm.rm_get_current_arm_state()
                if ret[0] == 0:  # state read successfully
                    enable_flag = True
            else:
                enable_flag = False
                # drop all connections and destroy the SDK threads
                RoboticArm.rm_destory()
            print("enable state:", enable_flag)
            print("--------------------")
            if enable_flag == enable:
                loop_flag = True
                enable_flag = True
            else:
                loop_flag = False
                enable_flag = False
            # check whether the timeout has expired
            if elapsed_time > timeout:
                print("timed out....")
                elapsed_time_flag = True
                enable_flag = True
                break
            time.sleep(1)

        resp = enable_flag
        print(f"Returning response: {resp}")
        return resp

    def set_calibration(self):
        return

    def revert_calibration(self):
        return

    def apply_calibration(self):
        """
        Move to the initial position.
        """
        self.write(target_joint=self.init_joint_position)

    def write(self, target_joint: list):
        # self.rmarm.rm_movej(target_joint[:-1], 50, 0, 0, 1)
        self.rmarm.rm_movej_follow(target_joint[:-1])
        self.rmarm.rm_set_gripper_position(target_joint[-1], block=False, timeout=2)

    def write_endpose(self, target_endpose: list, gripper: int):
        self.rmarm.rm_movej_p(target_endpose, 50, 0, 0, 1)
        self.rmarm.rm_set_gripper_position(gripper, block=False, timeout=2)

    def write_joint_slow(self, target_joint: list):
        self.rmarm.rm_movej(target_joint, 5, 0, 0, 0)

    def write_joint_canfd(self, target_joint: list):
        self.rmarm.rm_movej_canfd(target_joint, False)

    def write_endpose_canfd(self, target_pose: list):
        self.rmarm.rm_movep_canfd(target_pose, False)

    def write_gripper(self, gripper: int):
        self.rmarm.rm_set_gripper_position(gripper, False, 2)

    def read(self) -> Dict:
        """
        - Arm joint state, in degrees, normalized to [-1, 1].
        - Arm gripper state, normalized to [-1, 1].
        """
        joint_msg = self.rmarm.rm_get_current_arm_state()[1]
        joint_state = joint_msg['joint']

        gripper_msg = self.rmarm.rm_get_gripper_state()[1]
        gripper_state = gripper_msg['actpos']

        return {
            "joint_1": joint_state[0]/180,
            "joint_2": joint_state[1]/180,
            "joint_3": joint_state[2]/180,
            "joint_4": joint_state[3]/180,
            "joint_5": joint_state[4]/180,
            "joint_6": joint_state[5]/180,
            "gripper": (gripper_state-500)/500
        }

    def read_current_arm_joint_state(self):
        return self.rmarm.rm_get_current_arm_state()[1]['joint']

    def read_current_arm_endpose_state(self):
        return self.rmarm.rm_get_current_arm_state()[1]['pose']

    def safe_disconnect(self):
        """
        Move to the safe disconnect position.
        """
        self.write(target_joint=self.safe_disable_position)
        # drop all connections and destroy the SDK threads
        RoboticArm.rm_destory()
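read() normalizes each joint by dividing degrees by 180 and maps the gripper's raw actpos (centered at 500) to [-1, 1]; send_action in the robot classes later in this diff inverts both mappings. A small self-contained sketch of that round trip, pure arithmetic with illustrative names only:

    # Sketch of the normalization in read() and its inverse in send_action().
    def normalize(joints_deg: list[float], gripper_raw: int) -> list[float]:
        return [q / 180 for q in joints_deg] + [(gripper_raw - 500) / 500]

    def denormalize(action: list[float]) -> tuple[list[float], int]:
        joints_deg = [a * 180 for a in action[:-1]]
        gripper_raw = int(action[-1] * 500 + 500)
        return joints_deg, gripper_raw

    joints, grip = denormalize(normalize([-90, 45, 0, 90, 10, -30], 750))
    assert grip == 750 and joints[0] == -90.0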
lerobot/common/robot_devices/motors/realman_dual.py (new file, 350 lines)
@@ -0,0 +1,350 @@
import time
import threading
from typing import Dict
from dataclasses import dataclass
from contextlib import contextmanager
from lerobot.common.robot_devices.motors.configs import RealmanDualMotorsBusConfig
from Robotic_Arm.rm_robot_interface import *


def compare_joint_difference(master_joints, follow_joints, threshold=30.0):
    """
    Compare the joint readings of the master arm against the follower arm.

    Args:
        master_joints (list): master-arm joints [joint1, joint2, ..., joint6]
        follow_joints (list): follower-arm joints [joint1, joint2, ..., joint6]
        threshold (float): difference threshold in degrees, 30.0 by default

    Returns:
        bool: True if every joint difference is within the threshold, False otherwise
    """
    # check that both readings have the same length
    if len(master_joints) != len(follow_joints):
        return False

    # compute the absolute difference per joint
    for i in range(len(master_joints)):
        diff = abs(master_joints[i] - follow_joints[i])
        if diff > threshold:
            return False

    return True


@dataclass
class ArmState:
    """State snapshot of one arm."""
    joint_positions: list
    gripper_position: int
    pose: list


class RealmanDualMotorsBus:
    """
    A thin wrapper around the Realman SDK for a dual-arm setup.
    """
    def __init__(self, config: RealmanDualMotorsBusConfig):
        self.config = config
        self._initialize_arms()
        self._initialize_parameters()
        self._initialize_positions()
        self._initialize_threading()

    def _initialize_arms(self):
        """Open the connections to both arms."""
        self.left_rmarm = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
        self.right_rmarm = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
        self.handle_left = self.left_rmarm.rm_create_robot_arm(
            self.config.left_ip, self.config.left_port
        )
        self.handle_right = self.right_rmarm.rm_create_robot_arm(
            self.config.right_ip, self.config.right_port
        )

    def _initialize_parameters(self):
        """Cache the configuration parameters."""
        self.motors = self.config.motors
        self.axis = self.config.axis
        self.joint_count = sum(self.axis.values())
        self.left_offset = self.axis['left_joint']

    def _initialize_positions(self):
        """Set up the initial positions."""
        self.init_joint_position = self.config.init_joint['joint']
        self.safe_disable_position = self.config.init_joint['joint']

        # move to the initial position
        self._move_to_initial_position()

        # read the initial pose
        time.sleep(3)
        self.init_pose = self._get_initial_pose()

    def _initialize_threading(self):
        """Set up thread control."""
        self.left_slow_busy = False
        self.right_slow_busy = False
        self.gripper_busy = False
        self._thread_lock = threading.Lock()

        # thread control for reads
        self._state_cache = {"joint": {}, "pose": {}}
        self._cache_lock = threading.Lock()
        self._keep_reading = True

        # start the background reader thread
        self._start_background_readers()

    def _start_background_readers(self):
        """Start the background reader thread."""
        threading.Thread(
            target=self._read_task,
            daemon=True,
            name="arm_reader"
        ).start()

    @property
    def motor_names(self) -> list[str]:
        return list(self.motors.keys())

    @property
    def motor_models(self) -> list[str]:
        return [model for _, model in self.motors.values()]

    @property
    def motor_indices(self) -> list[int]:
        return [idx for idx, _ in self.motors.values()]

    @contextmanager
    def _timeout_context(self, timeout: float = 5.0):
        """Timeout context manager."""
        start_time = time.time()
        try:
            yield lambda: time.time() - start_time < timeout
        except Exception as e:
            raise TimeoutError(f"Operation timed out: {e}")

    def _read_task(self):
        """Background read task for both arms, in the style of _left_slow_task."""
        while self._keep_reading:
            try:
                left_state = self._read_arm_state(self.left_rmarm, "left")
                with self._cache_lock:
                    self._state_cache["joint"].update(left_state["joint"])
                    self._state_cache["pose"].update(left_state["pose"])
            except Exception as e:
                print(f"Left-arm read failed: {e}")

            try:
                right_state = self._read_arm_state(self.right_rmarm, "right")
                with self._cache_lock:
                    self._state_cache["joint"].update(right_state["joint"])
                    self._state_cache["pose"].update(right_state["pose"])
            except Exception as e:
                print(f"Right-arm read failed: {e}")

    def _read_arm_state(self, arm: RoboticArm, prefix: str) -> dict:
        """Read the state of a single arm; keeps the original read logic."""
        joint_msg = arm.rm_get_current_arm_state()[1]
        gripper_msg = arm.rm_get_gripper_state()[1]

        joint_state = joint_msg['joint']
        gripper_state = gripper_msg['actpos']
        pose_state = joint_msg['pose']

        joint_state_dict = {}
        for i in range(len(joint_state)):
            joint_state_dict[f"{prefix}_joint_{i+1}"] = joint_state[i]
        joint_state_dict[f"{prefix}_gripper"] = gripper_state

        pose_state_dict = {
            f"{prefix}_x": pose_state[0],
            f"{prefix}_y": pose_state[1],
            f"{prefix}_z": pose_state[2],
            f"{prefix}_rx": pose_state[3],
            f"{prefix}_ry": pose_state[4],
            f"{prefix}_rz": pose_state[5],
        }

        return {"joint": joint_state_dict, 'pose': pose_state_dict}

    def _move_to_initial_position(self):
        """Move both arms to the initial position."""
        left_joints = self.init_joint_position[:self.left_offset]
        right_joints = self.init_joint_position[self.left_offset+1:-1]

        self.left_rmarm.rm_movej(left_joints, 5, 0, 0, 1)
        self.right_rmarm.rm_movej(right_joints, 5, 0, 0, 1)

    def _get_initial_pose(self) -> list:
        """Read the initial pose of both arms."""
        left_ret = self.left_rmarm.rm_get_current_arm_state()
        right_ret = self.right_rmarm.rm_get_current_arm_state()
        return left_ret[1]['pose'] + right_ret[1]['pose']

    def _validate_joint_count(self, joints: list, expected_count: int):
        """Validate the number of joints."""
        if len(joints) != expected_count:
            raise ValueError(f"Joint count mismatch: expected {expected_count}, got {len(joints)}")

    def _execute_slow_movement(self, arm: str, joint_data: list):
        """Run a slow movement on a worker thread."""
        busy_flag = f"{arm}_slow_busy"

        if not getattr(self, busy_flag):
            setattr(self, busy_flag, True)

            target_method = getattr(self, f"_{arm}_slow_task")
            threading.Thread(
                target=target_method,
                args=(joint_data.copy(),),
                daemon=True
            ).start()

    def _left_slow_task(self, joint_data: list):
        """Slow-movement task for the left arm."""
        try:
            self.write_left_joint_slow(joint_data)
        finally:
            self.left_slow_busy = False

    def _right_slow_task(self, joint_data: list):
        """Slow-movement task for the right arm."""
        try:
            self.write_right_joint_slow(joint_data)
        finally:
            self.right_slow_busy = False

    def _execute_arm_action(self, arm: str, action: dict, master_joint: list, follow_joint: list):
        """Execute an action on a single arm."""
        controller_status = action['master_controller_status'][arm]

        if controller_status['infrared'] == 1:
            if compare_joint_difference(master_joint, follow_joint):
                if arm == 'left':
                    self.write_left_joint_canfd(master_joint)
                else:
                    self.write_right_joint_canfd(master_joint)
            else:
                self._execute_slow_movement(arm, master_joint)

    def write_endpose(self, target_endpose: list):
        assert len(target_endpose) == 12, "the length of the target pose is not equal to 12"
        self.left_rmarm.rm_movej_p(target_endpose[:6], 50, 0, 0, 1)
        self.right_rmarm.rm_movej_p(target_endpose[6:], 50, 0, 0, 1)

    def write_left_joint_slow(self, left_joint: list):
        assert len(left_joint) == self.left_offset, "len of left master joints is not equal to the count of left joints"
        self.left_rmarm.rm_movej(left_joint, 5, 0, 0, 1)

    def write_right_joint_slow(self, right_joint: list):
        assert len(right_joint) == self.left_offset, "len of right master joints is not equal to the count of right joints"
        self.right_rmarm.rm_movej(right_joint, 5, 0, 0, 1)

    def write_left_joint_canfd(self, left_joint: list):
        assert len(left_joint) == self.left_offset, "len of left master joints is not equal to the count of left joints"
        self.left_rmarm.rm_movej_canfd(left_joint, False)

    def write_right_joint_canfd(self, right_joint: list):
        assert len(right_joint) == self.left_offset, "len of right master joints is not equal to the count of right joints"
        self.right_rmarm.rm_movej_canfd(right_joint, False)

    def write_endpose_canfd(self, target_endpose: list):
        assert len(target_endpose) == 12, "the length of the target pose is not equal to 12"
        self.left_rmarm.rm_movep_canfd(target_endpose[:6], False)
        self.right_rmarm.rm_movep_canfd(target_endpose[6:], False)

    def write_dual_gripper(self, left_gripper: int, right_gripper: int):
        try:
            self.left_rmarm.rm_set_gripper_position(left_gripper, False, 2)
            self.right_rmarm.rm_set_gripper_position(right_gripper, False, 2)
        finally:
            self.gripper_busy = False

    def _execute_gripper_thread(self, left_gripper: int, right_gripper: int):
        if not getattr(self, 'gripper_busy'):
            setattr(self, 'gripper_busy', True)

            threading.Thread(
                target=self.write_dual_gripper,
                args=(left_gripper, right_gripper),
                daemon=True
            ).start()

    def read_current_arm_joint_state(self):
        return self.left_rmarm.rm_get_current_arm_state()[1]['joint'] + self.right_rmarm.rm_get_current_arm_state()[1]['joint']

    def read_current_arm_endpose_state(self):
        return self.left_rmarm.rm_get_current_arm_state()[1]['pose'] + self.right_rmarm.rm_get_current_arm_state()[1]['pose']

    ########################## lerobot functions ##########################

    def connect(self, enable: bool = True) -> bool:
        """Enable both arms and poll their enable state."""
        with self._timeout_context() as is_timeout_valid:
            while is_timeout_valid():
                try:
                    if enable:
                        left_ret = self.left_rmarm.rm_get_current_arm_state()
                        right_ret = self.right_rmarm.rm_get_current_arm_state()
                        if left_ret[0] == 0 and right_ret[0] == 0:
                            print("Arms enabled")
                            return True
                    else:
                        RoboticArm.rm_destory()
                        print("Arms disconnected")
                        return True
                except Exception as e:
                    print(f"Connection error: {e}")
                time.sleep(1)
            print("Connection timed out")
            return False

    def set_calibration(self):
        raise NotImplementedError

    def revert_calibration(self):
        raise NotImplementedError

    def apply_calibration(self):
        """
        Move to the initial position.
        """
        self.write(target_joint=self.init_joint_position)

    def write(self, target_joint: list):
        """Write joint positions."""
        self._validate_joint_count(target_joint, self.joint_count)

        left_joints = target_joint[:self.left_offset]
        left_gripper = target_joint[self.left_offset]
        right_joints = target_joint[self.left_offset+1:-1]
        right_gripper = target_joint[-1]

        self.left_rmarm.rm_movej_canfd(left_joints, follow=False)
        # self.left_rmarm.rm_movej_follow(left_joints)
        # self.left_rmarm.rm_set_gripper_position(left_gripper, block=False, timeout=2)
        self.right_rmarm.rm_movej_canfd(right_joints, follow=False)
        # self.right_rmarm.rm_movej_follow(right_joints)
        # self.right_rmarm.rm_set_gripper_position(right_gripper, block=False, timeout=2)
        self._execute_gripper_thread(left_gripper, right_gripper)


    def read(self) -> Dict:
        """Read the arm state straight from the cache."""
        with self._cache_lock:
            return self._state_cache.copy()

    def safe_disconnect(self):
        """Disconnect safely."""
        try:
            self.write(target_joint=self.safe_disable_position)
            time.sleep(2)  # wait for the motion to finish
        except Exception as e:
            print(f"Failed to move to the safe position: {e}")
        finally:
            RoboticArm.rm_destory()

    ########################## lerobot functions ##########################
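write() splits the 14-element action vector by left_offset = axis['left_joint']: six left joints, the left gripper, six right joints, then the right gripper. A pure-list sketch of that slicing, reusing the default init_joint values from the config below:

    # Sketch: slicing the dual-arm action vector with left_offset = 6.
    target = [-170, 90, 0, 90, 120, 0, 10, 170, 90, 0, -90, 120, 0, 10]
    left_offset = 6
    left_joints = target[:left_offset]         # 6 left joints
    left_gripper = target[left_offset]         # 10
    right_joints = target[left_offset + 1:-1]  # 6 right joints
    right_gripper = target[-1]                 # 10
    assert len(left_joints) == len(right_joints) == 6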
@@ -44,6 +44,15 @@ def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig
            motors_buses[key] = FeetechMotorsBus(cfg)

        elif cfg.type == "realman":
            from lerobot.common.robot_devices.motors.realman import RealmanMotorsBus

            motors_buses[key] = RealmanMotorsBus(cfg)

        elif cfg.type == "realman_dual":
            from lerobot.common.robot_devices.motors.realman_dual import RealmanDualMotorsBus

            motors_buses[key] = RealmanDualMotorsBus(cfg)
        else:
            raise ValueError(f"The motor type '{cfg.type}' is not valid.")

@@ -65,3 +74,7 @@ def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
    else:
        raise ValueError(f"The motor type '{motor_type}' is not valid.")


def get_motor_names(arm: dict[str, MotorsBus]) -> list:
    return [f"{arm}_{motor}" for arm, bus in arm.items() for motor in bus.motors]
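Note that get_motor_names rebinds its arm parameter inside the comprehension; the result is bus-key-prefixed motor names. A sketch with a hypothetical stand-in bus object (the real argument is a dict of MotorsBus instances):

    # Sketch of get_motor_names output; SimpleNamespace stands in for a bus.
    from types import SimpleNamespace
    bus = SimpleNamespace(motors={"joint_1": (1, "realman"), "gripper": (7, "realman")})
    names = [f"{arm}_{motor}" for arm, b in {"main": bus}.items() for motor in b.motors]
    assert names == ["main_joint_1", "main_gripper"]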
@@ -27,6 +27,8 @@ from lerobot.common.robot_devices.motors.configs import (
    DynamixelMotorsBusConfig,
    FeetechMotorsBusConfig,
    MotorsBusConfig,
    RealmanMotorsBusConfig,
    RealmanDualMotorsBusConfig,
)

@@ -674,3 +676,147 @@ class LeKiwiRobotConfig(RobotConfig):
    )

    mock: bool = False


@RobotConfig.register_subclass("realman")
@dataclass
class RealmanRobotConfig(RobotConfig):
    inference_time: bool = False
    max_gripper: int = 990
    min_gripper: int = 10
    servo_config_file: str = "/home/maic/LYT/lerobot/lerobot/common/robot_devices/teleop/servo_arm.yaml"

    left_follower_arm: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": RealmanMotorsBusConfig(
                ip="192.168.3.18",
                port=8080,
                motors={
                    # name: (index, model)
                    "joint_1": [1, "realman"],
                    "joint_2": [2, "realman"],
                    "joint_3": [3, "realman"],
                    "joint_4": [4, "realman"],
                    "joint_5": [5, "realman"],
                    "joint_6": [6, "realman"],
                    "gripper": [7, "realman"],
                },
                init_joint={'joint': [-90, 90, 90, 90, 90, -90, 10]}
            )
        }
    )

    cameras: dict[str, CameraConfig] = field(
        default_factory=lambda: {
            # "one": OpenCVCameraConfig(
            #     camera_index=4,
            #     fps=30,
            #     width=640,
            #     height=480,
            # ),
            "left": IntelRealSenseCameraConfig(
                serial_number="153122077516",
                fps=30,
                width=640,
                height=480,
                use_depth=False
            ),
            # "right": IntelRealSenseCameraConfig(
            #     serial_number="405622075165",
            #     fps=30,
            #     width=640,
            #     height=480,
            #     use_depth=False
            # ),
            "front": IntelRealSenseCameraConfig(
                serial_number="145422072751",
                fps=30,
                width=640,
                height=480,
                use_depth=False
            ),
            "high": IntelRealSenseCameraConfig(
                serial_number="145422072193",
                fps=30,
                width=640,
                height=480,
                use_depth=False
            ),
        }
    )


@RobotConfig.register_subclass("realman_dual")
@dataclass
class RealmanDualRobotConfig(RobotConfig):
    inference_time: bool = False
    max_gripper: int = 990
    min_gripper: int = 10
    servo_config_file: str = "/home/maic/LYT/lerobot/lerobot/common/robot_devices/teleop/servo_dual.yaml"
    left_end_control_guid: str = '0300b14bff1100003708000010010000'
    right_end_control_guid: str = '0300509d5e040000120b000009050000'

    follower_arm: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": RealmanDualMotorsBusConfig(
                axis={'left_joint': 6, 'left_gripper': 1, 'right_joint': 6, 'right_gripper': 1},
                left_ip="192.168.3.18",
                left_port=8080,
                right_ip="192.168.3.19",
                right_port=8080,
                init_joint={'joint': [-170, 90, 0, 90, 120, 0, 10, 170, 90, 0, -90, 120, 0, 10]},
                motors={
                    # name: (index, model)
                    "left_joint_1": [1, "realman"],
                    "left_joint_2": [2, "realman"],
                    "left_joint_3": [3, "realman"],
                    "left_joint_4": [4, "realman"],
                    "left_joint_5": [5, "realman"],
                    "left_joint_6": [6, "realman"],
                    "left_gripper": [7, "realman"],
                    "right_joint_1": [8, "realman"],
                    "right_joint_2": [9, "realman"],
                    "right_joint_3": [10, "realman"],
                    "right_joint_4": [11, "realman"],
                    "right_joint_5": [12, "realman"],
                    "right_joint_6": [13, "realman"],
                    "right_gripper": [14, "realman"]
                },
            )
        }
    )

    cameras: dict[str, CameraConfig] = field(
        default_factory=lambda: {
            "left": IntelRealSenseCameraConfig(
                serial_number="153122077516",
                fps=30,
                width=640,
                height=480,
                use_depth=False
            ),
            "right": IntelRealSenseCameraConfig(
                serial_number="405622075165",
                fps=30,
                width=640,
                height=480,
                use_depth=False
            ),
            "front": IntelRealSenseCameraConfig(
                serial_number="145422072751",
                fps=30,
                width=640,
                height=480,
                use_depth=False
            ),
            "high": IntelRealSenseCameraConfig(
                serial_number="145422072193",
                fps=30,
                width=640,
                height=480,
                use_depth=False
            ),
        }
    )
lerobot/common/robot_devices/robots/realman.py (new file, 292 lines)
@@ -0,0 +1,292 @@
"""
Teleoperation of the Realman arm with a PS5 controller.
"""

import time
import torch
import numpy as np
from dataclasses import dataclass, field, replace
from collections import deque
from lerobot.common.robot_devices.teleop.realman_single import HybridController
from lerobot.common.robot_devices.motors.utils import get_motor_names, make_motors_buses_from_configs
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.robot_devices.robots.configs import RealmanRobotConfig


class RealmanRobot:
    def __init__(self, config: RealmanRobotConfig | None = None, **kwargs):
        if config is None:
            config = RealmanRobotConfig()
        # Overwrite config arguments using kwargs
        self.config = replace(config, **kwargs)
        self.robot_type = self.config.type
        self.inference_time = self.config.inference_time  # whether we are at inference time

        # build cameras
        self.cameras = make_cameras_from_configs(self.config.cameras)

        # build realman motors
        self.piper_motors = make_motors_buses_from_configs(self.config.left_follower_arm)
        self.arm = self.piper_motors['main']

        # build the initial teleop info
        self.init_info = {
            'init_joint': self.arm.init_joint_position,
            'init_pose': self.arm.init_pose,
            'max_gripper': config.max_gripper,
            'min_gripper': config.min_gripper,
            'servo_config_file': config.servo_config_file
        }

        # build state-action cache
        self.joint_queue = deque(maxlen=2)
        self.last_endpose = self.arm.init_pose

        # build gamepad teleop
        if not self.inference_time:
            self.teleop = HybridController(self.init_info)
        else:
            self.teleop = None

        self.logs = {}
        self.is_connected = False

    @property
    def camera_features(self) -> dict:
        cam_ft = {}
        for cam_key, cam in self.cameras.items():
            key = f"observation.images.{cam_key}"
            cam_ft[key] = {
                "shape": (cam.height, cam.width, cam.channels),
                "names": ["height", "width", "channels"],
                "info": None,
            }
        return cam_ft

    @property
    def motor_features(self) -> dict:
        action_names = get_motor_names(self.piper_motors)
        state_names = get_motor_names(self.piper_motors)
        return {
            "action": {
                "dtype": "float32",
                "shape": (len(action_names),),
                "names": action_names,
            },
            "observation.state": {
                "dtype": "float32",
                "shape": (len(state_names),),
                "names": state_names,
            },
        }

    @property
    def has_camera(self):
        return len(self.cameras) > 0

    @property
    def num_cameras(self):
        return len(self.cameras)

    def connect(self) -> None:
        """Connect the RealmanArm and the cameras"""
        if self.is_connected:
            raise RobotDeviceAlreadyConnectedError(
                "RealmanArm is already connected. Do not run `robot.connect()` twice."
            )

        # connect the RealmanArm
        self.arm.connect(enable=True)
        print("RealmanArm connected")

        # connect the cameras
        for name in self.cameras:
            self.cameras[name].connect()
            self.is_connected = self.is_connected and self.cameras[name].is_connected
            print(f"camera {name} connected")

        print("All connected")
        self.is_connected = True

        self.run_calibration()

    def disconnect(self) -> None:
        """Move to the home position, then disable the arm and the cameras"""
        # move the arm to the home position and disable it
        if not self.inference_time:
            self.teleop.stop()

        # disconnect the arm
        self.arm.safe_disconnect()
        print("RealmanArm will be disabled after 5 seconds")
        time.sleep(5)
        self.arm.connect(enable=False)

        # disconnect the cameras
        if len(self.cameras) > 0:
            for cam in self.cameras.values():
                cam.disconnect()

        self.is_connected = False

    def run_calibration(self):
        """Move the arm to the home position"""
        if not self.is_connected:
            raise ConnectionError()

        self.arm.apply_calibration()
        if not self.inference_time:
            self.teleop.reset()

    def teleop_step(
        self, record_data=False
    ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        if not self.is_connected:
            raise ConnectionError()

        if self.teleop is None and self.inference_time:
            self.teleop = HybridController(self.init_info)

        # read the current state and the teleop action
        before_read_t = time.perf_counter()
        state = self.arm.read()  # current joint positions from the robot
        action = self.teleop.get_action()  # target joint position and end pose from the gamepad
        self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t

        if action['control_mode'] == 'joint':
            # joint control mode (primary mode)
            current_pose = self.arm.read_current_arm_endpose_state()
            self.teleop.update_endpose_state(current_pose)

            target_joints = action['joint_angles'][:-1]
            self.arm.write_gripper(action['gripper'])
            print(action['gripper'])
            if action['master_controller_status']['infrared'] == 1:
                if action['master_controller_status']['button'] == 1:
                    self.arm.write_joint_canfd(target_joints)
                else:
                    self.arm.write_joint_slow(target_joints)

            # execute the action
            before_write_t = time.perf_counter()
            self.joint_queue.append(list(self.arm.read().values()))
            self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t

        else:
            target_pose = list(action['end_pose'])
            # execute the action
            before_write_t = time.perf_counter()
            if self.last_endpose != target_pose:
                self.arm.write_endpose_canfd(target_pose)
                self.last_endpose = target_pose
            self.arm.write_gripper(action['gripper'])

            target_joints = self.arm.read_current_arm_joint_state()
            self.joint_queue.append(list(self.arm.read().values()))
            self.teleop.update_joint_state(target_joints)
            self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t

        if not record_data:
            return

        state = torch.as_tensor(list(self.joint_queue[0]), dtype=torch.float32)
        action = torch.as_tensor(list(self.joint_queue[-1]), dtype=torch.float32)

        # Capture images from cameras
        images = {}
        for name in self.cameras:
            before_camread_t = time.perf_counter()
            images[name] = self.cameras[name].async_read()
            images[name] = torch.from_numpy(images[name])
            self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
            self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

        # Populate the output dictionaries
        obs_dict, action_dict = {}, {}
        obs_dict["observation.state"] = state
        action_dict["action"] = action
        for name in self.cameras:
            obs_dict[f"observation.images.{name}"] = images[name]

        return obs_dict, action_dict

    def send_action(self, action: torch.Tensor) -> torch.Tensor:
        """Write the actions predicted by the policy to the motors"""
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                "RealmanArm is not connected. You need to run `robot.connect()`."
            )

        # send to the motors: torch tensor to list
        target_joints = action.tolist()
        len_joint = len(target_joints) - 1
        target_joints = [target_joints[i]*180 for i in range(len_joint)] + [target_joints[-1]]
        target_joints[-1] = int(target_joints[-1]*500 + 500)
        self.arm.write(target_joints)

        return action

    def capture_observation(self) -> dict:
        """Capture the current images and joint positions"""
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                "RealmanArm is not connected. You need to run `robot.connect()`."
            )

        # read the current joint positions
        before_read_t = time.perf_counter()
        state = self.arm.read()  # 6 joints + 1 gripper
        self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t

        state = torch.as_tensor(list(state.values()), dtype=torch.float32)

        # read images from the cameras
        images = {}
        for name in self.cameras:
            before_camread_t = time.perf_counter()
            images[name] = self.cameras[name].async_read()
            images[name] = torch.from_numpy(images[name])
            self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
            self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

        # Populate the output dictionary and format to pytorch
        obs_dict = {}
        obs_dict["observation.state"] = state
        for name in self.cameras:
            obs_dict[f"observation.images.{name}"] = images[name]
        return obs_dict

    def teleop_safety_stop(self):
        """Move to the home position after recording one episode"""
        self.run_calibration()

    def __del__(self):
        if self.is_connected:
            self.disconnect()
        if not self.inference_time:
            self.teleop.stop()


if __name__ == '__main__':
    robot = RealmanRobot()
    robot.connect()
    # robot.run_calibration()
    while True:
        robot.teleop_step(record_data=True)

    robot.capture_observation()
    dummy_action = torch.Tensor([-0.40586111280653214, 0.5522833506266276, 0.4998166826036241, -0.3539944542778863, -0.524433347913954, 0.9064999898274739, 0.482])
    robot.send_action(dummy_action)
    robot.disconnect()
    print('ok')
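teleop_step pairs observations and actions through the two-slot joint_queue: the older snapshot becomes observation.state and the newer one becomes action. A minimal sketch of that pattern, detached from any hardware, with illustrative values:

    # Sketch: a maxlen=2 deque pairing the pre-command state with the
    # post-command state.
    from collections import deque
    joint_queue = deque(maxlen=2)
    joint_queue.append([0.0, 0.1])  # snapshot before the command
    joint_queue.append([0.0, 0.2])  # snapshot after the command
    state, action = joint_queue[0], joint_queue[-1]
    assert state == [0.0, 0.1] and action == [0.0, 0.2]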
lerobot/common/robot_devices/robots/realman_dual.py (new file, 319 lines)
@@ -0,0 +1,319 @@
"""
Teleoperation of the dual Realman arms with a PS5 controller.
"""

import time
import torch
import numpy as np
import logging
from typing import Optional, Tuple, Dict
from dataclasses import dataclass, field, replace
from collections import deque
from lerobot.common.robot_devices.teleop.realman_aloha_dual import HybridController
from lerobot.common.robot_devices.motors.utils import get_motor_names, make_motors_buses_from_configs
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.robot_devices.robots.configs import RealmanDualRobotConfig


class RealmanDualRobot:
    def __init__(self, config: RealmanDualRobotConfig | None = None, **kwargs):
        if config is None:
            config = RealmanDualRobotConfig()
        # Overwrite config arguments using kwargs
        self.config = replace(config, **kwargs)
        self.robot_type = self.config.type
        self.inference_time = self.config.inference_time  # whether we are at inference time

        # build cameras
        self.cameras = make_cameras_from_configs(self.config.cameras)

        # build realman motors
        self.piper_motors = make_motors_buses_from_configs(self.config.follower_arm)
        self.arm = self.piper_motors['main']

        # initialize teleoperation
        self._initialize_teleop()
        # init state
        self._initialize_state()

    def _initialize_teleop(self):
        """Initialize teleoperation"""
        self.init_info = {
            'init_joint': self.arm.init_joint_position,
            'init_pose': self.arm.init_pose,
            'max_gripper': self.config.max_gripper,
            'min_gripper': self.config.min_gripper,
            'servo_config_file': self.config.servo_config_file,
            'end_control_info': {'left': self.config.left_end_control_guid, 'right': self.config.right_end_control_guid}
        }

        if not self.inference_time:
            self.teleop = HybridController(self.init_info)
        else:
            self.teleop = None

    def _initialize_state(self):
        """Initialize the state"""
        self.joint_queue = deque(maxlen=2)
        self.last_endpose = self.arm.init_pose
        self.logs = {}
        self.is_connected = False

    def _read_robot_state(self) -> dict:
        """Read the robot state"""
        before_read_t = time.perf_counter()
        from copy import deepcopy
        state = deepcopy(self.arm.read())
        self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
        return state

    def _execute_action(self, action: dict, state: dict):
        """Execute an action"""
        before_write_t = time.perf_counter()

        if action['control_mode'] == 'joint':
            # self.arm.write_action(action, state)
            pass
        else:
            if list(action['pose'].values()) != list(state['pose'].values()):
                pose = list(action['pose'].values())
                self.arm.write_endpose_canfd(pose)

            elif list(action['joint'].values()) != list(state['joint'].values()):
                target_joint = list(action['joint'].values())
                self.arm.write(target_joint)

        self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t

    def _prepare_record_data(self) -> Tuple[Dict, Dict]:
        """Prepare the recorded data"""
        if len(self.joint_queue) < 2:
            return {}, {}

        state = torch.as_tensor(list(self.joint_queue[0]), dtype=torch.float32)
        action = torch.as_tensor(list(self.joint_queue[-1]), dtype=torch.float32)
        # capture images
        images = self._capture_images()
        # build the output dictionaries
        obs_dict = {
            "observation.state": state,
            **{f"observation.images.{name}": img for name, img in images.items()}
        }
        action_dict = {"action": action}
        return obs_dict, action_dict

    def _update_state_queue(self):
        """Update the state queue"""
        current_state = self.arm.read()['joint']
        current_state_lst = []
        for data in current_state:
            if "joint" in data:
                current_state_lst.append(current_state[data] / 180)
            elif "gripper" in data:
                current_state_lst.append((current_state[data]-500)/500)
        self.joint_queue.append(current_state_lst)

    def _capture_images(self) -> Dict[str, torch.Tensor]:
        """Capture images"""
        images = {}
        for name, camera in self.cameras.items():
            before_camread_t = time.perf_counter()
            image = camera.async_read()
            images[name] = torch.from_numpy(image)

            self.logs[f"read_camera_{name}_dt_s"] = camera.logs["delta_timestamp_s"]
            self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
        return images

    @property
    def camera_features(self) -> dict:
        cam_ft = {}
        for cam_key, cam in self.cameras.items():
            key = f"observation.images.{cam_key}"
            cam_ft[key] = {
                "shape": (cam.height, cam.width, cam.channels),
                "names": ["height", "width", "channels"],
                "info": None,
            }
        return cam_ft

    @property
    def motor_features(self) -> dict:
        action_names = get_motor_names(self.piper_motors)
        state_names = get_motor_names(self.piper_motors)
        return {
            "action": {
                "dtype": "float32",
                "shape": (len(action_names),),
                "names": action_names,
            },
            "observation.state": {
                "dtype": "float32",
                "shape": (len(state_names),),
                "names": state_names,
            },
        }

    @property
    def has_camera(self):
        return len(self.cameras) > 0

    @property
    def num_cameras(self):
        return len(self.cameras)

    def connect(self) -> None:
        """Connect the RealmanArm and the cameras"""
        if self.is_connected:
            raise RobotDeviceAlreadyConnectedError(
                "RealmanArm is already connected. Do not run `robot.connect()` twice."
            )

        # connect the RealmanArm
        self.arm.connect(enable=True)
        print("RealmanArm connected")

        # connect the cameras
        for name in self.cameras:
            self.cameras[name].connect()
            self.is_connected = self.is_connected and self.cameras[name].is_connected
            print(f"camera {name} connected")

        print("All connected")
        self.is_connected = True

        self.run_calibration()

    def disconnect(self) -> None:
        """Move to the home position, then disable the arms and the cameras"""
        # move the arms to the home position and disable them
        if not self.inference_time:
            self.teleop.stop()

        # disconnect the arms
        self.arm.safe_disconnect()
        print("RealmanArm will be disabled after 5 seconds")
        time.sleep(5)
        self.arm.connect(enable=False)

        # disconnect the cameras
        if len(self.cameras) > 0:
            for cam in self.cameras.values():
                cam.disconnect()

        self.is_connected = False

    def run_calibration(self):
        """Move the arms to the home position"""
        if not self.is_connected:
            raise ConnectionError()

        self.arm.apply_calibration()
        if not self.inference_time:
            self.teleop.reset()

    def teleop_step(
        self, record_data=False
    ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        if not self.is_connected:
            raise ConnectionError()

        if self.teleop is None and self.inference_time:
            self.teleop = HybridController(self.init_info)

        try:
            # read the current state
            state = self._read_robot_state()
            # get the action
            action = self.teleop.get_action(state)
            self._execute_action(action, state)
            # update the state queue
            self._update_state_queue()
            time.sleep(0.019)  # ~50 Hz

            if record_data:
                data = self._prepare_record_data()
                return data

        except Exception as e:
            logging.error(f"Teleoperation step failed: {e}")
            return None

    def send_action(self, action: torch.Tensor) -> torch.Tensor:
        """Write the actions predicted by the policy to the motors"""
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                "RealmanArm is not connected. You need to run `robot.connect()`."
            )

        # send to the motors: torch tensor to list
        target_joints = action.tolist()
        len_joint = len(target_joints) - 1
        target_joints = [target_joints[i]*180 for i in range(len_joint)] + [target_joints[-1]]
        target_joints[-1] = int(target_joints[-1]*500 + 500)
        self.arm.write(target_joints)

        return action

    def capture_observation(self) -> dict:
        """Capture the current images and joint positions"""
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                "RealmanArm is not connected. You need to run `robot.connect()`."
            )

        # read the current joint positions
        before_read_t = time.perf_counter()
        state = self.arm.read()  # 6 joints + 1 gripper
        self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t

        state = torch.as_tensor(list(state.values()), dtype=torch.float32)

        # read images from the cameras
        images = {}
        for name in self.cameras:
            before_camread_t = time.perf_counter()
            images[name] = self.cameras[name].async_read()
            images[name] = torch.from_numpy(images[name])
            self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
            self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

        # Populate the output dictionary and format to pytorch
        obs_dict = {}
        obs_dict["observation.state"] = state
        for name in self.cameras:
            obs_dict[f"observation.images.{name}"] = images[name]
        return obs_dict

    def teleop_safety_stop(self):
        """Move to the home position after recording one episode"""
        self.run_calibration()

    def __del__(self):
        if self.is_connected:
            self.disconnect()
        if not self.inference_time:
            self.teleop.stop()


if __name__ == '__main__':
    robot = RealmanDualRobot()
    robot.connect()
    # robot.run_calibration()
    while True:
        robot.teleop_step(record_data=True)

    # robot.capture_observation()
    # dummy_action = torch.Tensor([-0.40586111280653214, 0.5522833506266276, 0.4998166826036241, -0.3539944542778863, -0.524433347913954, 0.9064999898274739, 0.482])
    # robot.send_action(dummy_action)
    # robot.disconnect()
    # print('ok')
@@ -25,6 +25,8 @@ from lerobot.common.robot_devices.robots.configs import (
    So100RobotConfig,
    So101RobotConfig,
    StretchRobotConfig,
    RealmanRobotConfig,
    RealmanDualRobotConfig,
)

@@ -65,6 +67,9 @@ def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
        return StretchRobotConfig(**kwargs)
    elif robot_type == "lekiwi":
        return LeKiwiRobotConfig(**kwargs)
    elif robot_type == 'realman':
        return RealmanRobotConfig(**kwargs)

    else:
        raise ValueError(f"Robot type '{robot_type}' is not available.")

@@ -78,6 +83,17 @@ def make_robot_from_config(config: RobotConfig):
        from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator

        return MobileManipulator(config)

    elif isinstance(config, RealmanRobotConfig):
        from lerobot.common.robot_devices.robots.realman import RealmanRobot

        return RealmanRobot(config)

    elif isinstance(config, RealmanDualRobotConfig):
        from lerobot.common.robot_devices.robots.realman_dual import RealmanDualRobot

        return RealmanDualRobot(config)

    else:
        from lerobot.common.robot_devices.robots.stretch import StretchRobot
lerobot/common/robot_devices/teleop/find_gamepad.py (new file, 18 lines)
@@ -0,0 +1,18 @@
import pygame

def find_controller_index():
    # collect the device info of every pygame controller
    pygame_joysticks = {}
    for i in range(pygame.joystick.get_count()):
        joystick = pygame.joystick.Joystick(i)
        joystick.init()
        pygame_joysticks[joystick.get_guid()] = {
            'index': i,
            'device_name': joystick.get_name()
        }

    return pygame_joysticks


if __name__ == '__main__':
    # pygame must be initialized before joysticks can be enumerated
    pygame.init()
    pygame.joystick.init()
    print(find_controller_index())
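DummyEndposeMaster in gamepad.py below uses this GUID map to pick a stable joystick index per arm, since pygame indices can change between runs while GUIDs do not. A sketch of that lookup with a made-up GUID (not one from the diff):

    # Sketch: mapping an arm name to a pygame joystick index via its GUID.
    joysticks = {"03000000aabbccdd0000000000000000": {"index": 0, "device_name": "Flight Stick"}}
    end_control_info = {"left": "03000000aabbccdd0000000000000000"}
    left_index = joysticks[end_control_info["left"]]["index"]
    assert left_index == 0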
lerobot/common/robot_devices/teleop/gamepad.py (new file, 421 lines)
@@ -0,0 +1,421 @@
|
||||
import pygame
|
||||
import threading
|
||||
import time
|
||||
import logging
|
||||
from typing import Dict
|
||||
from dataclasses import dataclass
|
||||
from .find_gamepad import find_controller_index
|
||||
from .servo_server import ServoArmServer
|
||||
|
||||
|
||||
class RealmanAlohaMaster:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self._initialize_master_arm()
|
||||
|
||||
def _initialize_master_arm(self):
|
||||
"""初始化主控臂"""
|
||||
try:
|
||||
self.master_dual_arm = ServoArmServer(self.config.config_file)
|
||||
except Exception as e:
|
||||
logging.error(f"初始化主控臂失败: {e}")
|
||||
raise
|
||||
|
||||
def get_action(self) -> Dict:
|
||||
"""获取控制动作"""
|
||||
try:
|
||||
master_joint_actions = self.master_dual_arm.get_joint_data()
|
||||
return self._format_action(master_joint_actions)
|
||||
except Exception as e:
|
||||
logging.error(f"获取动作失败: {e}")
|
||||
|
||||
def _format_action(self, master_joint_actions: dict) -> dict:
|
||||
"""格式化动作数据"""
|
||||
master_controller_status = {
|
||||
'left': master_joint_actions['left_controller_status'],
|
||||
'right': master_joint_actions['right_controller_status']
|
||||
}
|
||||
|
||||
return {
|
||||
'control_mode': 'joint',
|
||||
'master_joint_actions': master_joint_actions['dual_joint_actions'],
|
||||
'left_joint_actions': master_joint_actions['left_joint_actions'][:-1],
|
||||
'right_joint_actions': master_joint_actions['right_joint_actions'][:-1],
|
||||
'left_gripper_actions': master_joint_actions['left_joint_actions'][-1], # 修复bug
|
||||
'right_gripper_actions': master_joint_actions['right_joint_actions'][-1],
|
||||
'master_controller_status': master_controller_status
|
||||
}
|
||||
|
||||
def stop(self):
|
||||
"""停止控制器"""
|
||||
try:
|
||||
if hasattr(self, 'master_dual_arm') and self.master_dual_arm:
|
||||
self.master_dual_arm.shutdown()
|
||||
print("混合控制器已退出")
|
||||
except Exception as e:
|
||||
logging.error(f"停止控制器失败: {e}")
|
||||
|
||||
|
||||
|
||||
class DummyEndposeMaster:
|
||||
def __init__(self, config):
|
||||
# 初始化pygame
|
||||
pygame.init()
|
||||
pygame.joystick.init()
|
||||
# 获取所有 USB 游戏控制器的信息
|
||||
self.joysticks = find_controller_index()
|
||||
print(self.joysticks)
|
||||
self.control_info = config.end_control_info
|
||||
left_stick = self._init_stick('left')
|
||||
right_stick = self._init_stick('right')
|
||||
self.controllers = [left_stick, right_stick]
|
||||
|
||||
def _init_stick(self, arm_name:str):
|
||||
stick_info = {}
|
||||
stick_info['index'] = self.joysticks[self.control_info[arm_name]]['index']
|
||||
stick_info['guid'] = self.control_info[arm_name]
|
||||
stick_info['name'] = f'{arm_name}'
|
||||
device_name = self.joysticks[self.control_info[arm_name]]['device_name']
|
||||
stick = XboxStick(stick_info) if "Xbox" in device_name else FlightStick(stick_info)
|
||||
stick.start_polling()
|
||||
return stick
|
||||
|
||||
def get_action(self, state) -> Dict:
|
||||
from copy import deepcopy
|
||||
|
||||
new_state = deepcopy(state)
|
||||
gamepad_action = {}
|
||||
xyz = []
|
||||
rxryrz = []
|
||||
gripper = []
|
||||
"""获取控制动作"""
|
||||
try:
|
||||
for i, controller in enumerate(self.controllers):
|
||||
# states = controller.get_raw_states()
|
||||
gamepad_action.update(controller.get_control_signal(controller.name))
|
||||
xyz += [f"{controller.name}_x", f"{controller.name}_y", f"{controller.name}_z"]
|
||||
rxryrz += [f"{controller.name}_joint_4", f"{controller.name}_joint_5", f"{controller.name}_joint_6"]
|
||||
gripper += [f"{controller.name}_gripper"]
|
||||
|
||||
for name in xyz:
|
||||
new_state['pose'][name] += (gamepad_action[name] * gamepad_action['xyz_vel'] * gamepad_action[name.split('_')[0]+'_ratio'])
|
||||
|
||||
for name in gripper:
|
||||
new_state['joint'][name] += int(gamepad_action[name] * gamepad_action['gripper_vel'] * gamepad_action[name.split('_')[0]+'_ratio'])
|
||||
new_state['joint'][name] = min(990, max(0, new_state['joint'][name]))
|
||||
|
||||
for name in rxryrz:
|
||||
new_state['joint'][name] += (gamepad_action[name] * gamepad_action['rxyz_vel'] * gamepad_action[name.split('_')[0]+'_ratio'])
|
||||
|
||||
new_state['control_mode'] = 'endpose'
|
||||
return new_state
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"获取动作失败: {e}")
|
||||
|
||||
def stop(self):
|
||||
"""停止控制器"""
|
||||
try:
|
||||
# 停止轮询线程
|
||||
for controller in self.controllers:
|
||||
controller.stop_polling()
|
||||
except Exception as e:
|
||||
logging.error(f"停止控制器失败: {e}")
|
||||
|
||||
|
||||
|
||||
class ControllerBase:
|
||||
def __init__(self, joystick_info: dict):
|
||||
# 初始化手柄对象
|
||||
self.joystick = pygame.joystick.Joystick(joystick_info['index'])
|
||||
self.joystick.init()
|
||||
self.name = joystick_info['name']
|
||||
self.guid = joystick_info['guid']
|
||||
|
||||
# 存储所有控制器状态的字典
|
||||
self.states = {
|
||||
'buttons': [False] * self.joystick.get_numbuttons(), # 按钮状态
|
||||
'axes': [0.0] * self.joystick.get_numaxes(), # 摇杆和轴状态
|
||||
'hats': [(0, 0)] * self.joystick.get_numhats() # 舵状态
|
||||
}
|
||||
|
||||
# deadzone
|
||||
self.deadzone = 0.15
|
||||
# validzone
|
||||
self.validzone = 0.05
|
||||
self.ratio = 1
|
||||
self.gripper_vel = 100
|
||||
self.rxyz_vel = 5
|
||||
self.xyz_vel = 0.02
|
||||
self.scale_up = 2
|
||||
self.scale_down = 10
|
||||
|
||||
# 线程控制标志
|
||||
self.running = False
|
||||
|
||||
def start_polling(self):
|
||||
"""启动线程以轮询控制器状态"""
|
||||
if not self.running:
|
||||
self.running = True
|
||||
self.thread = threading.Thread(target=self._poll_controller)
|
||||
self.thread.start()
|
||||
|
||||
def stop_polling(self):
|
||||
"""停止线程"""
|
||||
if self.running:
|
||||
self.running = False
|
||||
self.thread.join()
|
||||
|
||||
def _poll_controller(self):
|
||||
"""后台线程函数,用于轮询控制器状态"""
|
||||
while self.running:
|
||||
# 处理pygame事件
|
||||
pygame.event.pump()
|
||||
|
||||
# 获取按钮状态
|
||||
for i in range(self.joystick.get_numbuttons()):
|
||||
self.states['buttons'][i] = self.joystick.get_button(i)
|
||||
|
||||
# 获取摇杆和轴状态(通常范围是 -1.0 到 1.0)
|
||||
for i in range(self.joystick.get_numaxes()):
|
||||
self.states['axes'][i] = self.joystick.get_axis(i)
|
||||
|
||||
# 获取舵状态(通常返回一个元组 (x, y),值范围为 -1, 0, 1)
|
||||
for i in range(self.joystick.get_numhats()):
|
||||
self.states['hats'][i] = self.joystick.get_hat(i)
|
||||
|
||||
# 控制轮询频率
|
||||
time.sleep(0.01)
|
||||
|
||||
def get_raw_states(self):
|
||||
"""获取当前控制器状态"""
|
||||
return self.states
|
||||
|
||||
class FlightStick(ControllerBase):
    def __init__(self, joystick_info):
        super().__init__(joystick_info)

    def get_x_control_signal(self):
        x = 0
        if self.states['axes'][0] > self.validzone:
            x = 1
        elif self.states['axes'][0] < -self.validzone:
            x = -1
        return x

    def get_y_control_signal(self):
        y = 0
        if self.states['axes'][1] > self.validzone:
            y = -1
        elif self.states['axes'][1] < -self.validzone:
            y = 1
        return y

    def get_z_control_signal(self):
        z = 0
        if self.states['buttons'][0]:
            z = 1
        elif self.states['buttons'][1]:
            z = -1
        return z

    def get_gripper_control_signal(self):
        gripper = 0
        if self.states['buttons'][2] == 1:
            gripper = 1
        elif self.states['buttons'][3] == 1:
            gripper = -1
        return gripper

    def get_ratio_control_signal(self):
        ratio = self.ratio
        if self.states['axes'][2] > 0.8:
            ratio = self.ratio / self.scale_down
        elif self.states['axes'][2] < -0.8:
            ratio = self.ratio * self.scale_up
        return ratio

    def get_rx_control_signal(self):
        rx = 0
        if self.states['hats'][0][0] == -1:
            rx = 1
        elif self.states['hats'][0][0] == 1:
            rx = -1
        return rx

    def get_ry_control_signal(self):
        ry = 0
        if self.states['hats'][0][1] == 1:
            ry = -1
        elif self.states['hats'][0][1] == -1:
            ry = 1
        return ry

    def get_rz_control_signal(self):
        rz = 0
        if self.states['axes'][3] < -self.validzone:
            rz = -1
        elif self.states['axes'][3] > self.validzone:
            rz = 1
        return rz

    def get_control_signal(self, prefix: str = ""):
        """Collect all control signals into one dict."""
        return {
            f'{prefix}_x': self.get_x_control_signal(),
            f'{prefix}_y': self.get_y_control_signal(),
            f'{prefix}_z': self.get_z_control_signal(),
            f'{prefix}_joint_4': self.get_rx_control_signal(),
            f'{prefix}_joint_5': self.get_ry_control_signal(),
            f'{prefix}_joint_6': self.get_rz_control_signal(),
            f'{prefix}_gripper': self.get_gripper_control_signal(),
            f'{prefix}_ratio': self.get_ratio_control_signal(),
            'gripper_vel': self.gripper_vel,
            'rxyz_vel': self.rxyz_vel,
            'xyz_vel': self.xyz_vel
        }


class XboxStick(ControllerBase):
    def __init__(self, joystick_info: dict):
        super().__init__(joystick_info)

    def get_x_control_signal(self):
        """X-axis control signal (D-pad left/right)."""
        x = 0
        if self.states['hats'][0][0] == -1:
            x = 1
        elif self.states['hats'][0][0] == 1:
            x = -1
        return x

    def get_y_control_signal(self):
        """Y-axis control signal (D-pad up/down)."""
        y = 0
        if self.states['hats'][0][1] == 1:
            y = -1
        elif self.states['hats'][0][1] == -1:
            y = 1
        return y

    def get_z_control_signal(self):
        """Z-axis control signal (axis 4)."""
        z = 0
        if self.states['axes'][4] > self.deadzone:
            z = -1
        elif self.states['axes'][4] < -self.deadzone:
            z = 1
        return z

    def get_ratio_control_signal(self):
        """Speed-ratio control signal."""
        ratio = self.ratio
        if self.states['axes'][2] > 0.8:  # LT trigger
            ratio = self.ratio * self.scale_up
        elif self.states['axes'][5] > 0.8:  # RT trigger
            ratio = self.ratio / self.scale_down
        return ratio

    def get_gripper_control_signal(self):
        gripper = 0
        if self.states['buttons'][0] == 1:
            gripper = 1
        elif self.states['buttons'][1] == 1:
            gripper = -1
        return gripper

    def get_rx_control_signal(self):
        """RX-axis control signal (left stick, axis 0)."""
        rx = 0
        if self.states['axes'][0] > self.deadzone:
            rx = -1
        elif self.states['axes'][0] < -self.deadzone:
            rx = 1
        return rx

    def get_ry_control_signal(self):
        """RY-axis control signal (left stick, axis 1)."""
        ry = 0
        if self.states['axes'][1] > self.deadzone:
            ry = 1
        elif self.states['axes'][1] < -self.deadzone:
            ry = -1
        return ry

    def get_rz_control_signal(self):
        """RZ-axis control signal (shoulder buttons)."""
        rz = 0
        if self.states['buttons'][4] == 1:  # left shoulder
            rz = 1
        elif self.states['buttons'][5] == 1:  # right shoulder
            rz = -1
        return rz

    def get_control_signal(self, prefix: str = ""):
        """Collect all control signals into one dict."""
        return {
            f'{prefix}_x': self.get_x_control_signal(),
            f'{prefix}_y': self.get_y_control_signal(),
            f'{prefix}_z': self.get_z_control_signal(),
            f'{prefix}_joint_4': self.get_rx_control_signal(),
            f'{prefix}_joint_5': self.get_ry_control_signal(),
            f'{prefix}_joint_6': self.get_rz_control_signal(),
            f'{prefix}_gripper': self.get_gripper_control_signal(),
            f'{prefix}_ratio': self.get_ratio_control_signal(),
            'gripper_vel': self.gripper_vel,
            'rxyz_vel': self.rxyz_vel,
            'xyz_vel': self.xyz_vel
        }


@dataclass
class ControllerConfig:
    """Controller configuration."""
    init_joint: list
    init_pose: list
    max_gripper: int
    min_gripper: int
    config_file: str
    end_control_info: dict


def parse_init_info(init_info: dict) -> ControllerConfig:
    """Parse the initialization info into a ControllerConfig."""
    return ControllerConfig(
        init_joint=init_info['init_joint'],
        init_pose=init_info.get('init_pose', [0] * 12),
        max_gripper=init_info['max_gripper'],
        min_gripper=init_info['min_gripper'],
        config_file=init_info['servo_config_file'],
        end_control_info=init_info['end_control_info']
    )


if __name__ == "__main__":
    config = {
        'init_joint': {'joint': [-170, 90, 0, 90, 120, 0, 10, 170, 90, 0, -90, 120, 0, 10]},
        'init_pose': {},
        'max_gripper': {},
        'min_gripper': {},
        'servo_config_file': {},
        'end_control_info': {'left': "0300b14bff1100003708000010010000", 'right': '0300509d5e040000120b000009050000'}
    }
    config = parse_init_info(config)
    endpose_arm = DummyEndposeMaster(config)
    while True:
        gamepad_action = {}
        xyz = []
        for i, controller in enumerate(endpose_arm.controllers):
            # states = controller.get_raw_states()
            gamepad_action.update(controller.get_control_signal(controller.name))
            xyz += [f"{controller.name}_x", f"{controller.name}_y", f"{controller.name}_z"]
        time.sleep(1)
        print(gamepad_action)
76
lerobot/common/robot_devices/teleop/realman_aloha_dual.py
Normal file
@@ -0,0 +1,76 @@
import time
import logging
from typing import Dict
from dataclasses import dataclass
from .gamepad import RealmanAlohaMaster, DummyEndposeMaster


@dataclass
class ControllerConfig:
    """Controller configuration."""
    init_joint: list
    init_pose: list
    max_gripper: int
    min_gripper: int
    config_file: str
    end_control_info: dict


class HybridController:
    def __init__(self, init_info):
        self.config = self._parse_init_info(init_info)
        self.joint = self.config.init_joint.copy()
        self.pose = self.config.init_pose.copy()

        self.joint_arm = RealmanAlohaMaster(self.config)
        self.endpose_arm = DummyEndposeMaster(self.config)

    def _parse_init_info(self, init_info: dict) -> ControllerConfig:
        """Parse the initialization info into a ControllerConfig."""
        return ControllerConfig(
            init_joint=init_info['init_joint'],
            init_pose=init_info.get('init_pose', [0] * 12),
            max_gripper=init_info['max_gripper'],
            min_gripper=init_info['min_gripper'],
            config_file=init_info['servo_config_file'],
            end_control_info=init_info['end_control_info']
        )

    def get_action(self, state) -> Dict:
        """Return the current control action."""
        try:
            endpose_action = self.endpose_arm.get_action(state)
            return endpose_action
            # return self.joint_arm.get_action()
        except Exception as e:
            logging.error(f"Failed to get action: {e}")

    def stop(self):
        self.joint_arm.stop()

    def reset(self):
        """Reset the controller."""
        self.joint = self.config.init_joint.copy()
        self.pose = self.config.init_pose.copy()
        self.joint_control_mode = True


# Usage example
if __name__ == "__main__":
    init_info = {
        'init_joint': [-175, 90, 90, 45, 90, -90, 10, 175, 90, 90, -45, 90, 90, 10],
        'init_pose': [[-0.0305, 0.125938, 0.13153, 3.141, 0.698, -1.57, -0.030486, -0.11487, 0.144707, 3.141, 0.698, 1.57]],
        'max_gripper': 990,
        'min_gripper': 10,
        'servo_config_file': '/home/maic/LYT/lerobot/lerobot/common/robot_devices/teleop/servo_dual.yaml',
        'end_control_info': {'left': '0300b14bff1100003708000010010000', 'right': '030003f05e0400008e02000010010000'}
    }
    arm_controller = HybridController(init_info)
    time.sleep(1)
    try:
        while True:
            # get_action expects the current robot state; None is passed here
            # because this standalone test has no robot connected.
            print(arm_controller.get_action(None))
            time.sleep(1)
    except KeyboardInterrupt:
        arm_controller.stop()
4
lerobot/common/robot_devices/teleop/realman_mix.yaml
Normal file
@@ -0,0 +1,4 @@
init_joint: [-90, 90, 90, -90, -90, 90]
max_gripper: 990
min_gripper: 10
servo_config_file: "/home/maic/LYT/lerobot/lerobot/common/robot_devices/teleop/servo_arm.yaml"
466
lerobot/common/robot_devices/teleop/realman_single.py
Normal file
@@ -0,0 +1,466 @@
import pygame
import threading
import time
import serial
import binascii
import logging
import yaml
from typing import Dict
from Robotic_Arm.rm_robot_interface import *


class ServoArm:
    def __init__(self, config_file="config.yaml"):
        """Initialize the serial connection to the master arm and send the initial request.

        Args:
            config_file (str): Path to the configuration file.
        """
        self.config = self._load_config(config_file)
        self.port = self.config["port"]
        self.baudrate = self.config["baudrate"]
        self.joint_hex_data = self.config["joint_hex_data"]
        self.control_hex_data = self.config["control_hex_data"]
        self.arm_axis = self.config.get("arm_axis", 7)

        try:
            self.serial_conn = serial.Serial(self.port, self.baudrate, timeout=0)
            self.bytes_to_send = binascii.unhexlify(self.joint_hex_data.replace(" ", ""))
            self.serial_conn.write(self.bytes_to_send)
            time.sleep(1)
            self.connected = True
            logging.info(f"Serial connection established: {self.port}")
        except Exception as e:
            logging.error(f"Serial connection failed: {e}")
            self.connected = False

    def _load_config(self, config_file):
        """Load the configuration file.

        Args:
            config_file (str): Path to the configuration file.

        Returns:
            dict: Parsed configuration.
        """
        try:
            with open(config_file, "r") as file:
                config = yaml.safe_load(file)
            return config
        except Exception as e:
            logging.error(f"Failed to load config file: {e}")
            # Fall back to the default configuration
            return {
                "port": "/dev/ttyUSB0",
                "baudrate": 460800,
                "joint_hex_data": "55 AA 02 00 00 67",
                "control_hex_data": "55 AA 08 00 00 B9",
                "arm_axis": 6
            }

    def _bytes_to_signed_int(self, byte_data):
        """Convert little-endian bytes to a signed integer.

        Args:
            byte_data (bytes): Raw bytes.

        Returns:
            int: Signed integer value.
        """
        return int.from_bytes(byte_data, byteorder="little", signed=True)

    def _parse_joint_data(self, hex_received):
        """Parse the received hex string and extract joint values.

        Args:
            hex_received (str): Received data as a hex string.

        Returns:
            dict: Parsed joint data.
        """
        logging.debug(f"hex_received: {hex_received}")
        joints = {}
        for i in range(self.arm_axis):
            start = 14 + i * 10
            end = start + 8
            joint_hex = hex_received[start:end]
            joint_byte_data = bytearray.fromhex(joint_hex)
            joint_value = self._bytes_to_signed_int(joint_byte_data) / 10000.0
            joints[f"joint_{i+1}"] = joint_value
        grasp_start = 14 + self.arm_axis * 10
        grasp_hex = hex_received[grasp_start:grasp_start + 8]
        grasp_byte_data = bytearray.fromhex(grasp_hex)
        # Normalize the gripper value
        grasp_value = self._bytes_to_signed_int(grasp_byte_data) / 1000

        joints["grasp"] = grasp_value
        return joints
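
    # A minimal sketch of the frame layout assumed above, with a hypothetical
    # single-joint frame built only for illustration: 14 hex chars of header,
    # then 10 hex chars per joint, of which the first 8 encode a little-endian
    # signed 32-bit value in units of 1e-4 degrees.
    #
    #   sample_deg = 45.1234
    #   raw = int(sample_deg * 10000).to_bytes(4, "little", signed=True)
    #   frame = "0" * 14 + raw.hex().upper() + "00"   # header + value + padding
    #   int.from_bytes(bytearray.fromhex(frame[14:22]), "little", signed=True) / 10000.0
    #   # -> 45.1234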

    def _parse_controller_data(self, hex_received):
        status = {
            'infrared': 0,
            'button': 0
        }
        if len(hex_received) == 18:
            status['infrared'] = self._bytes_to_signed_int(bytearray.fromhex(hex_received[12:14]))
            status['button'] = self._bytes_to_signed_int(bytearray.fromhex(hex_received[14:16]))
        return status

    def get_joint_actions(self):
        """Read data from the serial port and parse joint actions.

        Returns:
            dict: Parsed joint data.
        """
        if not self.connected:
            return {}

        try:
            self.serial_conn.write(self.bytes_to_send)
            time.sleep(0.02)
            bytes_received = self.serial_conn.read(self.serial_conn.inWaiting())
            if len(bytes_received) == 0:
                return {}

            hex_received = binascii.hexlify(bytes_received).decode("utf-8").upper()
            actions = self._parse_joint_data(hex_received)
            return actions
        except Exception as e:
            logging.error(f"Error reading serial data: {e}")
            return {}

    def get_controller_status(self):
        bytes_to_send = binascii.unhexlify(self.control_hex_data.replace(" ", ""))
        self.serial_conn.write(bytes_to_send)
        time.sleep(0.02)
        bytes_received = self.serial_conn.read(self.serial_conn.inWaiting())
        hex_received = binascii.hexlify(bytes_received).decode("utf-8").upper()
        status = self._parse_controller_data(hex_received)
        return status

    def close(self):
        """Close the serial connection."""
        if self.connected and hasattr(self, 'serial_conn'):
            self.serial_conn.close()
            self.connected = False
            logging.info("Serial connection closed")


class HybridController:
    def __init__(self, init_info):
        # Initialize pygame and the gamepad
        pygame.init()
        pygame.joystick.init()

        # Check that a gamepad is connected
        if pygame.joystick.get_count() == 0:
            raise Exception("No gamepad detected")

        # Initialize the gamepad
        self.joystick = pygame.joystick.Joystick(0)
        self.joystick.init()
        # Stick dead zone
        self.deadzone = 0.15
        # Control mode: True for joint control (primary mode), False for end-pose control
        self.joint_control_mode = True
        # Fine-control mode
        self.fine_control_mode = False

        # Initialize the end pose and joint angles
        self.init_joint = init_info['init_joint']
        self.init_pose = init_info.get('init_pose', [0] * 6)
        self.max_gripper = init_info['max_gripper']
        self.min_gripper = init_info['min_gripper']
        servo_config_file = init_info['servo_config_file']
        self.joint = self.init_joint.copy()
        self.pose = self.init_pose.copy()
        self.pose_speeds = [0.0] * 6
        self.joint_speeds = [0.0] * 6
        self.tozero = False

        # Master-arm joint state
        self.master_joint_actions = {}
        self.master_controller_status = {}
        self.use_master_arm = False

        # End-pose limits
        self.pose_limits = [
            (-0.800, 0.800),  # X (m)
            (-0.800, 0.800),  # Y (m)
            (-0.800, 0.800),  # Z (m)
            (-3.14, 3.14),    # RX (rad)
            (-3.14, 3.14),    # RY (rad)
            (-3.14, 3.14)     # RZ (rad)
        ]

        # Joint angle limits (degrees)
        self.joint_limits = [
            (-180, 180),  # joint 1
            (-180, 180),  # joint 2
            (-180, 180),  # joint 3
            (-180, 180),  # joint 4
            (-180, 180),  # joint 5
            (-180, 180)   # joint 6
        ]

        # Control parameters
        self.linear_step = 0.002  # linear step size (m)
        self.angular_step = 0.01  # angular step size (rad)

        # Gripper state and speed
        self.gripper_speed = 10
        self.gripper = self.min_gripper

        # Initialize serial communication (master-arm joint state acquisition)
        self.servo_arm = None
        if servo_config_file:
            try:
                self.servo_arm = ServoArm(servo_config_file)
                self.use_master_arm = True
                logging.info("Serial master arm connected; master-slave control enabled")
            except Exception as e:
                logging.error(f"Serial master arm connection failed: {e}")
                self.use_master_arm = False

        # Start the update thread
        self.running = True
        self.thread = threading.Thread(target=self.update_controller)
        self.thread.start()

        print("Hybrid controller started")
        print("Primary control mode: joint control")
        if self.use_master_arm:
            print("Master-slave control: enabled")
        print("Back button: toggle control mode (joint/end pose)")
        print("L3 button: toggle fine-control mode")
        print("Start button: reset to the initial position")

    def _apply_nonlinear_mapping(self, value):
        """Apply a sign-preserving square mapping for finer control near zero."""
        sign = 1 if value >= 0 else -1
        return sign * (abs(value) ** 2)

    def _normalize_angle(self, angle):
        """Normalize an angle into the range [-pi, pi]."""
        import math
        while angle > math.pi:
            angle -= 2 * math.pi
        while angle < -math.pi:
            angle += 2 * math.pi
        return angle

    def update_controller(self):
        while self.running:
            try:
                pygame.event.pump()
            except Exception as e:
                print(f"Controller error: {e}")
                self.stop()
                continue

            # Check for control-mode toggle (Back button)
            if self.joystick.get_button(6):
                self.joint_control_mode = not self.joint_control_mode
                mode_str = "joint control" if self.joint_control_mode else "end-pose control"
                print(f"Switched to {mode_str} mode")
                time.sleep(0.3)  # debounce

            # Check for fine-control toggle (L3 button)
            if self.joystick.get_button(10):
                self.fine_control_mode = not self.fine_control_mode
                print(f"Switched to {'fine' if self.fine_control_mode else 'normal'} control mode")
                time.sleep(0.3)  # debounce

            # Check for reset (Start button)
            if self.joystick.get_button(7):
                print("Resetting the arm to the initial position...")
                self.tozero = True
                self.joint = self.init_joint.copy()
                self.pose = self.init_pose.copy()
                self.pose_speeds = [0.0] * 6
                self.joint_speeds = [0.0] * 6
                self.gripper_speed = 10
                self.gripper = self.min_gripper
                print("Arm reset to the initial position")
                time.sleep(0.3)  # debounce

            # Read the master-arm joint state over serial
            if self.servo_arm and self.servo_arm.connected:
                try:
                    self.master_joint_actions = self.servo_arm.get_joint_actions()
                    self.master_controller_status = self.servo_arm.get_controller_status()
                    if self.master_joint_actions:
                        logging.debug(f"Master-arm joint state: {self.master_joint_actions}")
                except Exception as e:
                    logging.error(f"Error reading master-arm state: {e}")
                    self.master_joint_actions = {}

            # Run the control path for the current mode
            if self.joint_control_mode:
                # Joint control mode: prefer master-arm data, with the gamepad as backup
                self.update_joint_control()
            else:
                # End-pose control mode: use the gamepad
                self.update_end_pose()
            time.sleep(0.02)

    def update_joint_control(self):
        """Update joint-angle control, preferring master-arm data."""
        if self.use_master_arm and self.master_joint_actions:
            # Master-slave mode: mirror the master arm's joint angles
            try:
                # Map master-arm joint angles onto the slave arm
                for i in range(6):  # assume only six joints are controlled
                    joint_key = f"joint_{i+1}"
                    if joint_key in self.master_joint_actions:
                        # Use the master-arm angle directly (already in degrees)
                        self.joint[i] = self.master_joint_actions[joint_key]

                        # Apply joint limits
                        min_val, max_val = self.joint_limits[i]
                        self.joint[i] = max(min_val, min(max_val, self.joint[i]))

                logging.debug(f"Master-arm joints mapped to slave: {self.joint[:6]}")
            except Exception as e:
                logging.error(f"Error mapping master-arm data: {e}")

        # If master-arm gripper data is available, mirror it as well
        if self.use_master_arm and "grasp" in self.master_joint_actions:
            self.gripper = self.master_joint_actions["grasp"] * 1000
            self.joint[-1] = self.gripper

    def update_end_pose(self):
        """Update end-pose control."""
        # Scale the step sizes for the current control mode
        current_linear_step = self.linear_step * (0.1 if self.fine_control_mode else 1.0)
        current_angular_step = self.angular_step * (0.1 if self.fine_control_mode else 1.0)

        # D-pad controls X and Y
        hat = self.joystick.get_hat(0)
        hat_up = hat[1] == 1      # Y+
        hat_down = hat[1] == -1   # Y-
        hat_left = hat[0] == -1   # X-
        hat_right = hat[0] == 1   # X+

        # Right stick controls Z
        right_y_raw = -self.joystick.get_axis(4)
        # Left stick controls RZ
        left_y_raw = -self.joystick.get_axis(1)

        # Apply the dead zone
        right_y = 0.0 if abs(right_y_raw) < self.deadzone else right_y_raw
        left_y = 0.0 if abs(left_y_raw) < self.deadzone else left_y_raw

        # Compute per-axis speeds
        self.pose_speeds[1] = current_linear_step if hat_up else (-current_linear_step if hat_down else 0.0)    # Y
        self.pose_speeds[0] = -current_linear_step if hat_left else (current_linear_step if hat_right else 0.0)  # X

        # Z speed (right stick Y-axis)
        z_mapping = self._apply_nonlinear_mapping(right_y)
        self.pose_speeds[2] = z_mapping * current_linear_step  # Z

        # L1/R1 control RX rotation
        LB = self.joystick.get_button(4)  # RX-
        RB = self.joystick.get_button(5)  # RX+
        self.pose_speeds[3] = (-current_angular_step if LB else (current_angular_step if RB else 0.0))

        # Triangle/square control RY rotation
        triangle = self.joystick.get_button(2)  # RY+
        square = self.joystick.get_button(3)    # RY-
        self.pose_speeds[4] = (current_angular_step if triangle else (-current_angular_step if square else 0.0))

        # Left stick Y-axis controls RZ rotation
        rz_mapping = self._apply_nonlinear_mapping(left_y)
        self.pose_speeds[5] = rz_mapping * current_angular_step * 2  # RZ

        # Gripper control (circle/cross)
        circle = self.joystick.get_button(1)  # open gripper
        cross = self.joystick.get_button(0)   # close gripper
        if circle:
            self.gripper = min(self.max_gripper, self.gripper + self.gripper_speed)
        elif cross:
            self.gripper = max(self.min_gripper, self.gripper - self.gripper_speed)

        # Integrate the end pose
        for i in range(6):
            self.pose[i] += self.pose_speeds[i]

        # Normalize the rotation angles
        for i in range(3, 6):
            self.pose[i] = self._normalize_angle(self.pose[i])

    def update_joint_state(self, joint):
        self.joint = joint

    def update_endpose_state(self, end_pose):
        self.pose = end_pose

    def update_tozero_state(self, tozero):
        self.tozero = tozero

    def get_action(self) -> Dict:
        """Return the current control command."""
        return {
            'control_mode': 'joint' if self.joint_control_mode else 'end_pose',
            'use_master_arm': self.use_master_arm,
            'master_joint_actions': self.master_joint_actions,
            'master_controller_status': self.master_controller_status,
            'end_pose': self.pose,
            'joint_angles': self.joint,
            'gripper': int(self.gripper),
            'tozero': self.tozero
        }

    def stop(self):
        """Stop the controller."""
        self.running = False
        # Avoid joining the polling thread from within itself
        if self.thread.is_alive() and self.thread is not threading.current_thread():
            self.thread.join()
        if self.servo_arm:
            self.servo_arm.close()
        pygame.quit()
        print("Hybrid controller stopped")

    def reset(self):
        """Reset to the initial state."""
        self.joint = self.init_joint.copy()
        self.pose = self.init_pose.copy()
        self.pose_speeds = [0.0] * 6
        self.joint_speeds = [0.0] * 6
        self.gripper_speed = 10
        self.gripper = self.min_gripper
        print("Reset to the initial state")


# Usage example
if __name__ == "__main__":
    # Initialize the Realman arm
    arm = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
    # Create the arm connection
    handle = arm.rm_create_robot_arm("192.168.3.18", 8080)
    print(f"Arm connection ID: {handle.id}")
    init_pose = arm.rm_get_current_arm_state()[1]['pose']

    with open('/home/maic/LYT/lerobot/lerobot/common/robot_devices/teleop/realman_mix.yaml', "r") as file:
        config = yaml.safe_load(file)
    config['init_pose'] = init_pose
    arm_controller = HybridController(config)
    try:
        while True:
            print(arm_controller.get_action())
            time.sleep(0.1)
    except KeyboardInterrupt:
        arm_controller.stop()
6
lerobot/common/robot_devices/teleop/servo_arm.yaml
Normal file
@@ -0,0 +1,6 @@
port: /dev/ttyUSB0
right_port: /dev/ttyUSB1
baudrate: 460800
joint_hex_data: "55 AA 02 00 00 67"
control_hex_data: "55 AA 08 00 00 B9"
arm_axis: 6
6
lerobot/common/robot_devices/teleop/servo_dual.yaml
Normal file
@@ -0,0 +1,6 @@
left_port: /dev/ttyUSB0
right_port: /dev/ttyUSB1
baudrate: 460800
joint_hex_data: "55 AA 02 00 00 67"
control_hex_data: "55 AA 08 00 00 B9"
arm_axis: 6
321
lerobot/common/robot_devices/teleop/servo_server.py
Normal file
@@ -0,0 +1,321 @@
import threading
import time
import serial
import binascii
import logging
import yaml
from typing import Dict

# logging.basicConfig(
#     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
# )


class ServoArmServer:
    def __init__(self, config_file="servo_dual.yaml"):
        """Initialize the server and create the left and right arm instances."""
        self.config_file = config_file
        self.left_servo_arm = None
        self.right_servo_arm = None
        self.running = False
        self.data_lock = threading.Lock()
        self.latest_data = {
            'left_joint_actions': {},
            'right_joint_actions': {},
            'left_controller_status': {},
            'right_controller_status': {},
            'timestamp': time.time()
        }

        # Initialize the arms
        self._initialize_arms()
        # Start the data-collection threads
        self._start_data_collection()

    def _initialize_arms(self):
        """Initialize the left and right master arms."""
        try:
            self.left_servo_arm = ServoArm(self.config_file, "left_port")
            logging.info("Left master arm initialized")
        except Exception as e:
            logging.error(f"Left master arm initialization failed: {e}")

        try:
            self.right_servo_arm = ServoArm(self.config_file, "right_port")
            logging.info("Right master arm initialized")
        except Exception as e:
            logging.error(f"Right master arm initialization failed: {e}")

    def _start_data_collection(self):
        """Start the data-collection threads."""
        self.running = True

        # Left-arm data-collection thread
        self.left_data_thread = threading.Thread(target=self._left_arm_data_loop)
        self.left_data_thread.daemon = True
        self.left_data_thread.start()

        # Right-arm data-collection thread
        self.right_data_thread = threading.Thread(target=self._right_arm_data_loop)
        self.right_data_thread.daemon = True
        self.right_data_thread.start()

        logging.info("Left and right arm data-collection threads started")

    def _left_arm_data_loop(self):
        """Left-arm data-collection loop."""
        while self.running:
            try:
                left_actions = {}
                left_controller_status = {}

                # Read the left arm
                if self.left_servo_arm and self.left_servo_arm.connected:
                    left_actions = self.left_servo_arm.get_joint_actions()
                    left_controller_status = self.left_servo_arm.get_controller_status()

                if not self._check_val_safety(left_actions):
                    time.sleep(0.02)
                    continue
                # Publish the left-arm data
                with self.data_lock:
                    self.latest_data['left_joint_actions'] = [left_actions[k] for k in left_actions]
                    self.latest_data['left_controller_status'] = left_controller_status
                    # Update dual_joint_actions
                    if self.latest_data['right_joint_actions']:
                        self.latest_data['dual_joint_actions'] = self.latest_data['left_joint_actions'] + self.latest_data['right_joint_actions']
                    self.latest_data['timestamp'] = time.time()

                time.sleep(0.02)  # 50 Hz collection rate

            except Exception as e:
                logging.error(f"Left-arm data-collection error: {e}")
                time.sleep(0.1)

    def _right_arm_data_loop(self):
        """Right-arm data-collection loop."""
        while self.running:
            try:
                right_actions = {}
                right_controller_status = {}

                # Read the right arm
                if self.right_servo_arm and self.right_servo_arm.connected:
                    right_actions = self.right_servo_arm.get_joint_actions()
                    right_controller_status = self.right_servo_arm.get_controller_status()

                if not self._check_val_safety(right_actions):
                    time.sleep(0.02)
                    continue
                # Publish the right-arm data
                with self.data_lock:
                    self.latest_data['right_joint_actions'] = [right_actions[k] for k in right_actions]
                    self.latest_data['right_controller_status'] = right_controller_status
                    # Update dual_joint_actions
                    if self.latest_data['left_joint_actions']:
                        self.latest_data['dual_joint_actions'] = self.latest_data['left_joint_actions'] + self.latest_data['right_joint_actions']
                    self.latest_data['timestamp'] = time.time()

                time.sleep(0.02)  # 50 Hz collection rate

            except Exception as e:
                logging.error(f"Right-arm data-collection error: {e}")
                time.sleep(0.1)

    def _check_val_safety(self, data: dict):
        """Reject frames with the wrong joint count or out-of-range values."""
        data = [data[k] for k in data]
        ret = True
        if len(data) != self.left_servo_arm.arm_axis + 1:
            ret = False
        for v in data:
            if v > 180 or v < -180:
                ret = False
        return ret

    # ZeroRPC service methods
    def get_joint_data(self):
        """Return the latest joint data."""
        with self.data_lock:
            return self.latest_data.copy()

    def get_left_joint_actions(self):
        """Return the left arm's joint data and controller status."""
        with self.data_lock:
            return {
                'data': self.latest_data['left_joint_actions'],
                'controller_status': self.latest_data['left_controller_status'],
                'timestamp': self.latest_data['timestamp']
            }

    def get_right_joint_actions(self):
        """Return the right arm's joint data and controller status."""
        with self.data_lock:
            return {
                'data': self.latest_data['right_joint_actions'],
                'controller_status': self.latest_data['right_controller_status'],
                'timestamp': self.latest_data['timestamp']
            }

    def get_connection_status(self):
        """Return the connection status."""
        return {
            'left_connected': self.left_servo_arm.connected if self.left_servo_arm else False,
            'right_connected': self.right_servo_arm.connected if self.right_servo_arm else False,
            'server_running': self.running
        }

    def ping(self):
        """Connectivity check."""
        return "pong"

    def shutdown(self):
        """Shut the server down."""
        logging.info("Shutting down the server...")
        self.running = False

        if self.left_servo_arm:
            self.left_servo_arm.close()
        if self.right_servo_arm:
            self.right_servo_arm.close()

        return "Server shutdown"


class ServoArm:
    def __init__(self, config_file="config.yaml", port_name="left_port"):
        """Initialize the serial connection to the master arm and send the initial request.

        Args:
            config_file (str): Path to the configuration file.
            port_name (str): Config key of the serial port to use ("left_port" or "right_port").
        """
        self.config = self._load_config(config_file)
        self.port = self.config[port_name]
        self.baudrate = self.config["baudrate"]
        self.joint_hex_data = self.config["joint_hex_data"]
        self.control_hex_data = self.config["control_hex_data"]
        self.arm_axis = self.config.get("arm_axis", 7)

        try:
            self.serial_conn = serial.Serial(self.port, self.baudrate, timeout=0)
            self.bytes_to_send = binascii.unhexlify(self.joint_hex_data.replace(" ", ""))
            self.serial_conn.write(self.bytes_to_send)
            time.sleep(1)
            self.connected = True
            logging.info(f"Serial connection established: {self.port}")
        except Exception as e:
            logging.error(f"Serial connection failed: {e}")
            self.connected = False

    def _load_config(self, config_file):
        """Load the configuration file.

        Args:
            config_file (str): Path to the configuration file.

        Returns:
            dict: Parsed configuration.
        """
        try:
            with open(config_file, "r") as file:
                config = yaml.safe_load(file)
            return config
        except Exception as e:
            logging.error(f"Failed to load config file: {e}")
            # Fall back to the default configuration
            return {
                "port": "/dev/ttyUSB0",
                "baudrate": 460800,
                "joint_hex_data": "55 AA 02 00 00 67",
                "control_hex_data": "55 AA 08 00 00 B9",
                "arm_axis": 6
            }

    def _bytes_to_signed_int(self, byte_data):
        """Convert little-endian bytes to a signed integer."""
        return int.from_bytes(byte_data, byteorder="little", signed=True)

    def _parse_joint_data(self, hex_received):
        """Parse the received hex string and extract joint values."""
        logging.debug(f"hex_received: {hex_received}")
        joints = {}
        for i in range(self.arm_axis):
            start = 14 + i * 10
            end = start + 8
            joint_hex = hex_received[start:end]
            joint_byte_data = bytearray.fromhex(joint_hex)
            joint_value = self._bytes_to_signed_int(joint_byte_data) / 10000.0
            joints[f"joint_{i+1}"] = joint_value
        grasp_start = 14 + self.arm_axis * 10
        grasp_hex = hex_received[grasp_start:grasp_start + 8]
        grasp_byte_data = bytearray.fromhex(grasp_hex)
        # Normalize the gripper value
        grasp_value = self._bytes_to_signed_int(grasp_byte_data) / 1000

        joints["grasp"] = grasp_value
        return joints

    def _parse_controller_data(self, hex_received):
        status = {
            'infrared': 0,
            'button': 0
        }
        if len(hex_received) == 18:
            status['infrared'] = self._bytes_to_signed_int(bytearray.fromhex(hex_received[12:14]))
            status['button'] = self._bytes_to_signed_int(bytearray.fromhex(hex_received[14:16]))
        return status

    def get_joint_actions(self):
        """Read data from the serial port and parse joint actions.

        Returns:
            dict: Parsed joint data.
        """
        if not self.connected:
            return {}

        try:
            self.serial_conn.write(self.bytes_to_send)
            time.sleep(0.025)
            bytes_received = self.serial_conn.read(self.serial_conn.inWaiting())
            if len(bytes_received) == 0:
                return {}

            hex_received = binascii.hexlify(bytes_received).decode("utf-8").upper()
            actions = self._parse_joint_data(hex_received)
            return actions
        except Exception as e:
            logging.error(f"Error reading serial data: {e}")
            return {}

    def get_controller_status(self):
        bytes_to_send = binascii.unhexlify(self.control_hex_data.replace(" ", ""))
        self.serial_conn.write(bytes_to_send)
        time.sleep(0.025)
        bytes_received = self.serial_conn.read(self.serial_conn.inWaiting())
        hex_received = binascii.hexlify(bytes_received).decode("utf-8").upper()
        status = self._parse_controller_data(hex_received)
        return status

    def close(self):
        """Close the serial connection."""
        if self.connected and hasattr(self, 'serial_conn'):
            self.serial_conn.close()
            self.connected = False
            logging.info("Serial connection closed")
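

# A minimal server bootstrap sketch: the "ZeroRPC service methods" marker above
# suggests this class is meant to be exposed over zerorpc. The bind address and
# port below are assumptions for illustration, not part of the original file.
if __name__ == "__main__":
    import zerorpc

    server = zerorpc.Server(ServoArmServer("servo_dual.yaml"))
    server.bind("tcp://0.0.0.0:4242")  # hypothetical endpoint
    server.run()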

@@ -175,7 +175,8 @@ def say(text, blocking=False):
        cmd = ["say", text]

    elif system == "Linux":
        # cmd = ["spd-say", text]
        cmd = ["edge-playback", "-t", text]
        if blocking:
            cmd.append("--wait")

@@ -273,7 +273,6 @@ def record(

    # Load pretrained policy
    policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)

    if not robot.is_connected:
        robot.connect()

@@ -290,6 +289,9 @@ def record(
    if has_method(robot, "teleop_safety_stop"):
        robot.teleop_safety_stop()

    recorded_episodes = 0
    while True:
        if recorded_episodes >= cfg.num_episodes:

156
realman.md
Normal file
@@ -0,0 +1,156 @@
# Install
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
```bash
conda create -y -n lerobot python=3.10
conda activate lerobot
```

Install 🤗 LeRobot:
```bash
pip install -e . -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install edge-tts
sudo apt install mpv -y

# pip uninstall numpy
# pip install numpy==1.26.0
# pip install pynput
```

/!\ For Linux only, ffmpeg and opencv require a conda install for now. Run this exact sequence of commands:
```bash
conda install ffmpeg=7.1.1 -c conda-forge
# pip uninstall opencv-python
# conda install "opencv>=4.10.0"
```

Install the Realman SDK:
```bash
pip install Robotic_Arm==1.0.4.1
pip install pygame
```

# Integrating piper into LeRobot
See lerobot_piper_tutorial/1. 🤗 LeRobot:新增机械臂的一般流程.pdf

# Teleoperate
```bash
cd piper_scripts/
bash can_activate.sh can0 1000000

cd ..
python lerobot/scripts/control_robot.py \
--robot.type=piper \
--robot.inference_time=false \
--control.type=teleoperate
```

# Record
Set the dataset root path:
```bash
HF_USER=$PWD/data
echo $HF_USER
```

```bash
python lerobot/scripts/control_robot.py \
--robot.type=realman \
--robot.inference_time=false \
--control.type=record \
--control.fps=15 \
--control.single_task="move" \
--control.repo_id=maic/test \
--control.num_episodes=2 \
--control.warmup_time_s=2 \
--control.episode_time_s=10 \
--control.reset_time_s=10 \
--control.play_sounds=true \
--control.push_to_hub=false \
--control.display_data=true
```

Press the right arrow -> at any time during episode recording to stop early and go to resetting; the same works during resetting, to stop early and go on to the next episode recording.
Press the left arrow <- at any time during episode recording or resetting to stop early, cancel the current episode, and re-record it.
Press escape ESC at any time during episode recording to end the session early and go straight to video encoding and dataset uploading.

# Visualize
```bash
python lerobot/scripts/visualize_dataset.py \
--repo-id ${HF_USER}/test \
--episode-index 0
```

# Replay
```bash
python lerobot/scripts/control_robot.py \
--robot.type=piper \
--robot.inference_time=false \
--control.type=replay \
--control.fps=30 \
--control.repo_id=${HF_USER}/test \
--control.episode=0
```

# Caution

1. In lerobot/common/datasets/video_utils.py the vcodec is set to **libopenh264**; check which vcodecs your build supports with **ffmpeg -codecs**.
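
For example, to check whether an H.264 encoder such as libopenh264 is available (the grep filter is just illustrative):
```bash
ffmpeg -codecs | grep -i h264
# the H.264 line lists the available encoders, e.g. libopenh264 or libx264
```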

# Train
See lerobot_piper_tutorial/2. 🤗 AutoDL训练.pdf for the detailed training workflow.
```bash
python lerobot/scripts/train.py \
--dataset.repo_id=${HF_USER}/jack \
--policy.type=act \
--output_dir=outputs/train/act_jack \
--job_name=act_jack \
--device=cuda \
--wandb.enable=true
```

# FT smolvla
```bash
python lerobot/scripts/train.py \
--dataset.repo_id=maic/move_the_bottle_into_ultrasonic_device_with_realman_single \
--policy.path=lerobot/smolvla_base \
--output_dir=outputs/train/smolvla_move_the_bottle_into_ultrasonic_device_with_realman_single \
--job_name=smolvla_move_the_bottle_into_ultrasonic_device_with_realman_single \
--policy.device=cuda \
--wandb.enable=false \
--steps=200000 \
--batch_size=16
```

# Inference
Inference still uses the record loop in control_robot.py; setting **--robot.inference_time=true** takes the gamepad out of the loop.
```bash
python lerobot/scripts/control_robot.py \
--robot.type=realman \
--robot.inference_time=true \
--control.type=record \
--control.fps=30 \
--control.single_task="move the bottle into ultrasonic device with realman single" \
--control.repo_id=maic/move_the_bottle_into_ultrasonic_device_with_realman_single \
--control.num_episodes=1 \
--control.warmup_time_s=2 \
--control.episode_time_s=30 \
--control.reset_time_s=10 \
--control.push_to_hub=false \
--control.policy.path=outputs/train/act_move_the_bottle_into_ultrasonic_device_with_realman_single/checkpoints/100000/pretrained_model
```

```bash
python lerobot/scripts/control_robot.py \
--robot.type=realman \
--robot.inference_time=true \
--control.type=record \
--control.fps=30 \
--control.single_task="move the bottle into ultrasonic device with realman single" \
--control.repo_id=maic/eval_smolvla_move_the_bottle_into_ultrasonic_device_with_realman_single \
--control.num_episodes=1 \
--control.warmup_time_s=2 \
--control.episode_time_s=60 \
--control.reset_time_s=10 \
--control.push_to_hub=false \
--control.policy.path=outputs/train/smolvla_move_the_bottle_into_ultrasonic_device_with_realman_single/checkpoints/160000/pretrained_model \
--control.display_data=true
```
31
realman_src/dual_arm_connect_test.py
Normal file
@@ -0,0 +1,31 @@
from Robotic_Arm.rm_robot_interface import *

armleft = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
armright = RoboticArm()


lefthandle = armleft.rm_create_robot_arm("169.254.128.18", 8080)
print("Arm ID:", lefthandle.id)
righthandle = armright.rm_create_robot_arm("169.254.128.19", 8080)
print("Arm ID:", righthandle.id)

# software_info = armleft.rm_get_arm_software_info()
# if software_info[0] == 0:
#     print("\n================== Arm Software Information ==================")
#     print("Arm Model: ", software_info[1]['product_version'])
#     print("Algorithm Library Version: ", software_info[1]['algorithm_info']['version'])
#     print("Control Layer Software Version: ", software_info[1]['ctrl_info']['version'])
#     print("Dynamics Version: ", software_info[1]['dynamic_info']['model_version'])
#     print("Planning Layer Software Version: ", software_info[1]['plan_info']['version'])
#     print("==============================================================\n")
# else:
#     print("\nFailed to get arm software information, Error code: ", software_info[0], "\n")

print("Left: ", armleft.rm_get_current_arm_state())
print("Left: ", armleft.rm_get_arm_all_state())
# rm_movej_p requires a target pose argument; the original call passed none,
# so it is left commented out here.
# armleft.rm_movej_p()
# print("Right: ", armright.rm_get_current_arm_state())


# Disconnect everything and destroy the SDK threads
RoboticArm.rm_destory()
15
realman_src/movep_canfd.py
Normal file
@@ -0,0 +1,15 @@
from Robotic_Arm.rm_robot_interface import *
import time

# Instantiate the RoboticArm class
arm = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
# Create the arm connection and print the connection id
handle = arm.rm_create_robot_arm("192.168.3.18", 8080)
print(handle.id)

print(arm.rm_movep_follow([-0.330512, 0.255993, -0.161205, 3.141, 0.0, -1.57]))
time.sleep(2)
# print(arm.rm_movep_follow([0.3, 0, 0.3, 3.14, 0, 0]))
# time.sleep(2)

arm.rm_delete_robot_arm()
0
realman_src/realman_aloha/__init__.py
Normal file
4
realman_src/realman_aloha/shadow_camera/.gitignore
vendored Normal file
@@ -0,0 +1,4 @@
__pycache__/
*.pyc
*.pyo
*.pt
0
realman_src/realman_aloha/shadow_camera/README.md
Normal file
0
realman_src/realman_aloha/shadow_camera/__init__.py
Normal file
33
realman_src/realman_aloha/shadow_camera/pyproject.toml
Normal file
@@ -0,0 +1,33 @@
[tool.poetry]
name = "shadow_camera"
version = "0.1.0"
description = "camera class, currently includes realsense"
readme = "README.md"
authors = ["Shadow <qiuchengzhan@gmail.com>"]
license = "MIT"
#include = ["realman_vision/pytransform/_pytransform.so",]
classifiers = [
    "Operating System :: POSIX :: Linux amd64",
    "Programming Language :: Python :: 3.10",
]

[tool.poetry.dependencies]
python = ">=3.9"
numpy = ">=2.0.1"
opencv-python = ">=4.10.0.84"
pyrealsense2 = ">=2.55.1.6486"

[tool.poetry.dev-dependencies]  # Development-time dependencies, e.g. testing and documentation tools.
pytest = ">=8.3"
black = ">=24.10.0"

[tool.poetry.plugins."scripts"]  # Command-line scripts that let users run a given function from the shell.


[tool.poetry.group.dev.dependencies]


[build-system]
requires = ["poetry-core>=1.8.4"]
build-backend = "poetry.core.masonry.api"
@@ -0,0 +1 @@
__version__ = '0.1.0'
@@ -0,0 +1,38 @@
from abc import ABCMeta, abstractmethod


class BaseCamera(metaclass=ABCMeta):
    """Camera base class."""

    def __init__(self):
        pass

    @abstractmethod
    def start_camera(self):
        """Start the camera."""
        pass

    @abstractmethod
    def stop_camera(self):
        """Stop the camera."""
        pass

    @abstractmethod
    def set_resolution(self, resolution_width, resolution_height):
        """Set the camera's color-image resolution."""
        pass

    @abstractmethod
    def set_frame_rate(self, fps):
        """Set the camera's color-image frame rate."""
        pass

    @abstractmethod
    def read_frame(self):
        """Read one color frame and one depth frame."""
        pass

    @abstractmethod
    def get_camera_intrinsics(self):
        """Return the intrinsics of the color and depth streams."""
        pass
Binary file not shown.
@@ -0,0 +1,38 @@
from shadow_camera import base_camera
import cv2


class OpenCVCamera(base_camera.BaseCamera):
    """OpenCV-based camera class."""

    def __init__(self, device_id=0):
        """Initialize video capture.

        Args:
            device_id: camera device ID
        """
        self.cap = cv2.VideoCapture(device_id)

    def get_frame(self):
        """Grab the current frame.

        Returns:
            frame: image data of the current frame, or None if no frame is available
        """
        ret, frame = self.cap.read()
        return frame if ret else None

    def get_frame_info(self):
        """Return information about the current frame.

        Returns:
            dict: frame info
        """
        width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        channels = int(self.cap.get(cv2.CAP_PROP_FRAME_CHANNELS))

        return {
            'width': width,
            'height': height,
            'channels': channels
        }
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,280 @@
import time
import logging
import numpy as np
import pyrealsense2 as rs
import base_camera

# Logging configuration
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


class RealSenseCamera(base_camera.BaseCamera):
    """Intel RealSense camera class."""

    def __init__(self, serial_num=None, is_depth_frame=False):
        """
        Initialize the camera object.
        :param serial_num: camera serial number, None by default
        """
        super().__init__()
        self._color_resolution = [640, 480]
        self._depth_resolution = [640, 480]
        self._color_frames_rate = 30
        self._depth_frames_rate = 15
        self.timestamp = 0
        self.color_timestamp = 0
        self.depth_timestamp = 0
        self._colorizer = rs.colorizer()
        self._config = rs.config()
        self.is_depth_frame = is_depth_frame
        self.camera_on = False
        self.serial_num = serial_num

    def get_serial_num(self):
        serial_num = {}
        context = rs.context()
        devices = context.query_devices()  # enumerate all devices
        if len(context.devices) > 0:
            for i, device in enumerate(devices):
                serial_num[i] = device.get_info(rs.camera_info.serial_number)

        logging.info(f"Detected serial numbers: {serial_num}")
        return serial_num

    def _set_config(self):
        if self.serial_num is not None:
            logging.info(f"Setting device with serial number: {self.serial_num}")
            self._config.enable_device(self.serial_num)

        self._config.enable_stream(
            rs.stream.color,
            self._color_resolution[0],
            self._color_resolution[1],
            rs.format.rgb8,
            self._color_frames_rate,
        )
        if self.is_depth_frame:
            self._config.enable_stream(
                rs.stream.depth,
                self._depth_resolution[0],
                self._depth_resolution[1],
                rs.format.z16,
                self._depth_frames_rate,
            )

    def start_camera(self):
        """
        Start the camera and read the intrinsics; if frame alignment is used later,
        all intrinsics refer to the color stream.
        """
        self._pipeline = rs.pipeline()
        if self.is_depth_frame:
            self.point_cloud = rs.pointcloud()
            self._align = rs.align(rs.stream.color)
        self._set_config()

        self.profile = self._pipeline.start(self._config)

        if self.is_depth_frame:
            self._depth_intrinsics = (
                self.profile.get_stream(rs.stream.depth)
                .as_video_stream_profile()
                .get_intrinsics()
            )

        self._color_intrinsics = (
            self.profile.get_stream(rs.stream.color)
            .as_video_stream_profile()
            .get_intrinsics()
        )
        self.camera_on = True
        logging.info("Camera started successfully")
        logging.info(
            f"Camera started with color resolution: {self._color_resolution}, depth resolution: {self._depth_resolution}"
        )
        logging.info(
            f"Color FPS: {self._color_frames_rate}, Depth FPS: {self._depth_frames_rate}"
        )

    def stop_camera(self):
        """
        Stop the camera.
        """
        self._pipeline.stop()
        self.camera_on = False
        logging.info("Camera stopped")

    def set_resolution(self, color_resolution, depth_resolution):
        self._color_resolution = color_resolution
        self._depth_resolution = depth_resolution
        logging.info(
            "Optional color resolution:"
            "[320, 180] [320, 240] [424, 240] [640, 360] [640, 480]"
            "[848, 480] [960, 540] [1280, 720] [1920, 1080]"
        )
        logging.info(
            "Optional depth resolution:"
            "[256, 144] [424, 240] [480, 270] [640, 360] [640, 400]"
            "[640, 480] [848, 100] [848, 480] [1280, 720] [1280, 800]"
        )
        logging.info(f"Set color resolution to: {color_resolution}")
        logging.info(f"Set depth resolution to: {depth_resolution}")

    def set_frame_rate(self, color_fps, depth_fps):
        self._color_frames_rate = color_fps
        self._depth_frames_rate = depth_fps
        logging.info("Optional color fps: 6 15 30 60 ")
        logging.info("Optional depth fps: 6 15 30 60 90 100 300")
        logging.info(f"Set color FPS to: {color_fps}")
        logging.info(f"Set depth FPS to: {depth_fps}")

    # TODO: compensate via white-balance adjustment
    # def set_exposure(self, exposure):

    def read_frame(self, is_color=True, is_depth=True, is_colorized_depth=False, is_point_cloud=False):
        """
        Read one color frame and one depth frame.
        :return: NumPy arrays for the color and depth images
        """
        while not self.camera_on:
            time.sleep(0.5)
        color_image = None
        depth_image = None
        colorized_depth = None
        point_cloud = None
        try:
            frames = self._pipeline.wait_for_frames()
            if is_color:
                color_frame = frames.get_color_frame()
                color_image = np.asanyarray(color_frame.get_data())
            else:
                color_image = None

            if is_depth:
                depth_frame = frames.get_depth_frame()
                depth_image = np.asanyarray(depth_frame.get_data())
            else:
                depth_image = None

            colorized_depth = (
                np.asanyarray(self._colorizer.colorize(depth_frame).get_data())
                if is_colorized_depth
                else None
            )
            point_cloud = (
                np.asanyarray(self.point_cloud.calculate(depth_frame).get_vertices())
                if is_point_cloud
                else None
            )
            # Timestamps are in ms; after alignment, color timestamp > depth = aligned, so color is used
            self.color_timestamp = color_frame.get_timestamp()
            if self.is_depth_frame:
                self.depth_timestamp = depth_frame.get_timestamp()

        except Exception as e:
            logging.warning(e)
            if "Frame didn't arrive within 5000" in str(e):
                logging.warning("Frame didn't arrive within 5000ms, resetting device")
                self.stop_camera()
                self.start_camera()

        return color_image, depth_image, colorized_depth, point_cloud

    def read_align_frame(self, is_color=True, is_depth=True, is_colorized_depth=False, is_point_cloud=False):
        """
        Read one aligned pair of color and depth frames.
        :return: NumPy arrays for the color and depth images
        """
        while not self.camera_on:
            time.sleep(0.5)
        try:
            frames = self._pipeline.wait_for_frames()
            aligned_frames = self._align.process(frames)
            aligned_color_frame = aligned_frames.get_color_frame()
            self._aligned_depth_frame = aligned_frames.get_depth_frame()

            color_image = np.asanyarray(aligned_color_frame.get_data())
            depth_image = np.asanyarray(self._aligned_depth_frame.get_data())
            colorized_depth = (
                np.asanyarray(
                    self._colorizer.colorize(self._aligned_depth_frame).get_data()
                )
                if is_colorized_depth
                else None
            )

            if is_point_cloud:
                points = self.point_cloud.calculate(self._aligned_depth_frame)
                # Convert the tuple data into a NumPy array
                point_cloud = np.array(
                    [[point[0], point[1], point[2]] for point in points.get_vertices()]
                )
            else:
                point_cloud = None

            # Timestamps are in ms; after alignment, color timestamp > depth = aligned, so color is used
            self.timestamp = aligned_color_frame.get_timestamp()

            return color_image, depth_image, colorized_depth, point_cloud

        except Exception as e:
            if "Frame didn't arrive within 5000" in str(e):
                logging.warning("Frame didn't arrive within 5000ms, resetting device")
                self.stop_camera()
                self.start_camera()
            # device = self.profile.get_device()
            # device.hardware_reset()

    def get_camera_intrinsics(self):
        """
        Return the intrinsics of the color and depth streams.
        :return: color and depth intrinsics
        """
        # Width/height: .width, .height; focal length: .fx, .fy; principal point: .ppx, .ppy; distortion: .coeffs
        logging.info("Getting camera intrinsics")
        logging.info(
            "Width and height: .width, .height; Focal length: .fx, .fy; Pixel coordinates: .ppx, .ppy; Distortion coefficient: .coeffs"
        )
        return self._color_intrinsics, self._depth_intrinsics

    def get_3d_camera_coordinate(self, depth_pixel, align=False):
        """
        Return the 3D coordinate in the depth-camera frame.
        :param depth_pixel: depth pixel coordinate
        :param align: whether aligned frames were used

        :return: depth value and camera coordinate
        """
        if not hasattr(self, "_aligned_depth_frame"):
            raise AttributeError(
                "Aligned depth frame not set. Call read_align_frame() first."
            )

        distance = self._aligned_depth_frame.get_distance(
            depth_pixel[0], depth_pixel[1]
        )
        intrinsics = self._color_intrinsics if align else self._depth_intrinsics
        camera_coordinate = rs.rs2_deproject_pixel_to_point(
            intrinsics, depth_pixel, distance
        )
        return distance, camera_coordinate


if __name__ == "__main__":
    camera = RealSenseCamera(is_depth_frame=False)
    camera.get_serial_num()
    camera.start_camera()
    # camera.set_frame_rate(60, 60)
    # The depth stream is not enabled above, so skip depth on read
    color_image, depth_image, colorized_depth, point_cloud = camera.read_frame(is_depth=False)
    camera.stop_camera()
    logging.info(f"Color image shape: {color_image.shape}")
    # logging.info(f"Depth image shape: {depth_image.shape}")
    # logging.info(f"Colorized depth image shape: {colorized_depth.shape}")
    # logging.info(f"Point cloud shape: {point_cloud.shape}")
    logging.info(f"Color timestamp: {camera.timestamp}")
    logging.info(f"Color timestamp: {camera.color_timestamp}")
    # logging.info(f"Depth timestamp: {camera.depth_timestamp}")
    logging.info("Test passed")

@@ -0,0 +1,101 @@
import pyrealsense2 as rs
import numpy as np
import h5py
import time
import threading
import keyboard  # used to listen for keyboard input

# Global state
is_recording = False  # flag controlling the recording state
color_images = []     # recorded color images
depth_images = []     # recorded depth images
timestamps = []       # recorded timestamps


# Configure the D435 camera
def configure_camera():
    pipeline = rs.pipeline()
    config = rs.config()
    config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30)  # color stream
    config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, 30)   # depth stream
    pipeline.start(config)
    return pipeline


# Listen for keyboard input to control the recording state
def listen_for_keyboard():
    global is_recording
    while True:
        if keyboard.is_pressed('s'):  # press 's' to start recording
            is_recording = True
            print("Recording started.")
            time.sleep(0.5)  # debounce
        elif keyboard.is_pressed('q'):  # press 'q' to stop recording
            is_recording = False
            print("Recording stopped.")
            time.sleep(0.5)  # debounce
        elif keyboard.is_pressed('e'):  # press 'e' to exit the program
            print("Exiting program.")
            exit()
        time.sleep(0.1)


# Capture image data
def capture_frames(pipeline):
    global is_recording, color_images, depth_images, timestamps
    try:
        while True:
            if is_recording:
                frames = pipeline.wait_for_frames()
                color_frame = frames.get_color_frame()
                depth_frame = frames.get_depth_frame()

                if not color_frame or not depth_frame:
                    continue

                # Current timestamp
                timestamp = time.time()

                # Convert the frames into numpy arrays
                color_image = np.asanyarray(color_frame.get_data())
                depth_image = np.asanyarray(depth_frame.get_data())

                # Store the data
                color_images.append(color_image)
                depth_images.append(depth_image)
                timestamps.append(timestamp)

                print(f"Captured frame at {timestamp}")

            else:
                time.sleep(0.1)  # idle while not recording

    finally:
        pipeline.stop()


# Save the recording to an HDF5 file
def save_to_hdf5(color_images, depth_images, timestamps, filename="output.h5"):
    with h5py.File(filename, "w") as f:
        f.create_dataset("color_images", data=np.array(color_images), compression="gzip")
        f.create_dataset("depth_images", data=np.array(depth_images), compression="gzip")
        f.create_dataset("timestamps", data=np.array(timestamps), compression="gzip")
    print(f"Data saved to {filename}")
|
||||
|
||||
# 主函数
|
||||
def main():
|
||||
global is_recording, color_images, depth_images, timestamps
|
||||
|
||||
# 启动键盘监听线程
|
||||
keyboard_thread = threading.Thread(target=listen_for_keyboard)
|
||||
keyboard_thread.daemon = True
|
||||
keyboard_thread.start()
|
||||
|
||||
# 配置相机
|
||||
pipeline = configure_camera()
|
||||
|
||||
# 开始采集图像
|
||||
capture_frames(pipeline)
|
||||
|
||||
# 录制结束后保存数据
|
||||
if color_images and depth_images and timestamps:
|
||||
save_to_hdf5(color_images, depth_images, timestamps, "mobile_aloha_data.h5")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
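
For reference, a minimal read-back sketch (not part of the file above); the filename and dataset names match what save_to_hdf5 writes in main(), and the shapes follow from the stream configuration above:

import h5py

with h5py.File("mobile_aloha_data.h5", "r") as f:
    color_images = f["color_images"][:]  # (N, 480, 640, 3) uint8
    depth_images = f["depth_images"][:]  # (N, 480, 640) uint16
    timestamps = f["timestamps"][:]      # (N,) float64
print(color_images.shape, depth_images.shape, timestamps.shape)
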
152
realman_src/realman_aloha/shadow_camera/test/test_camera.py
Normal file
@@ -0,0 +1,152 @@
import os
import cv2
import time
import numpy as np
from os import path
import pyrealsense2 as rs
from shadow_camera import realsense
import logging


def test_camera():
    camera = realsense.RealSenseCamera('241122071186')
    camera.start_camera()

    while True:
        # result = camera.read_align_frame()
        # if result is None:
        #     print('is None')
        #     continue
        # start_time = time.time()
        color_image, depth_image, colorized_depth, vtx = camera.read_frame()
        color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)

        print(f"color_image: {color_image.shape}")
        # print(f"Time: {end_time - start_time}")
        cv2.imshow("bgr_image", color_image)

        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
    camera.stop_camera()


def test_get_serial_num():
    camera = realsense.RealSenseCamera()
    device = camera.get_serial_num()


class CameraCapture:
    def __init__(self, camera_serial_num=None, save_dir="./save"):
        self._camera_serial_num = camera_serial_num
        self._color_save_dir = path.join(save_dir, "color")
        self._depth_save_dir = path.join(save_dir, "depth")
        os.makedirs(save_dir, exist_ok=True)
        os.makedirs(self._color_save_dir, exist_ok=True)
        os.makedirs(self._depth_save_dir, exist_ok=True)

    def get_serial_num(self):
        self._camera_serial_num = {}
        camera_names = ["left", "right", "head", "table"]
        context = rs.context()
        devices = context.query_devices()  # enumerate all connected devices
        if len(context.devices) > 0:
            for i, device in enumerate(devices):
                self._camera_serial_num[camera_names[i]] = device.get_info(
                    rs.camera_info.serial_number
                )
        print(self._camera_serial_num)

        return self._camera_serial_num

    def start_camera(self):
        if self._camera_serial_num is None:
            self.get_serial_num()
        self._camera_left = realsense.RealSenseCamera(self._camera_serial_num["left"])
        self._camera_right = realsense.RealSenseCamera(self._camera_serial_num["right"])
        self._camera_head = realsense.RealSenseCamera(self._camera_serial_num["head"])

        self._camera_left.start_camera()
        self._camera_right.start_camera()
        self._camera_head.start_camera()

    def stop_camera(self):
        self._camera_left.stop_camera()
        self._camera_right.stop_camera()
        self._camera_head.stop_camera()

    def _save_datas(self, timestamp, color_image, depth_image, camera_name):
        color_filename = path.join(
            self._color_save_dir, f"{timestamp}" + camera_name + ".jpg"
        )
        depth_filename = path.join(
            self._depth_save_dir, f"{timestamp}" + camera_name + ".png"
        )
        cv2.imwrite(color_filename, color_image)
        cv2.imwrite(depth_filename, depth_image)

    def capture_images(self):
        while True:
            (
                color_image_left,
                depth_image_left,
                _,
                _,
            ) = self._camera_left.read_align_frame()
            (
                color_image_right,
                depth_image_right,
                _,
                _,
            ) = self._camera_right.read_align_frame()
            (
                color_image_head,
                depth_image_head,
                _,
                point_cloud3,
            ) = self._camera_head.read_align_frame()

            bgr_color_image_left = cv2.cvtColor(color_image_left, cv2.COLOR_RGB2BGR)
            bgr_color_image_right = cv2.cvtColor(color_image_right, cv2.COLOR_RGB2BGR)
            bgr_color_image_head = cv2.cvtColor(color_image_head, cv2.COLOR_RGB2BGR)

            timestamp = time.time() * 1000

            cv2.imshow("Camera left", bgr_color_image_left)
            cv2.imshow("Camera right", bgr_color_image_right)
            cv2.imshow("Camera head", bgr_color_image_head)

            # self._save_datas(
            #     timestamp, bgr_color_image_left, depth_image_left, "left"
            # )
            # self._save_datas(
            #     timestamp, bgr_color_image_right, depth_image_right, "right"
            # )
            # self._save_datas(
            #     timestamp, bgr_color_image_head, depth_image_head, "head"
            # )

            if cv2.waitKey(1) & 0xFF == ord("q"):
                break

        cv2.destroyAllWindows()


if __name__ == "__main__":
    # test_camera()
    test_get_serial_num()
    """
    Pass camera serial numbers to designate the left/right cameras:
    dict: {'left': '241222075132', 'right': '242322076532', 'head': '242322076532'}
    Save path:
    str: ./save
    If the input is empty, serial numbers are assigned automatically (left/right/head
    not designated) and the save path defaults to ./save.
    """

    # capture = CameraCapture()
    # capture.get_serial_num()
    # test_get_serial_num()

    # capture.start_camera()
    # capture.capture_images()
    # capture.stop_camera()
@@ -0,0 +1,71 @@
import pytest
import pyrealsense2 as rs
from shadow_camera.realsense import RealSenseCamera


class TestRealSenseCamera:
    @pytest.fixture(autouse=True)
    def setup_camera(self):
        self.camera = RealSenseCamera()

    def test_get_serial_num(self):
        serial_nums = self.camera.get_serial_num()
        assert isinstance(serial_nums, dict)
        assert len(serial_nums) > 0

    def test_start_stop_camera(self):
        self.camera.start_camera()
        assert self.camera.camera_on is True
        self.camera.stop_camera()
        assert self.camera.camera_on is False

    def test_set_resolution(self):
        color_resolution = [1280, 720]
        depth_resolution = [1280, 720]
        self.camera.set_resolution(color_resolution, depth_resolution)
        assert self.camera._color_resolution == color_resolution
        assert self.camera._depth_resolution == depth_resolution

    def test_set_frame_rate(self):
        color_fps = 60
        depth_fps = 60
        self.camera.set_frame_rate(color_fps, depth_fps)
        assert self.camera._color_frames_rate == color_fps
        assert self.camera._depth_frames_rate == depth_fps

    def test_read_frame(self):
        self.camera.start_camera()
        color_image, depth_image, colorized_depth, point_cloud = (
            self.camera.read_frame()
        )
        assert color_image is not None
        assert depth_image is not None
        self.camera.stop_camera()

    def test_read_align_frame(self):
        self.camera.start_camera()
        color_image, depth_image, colorized_depth, point_cloud = (
            self.camera.read_align_frame()
        )
        assert color_image is not None
        assert depth_image is not None
        self.camera.stop_camera()

    def test_get_camera_intrinsics(self):
        self.camera.start_camera()
        color_intrinsics, depth_intrinsics = self.camera.get_camera_intrinsics()
        assert color_intrinsics is not None
        assert depth_intrinsics is not None
        self.camera.stop_camera()

    def test_get_3d_camera_coordinate(self):
        self.camera.start_camera()
        # Call read_align_frame first to make sure _aligned_depth_frame is set
        self.camera.read_align_frame()
        depth_pixel = [320, 240]
        distance, camera_coordinate = self.camera.get_3d_camera_coordinate(
            depth_pixel, align=True
        )
        assert distance > 0
        assert len(camera_coordinate) == 3
        self.camera.stop_camera()
10
realman_src/realman_aloha/shadow_rm_act/.gitignore
vendored
Normal file
@@ -0,0 +1,10 @@
__pycache__/
build/
devel/
dist/
data/
.catkin_workspace
*.pyc
*.pyo
*.pt
.vscode/
89
realman_src/realman_aloha/shadow_rm_act/README.md
Normal file
@@ -0,0 +1,89 @@
# ACT: Action Chunking with Transformers

### *New*: [ACT tuning tips](https://docs.google.com/document/d/1FVIZfoALXg_ZkYKaYVh-qOlaXveq5CtvJHXkY25eYhs/edit?usp=sharing)
TL;DR: if your ACT policy is jerky or pauses in the middle of an episode, just train for longer! Success rate and smoothness can improve well after the loss plateaus.

#### Project Website: https://tonyzhaozh.github.io/aloha/

This repo contains the implementation of ACT, together with 2 simulated environments:
Transfer Cube and Bimanual Insertion. You can train and evaluate ACT in sim or real.
For real, you would also need to install [ALOHA](https://github.com/tonyzhaozh/aloha).

### Updates:
You can find all scripted/human demos for the simulated environments [here](https://drive.google.com/drive/folders/1gPR03v05S1xiInoVJn7G7VJ9pDCnxq9O?usp=share_link).


### Repo Structure
- ``imitate_episodes.py`` Train and evaluate ACT
- ``policy.py`` An adaptor for the ACT policy
- ``detr`` Model definitions of ACT, modified from DETR
- ``sim_env.py`` MuJoCo + DM_Control environments with joint-space control
- ``ee_sim_env.py`` MuJoCo + DM_Control environments with EE-space control
- ``scripted_policy.py`` Scripted policies for the sim environments
- ``constants.py`` Constants shared across files
- ``utils.py`` Utils such as data loading and helper functions
- ``visualize_episodes.py`` Save videos from a .hdf5 dataset


### Installation

    conda create -n aloha python=3.8.10
    conda activate aloha
    pip install torchvision
    pip install torch
    pip install pyquaternion
    pip install pyyaml
    pip install rospkg
    pip install pexpect
    pip install mujoco==2.3.7
    pip install dm_control==1.0.14
    pip install opencv-python
    pip install matplotlib
    pip install einops
    pip install packaging
    pip install h5py
    pip install ipython
    cd act/detr && pip install -e .

### Example Usages

To set up a new terminal, run:

    conda activate aloha
    cd <path to act repo>

### Simulated experiments

We use the ``sim_transfer_cube_scripted`` task in the examples below. Another option is ``sim_insertion_scripted``.
To generate 50 episodes of scripted data, run:

    python3 record_sim_episodes.py \
    --task_name sim_transfer_cube_scripted \
    --dataset_dir <data save dir> \
    --num_episodes 50

You can add the flag ``--onscreen_render`` to see real-time rendering.
To visualize an episode after it is collected, run:

    python3 visualize_episodes.py --dataset_dir <data save dir> --episode_idx 0

To train ACT:

    # Transfer Cube task
    python3 imitate_episodes.py \
    --task_name sim_transfer_cube_scripted \
    --ckpt_dir <ckpt dir> \
    --policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 512 --batch_size 8 --dim_feedforward 3200 \
    --num_epochs 2000 --lr 1e-5 \
    --seed 0


To evaluate the policy, run the same command but add ``--eval``. This loads the best validation checkpoint.
The success rate should be around 90% for transfer cube, and around 50% for insertion.
To enable temporal ensembling, add the flag ``--temporal_agg``.
Videos will be saved to ``<ckpt_dir>`` for each rollout.
You can also add ``--onscreen_render`` to see real-time rendering during evaluation.
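For example, evaluating the Transfer Cube policy with temporal ensembling would look like the following (a sketch that reuses the training flags above; ``--temporal_agg`` is optional):

    python3 imitate_episodes.py \
    --task_name sim_transfer_cube_scripted \
    --ckpt_dir <ckpt dir> \
    --policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 512 --batch_size 8 --dim_feedforward 3200 \
    --num_epochs 2000 --lr 1e-5 \
    --seed 0 --eval --temporal_agg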

For real-world data, where things can be harder to model, train for at least 5000 epochs, or 3-4x longer after the loss has plateaued.
Please refer to the [tuning tips](https://docs.google.com/document/d/1FVIZfoALXg_ZkYKaYVh-qOlaXveq5CtvJHXkY25eYhs/edit?usp=sharing) for more info.
74
realman_src/realman_aloha/shadow_rm_act/config/config.yaml
Normal file
@@ -0,0 +1,74 @@
robot_env: {
  # TODO change the path to the correct one
  rm_left_arm: '/home/rm/aloha/shadow_rm_aloha/config/rm_left_arm.yaml',
  rm_right_arm: '/home/rm/aloha/shadow_rm_aloha/config/rm_right_arm.yaml',
  arm_axis: 6,
  head_camera: '215222076892',
  bottom_camera: '215222076981',
  left_camera: '152122078151',
  right_camera: '152122073489',
  # init_left_arm_angle: [0.226, 21.180, 91.304, -0.515, 67.486, 2.374, 0.9],
  # init_right_arm_angle: [-1.056, 33.057, 84.376, -0.204, 66.357, -3.236, 0.9]
  init_left_arm_angle: [6.45, 66.093, 2.9, 20.919, -1.491, 100.756, 18.808, 0.617],
  init_right_arm_angle: [166.953, -33.575, -163.917, 73.3, -9.581, 69.51, 0.876]
}
dataset_dir: '/home/rm/aloha/shadow_rm_aloha/data/dataset/20250103'
checkpoint_dir: '/home/rm/aloha/shadow_rm_act/data'
# checkpoint_name: 'policy_best.ckpt'
checkpoint_name: 'policy_9500.ckpt'
state_dim: 14
save_episode: True
num_rollouts: 50 # number of rollouts (trajectories) to collect during training
real_robot: True
policy_class: 'ACT'
onscreen_render: False
camera_names: ['cam_high', 'cam_low', 'cam_left', 'cam_right']
episode_len: 300 # maximum episode length (number of timesteps)
task_name: 'aloha_01_11.28'
temporal_agg: False # whether to use temporal aggregation (ensembling)
batch_size: 8 # samples per batch during training
seed: 1000 # random seed
chunk_size: 30 # chunk size used when processing action sequences
eval_every: 1 # evaluate the model every eval_every steps
num_steps: 10000 # total number of training steps
validate_every: 1 # validate the model every validate_every steps
save_every: 500 # save a checkpoint every save_every steps
load_pretrain: False # whether to load a pretrained model
resume_ckpt_path:
name_filter: # TODO
skip_mirrored_data: False # whether to skip mirrored data (e.g. from symmetry-based data augmentation)
stats_dir:
sample_weights:
train_ratio: 0.8 # fraction of the data used for training (the rest is used for validation)

policy_config: {
  hidden_dim: 512, # Size of the embeddings (dimension of the transformer)
  state_dim: 14, # Dimension of the state
  position_embedding: 'sine', # ('sine' or 'learned'). Type of positional embedding to use on top of the image features
  lr_backbone: 1.0e-5,
  masks: False, # If true, the model masks the non-visible pixels
  backbone: 'resnet18',
  dilation: False, # If true, we replace stride with dilation in the last convolutional block (DC5)
  dropout: 0.1, # Dropout applied in the transformer
  nheads: 8,
  dim_feedforward: 3200, # Intermediate size of the feedforward layers in the transformer blocks
  enc_layers: 4, # Number of encoding layers in the transformer
  dec_layers: 7, # Number of decoding layers in the transformer
  pre_norm: False, # If true, apply LayerNorm to the input instead of the output of the MultiheadAttention and FeedForward
  num_queries: 30,
  camera_names: ['cam_high', 'cam_low', 'cam_left', 'cam_right'],
  vq: False,
  vq_class: none,
  vq_dim: 64,
  action_dim: 14,
  no_encoder: False,
  lr: 1.0e-5,
  weight_decay: 1.0e-4,
  kl_weight: 10,

  # lr_drop: 200,
  # clip_max_norm: 0.1,
}
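
A minimal sketch (not part of the file above) of consuming this config from Python, assuming PyYAML is installed; the path is the one in this tree:

import yaml

with open("realman_src/realman_aloha/shadow_rm_act/config/config.yaml") as f:
    cfg = yaml.safe_load(f)

policy_config = cfg["policy_config"]
print(cfg["task_name"], policy_config["hidden_dim"], policy_config["num_queries"])
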
267
realman_src/realman_aloha/shadow_rm_act/ee_sim_env.py
Normal file
@@ -0,0 +1,267 @@
import numpy as np
import collections
import os

from constants import DT, XML_DIR, START_ARM_POSE
from constants import PUPPET_GRIPPER_POSITION_CLOSE
from constants import PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN
from constants import PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN

from src.shadow_act.utils.utils import sample_box_pose, sample_insertion_pose
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base

import IPython
e = IPython.embed


def make_ee_sim_env(task_name):
    """
    Environment for simulated robot bi-manual manipulation, with end-effector control.
    Action space:      [left_arm_pose (7),             # position and quaternion for end effector
                        left_gripper_positions (1),    # normalized gripper position (0: close, 1: open)
                        right_arm_pose (7),            # position and quaternion for end effector
                        right_gripper_positions (1),]  # normalized gripper position (0: close, 1: open)

    Observation space: {"qpos": Concat[ left_arm_qpos (6),          # absolute joint position
                                        left_gripper_position (1),  # normalized gripper position (0: close, 1: open)
                                        right_arm_qpos (6),         # absolute joint position
                                        right_gripper_qpos (1)]     # normalized gripper position (0: close, 1: open)
                        "qvel": Concat[ left_arm_qvel (6),          # absolute joint velocity (rad)
                                        left_gripper_velocity (1),  # normalized gripper velocity (pos: opening, neg: closing)
                                        right_arm_qvel (6),         # absolute joint velocity (rad)
                                        right_gripper_qvel (1)]     # normalized gripper velocity (pos: opening, neg: closing)
                        "images": {"main": (480x640x3)}             # h, w, c, dtype='uint8'
    """
    if 'sim_transfer_cube' in task_name:
        xml_path = os.path.join(XML_DIR, 'bimanual_viperx_ee_transfer_cube.xml')
        physics = mujoco.Physics.from_xml_path(xml_path)
        task = TransferCubeEETask(random=False)
        env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
                                  n_sub_steps=None, flat_observation=False)
    elif 'sim_insertion' in task_name:
        xml_path = os.path.join(XML_DIR, 'bimanual_viperx_ee_insertion.xml')
        physics = mujoco.Physics.from_xml_path(xml_path)
        task = InsertionEETask(random=False)
        env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
                                  n_sub_steps=None, flat_observation=False)
    else:
        raise NotImplementedError
    return env


class BimanualViperXEETask(base.Task):
    def __init__(self, random=None):
        super().__init__(random=random)

    def before_step(self, action, physics):
        a_len = len(action) // 2
        action_left = action[:a_len]
        action_right = action[a_len:]

        # set mocap position and quat
        # left
        np.copyto(physics.data.mocap_pos[0], action_left[:3])
        np.copyto(physics.data.mocap_quat[0], action_left[3:7])
        # right
        np.copyto(physics.data.mocap_pos[1], action_right[:3])
        np.copyto(physics.data.mocap_quat[1], action_right[3:7])

        # set gripper
        g_left_ctrl = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(action_left[7])
        g_right_ctrl = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(action_right[7])
        np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl]))

    def initialize_robots(self, physics):
        # reset joint position
        physics.named.data.qpos[:16] = START_ARM_POSE

        # reset mocap to align with end effector
        # to obtain these numbers:
        # (1) make an ee_sim env and reset to the same start_pose
        # (2) get env._physics.named.data.xpos['vx300s_left/gripper_link']
        #     get env._physics.named.data.xquat['vx300s_left/gripper_link']
        #     repeat the same for right side
        np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084])
        np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0])
        # right
        np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084]))
        np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0])

        # reset gripper control
        close_gripper_control = np.array([
            PUPPET_GRIPPER_POSITION_CLOSE,
            -PUPPET_GRIPPER_POSITION_CLOSE,
            PUPPET_GRIPPER_POSITION_CLOSE,
            -PUPPET_GRIPPER_POSITION_CLOSE,
        ])
        np.copyto(physics.data.ctrl, close_gripper_control)

    def initialize_episode(self, physics):
        """Sets the state of the environment at the start of each episode."""
        super().initialize_episode(physics)

    @staticmethod
    def get_qpos(physics):
        qpos_raw = physics.data.qpos.copy()
        left_qpos_raw = qpos_raw[:8]
        right_qpos_raw = qpos_raw[8:16]
        left_arm_qpos = left_qpos_raw[:6]
        right_arm_qpos = right_qpos_raw[:6]
        left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[6])]
        right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[6])]
        return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])

    @staticmethod
    def get_qvel(physics):
        qvel_raw = physics.data.qvel.copy()
        left_qvel_raw = qvel_raw[:8]
        right_qvel_raw = qvel_raw[8:16]
        left_arm_qvel = left_qvel_raw[:6]
        right_arm_qvel = right_qvel_raw[:6]
        left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[6])]
        right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[6])]
        return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])

    @staticmethod
    def get_env_state(physics):
        raise NotImplementedError

    def get_observation(self, physics):
        # note: it is important to do .copy()
        obs = collections.OrderedDict()
        obs['qpos'] = self.get_qpos(physics)
        obs['qvel'] = self.get_qvel(physics)
        obs['env_state'] = self.get_env_state(physics)
        obs['images'] = dict()
        obs['images']['top'] = physics.render(height=480, width=640, camera_id='top')
        obs['images']['angle'] = physics.render(height=480, width=640, camera_id='angle')
        obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close')
        # used in scripted policy to obtain starting pose
        obs['mocap_pose_left'] = np.concatenate([physics.data.mocap_pos[0], physics.data.mocap_quat[0]]).copy()
        obs['mocap_pose_right'] = np.concatenate([physics.data.mocap_pos[1], physics.data.mocap_quat[1]]).copy()

        # used when replaying joint trajectory
        obs['gripper_ctrl'] = physics.data.ctrl.copy()
        return obs

    def get_reward(self, physics):
        raise NotImplementedError


class TransferCubeEETask(BimanualViperXEETask):
    def __init__(self, random=None):
        super().__init__(random=random)
        self.max_reward = 4

    def initialize_episode(self, physics):
        """Sets the state of the environment at the start of each episode."""
        self.initialize_robots(physics)
        # randomize box position
        cube_pose = sample_box_pose()
        box_start_idx = physics.model.name2id('red_box_joint', 'joint')
        np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose)
        # print(f"randomized cube position to {cube_pose}")

        super().initialize_episode(physics)

    @staticmethod
    def get_env_state(physics):
        env_state = physics.data.qpos.copy()[16:]
        return env_state

    def get_reward(self, physics):
        # return whether left gripper is holding the box
        all_contact_pairs = []
        for i_contact in range(physics.data.ncon):
            id_geom_1 = physics.data.contact[i_contact].geom1
            id_geom_2 = physics.data.contact[i_contact].geom2
            name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
            name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
            contact_pair = (name_geom_1, name_geom_2)
            all_contact_pairs.append(contact_pair)

        touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
        touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
        touch_table = ("red_box", "table") in all_contact_pairs

        reward = 0
        if touch_right_gripper:
            reward = 1
        if touch_right_gripper and not touch_table:  # lifted
            reward = 2
        if touch_left_gripper:  # attempted transfer
            reward = 3
        if touch_left_gripper and not touch_table:  # successful transfer
            reward = 4
        return reward


class InsertionEETask(BimanualViperXEETask):
    def __init__(self, random=None):
        super().__init__(random=random)
        self.max_reward = 4

    def initialize_episode(self, physics):
        """Sets the state of the environment at the start of each episode."""
        self.initialize_robots(physics)
        # randomize peg and socket position
        peg_pose, socket_pose = sample_insertion_pose()
        id2index = lambda j_id: 16 + (j_id - 16) * 7  # first 16 is robot qpos, 7 is pose dim # hacky

        peg_start_id = physics.model.name2id('red_peg_joint', 'joint')
        peg_start_idx = id2index(peg_start_id)
        np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose)
        # print(f"randomized peg position to {peg_pose}")

        socket_start_id = physics.model.name2id('blue_socket_joint', 'joint')
        socket_start_idx = id2index(socket_start_id)
        np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose)
        # print(f"randomized socket position to {socket_pose}")

        super().initialize_episode(physics)

    @staticmethod
    def get_env_state(physics):
        env_state = physics.data.qpos.copy()[16:]
        return env_state

    def get_reward(self, physics):
        # return whether peg touches the pin
        all_contact_pairs = []
        for i_contact in range(physics.data.ncon):
            id_geom_1 = physics.data.contact[i_contact].geom1
            id_geom_2 = physics.data.contact[i_contact].geom2
            name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
            name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
            contact_pair = (name_geom_1, name_geom_2)
            all_contact_pairs.append(contact_pair)

        touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
        touch_left_gripper = ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
                             ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
                             ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
                             ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs

        peg_touch_table = ("red_peg", "table") in all_contact_pairs
        socket_touch_table = ("socket-1", "table") in all_contact_pairs or \
                             ("socket-2", "table") in all_contact_pairs or \
                             ("socket-3", "table") in all_contact_pairs or \
                             ("socket-4", "table") in all_contact_pairs
        peg_touch_socket = ("red_peg", "socket-1") in all_contact_pairs or \
                           ("red_peg", "socket-2") in all_contact_pairs or \
                           ("red_peg", "socket-3") in all_contact_pairs or \
                           ("red_peg", "socket-4") in all_contact_pairs
        pin_touched = ("red_peg", "pin") in all_contact_pairs

        reward = 0
        if touch_left_gripper and touch_right_gripper:  # touch both
            reward = 1
        if touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table):  # grasp both
            reward = 2
        if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table):  # peg and socket touching
            reward = 3
        if pin_touched:  # successful insertion
            reward = 4
        return reward
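
A minimal usage sketch for make_ee_sim_env (not part of the file above); it assumes the MuJoCo XML assets and constants.py from this repo are importable, and simply holds the initial end-effector poses:

import numpy as np
from ee_sim_env import make_ee_sim_env

env = make_ee_sim_env('sim_transfer_cube_scripted')
ts = env.reset()
for _ in range(5):
    # 16-dim action, per the docstring: [xyz(3), quat(4), gripper(1)] for each arm
    action = np.concatenate([ts.observation['mocap_pose_left'], [1.0],
                             ts.observation['mocap_pose_right'], [1.0]])
    ts = env.step(action)
print(ts.observation['qpos'].shape)  # (14,)
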
36
realman_src/realman_aloha/shadow_rm_act/pyproject.toml
Normal file
@@ -0,0 +1,36 @@
[tool.poetry]
name = "shadow_act"
version = "0.1.0"
description = "Embodied data, ACT and other methods; training and verification function packages"
readme = "README.md"
authors = ["Shadow <qiuchengzhan@gmail.com>"]
license = "MIT"
# include = ["realman_vision/pytransform/_pytransform.so",]
classifiers = [
    "Operating System :: POSIX :: Linux amd64",
    "Programming Language :: Python :: 3.10",
]

[tool.poetry.dependencies]
python = ">=3.9"
wandb = ">=0.18.0"
einops = ">=0.8.0"


[tool.poetry.dev-dependencies]  # Development-time dependencies, e.g. testing and documentation tools.
pytest = ">=8.3"
black = ">=24.10.0"


[tool.poetry.plugins."scripts"]  # Defines command-line scripts so users can run the specified functions from the shell.


[tool.poetry.group.dev.dependencies]


[build-system]
requires = ["poetry-core>=1.8.4"]
build-backend = "poetry.core.masonry.api"
189
realman_src/realman_aloha/shadow_rm_act/record_sim_episodes.py
Normal file
@@ -0,0 +1,189 @@
import time
import os
import numpy as np
import argparse
import matplotlib.pyplot as plt
import h5py

from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN, SIM_TASK_CONFIGS
from ee_sim_env import make_ee_sim_env
from sim_env import make_sim_env, BOX_POSE
from scripted_policy import PickAndTransferPolicy, InsertionPolicy

import IPython
e = IPython.embed


def main(args):
    """
    Generate demonstration data in simulation.
    First rollout the policy (defined in ee space) in ee_sim_env. Obtain the joint trajectory.
    Replace the gripper joint positions with the commanded joint position.
    Replay this joint trajectory (as action sequence) in sim_env, and record all observations.
    Save this episode of data, and continue to the next episode of data collection.
    """

    task_name = args['task_name']
    dataset_dir = args['dataset_dir']
    num_episodes = args['num_episodes']
    onscreen_render = args['onscreen_render']
    inject_noise = False
    render_cam_name = 'angle'

    if not os.path.isdir(dataset_dir):
        os.makedirs(dataset_dir, exist_ok=True)

    episode_len = SIM_TASK_CONFIGS[task_name]['episode_len']
    camera_names = SIM_TASK_CONFIGS[task_name]['camera_names']
    if task_name == 'sim_transfer_cube_scripted':
        policy_cls = PickAndTransferPolicy
    elif task_name == 'sim_insertion_scripted':
        policy_cls = InsertionPolicy
    else:
        raise NotImplementedError

    success = []
    for episode_idx in range(num_episodes):
        print(f'{episode_idx=}')
        print('Rolling out the EE-space scripted policy')
        # setup the environment
        env = make_ee_sim_env(task_name)
        ts = env.reset()
        episode = [ts]
        policy = policy_cls(inject_noise)
        # setup plotting
        if onscreen_render:
            ax = plt.subplot()
            plt_img = ax.imshow(ts.observation['images'][render_cam_name])
            plt.ion()
        for step in range(episode_len):
            action = policy(ts)
            ts = env.step(action)
            episode.append(ts)
            if onscreen_render:
                plt_img.set_data(ts.observation['images'][render_cam_name])
                plt.pause(0.002)
        plt.close()

        episode_return = np.sum([ts.reward for ts in episode[1:]])
        episode_max_reward = np.max([ts.reward for ts in episode[1:]])
        if episode_max_reward == env.task.max_reward:
            print(f"{episode_idx=} Successful, {episode_return=}")
        else:
            print(f"{episode_idx=} Failed")

        joint_traj = [ts.observation['qpos'] for ts in episode]
        # replace gripper pose with gripper control
        gripper_ctrl_traj = [ts.observation['gripper_ctrl'] for ts in episode]
        for joint, ctrl in zip(joint_traj, gripper_ctrl_traj):
            left_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[0])
            right_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[2])
            joint[6] = left_ctrl
            joint[6+7] = right_ctrl

        subtask_info = episode[0].observation['env_state'].copy()  # box pose at step 0

        # clear unused variables
        del env
        del episode
        del policy

        # setup the environment
        print('Replaying joint commands')
        env = make_sim_env(task_name)
        BOX_POSE[0] = subtask_info  # make sure sim_env has the same object configuration as ee_sim_env
        ts = env.reset()

        episode_replay = [ts]
        # setup plotting
        if onscreen_render:
            ax = plt.subplot()
            plt_img = ax.imshow(ts.observation['images'][render_cam_name])
            plt.ion()
        for t in range(len(joint_traj)):  # note: this will increase episode length by 1
            action = joint_traj[t]
            ts = env.step(action)
            episode_replay.append(ts)
            if onscreen_render:
                plt_img.set_data(ts.observation['images'][render_cam_name])
                plt.pause(0.02)

        episode_return = np.sum([ts.reward for ts in episode_replay[1:]])
        episode_max_reward = np.max([ts.reward for ts in episode_replay[1:]])
        if episode_max_reward == env.task.max_reward:
            success.append(1)
            print(f"{episode_idx=} Successful, {episode_return=}")
        else:
            success.append(0)
            print(f"{episode_idx=} Failed")

        plt.close()

        """
        For each timestep:
        observations
        - images
            - each_cam_name     (480, 640, 3) 'uint8'
        - qpos                  (14,)         'float64'
        - qvel                  (14,)         'float64'

        action                  (14,)         'float64'
        """

        data_dict = {
            '/observations/qpos': [],
            '/observations/qvel': [],
            '/action': [],
        }
        for cam_name in camera_names:
            data_dict[f'/observations/images/{cam_name}'] = []

        # because of the replay, there will be eps_len + 1 actions and eps_len + 2 timesteps
        # truncate here to be consistent
        joint_traj = joint_traj[:-1]
        episode_replay = episode_replay[:-1]

        # len(joint_traj) i.e. actions: max_timesteps
        # len(episode_replay) i.e. time steps: max_timesteps + 1
        max_timesteps = len(joint_traj)
        while joint_traj:
            action = joint_traj.pop(0)
            ts = episode_replay.pop(0)
            data_dict['/observations/qpos'].append(ts.observation['qpos'])
            data_dict['/observations/qvel'].append(ts.observation['qvel'])
            data_dict['/action'].append(action)
            for cam_name in camera_names:
                data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])

        # HDF5
        t0 = time.time()
        dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}')
        with h5py.File(dataset_path + '.hdf5', 'w', rdcc_nbytes=1024 ** 2 * 2) as root:
            root.attrs['sim'] = True
            obs = root.create_group('observations')
            image = obs.create_group('images')
            for cam_name in camera_names:
                _ = image.create_dataset(cam_name, (max_timesteps, 480, 640, 3), dtype='uint8',
                                         chunks=(1, 480, 640, 3), )
                # compression='gzip', compression_opts=2,)
                # compression=32001, compression_opts=(0, 0, 0, 0, 9, 1, 1), shuffle=False)
            qpos = obs.create_dataset('qpos', (max_timesteps, 14))
            qvel = obs.create_dataset('qvel', (max_timesteps, 14))
            action = root.create_dataset('action', (max_timesteps, 14))

            for name, array in data_dict.items():
                root[name][...] = array
        print(f'Saving: {time.time() - t0:.1f} secs\n')

    print(f'Saved to {dataset_dir}')
    print(f'Success: {np.sum(success)} / {len(success)}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
    parser.add_argument('--dataset_dir', action='store', type=str, help='dataset saving dir', required=True)
    parser.add_argument('--num_episodes', action='store', type=int, help='num_episodes', required=False)
    parser.add_argument('--onscreen_render', action='store_true')

    main(vars(parser.parse_args()))
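
For reference, a minimal sketch (not part of the file above) of reading one saved episode back; the schema matches the datasets created above, and the sketch assumes 'top' is among the configured camera_names:

import h5py

with h5py.File('episode_0.hdf5', 'r') as root:
    qpos = root['/observations/qpos'][:]           # (max_timesteps, 14)
    action = root['/action'][:]                    # (max_timesteps, 14)
    top_cam = root['/observations/images/top'][:]  # (max_timesteps, 480, 640, 3) uint8
print(qpos.shape, action.shape, top_cam.shape)
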
194
realman_src/realman_aloha/shadow_rm_act/scripted_policy.py
Normal file
@@ -0,0 +1,194 @@
import numpy as np
import matplotlib.pyplot as plt
from pyquaternion import Quaternion

from constants import SIM_TASK_CONFIGS
from ee_sim_env import make_ee_sim_env

import IPython
e = IPython.embed


class BasePolicy:
    def __init__(self, inject_noise=False):
        self.inject_noise = inject_noise
        self.step_count = 0
        self.left_trajectory = None
        self.right_trajectory = None

    def generate_trajectory(self, ts_first):
        raise NotImplementedError

    @staticmethod
    def interpolate(curr_waypoint, next_waypoint, t):
        t_frac = (t - curr_waypoint["t"]) / (next_waypoint["t"] - curr_waypoint["t"])
        curr_xyz = curr_waypoint['xyz']
        curr_quat = curr_waypoint['quat']
        curr_grip = curr_waypoint['gripper']
        next_xyz = next_waypoint['xyz']
        next_quat = next_waypoint['quat']
        next_grip = next_waypoint['gripper']
        xyz = curr_xyz + (next_xyz - curr_xyz) * t_frac
        quat = curr_quat + (next_quat - curr_quat) * t_frac
        gripper = curr_grip + (next_grip - curr_grip) * t_frac
        return xyz, quat, gripper

    def __call__(self, ts):
        # generate trajectory at first timestep, then open-loop execution
        if self.step_count == 0:
            self.generate_trajectory(ts)

        # obtain left and right waypoints
        if self.left_trajectory[0]['t'] == self.step_count:
            self.curr_left_waypoint = self.left_trajectory.pop(0)
        next_left_waypoint = self.left_trajectory[0]

        if self.right_trajectory[0]['t'] == self.step_count:
            self.curr_right_waypoint = self.right_trajectory.pop(0)
        next_right_waypoint = self.right_trajectory[0]

        # interpolate between waypoints to obtain current pose and gripper command
        left_xyz, left_quat, left_gripper = self.interpolate(self.curr_left_waypoint, next_left_waypoint, self.step_count)
        right_xyz, right_quat, right_gripper = self.interpolate(self.curr_right_waypoint, next_right_waypoint, self.step_count)

        # Inject noise
        if self.inject_noise:
            scale = 0.01
            left_xyz = left_xyz + np.random.uniform(-scale, scale, left_xyz.shape)
            right_xyz = right_xyz + np.random.uniform(-scale, scale, right_xyz.shape)

        action_left = np.concatenate([left_xyz, left_quat, [left_gripper]])
        action_right = np.concatenate([right_xyz, right_quat, [right_gripper]])

        self.step_count += 1
        return np.concatenate([action_left, action_right])


class PickAndTransferPolicy(BasePolicy):

    def generate_trajectory(self, ts_first):
        init_mocap_pose_right = ts_first.observation['mocap_pose_right']
        init_mocap_pose_left = ts_first.observation['mocap_pose_left']

        box_info = np.array(ts_first.observation['env_state'])
        box_xyz = box_info[:3]
        box_quat = box_info[3:]
        # print(f"Generate trajectory for {box_xyz=}")

        gripper_pick_quat = Quaternion(init_mocap_pose_right[3:])
        gripper_pick_quat = gripper_pick_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=-60)

        meet_left_quat = Quaternion(axis=[1.0, 0.0, 0.0], degrees=90)

        meet_xyz = np.array([0, 0.5, 0.25])

        self.left_trajectory = [
            {"t": 0, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 0},  # sleep
            {"t": 100, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 1},  # approach meet position
            {"t": 260, "xyz": meet_xyz + np.array([0.02, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 1},  # move to meet position
            {"t": 310, "xyz": meet_xyz + np.array([0.02, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 0},  # close gripper
            {"t": 360, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": np.array([1, 0, 0, 0]), "gripper": 0},  # move left
            {"t": 400, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": np.array([1, 0, 0, 0]), "gripper": 0},  # stay
        ]

        self.right_trajectory = [
            {"t": 0, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 0},  # sleep
            {"t": 90, "xyz": box_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat.elements, "gripper": 1},  # approach the cube
            {"t": 130, "xyz": box_xyz + np.array([0, 0, -0.015]), "quat": gripper_pick_quat.elements, "gripper": 1},  # go down
            {"t": 170, "xyz": box_xyz + np.array([0, 0, -0.015]), "quat": gripper_pick_quat.elements, "gripper": 0},  # close gripper
            {"t": 200, "xyz": meet_xyz + np.array([0.05, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 0},  # approach meet position
            {"t": 220, "xyz": meet_xyz, "quat": gripper_pick_quat.elements, "gripper": 0},  # move to meet position
            {"t": 310, "xyz": meet_xyz, "quat": gripper_pick_quat.elements, "gripper": 1},  # open gripper
            {"t": 360, "xyz": meet_xyz + np.array([0.1, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 1},  # move to right
            {"t": 400, "xyz": meet_xyz + np.array([0.1, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 1},  # stay
        ]


class InsertionPolicy(BasePolicy):

    def generate_trajectory(self, ts_first):
        init_mocap_pose_right = ts_first.observation['mocap_pose_right']
        init_mocap_pose_left = ts_first.observation['mocap_pose_left']

        peg_info = np.array(ts_first.observation['env_state'])[:7]
        peg_xyz = peg_info[:3]
        peg_quat = peg_info[3:]

        socket_info = np.array(ts_first.observation['env_state'])[7:]
        socket_xyz = socket_info[:3]
        socket_quat = socket_info[3:]

        gripper_pick_quat_right = Quaternion(init_mocap_pose_right[3:])
        gripper_pick_quat_right = gripper_pick_quat_right * Quaternion(axis=[0.0, 1.0, 0.0], degrees=-60)

        gripper_pick_quat_left = Quaternion(init_mocap_pose_right[3:])
        gripper_pick_quat_left = gripper_pick_quat_left * Quaternion(axis=[0.0, 1.0, 0.0], degrees=60)

        meet_xyz = np.array([0, 0.5, 0.15])
        lift_right = 0.00715

        self.left_trajectory = [
            {"t": 0, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 0},  # sleep
            {"t": 120, "xyz": socket_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat_left.elements, "gripper": 1},  # approach the socket
            {"t": 170, "xyz": socket_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_left.elements, "gripper": 1},  # go down
            {"t": 220, "xyz": socket_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_left.elements, "gripper": 0},  # close gripper
            {"t": 285, "xyz": meet_xyz + np.array([-0.1, 0, 0]), "quat": gripper_pick_quat_left.elements, "gripper": 0},  # approach meet position
            {"t": 340, "xyz": meet_xyz + np.array([-0.05, 0, 0]), "quat": gripper_pick_quat_left.elements, "gripper": 0},  # insertion
            {"t": 400, "xyz": meet_xyz + np.array([-0.05, 0, 0]), "quat": gripper_pick_quat_left.elements, "gripper": 0},  # insertion
        ]

        self.right_trajectory = [
            {"t": 0, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 0},  # sleep
            {"t": 120, "xyz": peg_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat_right.elements, "gripper": 1},  # approach the peg
            {"t": 170, "xyz": peg_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_right.elements, "gripper": 1},  # go down
            {"t": 220, "xyz": peg_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_right.elements, "gripper": 0},  # close gripper
            {"t": 285, "xyz": meet_xyz + np.array([0.1, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0},  # approach meet position
            {"t": 340, "xyz": meet_xyz + np.array([0.05, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0},  # insertion
            {"t": 400, "xyz": meet_xyz + np.array([0.05, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0},  # insertion
        ]


def test_policy(task_name):
    # example of rolling out the pick-and-transfer policy
    onscreen_render = True
    inject_noise = False

    # setup the environment
    episode_len = SIM_TASK_CONFIGS[task_name]['episode_len']
    if 'sim_transfer_cube' in task_name:
        env = make_ee_sim_env('sim_transfer_cube')
    elif 'sim_insertion' in task_name:
        env = make_ee_sim_env('sim_insertion')
    else:
        raise NotImplementedError

    for episode_idx in range(2):
        ts = env.reset()
        episode = [ts]
        if onscreen_render:
            ax = plt.subplot()
            plt_img = ax.imshow(ts.observation['images']['angle'])
            plt.ion()

        policy = PickAndTransferPolicy(inject_noise)
        for step in range(episode_len):
            action = policy(ts)
            ts = env.step(action)
            episode.append(ts)
            if onscreen_render:
                plt_img.set_data(ts.observation['images']['angle'])
                plt.pause(0.02)
        plt.close()

        episode_return = np.sum([ts.reward for ts in episode[1:]])
        if episode_return > 0:
            print(f"{episode_idx=} Successful, {episode_return=}")
        else:
            print(f"{episode_idx=} Failed")


if __name__ == '__main__':
    test_task_name = 'sim_transfer_cube_scripted'
    test_policy(test_task_name)
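
Note that BasePolicy.interpolate above lerps the quaternion component-wise rather than slerping; that is adequate for the short, nearby waypoints used here, but it does not stay on the unit sphere. A hedged sketch of a slerp-based alternative using pyquaternion (the helper name interpolate_slerp is illustrative, not part of the repo):

from pyquaternion import Quaternion

def interpolate_slerp(curr_waypoint, next_waypoint, t):
    # Same waypoint schema as BasePolicy.interpolate; quaternion slerped instead of lerped.
    t_frac = (t - curr_waypoint["t"]) / (next_waypoint["t"] - curr_waypoint["t"])
    xyz = curr_waypoint['xyz'] + (next_waypoint['xyz'] - curr_waypoint['xyz']) * t_frac
    quat = Quaternion.slerp(Quaternion(curr_waypoint['quat']),
                            Quaternion(next_waypoint['quat']), t_frac).elements
    gripper = curr_waypoint['gripper'] + (next_waypoint['gripper'] - curr_waypoint['gripper']) * t_frac
    return xyz, quat, gripper
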
278
realman_src/realman_aloha/shadow_rm_act/sim_env.py
Normal file
@@ -0,0 +1,278 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import collections
|
||||
import matplotlib.pyplot as plt
|
||||
from dm_control import mujoco
|
||||
from dm_control.rl import control
|
||||
from dm_control.suite import base
|
||||
|
||||
from constants import DT, XML_DIR, START_ARM_POSE
|
||||
from constants import PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN
|
||||
from constants import MASTER_GRIPPER_POSITION_NORMALIZE_FN
|
||||
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN
|
||||
from constants import PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
BOX_POSE = [None] # to be changed from outside
|
||||
|
||||
def make_sim_env(task_name):
|
||||
"""
|
||||
Environment for simulated robot bi-manual manipulation, with joint position control
|
||||
Action space: [left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||
|
||||
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
|
||||
"""
|
||||
if 'sim_transfer_cube' in task_name:
|
||||
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_transfer_cube.xml')
|
||||
physics = mujoco.Physics.from_xml_path(xml_path)
|
||||
task = TransferCubeTask(random=False)
|
||||
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
|
||||
n_sub_steps=None, flat_observation=False)
|
||||
elif 'sim_insertion' in task_name:
|
||||
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_insertion.xml')
|
||||
physics = mujoco.Physics.from_xml_path(xml_path)
|
||||
task = InsertionTask(random=False)
|
||||
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
|
||||
n_sub_steps=None, flat_observation=False)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return env
|
||||
|
||||
class BimanualViperXTask(base.Task):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
|
||||
def before_step(self, action, physics):
|
||||
left_arm_action = action[:6]
|
||||
right_arm_action = action[7:7+6]
|
||||
normalized_left_gripper_action = action[6]
|
||||
normalized_right_gripper_action = action[7+6]
|
||||
|
||||
left_gripper_action = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(normalized_left_gripper_action)
|
||||
right_gripper_action = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(normalized_right_gripper_action)
|
||||
|
||||
full_left_gripper_action = [left_gripper_action, -left_gripper_action]
|
||||
full_right_gripper_action = [right_gripper_action, -right_gripper_action]
|
||||
|
||||
env_action = np.concatenate([left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action])
|
||||
super().before_step(env_action, physics)
|
||||
return
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_qpos(physics):
|
||||
qpos_raw = physics.data.qpos.copy()
|
||||
left_qpos_raw = qpos_raw[:8]
|
||||
right_qpos_raw = qpos_raw[8:16]
|
||||
left_arm_qpos = left_qpos_raw[:6]
|
||||
right_arm_qpos = right_qpos_raw[:6]
|
||||
left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[6])]
|
||||
right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[6])]
|
||||
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
||||
|
||||
@staticmethod
|
||||
def get_qvel(physics):
|
||||
qvel_raw = physics.data.qvel.copy()
|
||||
left_qvel_raw = qvel_raw[:8]
|
||||
right_qvel_raw = qvel_raw[8:16]
|
||||
left_arm_qvel = left_qvel_raw[:6]
|
||||
right_arm_qvel = right_qvel_raw[:6]
|
||||
left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[6])]
|
||||
right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[6])]
|
||||
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_observation(self, physics):
|
||||
obs = collections.OrderedDict()
|
||||
obs['qpos'] = self.get_qpos(physics)
|
||||
obs['qvel'] = self.get_qvel(physics)
|
||||
obs['env_state'] = self.get_env_state(physics)
|
||||
obs['images'] = dict()
|
||||
obs['images']['top'] = physics.render(height=480, width=640, camera_id='top')
|
||||
obs['images']['angle'] = physics.render(height=480, width=640, camera_id='angle')
|
||||
obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close')
|
||||
|
||||
return obs
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether left gripper is holding the box
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TransferCubeTask(BimanualViperXTask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
|
||||
# reset qpos, control and box position
|
||||
with physics.reset_context():
|
||||
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||
np.copyto(physics.data.ctrl, START_ARM_POSE)
|
||||
assert BOX_POSE[0] is not None
|
||||
physics.named.data.qpos[-7:] = BOX_POSE[0]
|
||||
# print(f"{BOX_POSE=}")
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether left gripper is holding the box
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_table = ("red_box", "table") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_right_gripper:
|
||||
reward = 1
|
||||
if touch_right_gripper and not touch_table: # lifted
|
||||
reward = 2
|
||||
if touch_left_gripper: # attempted transfer
|
||||
reward = 3
|
||||
if touch_left_gripper and not touch_table: # successful transfer
|
||||
reward = 4
|
||||
return reward
|
||||
|
||||
|
||||
class InsertionTask(BimanualViperXTask):
    def __init__(self, random=None):
        super().__init__(random=random)
        self.max_reward = 4

    def initialize_episode(self, physics):
        """Sets the state of the environment at the start of each episode."""
        # TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
        # reset qpos, control and box position
        with physics.reset_context():
            physics.named.data.qpos[:16] = START_ARM_POSE
            np.copyto(physics.data.ctrl, START_ARM_POSE)
            assert BOX_POSE[0] is not None
            physics.named.data.qpos[-7*2:] = BOX_POSE[0]  # two objects
            # print(f"{BOX_POSE=}")
        super().initialize_episode(physics)

    @staticmethod
    def get_env_state(physics):
        env_state = physics.data.qpos.copy()[16:]
        return env_state

    def get_reward(self, physics):
        # return whether peg touches the pin
        all_contact_pairs = []
        for i_contact in range(physics.data.ncon):
            id_geom_1 = physics.data.contact[i_contact].geom1
            id_geom_2 = physics.data.contact[i_contact].geom2
            name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
            name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
            contact_pair = (name_geom_1, name_geom_2)
            all_contact_pairs.append(contact_pair)

        touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
        touch_left_gripper = ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
                             ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
                             ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
                             ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs

        peg_touch_table = ("red_peg", "table") in all_contact_pairs
        socket_touch_table = ("socket-1", "table") in all_contact_pairs or \
                             ("socket-2", "table") in all_contact_pairs or \
                             ("socket-3", "table") in all_contact_pairs or \
                             ("socket-4", "table") in all_contact_pairs
        peg_touch_socket = ("red_peg", "socket-1") in all_contact_pairs or \
                           ("red_peg", "socket-2") in all_contact_pairs or \
                           ("red_peg", "socket-3") in all_contact_pairs or \
                           ("red_peg", "socket-4") in all_contact_pairs
        pin_touched = ("red_peg", "pin") in all_contact_pairs

        reward = 0
        if touch_left_gripper and touch_right_gripper:  # touch both
            reward = 1
        if touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table):  # grasp both
            reward = 2
        if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table):  # peg and socket touching
            reward = 3
        if pin_touched:  # successful insertion
            reward = 4
        return reward

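
# Both reward functions above scan MuJoCo's contact buffer the same way. As a
# minimal standalone sketch of that query (assuming a dm_control `physics`
# handle), the shared logic could be factored out like this:
def current_contact_pairs(physics):
    """Return the set of (geom1_name, geom2_name) pairs currently in contact."""
    pairs = set()
    for i_contact in range(physics.data.ncon):
        name_1 = physics.model.id2name(physics.data.contact[i_contact].geom1, 'geom')
        name_2 = physics.model.id2name(physics.data.contact[i_contact].geom2, 'geom')
        pairs.add((name_1, name_2))
    return pairs
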
def get_action(master_bot_left, master_bot_right):
    action = np.zeros(14)
    # arm action
    action[:6] = master_bot_left.dxl.joint_states.position[:6]
    action[7:7+6] = master_bot_right.dxl.joint_states.position[:6]
    # gripper action
    left_gripper_pos = master_bot_left.dxl.joint_states.position[7]
    right_gripper_pos = master_bot_right.dxl.joint_states.position[7]
    normalized_left_pos = MASTER_GRIPPER_POSITION_NORMALIZE_FN(left_gripper_pos)
    normalized_right_pos = MASTER_GRIPPER_POSITION_NORMALIZE_FN(right_gripper_pos)
    action[6] = normalized_left_pos
    action[7+6] = normalized_right_pos
    return action

def test_sim_teleop():
    """ Testing teleoperation in sim with ALOHA. Requires hardware and ALOHA repo to work. """
    from interbotix_xs_modules.arm import InterbotixManipulatorXS

    BOX_POSE[0] = [0.2, 0.5, 0.05, 1, 0, 0, 0]

    # source of data
    master_bot_left = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper",
                                              robot_name='master_left', init_node=True)
    master_bot_right = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper",
                                               robot_name='master_right', init_node=False)

    # setup the environment
    env = make_sim_env('sim_transfer_cube')
    ts = env.reset()
    episode = [ts]
    # setup plotting
    ax = plt.subplot()
    plt_img = ax.imshow(ts.observation['images']['angle'])
    plt.ion()

    for t in range(1000):
        action = get_action(master_bot_left, master_bot_right)
        ts = env.step(action)
        episode.append(ts)

        plt_img.set_data(ts.observation['images']['angle'])
        plt.pause(0.02)


if __name__ == '__main__':
    test_sim_teleop()

@@ -0,0 +1 @@
__version__ = '0.1.0'
@@ -0,0 +1 @@
__version__ = '0.1.0'
@@ -0,0 +1,575 @@
import os
import time
import yaml
import torch
import pickle
import dm_env
import logging
import collections
import numpy as np
import tracemalloc
from einops import rearrange
import matplotlib.pyplot as plt
from torchvision import transforms
from shadow_rm_robot.realman_arm import RmArm
from shadow_camera.realsense import RealSenseCamera
from shadow_act.models.latent_model import Latent_Model_Transformer
from shadow_act.network.policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy
from shadow_act.utils.utils import set_seed


# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# # Hide h5py's "Creating converter from 7 to 5" warning
# logging.getLogger("h5py").setLevel(logging.WARNING)


class RmActEvaluator:
    def __init__(self, config, save_episode=True, num_rollouts=50):
        """
        Initialize the evaluator.

        Args:
            config (dict): configuration dictionary
            save_episode (bool): whether to save every episode
            num_rollouts (int): number of evaluation rollouts
        """
        self.config = config
        self._seed = config["seed"]
        self.robot_env = config["robot_env"]
        self.checkpoint_dir = config["checkpoint_dir"]
        self.checkpoint_name = config["checkpoint_name"]
        self.save_episode = save_episode
        self.num_rollouts = num_rollouts
        self.state_dim = config["state_dim"]
        self.real_robot = config["real_robot"]
        self.policy_class = config["policy_class"]
        self.onscreen_render = config["onscreen_render"]
        self.camera_names = config["camera_names"]
        self.max_timesteps = config["episode_len"]
        self.task_name = config["task_name"]
        self.temporal_agg = config["temporal_agg"]
        self.onscreen_cam = "angle"
        self.policy_config = config["policy_config"]
        self.vq = config["policy_config"]["vq"]
        # self.actuator_config = config["actuator_config"]
        # self.use_actuator_net = self.actuator_config["actuator_network_dir"] is not None
        self.stats = None
        self.env = None
        self.env_max_reward = 0

    def _make_policy(self, policy_class, policy_config):
        """
        Create a policy object from the policy class name and its configuration.

        Args:
            policy_class (str): policy class name
            policy_config (dict): policy configuration dictionary

        Returns:
            policy: the created policy object
        """
        if policy_class == "ACT":
            return ACTPolicy(policy_config)
        elif policy_class == "CNNMLP":
            return CNNMLPPolicy(policy_config)
        elif policy_class == "Diffusion":
            return DiffusionPolicy(policy_config)
        else:
            raise NotImplementedError(f"Policy class {policy_class} is not implemented")

    def load_policy_and_stats(self):
        """
        Load the policy and the dataset statistics.
        """
        checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name)
        logging.info(f"Loading policy from: {checkpoint_path}")
        self.policy = self._make_policy(self.policy_class, self.policy_config)
        # Load the weights and switch to eval mode
        self.policy.load_state_dict(torch.load(checkpoint_path, weights_only=True))
        self.policy.cuda()
        self.policy.eval()

        if self.vq:
            vq_dim = self.config["policy_config"]["vq_dim"]
            vq_class = self.config["policy_config"]["vq_class"]
            self.latent_model = Latent_Model_Transformer(vq_dim, vq_dim, vq_class)
            latent_model_checkpoint_path = os.path.join(
                self.checkpoint_dir, "latent_model_last.ckpt"
            )
            self.latent_model.deserialize(torch.load(latent_model_checkpoint_path))
            self.latent_model.eval()
            self.latent_model.cuda()
            logging.info(
                f"Loaded policy from: {checkpoint_path}, latent model from: {latent_model_checkpoint_path}"
            )
        else:
            logging.info(f"Loaded: {checkpoint_path}")

        stats_path = os.path.join(self.checkpoint_dir, "dataset_stats.pkl")
        with open(stats_path, "rb") as f:
            self.stats = pickle.load(f)

    def pre_process(self, state_qpos):
        """
        Pre-process the state positions.

        Args:
            state_qpos (np.array): joint-position array

        Returns:
            np.array: pre-processed joint positions
        """
        if self.policy_class == "Diffusion":
            return ((state_qpos + 1) / 2) * (
                self.stats["action_max"] - self.stats["action_min"]
            ) + self.stats["action_min"]
        # Standardize to zero mean and unit standard deviation
        return (state_qpos - self.stats["qpos_mean"]) / self.stats["qpos_std"]

    def post_process(self, action):
        """
        Post-process an action.

        Args:
            action (np.array): action array

        Returns:
            np.array: post-processed action
        """
        # Undo the standardization
        return action * self.stats["action_std"] + self.stats["action_mean"]

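    # pre_process and post_process undo each other over the dataset statistics.
    # A quick round-trip check with toy stats (illustrative values only, and
    # assuming qpos and action share the same statistics, as they do for
    # joint-space actions):
    #     stats = {"qpos_mean": 0.5, "qpos_std": 2.0,
    #              "action_mean": 0.5, "action_std": 2.0}
    #     q = np.array([1.0, -1.0])
    #     normalized = (q - stats["qpos_mean"]) / stats["qpos_std"]
    #     restored = normalized * stats["action_std"] + stats["action_mean"]
    #     assert np.allclose(restored, q)
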
    def get_image_torch(self, timestep, camera_names, random_crop_resize=False):
        """
        Fetch the camera images of a timestep.

        Args:
            timestep (object): timestep object
            camera_names (list): list of camera names
            random_crop_resize (bool): whether to center-crop and resize

        Returns:
            torch.Tensor: normalized images of shape (1, num_cameras, channels, height, width)
        """
        current_images = []
        for cam_name in camera_names:
            current_image = rearrange(
                timestep.observation["images"][cam_name], "h w c -> c h w"
            )
            current_images.append(current_image)
        current_image = np.stack(current_images, axis=0)
        current_image = (
            torch.from_numpy(current_image / 255.0).float().cuda().unsqueeze(0)
        )

        if random_crop_resize:
            logging.info("Random crop resize is used!")
            original_size = current_image.shape[-2:]
            ratio = 0.95
            current_image = current_image[
                ...,
                int(original_size[0] * (1 - ratio) / 2) : int(
                    original_size[0] * (1 + ratio) / 2
                ),
                int(original_size[1] * (1 - ratio) / 2) : int(
                    original_size[1] * (1 + ratio) / 2
                ),
            ]
            current_image = current_image.squeeze(0)
            resize_transform = transforms.Resize(original_size, antialias=True)
            current_image = resize_transform(current_image)
            current_image = current_image.unsqueeze(0)

        return current_image

    def load_environment(self):
        """
        Load the environment.
        """
        if self.real_robot:
            self.env = DeviceAloha(self.robot_env)
            self.env_max_reward = 0
        else:
            from sim_env import make_sim_env

            self.env = make_sim_env(self.task_name)
            self.env_max_reward = self.env.task.max_reward

    def get_auto_index(self, checkpoint_dir):
        max_idx = 1000
        for i in range(max_idx + 1):
            if not os.path.isfile(os.path.join(checkpoint_dir, f"qpos_{i}.npy")):
                return i
        raise Exception(f"Error getting auto index, or more than {max_idx} episodes")

    def evaluate(self, checkpoint_name=None):
        """
        Evaluate the policy.

        Returns:
            tuple: success rate and average return
        """
        if checkpoint_name is not None:
            self.checkpoint_name = checkpoint_name
        set_seed(self._seed)  # seed numpy and torch
        self.load_policy_and_stats()
        self.load_environment()

        query_frequency = self.policy_config["num_queries"]

        # With temporal aggregation, the policy is queried at every timestep
        if self.temporal_agg:
            query_frequency = 1
            num_queries = self.policy_config["num_queries"]

        # # On the real robot, the base delay is 13???
        # if self.real_robot:
        #     BASE_DELAY = 13
        #     # query_frequency -= BASE_DELAY

        max_timesteps = int(self.max_timesteps * 1)  # may increase for real-world tasks
        episode_returns = []
        highest_rewards = []

        for rollout_id in range(self.num_rollouts):

            timestep = self.env.reset()

            if self.onscreen_render:
                # TODO render on screen
                pass
            if self.temporal_agg:
                all_time_actions = torch.zeros(
                    [max_timesteps, max_timesteps + num_queries, self.state_dim]
                ).cuda()
            qpos_history_raw = np.zeros((max_timesteps, self.state_dim))
            rewards = []

            with torch.inference_mode():
                time_0 = time.time()
                DT = 1 / 30
                cumulated_delay = 0
                for t in range(max_timesteps):
                    time_1 = time.time()
                    if self.onscreen_render:
                        # TODO display the image
                        pass
                    # process previous timestep to get qpos and image_list
                    obs = timestep.observation
                    qpos_numpy = np.array(obs["qpos"])
                    qpos_history_raw[t] = qpos_numpy
                    qpos = self.pre_process(qpos_numpy)
                    qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)

                    logging.info(f"t={t}")

                    if t % query_frequency == 0:
                        current_image = self.get_image_torch(
                            timestep,
                            self.camera_names,
                            random_crop_resize=(
                                self.config["policy_class"] == "Diffusion"
                            ),
                        )

                    if t == 0:
                        # Warm up the network
                        for _ in range(10):
                            self.policy(qpos, current_image)
                        logging.info("Network warm up done")

                    if self.config["policy_class"] == "ACT":
                        if t % query_frequency == 0:
                            if self.vq:
                                if rollout_id == 0:
                                    for _ in range(10):
                                        vq_sample = self.latent_model.generate(
                                            1, temperature=1, x=None
                                        )
                                        logging.info(
                                            torch.nonzero(vq_sample[0])[:, 1]
                                            .cpu()
                                            .numpy()
                                        )
                                vq_sample = self.latent_model.generate(
                                    1, temperature=1, x=None
                                )
                                all_actions = self.policy(
                                    qpos, current_image, vq_sample=vq_sample
                                )
                            else:
                                all_actions = self.policy(qpos, current_image)
                            # if self.real_robot:
                            #     all_actions = torch.cat(
                            #         [
                            #             all_actions[:, :-BASE_DELAY, :-2],
                            #             all_actions[:, BASE_DELAY:, -2:],
                            #         ],
                            #         dim=2,
                            #     )
                        if self.temporal_agg:
                            all_time_actions[[t], t : t + num_queries] = all_actions
                            actions_for_curr_step = all_time_actions[:, t]
                            actions_populated = torch.all(
                                actions_for_curr_step != 0, axis=1
                            )
                            actions_for_curr_step = actions_for_curr_step[
                                actions_populated
                            ]
                            k = 0.01
                            exp_weights = np.exp(
                                -k * np.arange(len(actions_for_curr_step))
                            )
                            exp_weights = exp_weights / exp_weights.sum()
                            exp_weights = (
                                torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
                            )
                            raw_action = (actions_for_curr_step * exp_weights).sum(
                                dim=0, keepdim=True
                            )
                        else:
                            raw_action = all_actions[:, t % query_frequency]
                    elif self.config["policy_class"] == "Diffusion":
                        if t % query_frequency == 0:
                            all_actions = self.policy(qpos, current_image)
                            # if self.real_robot:
                            #     all_actions = torch.cat(
                            #         [
                            #             all_actions[:, :-BASE_DELAY, :-2],
                            #             all_actions[:, BASE_DELAY:, -2:],
                            #         ],
                            #         dim=2,
                            #     )
                        raw_action = all_actions[:, t % query_frequency]
                    elif self.config["policy_class"] == "CNNMLP":
                        raw_action = self.policy(qpos, current_image)
                        all_actions = raw_action.unsqueeze(0)
                    else:
                        raise NotImplementedError

                    ### post-process actions
                    raw_action = raw_action.squeeze(0).cpu().numpy()
                    action = self.post_process(raw_action)

                    ### step the environment
                    if self.real_robot:
                        logging.info(f" action = {action}")
                    timestep = self.env.step(action)

                    rewards.append(timestep.reward)
                    duration = time.time() - time_1
                    sleep_time = max(0, DT - duration)
                    time.sleep(sleep_time)
                    if duration >= DT:
                        cumulated_delay += duration - DT
                        logging.warning(
                            f"Warning: step duration: {duration:.3f} s at step {t} longer than DT: {DT} s, cumulated delay: {cumulated_delay:.3f} s"
                        )

                logging.info(f"Avg fps {max_timesteps / (time.time() - time_0)}")
                plt.close()

            if self.real_robot:
                log_id = self.get_auto_index(self.checkpoint_dir)
                np.save(
                    os.path.join(self.checkpoint_dir, f"qpos_{log_id}.npy"),
                    qpos_history_raw,
                )
                plt.figure(figsize=(10, 20))
                for i in range(self.state_dim):
                    plt.subplot(self.state_dim, 1, i + 1)
                    plt.plot(qpos_history_raw[:, i])
                    if i != self.state_dim - 1:
                        plt.xticks([])
                plt.tight_layout()
                plt.savefig(os.path.join(self.checkpoint_dir, f"qpos_{log_id}.png"))
                plt.close()

            rewards = np.array(rewards)
            episode_return = np.sum(rewards[rewards != None])
            episode_returns.append(episode_return)
            episode_highest_reward = np.max(rewards)
            highest_rewards.append(episode_highest_reward)
            logging.info(
                f"Rollout {rollout_id}\n{episode_return=}, {episode_highest_reward=}, {self.env_max_reward=}, Success: {episode_highest_reward == self.env_max_reward}"
            )

        success_rate = np.mean(np.array(highest_rewards) == self.env_max_reward)
        avg_return = np.mean(episode_returns)
        summary_str = (
            f"\nSuccess rate: {success_rate}\nAverage return: {avg_return}\n\n"
        )
        for r in range(self.env_max_reward + 1):
            more_or_equal_r = (np.array(highest_rewards) >= r).sum()
            more_or_equal_r_rate = more_or_equal_r / self.num_rollouts
            summary_str += f"Reward >= {r}: {more_or_equal_r}/{self.num_rollouts} = {more_or_equal_r_rate * 100}%\n"

        logging.info(summary_str)

        result_file_name = "result_" + self.checkpoint_name.split(".")[0] + ".txt"
        with open(os.path.join(self.checkpoint_dir, result_file_name), "w") as f:
            f.write(summary_str)
            f.write(repr(episode_returns))
            f.write("\n\n")
            f.write(repr(highest_rewards))

        return success_rate, avg_return

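# The temporal aggregation above follows ACT-style action chunking: every past
# chunk that still covers the current step is averaged with exponentially
# decaying weights, oldest prediction weighted highest. A minimal standalone
# sketch of the same computation (k matches the 0.01 used above):
def temporal_ensemble(actions_for_curr_step, k=0.01):
    """actions_for_curr_step: (n, action_dim) array, oldest prediction first."""
    weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
    weights = weights / weights.sum()
    return (actions_for_curr_step * weights[:, None]).sum(axis=0)
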
class DeviceAloha:
    def __init__(self, aloha_config):
        """
        Initialize the device.

        Args:
            aloha_config (dict): arm and camera configuration
        """
        config_left_arm = aloha_config["rm_left_arm"]
        config_right_arm = aloha_config["rm_right_arm"]
        config_head_camera = aloha_config["head_camera"]
        config_bottom_camera = aloha_config["bottom_camera"]
        config_left_camera = aloha_config["left_camera"]
        config_right_camera = aloha_config["right_camera"]
        self.init_left_arm_angle = aloha_config["init_left_arm_angle"]
        self.init_right_arm_angle = aloha_config["init_right_arm_angle"]
        self.arm_axis = aloha_config["arm_axis"]
        self.arm_left = RmArm(config_left_arm)
        self.arm_right = RmArm(config_right_arm)
        # one RealSense per viewpoint, matched to its config
        self.camera_top = RealSenseCamera(config_head_camera, False)
        self.camera_bottom = RealSenseCamera(config_bottom_camera, False)
        self.camera_left = RealSenseCamera(config_left_camera, False)
        self.camera_right = RealSenseCamera(config_right_camera, False)
        self.camera_left.start_camera()
        self.camera_right.start_camera()
        self.camera_bottom.start_camera()
        self.camera_top.start_camera()

    def close(self):
        """
        Close the cameras.
        """
        self.camera_left.close()
        self.camera_right.close()
        self.camera_bottom.close()
        self.camera_top.close()

    def get_qpos(self):
        """
        Get the joint angles.

        Returns:
            np.array: joint angles of both arms
        """
        left_slave_arm_angle = self.arm_left.get_joint_angle()
        left_joint_angles_array = np.array(list(left_slave_arm_angle.values()))
        right_slave_arm_angle = self.arm_right.get_joint_angle()
        right_joint_angles_array = np.array(list(right_slave_arm_angle.values()))
        return np.concatenate([left_joint_angles_array, right_joint_angles_array])

    def get_qvel(self):
        """
        Get the joint velocities.

        Returns:
            np.array: joint velocities of both arms
        """
        left_slave_arm_velocity = self.arm_left.get_joint_velocity()
        left_joint_velocity_array = np.array(list(left_slave_arm_velocity.values()))
        right_slave_arm_velocity = self.arm_right.get_joint_velocity()
        right_joint_velocity_array = np.array(list(right_slave_arm_velocity.values()))
        return np.concatenate([left_joint_velocity_array, right_joint_velocity_array])

    def get_effort(self):
        """
        Get the joint efforts.

        Returns:
            np.array: joint efforts of both arms
        """
        left_slave_arm_effort = self.arm_left.get_joint_effort()
        left_joint_effort_array = np.array(list(left_slave_arm_effort.values()))
        right_slave_arm_effort = self.arm_right.get_joint_effort()
        right_joint_effort_array = np.array(list(right_slave_arm_effort.values()))
        return np.concatenate([left_joint_effort_array, right_joint_effort_array])

    def get_images(self):
        """
        Get the camera images.

        Returns:
            dict: dictionary of images keyed by camera name
        """
        self.top_image, _, _, _ = self.camera_top.read_frame(True, False, False, False)
        self.bottom_image, _, _, _ = self.camera_bottom.read_frame(
            True, False, False, False
        )
        self.left_image, _, _, _ = self.camera_left.read_frame(
            True, False, False, False
        )
        self.right_image, _, _, _ = self.camera_right.read_frame(
            True, False, False, False
        )
        return {
            "cam_high": self.top_image,
            "cam_low": self.bottom_image,
            "cam_left": self.left_image,
            "cam_right": self.right_image,
        }

    def get_observation(self):
        obs = collections.OrderedDict()
        obs["qpos"] = self.get_qpos()
        obs["qvel"] = self.get_qvel()
        obs["effort"] = self.get_effort()
        obs["images"] = self.get_images()
        return obs

    def reset(self):
        logging.info("Resetting the environment")
        self.arm_left.set_joint_position(self.init_left_arm_angle[0:self.arm_axis])
        self.arm_right.set_joint_position(self.init_right_arm_angle[0:self.arm_axis])
        self.arm_left.set_gripper_position(0)
        self.arm_right.set_gripper_position(0)
        return dm_env.TimeStep(
            step_type=dm_env.StepType.FIRST,
            reward=0,
            discount=None,
            observation=self.get_observation(),
        )

    def step(self, target_angle):
        self.arm_left.set_joint_canfd_position(target_angle[0:self.arm_axis])
        self.arm_right.set_joint_canfd_position(target_angle[self.arm_axis+1:self.arm_axis*2+1])
        self.arm_left.set_gripper_position(target_angle[self.arm_axis])
        self.arm_right.set_gripper_position(target_angle[(self.arm_axis*2 + 1)])
        return dm_env.TimeStep(
            step_type=dm_env.StepType.MID,
            reward=0,
            discount=None,
            observation=self.get_observation(),
        )


if __name__ == "__main__":
    # with open("/home/rm/code/shadow_act/config/config.yaml", "r") as f:
    #     config = yaml.safe_load(f)
    # aloha_config = config["robot_env"]
    # device = DeviceAloha(aloha_config)
    # device.reset()
    # while True:
    #     init_angle = np.concatenate([device.init_left_arm_angle, device.init_right_arm_angle])
    #     time_step = time.time()
    #     timestep = device.step(init_angle)
    #     logging.info(f"Time: {time.time() - time_step}")
    #     obs = timestep.observation

    with open("/home/wang/project/shadow_rm_act/config/config.yaml", "r") as f:
        config = yaml.safe_load(f)
    # logging.info(f"Config: {config}")
    evaluator = RmActEvaluator(config)
    success_rate, avg_return = evaluator.evaluate()
@@ -0,0 +1 @@
__version__ = '0.1.0'
@@ -0,0 +1,153 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Backbone modules.
"""
import torch
import torchvision
from torch import nn
from typing import Dict, List
import torch.nn.functional as F
from .position_encoding import build_position_encoding
from torchvision.models import ResNet18_Weights
from torchvision.models._utils import IntermediateLayerGetter
from shadow_act.utils.misc import NestedTensor, is_main_process


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
    without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
    produce nans.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        num_batches_tracked_key = prefix + "num_batches_tracked"
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias

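# FrozenBatchNorm2d should agree with an eval-mode nn.BatchNorm2d that carries
# the same buffers; a quick equivalence check (illustrative, default buffers
# and matching eps=1e-5):
#     bn = nn.BatchNorm2d(8, eps=1e-5).eval()
#     fbn = FrozenBatchNorm2d(8)
#     x = torch.randn(2, 8, 4, 4)
#     assert torch.allclose(bn(x), fbn(x), atol=1e-6)
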
class BackboneBase(nn.Module):

    def __init__(
        self,
        backbone: nn.Module,
        train_backbone: bool,
        num_channels: int,
        return_interm_layers: bool,
    ):
        super().__init__()
        # for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
        #     if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
        #         parameter.requires_grad_(False)
        if return_interm_layers:
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:
            return_layers = {"layer4": "0"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor):
        xs = self.body(tensor)
        return xs
        # out: Dict[str, NestedTensor] = {}
        # for name, x in xs.items():
        #     m = tensor_list.mask
        #     assert m is not None
        #     mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
        #     out[name] = NestedTensor(x, mask)
        # return out

class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""

    def __init__(
        self,
        name: str,
        train_backbone: bool,
        return_interm_layers: bool,
        dilation: bool,
    ):
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            weights=ResNet18_Weights.IMAGENET1K_V1 if is_main_process() else None,
            norm_layer=FrozenBatchNorm2d,
        )
        # backbone = getattr(torchvision.models, name)(
        #     replace_stride_with_dilation=[False, False, dilation],
        #     pretrained=is_main_process(),
        #     norm_layer=FrozenBatchNorm2d,
        # )  # pretrained # TODO do we want frozen batch_norm??
        num_channels = 512 if name in ("resnet18", "resnet34") else 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)

class Joiner(nn.Sequential):
    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        xs = self[0](tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for name, x in xs.items():
            out.append(x)
            # position encoding
            pos.append(self[1](x).to(x.dtype))

        return out, pos

def build_backbone(
    hidden_dim, position_embedding_type, lr_backbone, masks, backbone, dilation
):
    position_embedding = build_position_encoding(
        hidden_dim=hidden_dim, position_embedding_type=position_embedding_type
    )
    train_backbone = lr_backbone > 0
    return_interm_layers = masks
    backbone = Backbone(backbone, train_backbone, return_interm_layers, dilation)
    model = Joiner(backbone, position_embedding)
    model.num_channels = backbone.num_channels
    return model
@@ -0,0 +1,436 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR model and criterion classes.
"""
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
from shadow_act.models.transformer import Transformer
from .backbone import build_backbone
from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer

import numpy as np


def reparametrize(mu, logvar):
    std = logvar.div(2).exp()
    eps = Variable(std.data.new(std.size()).normal_())
    return mu + std * eps

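# `reparametrize` is the standard VAE reparameterization trick: rather than
# sampling z ~ N(mu, sigma^2) directly, it samples eps ~ N(0, 1) and computes
# z = mu + sigma * eps, keeping the sample differentiable w.r.t. mu and logvar.
# An equivalent modern form without the deprecated Variable wrapper (a sketch):
#     def reparametrize_v2(mu, logvar):
#         std = torch.exp(0.5 * logvar)
#         return mu + std * torch.randn_like(std)
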
def get_sinusoid_encoding_table(n_position, d_hid):
    def get_position_angle_vec(position):
        return [
            position / np.power(10000, 2 * (hid_j // 2) / d_hid)
            for hid_j in range(d_hid)
        ]

    sinusoid_table = np.array(
        [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
    )
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    return torch.FloatTensor(sinusoid_table).unsqueeze(0)

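# The table above is the "Attention is all you need" encoding:
# PE[pos, 2i] = sin(pos / 10000^(2i/d)) and PE[pos, 2i+1] = cos(pos / 10000^(2i/d)).
# A quick sanity check against the closed form (illustrative):
#     table = get_sinusoid_encoding_table(4, 6)[0]
#     assert torch.isclose(table[2, 0], torch.sin(torch.tensor(2.0)))
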
class DETRVAE(nn.Module):
    """This is the DETR module that performs object detection"""

    def __init__(
        self,
        backbones,
        transformer,
        encoder,
        state_dim,
        num_queries,
        camera_names,
        vq,
        vq_class,
        vq_dim,
        action_dim,
    ):
        """Initializes the model.
        Parameters:
            backbones: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            state_dim: robot state dimension of the environment
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
                         DETR can detect in a single image. For COCO, we recommend 100 queries.
        """
        super().__init__()
        self.num_queries = num_queries
        self.camera_names = camera_names
        self.transformer = transformer
        self.encoder = encoder
        self.vq, self.vq_class, self.vq_dim = vq, vq_class, vq_dim
        self.state_dim, self.action_dim = state_dim, action_dim
        hidden_dim = transformer.d_model
        self.action_head = nn.Linear(hidden_dim, action_dim)
        self.is_pad_head = nn.Linear(hidden_dim, 1)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        if backbones is not None:
            self.input_proj = nn.Conv2d(
                backbones[0].num_channels, hidden_dim, kernel_size=1
            )
            self.backbones = nn.ModuleList(backbones)
            self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
        else:
            # input_dim = 14 + 7 # robot_state + env_state
            self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
            self.input_proj_env_state = nn.Linear(7, hidden_dim)
            self.pos = torch.nn.Embedding(2, hidden_dim)
            self.backbones = None

        # encoder extra parameters
        self.latent_dim = 32  # final size of latent z # TODO tune
        self.cls_embed = nn.Embedding(1, hidden_dim)  # extra cls token embedding
        self.encoder_action_proj = nn.Linear(
            action_dim, hidden_dim
        )  # project action to embedding
        self.encoder_joint_proj = nn.Linear(
            action_dim, hidden_dim
        )  # project qpos to embedding
        if self.vq:
            self.latent_proj = nn.Linear(hidden_dim, self.vq_class * self.vq_dim)
        else:
            self.latent_proj = nn.Linear(
                hidden_dim, self.latent_dim * 2
            )  # project hidden state to latent std, var
        self.register_buffer(
            "pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
        )  # [CLS], qpos, a_seq

        # decoder extra parameters
        if self.vq:
            self.latent_out_proj = nn.Linear(self.vq_class * self.vq_dim, hidden_dim)
        else:
            self.latent_out_proj = nn.Linear(
                self.latent_dim, hidden_dim
            )  # project latent sample to embedding
        self.additional_pos_embed = nn.Embedding(
            2, hidden_dim
        )  # learned position embedding for proprio and latent

    def encode(self, qpos, actions=None, is_pad=None, vq_sample=None):
        bs, _ = qpos.shape
        if self.encoder is None:
            latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(
                qpos.device
            )
            latent_input = self.latent_out_proj(latent_sample)
            probs = binaries = mu = logvar = None
        else:
            # cvae encoder
            is_training = actions is not None  # train or val
            ### Obtain latent z from action sequence
            if is_training:
                # project action sequence to embedding dim, and concat with a CLS token
                action_embed = self.encoder_action_proj(
                    actions
                )  # (bs, seq, hidden_dim)
                qpos_embed = self.encoder_joint_proj(qpos)  # (bs, hidden_dim)
                qpos_embed = torch.unsqueeze(qpos_embed, axis=1)  # (bs, 1, hidden_dim)
                cls_embed = self.cls_embed.weight  # (1, hidden_dim)
                cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(
                    bs, 1, 1
                )  # (bs, 1, hidden_dim)
                encoder_input = torch.cat(
                    [cls_embed, qpos_embed, action_embed], axis=1
                )  # (bs, seq+1, hidden_dim)
                encoder_input = encoder_input.permute(
                    1, 0, 2
                )  # (seq+1, bs, hidden_dim)
                # do not mask cls token
                cls_joint_is_pad = torch.full((bs, 2), False).to(
                    qpos.device
                )  # False: not a padding
                is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1)  # (bs, seq+1)
                # obtain position embedding
                pos_embed = self.pos_table.clone().detach()
                pos_embed = pos_embed.permute(1, 0, 2)  # (seq+1, 1, hidden_dim)
                # query model
                encoder_output = self.encoder(
                    encoder_input, pos=pos_embed, src_key_padding_mask=is_pad
                )
                encoder_output = encoder_output[0]  # take cls output only
                latent_info = self.latent_proj(encoder_output)

                if self.vq:
                    logits = latent_info.reshape(
                        [*latent_info.shape[:-1], self.vq_class, self.vq_dim]
                    )
                    probs = torch.softmax(logits, dim=-1)
                    binaries = (
                        F.one_hot(
                            torch.multinomial(probs.view(-1, self.vq_dim), 1).squeeze(
                                -1
                            ),
                            self.vq_dim,
                        )
                        .view(-1, self.vq_class, self.vq_dim)
                        .float()
                    )
                    binaries_flat = binaries.view(-1, self.vq_class * self.vq_dim)
                    probs_flat = probs.view(-1, self.vq_class * self.vq_dim)
                    straight_through = binaries_flat - probs_flat.detach() + probs_flat
                    latent_input = self.latent_out_proj(straight_through)
                    mu = logvar = None
                else:
                    probs = binaries = None
                    mu = latent_info[:, : self.latent_dim]
                    logvar = latent_info[:, self.latent_dim :]
                    latent_sample = reparametrize(mu, logvar)
                    latent_input = self.latent_out_proj(latent_sample)

            else:
                mu = logvar = binaries = probs = None
                if self.vq:
                    latent_input = self.latent_out_proj(
                        vq_sample.view(-1, self.vq_class * self.vq_dim)
                    )
                else:
                    latent_sample = torch.zeros(
                        [bs, self.latent_dim], dtype=torch.float32
                    ).to(qpos.device)
                    latent_input = self.latent_out_proj(latent_sample)

        return latent_input, probs, binaries, mu, logvar

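    # The `binaries_flat - probs_flat.detach() + probs_flat` line above is a
    # straight-through estimator: the forward pass sees the hard one-hot codes,
    # while the gradient flows through the soft probabilities. In isolation
    # (with argmax instead of multinomial for illustration):
    #     hard = F.one_hot(probs.argmax(-1), probs.shape[-1]).float()
    #     st = hard - probs.detach() + probs  # value == hard, grad == grad(probs)
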
    def forward(
        self, qpos, image, env_state, actions=None, is_pad=None, vq_sample=None
    ):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None
        actions: batch, seq, action_dim
        """

        latent_input, probs, binaries, mu, logvar = self.encode(
            qpos, actions, is_pad, vq_sample
        )

        # cvae decoder
        if self.backbones is not None:
            # Image observation features and position embeddings
            all_cam_features = []
            all_cam_pos = []
            for cam_id, cam_name in enumerate(self.camera_names):
                # TODO: fix this error
                features, pos = self.backbones[0](image[:, cam_id])
                features = features[0]  # take the last layer feature
                pos = pos[0]
                all_cam_features.append(self.input_proj(features))
                all_cam_pos.append(pos)
            # proprioception features
            proprio_input = self.input_proj_robot_state(qpos)
            # fold camera dimension into width dimension
            src = torch.cat(all_cam_features, axis=3)
            pos = torch.cat(all_cam_pos, axis=3)
            hs = self.transformer(
                src,
                None,
                self.query_embed.weight,
                pos,
                latent_input,
                proprio_input,
                self.additional_pos_embed.weight,
            )[0]
        else:
            qpos = self.input_proj_robot_state(qpos)
            env_state = self.input_proj_env_state(env_state)
            transformer_input = torch.cat([qpos, env_state], axis=1)  # seq length = 2
            hs = self.transformer(
                transformer_input, None, self.query_embed.weight, self.pos.weight
            )[0]
        a_hat = self.action_head(hs)
        is_pad_hat = self.is_pad_head(hs)
        return a_hat, is_pad_hat, [mu, logvar], probs, binaries

class CNNMLP(nn.Module):
    def __init__(self, backbones, state_dim, camera_names):
        """Initializes the model.
        Parameters:
            backbones: torch module of the backbone to be used. See backbone.py
            state_dim: robot state dimension of the environment
        """
        super().__init__()
        self.camera_names = camera_names
        self.action_head = nn.Linear(1000, state_dim)  # TODO add more
        if backbones is not None:
            self.backbones = nn.ModuleList(backbones)
            backbone_down_projs = []
            for backbone in backbones:
                down_proj = nn.Sequential(
                    nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
                    nn.Conv2d(128, 64, kernel_size=5),
                    nn.Conv2d(64, 32, kernel_size=5),
                )
                backbone_down_projs.append(down_proj)
            self.backbone_down_projs = nn.ModuleList(backbone_down_projs)

            mlp_in_dim = 768 * len(backbones) + state_dim
            self.mlp = mlp(
                input_dim=mlp_in_dim,
                hidden_dim=1024,
                output_dim=state_dim,
                hidden_depth=2,
            )
        else:
            raise NotImplementedError

    def forward(self, qpos, image, env_state, actions=None):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None
        actions: batch, seq, action_dim
        """
        is_training = actions is not None  # train or val
        bs, _ = qpos.shape
        # Image observation features and position embeddings
        all_cam_features = []
        for cam_id, cam_name in enumerate(self.camera_names):
            features, pos = self.backbones[cam_id](image[:, cam_id])
            features = features[0]  # take the last layer feature
            pos = pos[0]  # not used
            all_cam_features.append(self.backbone_down_projs[cam_id](features))
        # flatten everything
        flattened_features = []
        for cam_feature in all_cam_features:
            flattened_features.append(cam_feature.reshape([bs, -1]))
        flattened_features = torch.cat(flattened_features, axis=1)  # 768 each
        features = torch.cat([flattened_features, qpos], axis=1)  # qpos: 14
        a_hat = self.mlp(features)
        return a_hat

def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
    if hidden_depth == 0:
        mods = [nn.Linear(input_dim, output_dim)]
    else:
        mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
        for i in range(hidden_depth - 1):
            mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
        mods.append(nn.Linear(hidden_dim, output_dim))
    trunk = nn.Sequential(*mods)
    return trunk

def build_encoder(
    hidden_dim,  # 256
    dropout,  # 0.1
    nheads,  # 8
    dim_feedforward,
    num_encoder_layers,  # 4 # TODO shared with VAE decoder
    normalize_before,  # False
):
    activation = "relu"

    encoder_layer = TransformerEncoderLayer(
        hidden_dim, nheads, dim_feedforward, dropout, activation, normalize_before
    )
    encoder_norm = nn.LayerNorm(hidden_dim) if normalize_before else None
    encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

    return encoder

def build_vae(
    hidden_dim,
    state_dim,
    position_embedding_type,
    lr_backbone,
    masks,
    backbone,
    dilation,
    dropout,
    nheads,
    dim_feedforward,
    enc_layers,
    dec_layers,
    pre_norm,
    num_queries,
    camera_names,
    vq,
    vq_class,
    vq_dim,
    action_dim,
    no_encoder,
):
    # TODO hardcode

    # From state
    # backbone = None # from state for now, no need for conv nets
    # From image
    backbones = []
    backbone = build_backbone(
        hidden_dim, position_embedding_type, lr_backbone, masks, backbone, dilation
    )
    backbones.append(backbone)

    transformer = build_transformer(
        hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers, pre_norm
    )

    if no_encoder:
        encoder = None
    else:
        encoder = build_encoder(
            hidden_dim,
            dropout,
            nheads,
            dim_feedforward,
            enc_layers,
            pre_norm,
        )

    model = DETRVAE(
        backbones,
        transformer,
        encoder,
        state_dim,
        num_queries,
        camera_names,
        vq,
        vq_class,
        vq_dim,
        action_dim,
    )

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (n_parameters / 1e6,))

    return model

# TODO
def build_cnnmlp(args):
    state_dim = 14  # TODO hardcode

    # From state
    # backbone = None # from state for now, no need for conv nets
    # From image
    backbones = []
    for _ in args.camera_names:
        backbone = build_backbone(args)
        backbones.append(backbone)

    model = CNNMLP(
        backbones,
        state_dim=state_dim,
        camera_names=args.camera_names,
    )

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (n_parameters / 1e6,))

    return model
@@ -0,0 +1,122 @@
#!/usr/bin/env python3
import torch.nn as nn
from torch.nn import functional as F
import torch

DROPOUT_RATE = 0.1  # dropout rate


# A causal transformer block
class Causal_Transformer_Block(nn.Module):
    def __init__(self, seq_len, latent_dim, num_head) -> None:
        """
        Initialize the causal transformer block.

        Args:
            seq_len (int): sequence length
            latent_dim (int): latent dimension
            num_head (int): number of attention heads
        """
        super().__init__()
        self.num_head = num_head
        self.latent_dim = latent_dim
        self.ln_1 = nn.LayerNorm(latent_dim)  # layer norm
        self.attn = nn.MultiheadAttention(latent_dim, num_head, dropout=DROPOUT_RATE, batch_first=True)  # multi-head attention
        self.ln_2 = nn.LayerNorm(latent_dim)  # layer norm
        self.mlp = nn.Sequential(
            nn.Linear(latent_dim, 4 * latent_dim),  # linear layer
            nn.GELU(),  # GELU activation
            nn.Linear(4 * latent_dim, latent_dim),  # linear layer
            nn.Dropout(DROPOUT_RATE),  # dropout
        )

        # self.register_buffer("attn_mask", torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool())  # register the attention mask

    def forward(self, x):
        """
        Forward pass.

        Args:
            x (torch.Tensor): input tensor

        Returns:
            torch.Tensor: output tensor
        """
        # Upper-triangular mask so no position can attend to the future
        attn_mask = torch.triu(torch.ones(x.shape[1], x.shape[1], device=x.device, dtype=torch.bool), diagonal=1)
        x = self.ln_1(x)  # layer norm
        x = x + self.attn(x, x, x, attn_mask=attn_mask)[0]  # add attention output
        x = self.ln_2(x)  # layer norm
        x = x + self.mlp(x)  # add MLP output

        return x

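# The mask built in forward() above is strictly upper-triangular, so a position
# can attend to itself and to the past but never to the future. For a length-3
# sequence (True = masked):
#     torch.triu(torch.ones(3, 3, dtype=torch.bool), diagonal=1)
#     # tensor([[False,  True,  True],
#     #         [False, False,  True],
#     #         [False, False, False]])
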
# Model the latent-code sequence with self-attention instead of an RNN
class Latent_Model_Transformer(nn.Module):
    def __init__(self, input_dim, output_dim, seq_len, latent_dim=256, num_head=8, num_layer=3) -> None:
        """
        Initialize the latent model transformer.

        Args:
            input_dim (int): input dimension
            output_dim (int): output dimension
            seq_len (int): sequence length
            latent_dim (int, optional): latent dimension, defaults to 256
            num_head (int, optional): number of attention heads, defaults to 8
            num_layer (int, optional): number of transformer layers, defaults to 3
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.seq_len = seq_len
        self.latent_dim = latent_dim
        self.num_head = num_head
        self.num_layer = num_layer
        self.input_layer = nn.Linear(input_dim, latent_dim)  # input layer
        self.weight_pos_embed = nn.Embedding(seq_len, latent_dim)  # position embedding
        self.attention_blocks = nn.Sequential(
            nn.Dropout(DROPOUT_RATE),  # dropout
            *[Causal_Transformer_Block(seq_len, latent_dim, num_head) for _ in range(num_layer)],  # stacked causal transformer blocks
            nn.LayerNorm(latent_dim)  # layer norm
        )
        self.output_layer = nn.Linear(latent_dim, output_dim)  # output layer

    def forward(self, x):
        """
        Forward pass.

        Args:
            x (torch.Tensor): input tensor

        Returns:
            torch.Tensor: output logits
        """
        x = self.input_layer(x)  # input layer
        x = x + self.weight_pos_embed(torch.arange(x.shape[1], device=x.device))  # add position embedding
        x = self.attention_blocks(x)  # run the attention blocks
        logits = self.output_layer(x)  # output layer

        return logits

    @torch.no_grad()
    def generate(self, n, temperature=0.1, x=None):
        """
        Generate sequences.

        Args:
            n (int): number of sequences to generate
            temperature (float, optional): sampling temperature, defaults to 0.1
            x (torch.Tensor, optional): initial input tensor, defaults to None

        Returns:
            torch.Tensor: the generated sequences
        """
        if x is None:
            x = torch.zeros((n, 1, self.input_dim), device=self.weight_pos_embed.weight.device)  # initialize the input
        for i in range(self.seq_len):
            logits = self.forward(x)[:, -1]  # output at the last timestep
            probs = torch.softmax(logits / temperature, dim=-1)  # probability distribution
            samples = torch.multinomial(probs, num_samples=1)[..., 0]  # sample from the distribution
            samples_one_hot = F.one_hot(samples.long(), num_classes=self.output_dim).float()  # one-hot encode
            x = torch.cat([x, samples_one_hot[:, None, :]], dim=1)  # append the new sample to the input

        return x[:, 1:, :]  # return the generated sequence (drop the initial zero input)
@@ -0,0 +1,91 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn

from shadow_act.utils.misc import NestedTensor


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor):
        x = tensor
        # mask = tensor_list.mask
        # assert mask is not None
        # not_mask = ~mask

        not_mask = torch.ones_like(x[0, [0]])
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos

class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    """
    def __init__(self, num_pos_feats=256):
        super().__init__()
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        h, w = x.shape[-2:]
        i = torch.arange(w, device=x.device)
        j = torch.arange(h, device=x.device)
        x_emb = self.col_embed(i)
        y_emb = self.row_embed(j)
        pos = torch.cat([
            x_emb.unsqueeze(0).repeat(h, 1, 1),
            y_emb.unsqueeze(1).repeat(1, w, 1),
        ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
        return pos

def build_position_encoding(hidden_dim, position_embedding_type):
    N_steps = hidden_dim // 2
    if position_embedding_type in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
    elif position_embedding_type in ('v3', 'learned'):
        position_embedding = PositionEmbeddingLearned(N_steps)
    else:
        raise ValueError(f"not supported {position_embedding_type}")

    return position_embedding
@@ -0,0 +1,424 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR Transformer class.

Copy-paste from torch.nn.Transformer with modifications:
    * positional encodings are passed in MHattention
    * extra LN at the end of encoder is removed
    * decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn, Tensor

class Transformer(nn.Module):

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
    ):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(
            encoder_layer, num_encoder_layers, encoder_norm
        )

        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
        )

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        src,
        mask,
        query_embed,
        pos_embed,
        latent_input=None,
        proprio_input=None,
        additional_pos_embed=None,
    ):
        # TODO flatten only when input has H and W
        if len(src.shape) == 4:  # has H and W
            # flatten NxCxHxW to HWxNxC
            bs, c, h, w = src.shape
            src = src.flatten(2).permute(2, 0, 1)
            pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
            # mask = mask.flatten(1)

            additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(
                1, bs, 1
            )  # seq, bs, dim
            pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)

            addition_input = torch.stack([latent_input, proprio_input], axis=0)
            src = torch.cat([addition_input, src], axis=0)
        else:
            assert len(src.shape) == 3
            # flatten NxHWxC to HWxNxC
            bs, hw, c = src.shape
            src = src.permute(1, 0, 2)
            pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)

        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        hs = self.decoder(
            tgt,
            memory,
            memory_key_padding_mask=mask,
            pos=pos_embed,
            query_pos=query_embed,
        )
        hs = hs.transpose(1, 2)
        return hs

class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        output = src

        for layer in self.layers:
            output = layer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
            )

        if self.norm is not None:
            output = self.norm(output)

        return output

class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)

class TransformerEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model,
|
||||
nhead,
|
||||
dim_feedforward=2048,
|
||||
dropout=0.1,
|
||||
activation="relu",
|
||||
normalize_before=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
q = k = self.with_pos_embed(src, pos)
|
||||
src2 = self.self_attn(
|
||||
q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
|
||||
)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src = self.norm1(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
||||
src = src + self.dropout2(src2)
|
||||
src = self.norm2(src)
|
||||
return src
|
||||
|
||||
def forward_pre(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
src2 = self.norm1(src)
|
||||
q = k = self.with_pos_embed(src2, pos)
|
||||
src2 = self.self_attn(
|
||||
q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
|
||||
)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src2 = self.norm2(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
|
||||
src = src + self.dropout2(src2)
|
||||
return src
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
||||
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model,
|
||||
nhead,
|
||||
dim_feedforward=2048,
|
||||
dropout=0.1,
|
||||
activation="relu",
|
||||
normalize_before=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
self.dropout3 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
q = k = self.with_pos_embed(tgt, query_pos)
|
||||
tgt2 = self.self_attn(
|
||||
q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
|
||||
)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt = self.norm1(tgt)
|
||||
tgt2 = self.multihead_attn(
|
||||
query=self.with_pos_embed(tgt, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory,
|
||||
attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask,
|
||||
)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt = self.norm2(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
tgt = self.norm3(tgt)
|
||||
return tgt
|
||||
|
||||
def forward_pre(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
tgt2 = self.norm1(tgt)
|
||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
||||
tgt2 = self.self_attn(
|
||||
q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
|
||||
)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt2 = self.norm2(tgt)
|
||||
tgt2 = self.multihead_attn(
|
||||
query=self.with_pos_embed(tgt2, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory,
|
||||
attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask,
|
||||
)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt2 = self.norm3(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
return tgt
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask,
|
||||
memory_mask,
|
||||
tgt_key_padding_mask,
|
||||
memory_key_padding_mask,
|
||||
pos,
|
||||
query_pos,
|
||||
)
|
||||
return self.forward_post(
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask,
|
||||
memory_mask,
|
||||
tgt_key_padding_mask,
|
||||
memory_key_padding_mask,
|
||||
pos,
|
||||
query_pos,
|
||||
)
|
||||
|
||||
|
||||
def _get_clones(module, N):
|
||||
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
||||
|
||||
|
||||
def build_transformer(
|
||||
hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers, pre_norm
|
||||
):
|
||||
return Transformer(
|
||||
d_model=hidden_dim,
|
||||
dropout=dropout,
|
||||
nhead=nheads,
|
||||
dim_feedforward=dim_feedforward,
|
||||
num_encoder_layers=enc_layers,
|
||||
num_decoder_layers=dec_layers,
|
||||
normalize_before=pre_norm,
|
||||
return_intermediate_dec=True,
|
||||
)
|
||||
|
||||
|
||||
def _get_activation_fn(activation):
|
||||
"""Return an activation function given a string"""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
if activation == "gelu":
|
||||
return F.gelu
|
||||
if activation == "glu":
|
||||
return F.glu
|
||||
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
||||
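

# Illustrative sketch: a minimal smoke test of build_transformer() above. The
# hyperparameters and the 100 action queries are placeholder assumptions, not
# values taken from this repo's config; shapes follow the NxCxHxW branch of forward().
if __name__ == "__main__":
    transformer = build_transformer(
        hidden_dim=256, dropout=0.1, nheads=8, dim_feedforward=2048,
        enc_layers=4, dec_layers=7, pre_norm=False,
    )
    bs, c, h, w = 2, 256, 15, 20
    src = torch.randn(bs, c, h, w)      # backbone feature map
    pos = torch.randn(bs, c, h, w)      # positional embedding, same shape as src
    query = torch.randn(100, c)         # learned action queries
    latent = torch.randn(bs, c)         # CVAE latent token
    proprio = torch.randn(bs, c)        # proprioception token
    extra_pos = torch.randn(2, c)       # pos embed for the two extra tokens
    hs = transformer(src, None, query, pos, latent, proprio, extra_pos)
    print(hs.shape)  # (dec_layers, bs, 100, c) since return_intermediate_dec=True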
@@ -0,0 +1 @@
__version__ = '0.1.0'
@@ -0,0 +1,522 @@
import torch
import logging
import numpy as np
import torch.nn as nn
from torch.nn import functional as F
import torchvision.transforms as transforms
from shadow_act.models.detr_vae import build_vae, build_cnnmlp

# NOTE: DiffusionPolicy depends on the imports below; they are kept commented out
# so the other policies can be used without these optional dependencies.
# from diffusers.training_utils import EMAModel
# from robomimic.models.base_nets import ResNet18Conv, SpatialSoftmax
# from robomimic.algo.diffusion_policy import replace_bn_with_gn, ConditionalUnet1D

# from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
# from diffusers.schedulers.scheduling_ddim import DDIMScheduler

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# TODO: refactor the DiffusionPolicy class
class DiffusionPolicy(nn.Module):
    def __init__(self, args_override):
        """
        Initialize the DiffusionPolicy class.

        Args:
            args_override (dict): dictionary of parameter overrides
        """
        super().__init__()

        self.camera_names = args_override["camera_names"]
        self.observation_horizon = args_override["observation_horizon"]
        self.action_horizon = args_override["action_horizon"]
        self.prediction_horizon = args_override["prediction_horizon"]
        self.num_inference_timesteps = args_override["num_inference_timesteps"]
        self.ema_power = args_override["ema_power"]
        self.lr = args_override["lr"]
        self.weight_decay = 0

        self.num_kp = 32
        self.feature_dimension = 64
        self.ac_dim = args_override["action_dim"]
        self.obs_dim = self.feature_dimension * len(self.camera_names) + 14

        backbones = []
        pools = []
        linears = []
        for _ in self.camera_names:
            backbones.append(
                ResNet18Conv(input_channel=3, pretrained=False, input_coord_conv=False)
            )
            pools.append(
                SpatialSoftmax(
                    input_shape=[512, 15, 20],
                    num_kp=self.num_kp,
                    temperature=1.0,
                    learnable_temperature=False,
                    noise_std=0.0,
                )
            )
            linears.append(
                torch.nn.Linear(int(np.prod([self.num_kp, 2])), self.feature_dimension)
            )
        backbones = nn.ModuleList(backbones)
        pools = nn.ModuleList(pools)
        linears = nn.ModuleList(linears)

        backbones = replace_bn_with_gn(backbones)

        noise_pred_net = ConditionalUnet1D(
            input_dim=self.ac_dim,
            global_cond_dim=self.obs_dim * self.observation_horizon,
        )

        nets = nn.ModuleDict(
            {
                "policy": nn.ModuleDict(
                    {
                        "backbones": backbones,
                        "pools": pools,
                        "linears": linears,
                        "noise_pred_net": noise_pred_net,
                    }
                )
            }
        )

        nets = nets.float().cuda()
        ENABLE_EMA = True
        if ENABLE_EMA:
            ema = EMAModel(model=nets, power=self.ema_power)
        else:
            ema = None
        self.nets = nets
        self.ema = ema

        # Set up the noise scheduler
        self.noise_scheduler = DDIMScheduler(
            num_train_timesteps=50,
            beta_schedule="squaredcos_cap_v2",
            clip_sample=True,
            set_alpha_to_one=True,
            steps_offset=0,
            prediction_type="epsilon",
        )

        n_parameters = sum(p.numel() for p in self.parameters())
        logger.info("number of parameters: %.2fM", n_parameters / 1e6)

    def configure_optimizers(self):
        """
        Configure the optimizer.

        Returns:
            optimizer: the configured optimizer
        """
        optimizer = torch.optim.AdamW(
            self.nets.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        return optimizer

    def __call__(self, qpos, image, actions=None, is_pad=None):
        """
        Forward pass.

        Args:
            qpos (torch.Tensor): joint position tensor
            image (torch.Tensor): image tensor
            actions (torch.Tensor, optional): action tensor
            is_pad (torch.Tensor, optional): padding mask tensor

        Returns:
            dict: loss dictionary (during training)
            torch.Tensor: action tensor (during inference)
        """
        B = qpos.shape[0]
        if actions is not None:  # training
            nets = self.nets
            all_features = []
            for cam_id in range(len(self.camera_names)):
                cam_image = image[:, cam_id]
                cam_features = nets["policy"]["backbones"][cam_id](cam_image)
                pool_features = nets["policy"]["pools"][cam_id](cam_features)
                pool_features = torch.flatten(pool_features, start_dim=1)
                out_features = nets["policy"]["linears"][cam_id](pool_features)
                all_features.append(out_features)

            obs_cond = torch.cat(all_features + [qpos], dim=1)

            # sample noise to add to the actions
            noise = torch.randn(actions.shape, device=obs_cond.device)

            # sample a diffusion iteration for each data point
            timesteps = torch.randint(
                0,
                self.noise_scheduler.config.num_train_timesteps,
                (B,),
                device=obs_cond.device,
            ).long()

            # add noise to the clean actions according to the noise magnitude
            # at each diffusion iteration
            noisy_actions = self.noise_scheduler.add_noise(actions, noise, timesteps)

            # predict the noise residual
            noise_pred = nets["policy"]["noise_pred_net"](
                noisy_actions, timesteps, global_cond=obs_cond
            )

            # L2 loss
            all_l2 = F.mse_loss(noise_pred, noise, reduction="none")
            loss = (all_l2 * ~is_pad.unsqueeze(-1)).mean()

            loss_dict = {}
            loss_dict["l2_loss"] = loss
            loss_dict["loss"] = loss

            if self.training and self.ema is not None:
                self.ema.step(nets)
            return loss_dict
        else:  # inference
            To = self.observation_horizon
            Ta = self.action_horizon
            Tp = self.prediction_horizon
            action_dim = self.ac_dim

            nets = self.nets
            if self.ema is not None:
                nets = self.ema.averaged_model

            all_features = []
            for cam_id in range(len(self.camera_names)):
                cam_image = image[:, cam_id]
                cam_features = nets["policy"]["backbones"][cam_id](cam_image)
                pool_features = nets["policy"]["pools"][cam_id](cam_features)
                pool_features = torch.flatten(pool_features, start_dim=1)
                out_features = nets["policy"]["linears"][cam_id](pool_features)
                all_features.append(out_features)

            obs_cond = torch.cat(all_features + [qpos], dim=1)

            # initialize actions from Gaussian noise
            noisy_action = torch.randn((B, Tp, action_dim), device=obs_cond.device)
            naction = noisy_action

            # initialize the scheduler
            self.noise_scheduler.set_timesteps(self.num_inference_timesteps)

            for k in self.noise_scheduler.timesteps:
                # predict the noise
                noise_pred = nets["policy"]["noise_pred_net"](
                    sample=naction, timestep=k, global_cond=obs_cond
                )

                # inverse diffusion step (remove noise)
                naction = self.noise_scheduler.step(
                    model_output=noise_pred, timestep=k, sample=naction
                ).prev_sample

            return naction

    def serialize(self):
        """
        Serialize the model.

        Returns:
            dict: model state dict
        """
        return {
            "nets": self.nets.state_dict(),
            "ema": (
                self.ema.averaged_model.state_dict() if self.ema is not None else None
            ),
        }

    def deserialize(self, model_dict):
        """
        Deserialize the model.

        Args:
            model_dict (dict): model state dict

        Returns:
            status: load status
        """
        status = self.nets.load_state_dict(model_dict["nets"])
        logger.info("Loaded model")
        if model_dict.get("ema", None) is not None:
            logger.info("Loaded EMA")
            status_ema = self.ema.averaged_model.load_state_dict(model_dict["ema"])
            status = [status, status_ema]
        return status


class ACTPolicy(nn.Module):
    def __init__(self, act_config):
        """
        Initialize the ACTPolicy class.

        Args:
            act_config (dict): configuration dictionary for the ACT policy
        """
        super().__init__()
        lr_backbone = act_config["lr_backbone"]
        vq = act_config["vq"]
        lr = act_config["lr"]
        weight_decay = act_config["weight_decay"]

        model = build_vae(
            act_config["hidden_dim"],
            act_config["state_dim"],
            act_config["position_embedding"],
            lr_backbone,
            act_config["masks"],
            act_config["backbone"],
            act_config["dilation"],
            act_config["dropout"],
            act_config["nheads"],
            act_config["dim_feedforward"],
            act_config["enc_layers"],
            act_config["dec_layers"],
            act_config["pre_norm"],
            act_config["num_queries"],
            act_config["camera_names"],
            vq,
            act_config["vq_class"],
            act_config["vq_dim"],
            act_config["action_dim"],
            act_config["no_encoder"],
        )
        model.cuda()

        param_dicts = [
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if "backbone" not in n and p.requires_grad
                ]
            },
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if "backbone" in n and p.requires_grad
                ],
                "lr": lr_backbone,
            },
        ]
        self.optimizer = torch.optim.AdamW(
            param_dicts, lr=lr, weight_decay=weight_decay
        )
        self.model = model  # CVAE decoder
        self.kl_weight = act_config["kl_weight"]
        self.vq = vq
        logger.info(f"KL Weight {self.kl_weight}")

    def __call__(self, qpos, image, actions=None, is_pad=None, vq_sample=None):
        """
        Forward pass.

        Args:
            qpos (torch.Tensor): joint angle tensor
            image (torch.Tensor): image tensor
            actions (torch.Tensor, optional): action tensor
            is_pad (torch.Tensor, optional): padding mask tensor
            vq_sample (torch.Tensor, optional): VQ sample

        Returns:
            dict: loss dictionary (during training)
            torch.Tensor: action tensor (during inference)
        """
        env_state = None
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        image = normalize(image)
        if actions is not None:  # training
            actions = actions[:, : self.model.num_queries]
            is_pad = is_pad[:, : self.model.num_queries]

            loss_dict = dict()
            a_hat, is_pad_hat, (mu, logvar), probs, binaries = self.model(
                qpos, image, env_state, actions, is_pad
            )
            if self.vq or self.model.encoder is None:
                total_kld = [torch.tensor(0.0)]
            else:
                total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
            if self.vq:
                loss_dict["vq_discrepancy"] = F.l1_loss(
                    probs, binaries, reduction="mean"
                )
            all_l1 = F.l1_loss(actions, a_hat, reduction="none")
            l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
            loss_dict["l1"] = l1
            loss_dict["kl"] = total_kld[0]
            loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
            return loss_dict
        else:  # inference
            a_hat, _, (_, _), _, _ = self.model(
                qpos, image, env_state, vq_sample=vq_sample
            )  # no action, sample from the prior
            return a_hat

    def configure_optimizers(self):
        """
        Configure the optimizer.

        Returns:
            optimizer: the configured optimizer
        """
        return self.optimizer

    @torch.no_grad()
    def vq_encode(self, qpos, actions, is_pad):
        """
        VQ encoding.

        Args:
            qpos (torch.Tensor): joint position tensor
            actions (torch.Tensor): action tensor
            is_pad (torch.Tensor): padding mask tensor

        Returns:
            torch.Tensor: binary codes
        """
        actions = actions[:, : self.model.num_queries]
        is_pad = is_pad[:, : self.model.num_queries]

        _, _, binaries, _, _ = self.model.encode(qpos, actions, is_pad)

        return binaries

    def serialize(self):
        """
        Serialize the model.

        Returns:
            dict: model state dict
        """
        return self.state_dict()

    def deserialize(self, model_dict):
        """
        Deserialize the model.

        Args:
            model_dict (dict): model state dict

        Returns:
            status: load status
        """
        return self.load_state_dict(model_dict)


class CNNMLPPolicy(nn.Module):
    def __init__(self, args_override):
        """
        Initialize the CNNMLPPolicy class.

        Args:
            args_override (dict): dictionary of parameter overrides
        """
        super().__init__()
        # parser = argparse.ArgumentParser(
        #     "DETR training and evaluation script", parents=[get_args_parser()]
        # )
        # args = parser.parse_args()

        # for k, v in args_override.items():
        #     setattr(args, k, v)

        model = build_cnnmlp(args_override)
        model.cuda()

        param_dicts = [
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if "backbone" not in n and p.requires_grad
                ]
            },
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if "backbone" in n and p.requires_grad
                ],
                "lr": args_override["lr_backbone"],
            },
        ]
        self.model = model  # decoder
        self.optimizer = torch.optim.AdamW(
            param_dicts, lr=args_override["lr"], weight_decay=args_override["weight_decay"]
        )

    def __call__(self, qpos, image, actions=None, is_pad=None):
        """
        Forward pass.

        Args:
            qpos (torch.Tensor): joint position tensor
            image (torch.Tensor): image tensor
            actions (torch.Tensor, optional): action tensor
            is_pad (torch.Tensor, optional): padding mask tensor

        Returns:
            dict: loss dictionary (during training)
            torch.Tensor: action tensor (during inference)
        """
        env_state = None
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        image = normalize(image)
        if actions is not None:  # training
            actions = actions[:, 0]
            a_hat = self.model(qpos, image, env_state, actions)
            mse = F.mse_loss(actions, a_hat)
            loss_dict = dict()
            loss_dict["mse"] = mse
            loss_dict["loss"] = loss_dict["mse"]
            return loss_dict
        else:  # inference
            a_hat = self.model(qpos, image, env_state)  # no action, sample from the prior
            return a_hat

    def configure_optimizers(self):
        """
        Configure the optimizer.

        Returns:
            optimizer: the configured optimizer
        """
        return self.optimizer


def kl_divergence(mu, logvar):
    """
    Compute the KL divergence.

    Args:
        mu (torch.Tensor): mean tensor
        logvar (torch.Tensor): log-variance tensor

    Returns:
        tuple: total KL divergence, dimension-wise KL divergence, mean KL divergence
    """
    batch_size = mu.size(0)
    assert batch_size != 0
    if mu.data.ndimension() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.data.ndimension() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))

    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    total_kld = klds.sum(1).mean(0, True)
    dimension_wise_kld = klds.mean(0)
    mean_kld = klds.mean(1).mean(0, True)

    return total_kld, dimension_wise_kld, mean_kld
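

# Illustrative sketch: a quick numeric check of kl_divergence() above. For a unit
# Gaussian posterior (mu = 0, logvar = 0) the closed form
# -0.5 * (1 + logvar - mu^2 - exp(logvar)) is exactly zero.
if __name__ == "__main__":
    mu = torch.zeros(4, 32)
    logvar = torch.zeros(4, 32)
    total_kld, dim_kld, mean_kld = kl_divergence(mu, logvar)
    assert torch.allclose(total_kld, torch.zeros_like(total_kld))
    print(total_kld.item(), mean_kld.item())  # both 0.0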
@@ -0,0 +1,245 @@
import os
import yaml
import pickle
import torch
# import wandb
import logging
import numpy as np
from tqdm import tqdm
from copy import deepcopy
from itertools import repeat
from shadow_act.utils.utils import (
    set_seed,
    load_data,
    compute_dict_mean,
)
from shadow_act.network.policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy
from shadow_act.eval.rm_act_eval import RmActEvaluator

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


class RmActTrainer:
    def __init__(self, config):
        """
        Initialize the trainer: set the random seed, load the data,
        and save the dataset statistics.
        """
        self._config = config
        self._num_steps = config["num_steps"]
        self._ckpt_dir = config["checkpoint_dir"]
        self._state_dim = config["state_dim"]
        self._real_robot = config["real_robot"]
        self._policy_class = config["policy_class"]
        self._onscreen_render = config["onscreen_render"]
        self._policy_config = config["policy_config"]
        self._camera_names = config["camera_names"]
        self._max_timesteps = config["episode_len"]
        self._task_name = config["task_name"]
        self._temporal_agg = config["temporal_agg"]
        self._onscreen_cam = "angle"
        self._vq = config["policy_config"]["vq"]
        self._batch_size = config["batch_size"]

        self._seed = config["seed"]
        self._eval_every = config["eval_every"]
        self._validate_every = config["validate_every"]
        self._save_every = config["save_every"]
        self._load_pretrain = config["load_pretrain"]
        self._resume_ckpt_path = config["resume_ckpt_path"]

        if config["name_filter"] is None:
            name_filter = lambda n: True
        else:
            name_filter = config["name_filter"]

        self._eval = RmActEvaluator(config, True, 50)
        # Load the data
        self._train_dataloader, self._val_dataloader, self._stats, _ = load_data(
            config["dataset_dir"],
            name_filter,
            self._camera_names,
            self._batch_size,
            self._batch_size,
            config["chunk_size"],
            config["skip_mirrored_data"],
            self._load_pretrain,
            self._policy_class,
            config["stats_dir"],
            config["sample_weights"],
            config["train_ratio"],
        )
        # Save the dataset statistics
        stats_path = os.path.join(self._ckpt_dir, "dataset_stats.pkl")
        with open(stats_path, "wb") as f:
            pickle.dump(self._stats, f)
        expr_name = self._ckpt_dir.split("/")[-1]

        # wandb.init(
        #     project="train_rm_aloha", reinit=True, entity="train_rm_aloha", name=expr_name
        # )

    def _make_policy(self):
        """
        Create the policy object according to the policy class and config.
        """
        if self._policy_class == "ACT":
            return ACTPolicy(self._policy_config)
        elif self._policy_class == "CNNMLP":
            return CNNMLPPolicy(self._policy_config)
        elif self._policy_class == "Diffusion":
            return DiffusionPolicy(self._policy_config)
        else:
            raise NotImplementedError(f"Policy class {self._policy_class} is not implemented")

    def _make_optimizer(self):
        """
        Create the optimizer for the chosen policy class.
        """
        if self._policy_class in ["ACT", "CNNMLP", "Diffusion"]:
            return self._policy.configure_optimizers()
        else:
            raise NotImplementedError

    def _forward_pass(self, data):
        """
        Run a forward pass and compute the losses.
        """
        image_data, qpos_data, action_data, is_pad = data
        try:
            image_data, qpos_data, action_data, is_pad = (
                image_data.cuda(),
                qpos_data.cuda(),
                action_data.cuda(),
                is_pad.cuda(),
            )
        except RuntimeError as e:
            logging.error(f"CUDA error: {e}")
            raise
        return self._policy(qpos_data, image_data, action_data, is_pad)

    def _repeater(self):
        """
        Repeat the training dataloader indefinitely, yielding batches.
        """
        epoch = 0
        for loader in repeat(self._train_dataloader):
            for data in loader:
                yield data
            logging.info(f"Epoch {epoch} done")
            epoch += 1

    def train(self):
        """
        Train the model and keep the best checkpoint.
        """
        set_seed(self._seed)
        self._policy = self._make_policy()
        min_val_loss = np.inf
        best_ckpt_info = None

        # Load the pretrained model
        if self._load_pretrain:
            try:
                loading_status = self._policy.deserialize(
                    torch.load(
                        os.path.join(
                            "/home/zfu/interbotix_ws/src/act/ckpts/pretrain_all",
                            "policy_step_50000_seed_0.ckpt",
                        )
                    )
                )
                logging.info(f"loaded! {loading_status}")
            except FileNotFoundError as e:
                logging.error(f"Pretrain model not found: {e}")
            except Exception as e:
                logging.error(f"Error loading pretrain model: {e}")

        # Resume from a checkpoint
        if self._resume_ckpt_path is not None:
            try:
                loading_status = self._policy.deserialize(torch.load(self._resume_ckpt_path))
                logging.info(f"Resume policy from: {self._resume_ckpt_path}, Status: {loading_status}")
            except FileNotFoundError as e:
                logging.error(f"Checkpoint not found: {e}")
            except Exception as e:
                logging.error(f"Error loading checkpoint: {e}")

        self._policy.cuda()

        self._optimizer = self._make_optimizer()
        train_dataloader = self._repeater()  # infinite batch generator

        for step in tqdm(range(self._num_steps + 1)):
            # Validate the model (training below still runs on every step)
            if step % self._validate_every == 0:
                logging.info("validating")
                with torch.inference_mode():
                    self._policy.eval()
                    validation_dicts = []
                    for batch_idx, data in enumerate(self._val_dataloader):
                        forward_dict = self._forward_pass(data)  # forward_dict = {"loss": loss, "kl": kl, "mse": mse}
                        validation_dicts.append(forward_dict)
                        if batch_idx > 50:  # limit the number of validation batches TODO tune this limit
                            break

                    validation_summary = compute_dict_mean(validation_dicts)
                    epoch_val_loss = validation_summary["loss"]
                    if epoch_val_loss < min_val_loss:
                        min_val_loss = epoch_val_loss
                        best_ckpt_info = (
                            step,
                            min_val_loss,
                            deepcopy(self._policy.serialize()),
                        )

                # Log validation results to wandb
                # for k in list(validation_summary.keys()):
                #     validation_summary[f"val_{k}"] = validation_summary.pop(k)

                # wandb.log(validation_summary, step=step)
                logging.info(f"Val loss: {epoch_val_loss:.5f}")
                summary_string = " ".join(f"{k}: {v.item():.3f}" for k, v in validation_summary.items())
                logging.info(summary_string)

            # Evaluate the model
            # if (step > 0) and (step % self._eval_every == 0):
            #     ckpt_name = f"policy_step_{step}_seed_{self._seed}.ckpt"
            #     ckpt_path = os.path.join(self._ckpt_dir, ckpt_name)
            #     torch.save(self._policy.serialize(), ckpt_path)
            #     success, _ = self._eval.evaluate(ckpt_name)
            #     wandb.log({"success": success}, step=step)

            # Train the model
            self._policy.train()
            self._optimizer.zero_grad()
            data = next(train_dataloader)
            forward_dict = self._forward_pass(data)
            loss = forward_dict["loss"]
            loss.backward()
            self._optimizer.step()
            # wandb.log(forward_dict, step=step)

            # Save a checkpoint
            if step % self._save_every == 0:
                ckpt_path = os.path.join(self._ckpt_dir, f"policy_step_{step}_seed_{self._seed}.ckpt")
                torch.save(self._policy.serialize(), ckpt_path)

        # Save the final model
        ckpt_path = os.path.join(self._ckpt_dir, "policy_last.ckpt")
        torch.save(self._policy.serialize(), ckpt_path)

        best_step, min_val_loss, best_state_dict = best_ckpt_info
        ckpt_path = os.path.join(self._ckpt_dir, f"policy_step_{best_step}_seed_{self._seed}.ckpt")
        torch.save(best_state_dict, ckpt_path)
        logging.info(f"Training finished:\nSeed {self._seed}, val loss {min_val_loss:.6f} at step {best_step}")

        return best_ckpt_info


if __name__ == "__main__":
    with open("/home/rm/aloha/shadow_rm_act/config/config.yaml") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    trainer = RmActTrainer(config)
    trainer.train()
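

# Illustrative sketch: the config.yaml keys RmActTrainer reads, shown as a Python
# dict. Only the key names are taken from the reads in __init__ and train() above;
# every value below is a placeholder assumption, not a value from this repo.
_EXAMPLE_CONFIG = {
    "num_steps": 100000, "checkpoint_dir": "./ckpts", "state_dim": 14,
    "real_robot": True, "policy_class": "ACT", "onscreen_render": False,
    "policy_config": {"vq": False},  # plus the policy hyperparameters
    "camera_names": ["top"], "episode_len": 500, "task_name": "rm_task",
    "temporal_agg": True, "batch_size": 8, "seed": 0,
    "eval_every": 2000, "validate_every": 500, "save_every": 2000,
    "load_pretrain": False, "resume_ckpt_path": None, "name_filter": None,
    "dataset_dir": "./data", "chunk_size": 100, "skip_mirrored_data": False,
    "stats_dir": None, "sample_weights": None, "train_ratio": 0.99,
}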
@@ -0,0 +1 @@
__version__ = '0.1.0'
@@ -0,0 +1,88 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Utilities for bounding box manipulation and GIoU.
"""
import torch
from torchvision.ops.boxes import box_area


def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


def box_xyxy_to_cxcywh(x):
    x0, y0, x1, y1 = x.unbind(-1)
    b = [(x0 + x1) / 2, (y0 + y1) / 2,
         (x1 - x0), (y1 - y0)]
    return torch.stack(b, dim=-1)


# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union


def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    The boxes should be in [x0, y0, x1, y1] format

    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2)
    """
    # degenerate boxes gives inf / nan results
    # so do an early check
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2)

    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    area = wh[:, :, 0] * wh[:, :, 1]

    return iou - (area - union) / area


def masks_to_boxes(masks):
    """Compute the bounding boxes around the provided masks

    The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.

    Returns a [N, 4] tensors, with the boxes in xyxy format
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]

    y = torch.arange(0, h, dtype=torch.float)
    x = torch.arange(0, w, dtype=torch.float)
    y, x = torch.meshgrid(y, x)

    x_mask = (masks * x.unsqueeze(0))
    x_max = x_mask.flatten(1).max(-1)[0]
    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    y_mask = (masks * y.unsqueeze(0))
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)
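

# Illustrative sketch: sanity-checking generalized_box_iou() above. GIoU of a box
# with itself is 1; for disjoint boxes it drops below 0 (here the enclosing box
# has area 16 and the union is 5, giving 0 - 11/16 = -0.6875).
if __name__ == "__main__":
    a = torch.tensor([[0.0, 0.0, 2.0, 2.0]])
    b = torch.tensor([[0.0, 0.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0]])
    print(generalized_box_iou(a, b))  # tensor([[ 1.0000, -0.6875]])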
@@ -0,0 +1,468 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Misc functions, including distributed helpers.

Mostly copy-paste from torchvision references.
"""
import os
import subprocess
import time
from collections import defaultdict, deque
import datetime
import pickle
from packaging import version
from typing import Optional, List

import torch
import torch.distributed as dist
from torch import Tensor

# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
if version.parse(torchvision.__version__) < version.parse('0.7'):
    from torchvision.ops import _new_empty_tensor
    from torchvision.ops.misc import _output_size


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))


def get_sha():
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
    sha = 'N/A'
    diff = "clean"
    branch = 'N/A'
    try:
        sha = _run(['git', 'rev-parse', 'HEAD'])
        subprocess.check_output(['git', 'diff'], cwd=cwd)
        diff = _run(['git', 'diff-index', 'HEAD'])
        diff = "has uncommitted changes" if diff else "clean"
        branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
    except Exception:
        pass
    message = f"sha: {sha}, status: {diff}, branch: {branch}"
    return message


def collate_fn(batch):
    batch = list(zip(*batch))
    batch[0] = nested_tensor_from_tensor_list(batch[0])
    return tuple(batch)


def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


class NestedTensor(object):
    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            assert mask is not None
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)


def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], :img.shape[2]] = False
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)


# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    max_size = []
    for i in range(tensor_list[0].dim()):
        max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    # work around for
    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    # m[: img.shape[1], :img.shape[2]] = False
    # which is not yet supported in onnx
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)

        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """
    Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
    This will eventually be supported natively by PyTorch, and this
    class can go away.
    """
    if version.parse(torchvision.__version__) < version.parse('0.7'):
        if input.numel() > 0:
            return torch.nn.functional.interpolate(
                input, size, scale_factor, mode, align_corners
            )

        output_shape = _output_size(2, input, size, scale_factor)
        output_shape = list(input.shape[:-2]) + list(output_shape)
        return _new_empty_tensor(input, output_shape)
    else:
        return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
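

# Illustrative sketch: nested_tensor_from_tensor_list() above pads a list of
# differently sized CxHxW images to a common shape; the returned mask is True on
# padded pixels and False on valid ones.
if __name__ == "__main__":
    imgs = [torch.rand(3, 200, 300), torch.rand(3, 240, 250)]
    tensors, mask = nested_tensor_from_tensor_list(imgs).decompose()
    print(tensors.shape, mask.shape)  # torch.Size([2, 3, 240, 300]) torch.Size([2, 240, 300])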
@@ -0,0 +1,107 @@
"""
Plotting utilities to visualize training logs.
"""
import torch
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from pathlib import Path, PurePath


def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
    '''
    Function to plot specific fields from training log(s). Plots both training and test results.

    :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
              - fields = which results to plot from each log file - plots both training and test for each field.
              - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
              - log_name = optional, name of log file if different than default 'log.txt'.

    :: Outputs - matplotlib plots of results in fields, color coded for each log file.
               - solid lines are training results, dashed lines are test results.

    '''
    func_name = "plot_utils.py::plot_logs"

    # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
    # convert single Path to list to avoid 'not iterable' error

    if not isinstance(logs, list):
        if isinstance(logs, PurePath):
            logs = [logs]
            print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
        else:
            raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
            Expect list[Path] or single Path obj, received {type(logs)}")

    # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
    for i, dir in enumerate(logs):
        if not isinstance(dir, PurePath):
            raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
        if not dir.exists():
            raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
        # verify log_name exists
        fn = Path(dir / log_name)
        if not fn.exists():
            print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?")
            print(f"--> full path of missing log file: {fn}")
            return

    # load log file(s) and plot
    dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]

    fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))

    for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
        for j, field in enumerate(fields):
            if field == 'mAP':
                coco_eval = pd.DataFrame(
                    np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
                ).ewm(com=ewm_col).mean()
                axs[j].plot(coco_eval, c=color)
            else:
                df.interpolate().ewm(com=ewm_col).mean().plot(
                    y=[f'train_{field}', f'test_{field}'],
                    ax=axs[j],
                    color=[color] * 2,
                    style=['-', '--']
                )
    for ax, field in zip(axs, fields):
        ax.legend([Path(p).name for p in logs])
        ax.set_title(field)


def plot_precision_recall(files, naming_scheme='iter'):
    if naming_scheme == 'exp_id':
        # name becomes exp_id
        names = [f.parts[-3] for f in files]
    elif naming_scheme == 'iter':
        names = [f.stem for f in files]
    else:
        raise ValueError(f'not supported {naming_scheme}')
    fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
    for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
        data = torch.load(f)
        # precision is n_iou, n_points, n_cat, n_area, max_det
        precision = data['precision']
        recall = data['params'].recThrs
        scores = data['scores']
        # take precision for all classes, all areas and 100 detections
        precision = precision[0, :, :, 0, -1].mean(1)
        scores = scores[0, :, :, 0, -1].mean(1)
        prec = precision.mean()
        rec = data['recall'][0, :, 0, -1].mean()
        print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
              f'score={scores.mean():0.3f}, ' +
              f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
              )
        axs[0].plot(recall, precision, c=color)
        axs[1].plot(recall, scores, c=color)

    axs[0].set_title('Precision / Recall')
    axs[0].legend(names)
    axs[1].set_title('Scores / Recall')
    axs[1].legend(names)
    return fig, axs
@@ -0,0 +1,499 @@
import os
import pickle
import fnmatch
import numpy as np
import torch
import h5py
import cv2
from torch.utils.data import DataLoader
import torchvision.transforms as transforms


def flatten_list(lst):
    return [item for sublist in lst for item in sublist]


class EpisodicDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        dataset_path_list,
        camera_names,
        norm_stats,
        episode_ids,
        episode_len,
        chunk_size,
        policy_class,
    ):
        super().__init__()
        self.episode_ids = episode_ids
        self.dataset_path_list = dataset_path_list
        self.camera_names = camera_names
        self.norm_stats = norm_stats
        self.episode_len = episode_len
        self.chunk_size = chunk_size
        self.cumulative_len = np.cumsum(self.episode_len)
        self.max_episode_len = max(episode_len)
        self.policy_class = policy_class
        # image augmentation is only applied for diffusion policies
        self.augment_images = self.policy_class == "Diffusion"
        self.transformations = None
        self.__getitem__(0)  # initialize self.is_sim and self.transformations
        self.is_sim = False

    # def __len__(self):
    #     return sum(self.episode_len)

    def _locate_transition(self, index):
        assert index < self.cumulative_len[-1]
        episode_index = np.argmax(self.cumulative_len > index)  # argmax returns first True index
        start_ts = index - (self.cumulative_len[episode_index] - self.episode_len[episode_index])
        episode_id = self.episode_ids[episode_index]
        return episode_id, start_ts

    def __getitem__(self, index):
        episode_id, start_ts = self._locate_transition(index)
        dataset_path = self.dataset_path_list[episode_id]
        try:
            with h5py.File(dataset_path, "r") as root:
                try:  # some legacy data does not have this attribute
                    is_sim = root.attrs["sim"]
                except KeyError:
                    is_sim = False
                compressed = root.attrs.get("compress", False)
                if "/base_action" in root:
                    base_action = root["/base_action"][()]
                    base_action = preprocess_base_action(base_action)
                    action = np.concatenate([root["/action"][()], base_action], axis=-1)
                else:
                    action = root["/action"][()]
                    # dummy_base_action = np.zeros([action.shape[0], 2])
                    # action = np.concatenate([action, dummy_base_action], axis=-1)
                original_action_shape = action.shape
                episode_len = original_action_shape[0]
                # get observation at start_ts only
                qpos = root["/observations/qpos"][start_ts]
                qvel = root["/observations/qvel"][start_ts]
                image_dict = dict()
                for cam_name in self.camera_names:
                    image_dict[cam_name] = root[f"/observations/images/{cam_name}"][start_ts]

                if compressed:
                    for cam_name in image_dict.keys():
                        decompressed_image = cv2.imdecode(image_dict[cam_name], 1)
                        image_dict[cam_name] = np.array(decompressed_image)

                # get all actions after and including start_ts
                if is_sim:
                    action = action[start_ts:]
                    action_len = episode_len - start_ts
                else:
                    # hack: offset by one to make timesteps more aligned
                    action = action[max(0, start_ts - 1):]
                    action_len = episode_len - max(0, start_ts - 1)

            # self.is_sim = is_sim
            padded_action = np.zeros((self.max_episode_len, original_action_shape[1]), dtype=np.float32)
            padded_action[:action_len] = action
            is_pad = np.zeros(self.max_episode_len)
            is_pad[action_len:] = 1

            padded_action = padded_action[: self.chunk_size]
            is_pad = is_pad[: self.chunk_size]

            # new axis for different cameras
            all_cam_images = []
            for cam_name in self.camera_names:
                all_cam_images.append(image_dict[cam_name])
            all_cam_images = np.stack(all_cam_images, axis=0)

            # construct observations
            image_data = torch.from_numpy(all_cam_images)
            qpos_data = torch.from_numpy(qpos).float()
            action_data = torch.from_numpy(padded_action).float()
            is_pad = torch.from_numpy(is_pad).bool()

            # channel last -> channel first
            image_data = torch.einsum("k h w c -> k c h w", image_data)

            # augmentation
            if self.transformations is None:
                print("Initializing transformations")
                original_size = image_data.shape[2:]
                ratio = 0.95
                self.transformations = [
                    transforms.RandomCrop(size=[int(original_size[0] * ratio), int(original_size[1] * ratio)]),
                    transforms.Resize(original_size, antialias=True),
                    transforms.RandomRotation(degrees=[-5.0, 5.0], expand=False),
                    transforms.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5),  # hue=0.08 disabled
                ]

            if self.augment_images:
                for transform in self.transformations:
                    image_data = transform(image_data)

            # normalize image and change dtype to float
            image_data = image_data / 255.0

            if self.policy_class == "Diffusion":
                # normalize to [-1, 1]
                action_data = (
                    (action_data - self.norm_stats["action_min"])
                    / (self.norm_stats["action_max"] - self.norm_stats["action_min"])
                ) * 2 - 1
            else:
                # normalize to mean 0, std 1
                action_data = (action_data - self.norm_stats["action_mean"]) / self.norm_stats["action_std"]

            qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats["qpos_std"]

        except Exception as e:
            print(f"Error loading {dataset_path} in __getitem__")
            print(e)
            quit()

        # print(image_data.dtype, qpos_data.dtype, action_data.dtype, is_pad.dtype)
        return image_data, qpos_data, action_data, is_pad


def get_norm_stats(dataset_path_list):
    all_qpos_data = []
    all_action_data = []
    all_episode_len = []

    for dataset_path in dataset_path_list:
        try:
            with h5py.File(dataset_path, "r") as root:
                qpos = root["/observations/qpos"][()]
                qvel = root["/observations/qvel"][()]
                action = root["/action"][()]
                if "/base_action" in root:
                    base_action = root["/base_action"][()]
                    # base_action = preprocess_base_action(base_action)
                    # action = np.concatenate([action, base_action], axis=-1)
                # dummy_base_action = np.zeros([action.shape[0], 2])
                # action = np.concatenate([action, dummy_base_action], axis=-1)
        except Exception as e:
            print(f"Error loading {dataset_path} in get_norm_stats")
            print(e)
            quit()
        all_qpos_data.append(torch.from_numpy(qpos))
        all_action_data.append(torch.from_numpy(action))
        all_episode_len.append(len(qpos))
    all_qpos_data = torch.cat(all_qpos_data, dim=0)
    all_action_data = torch.cat(all_action_data, dim=0)

    # normalize action data
    action_mean = all_action_data.mean(dim=[0]).float()
    action_std = all_action_data.std(dim=[0]).float()
    action_std = torch.clip(action_std, 1e-2, np.inf)  # avoid division by near-zero std

    # normalize qpos data
    qpos_mean = all_qpos_data.mean(dim=[0]).float()
    qpos_std = all_qpos_data.std(dim=[0]).float()
    qpos_std = torch.clip(qpos_std, 1e-2, np.inf)  # avoid division by near-zero std

    action_min = all_action_data.min(dim=0).values.float()
    action_max = all_action_data.max(dim=0).values.float()

    eps = 0.0001
    stats = {
        "action_mean": action_mean.numpy(),
        "action_std": action_std.numpy(),
        "action_min": action_min.numpy() - eps,
        "action_max": action_max.numpy() + eps,
        "qpos_mean": qpos_mean.numpy(),
        "qpos_std": qpos_std.numpy(),
        "example_qpos": qpos,
    }

    return stats, all_episode_len


def find_all_hdf5(dataset_dir, skip_mirrored_data):
    hdf5_files = []
    for root, dirs, files in os.walk(dataset_dir):
        for filename in fnmatch.filter(files, "*.hdf5"):
            if "features" in filename:
                continue
            if skip_mirrored_data and "mirror" in filename:
                continue
            hdf5_files.append(os.path.join(root, filename))
    print(f"Found {len(hdf5_files)} hdf5 files")
    return hdf5_files


def BatchSampler(batch_size, episode_len_l, sample_weights):
    sample_probs = (
        np.array(sample_weights) / np.sum(sample_weights) if sample_weights is not None else None
    )
    sum_dataset_len_l = np.cumsum([0] + [np.sum(episode_len) for episode_len in episode_len_l])
    while True:
        batch = []
        for _ in range(batch_size):
            episode_idx = np.random.choice(len(episode_len_l), p=sample_probs)
            step_idx = np.random.randint(sum_dataset_len_l[episode_idx], sum_dataset_len_l[episode_idx + 1])
            batch.append(step_idx)
        yield batch


def load_data(
    dataset_dir_l,
    name_filter,
    camera_names,
    batch_size_train,
    batch_size_val,
    chunk_size,
    skip_mirrored_data=False,
    load_pretrain=False,
    policy_class=None,
    stats_dir_l=None,
    sample_weights=None,
    train_ratio=0.99,
):
    if isinstance(dataset_dir_l, str):
        dataset_dir_l = [dataset_dir_l]
    dataset_path_list_list = [find_all_hdf5(dataset_dir, skip_mirrored_data) for dataset_dir in dataset_dir_l]
    num_episodes_0 = len(dataset_path_list_list[0])
    dataset_path_list = flatten_list(dataset_path_list_list)
    dataset_path_list = [n for n in dataset_path_list if name_filter(n)]
    num_episodes_l = [len(dataset_path_list) for dataset_path_list in dataset_path_list_list]
    num_episodes_cumsum = np.cumsum(num_episodes_l)

    # obtain train/test split on dataset_dir_l[0]; the remaining directories are train-only
    shuffled_episode_ids_0 = np.random.permutation(num_episodes_0)
    train_episode_ids_0 = shuffled_episode_ids_0[: int(train_ratio * num_episodes_0)]
    val_episode_ids_0 = shuffled_episode_ids_0[int(train_ratio * num_episodes_0):]
    train_episode_ids_l = [train_episode_ids_0] + [
        np.arange(num_episodes) + num_episodes_cumsum[idx]
        for idx, num_episodes in enumerate(num_episodes_l[1:])
    ]
    val_episode_ids_l = [val_episode_ids_0]
    train_episode_ids = np.concatenate(train_episode_ids_l)
    val_episode_ids = np.concatenate(val_episode_ids_l)
    print(
        f"\n\nData from: {dataset_dir_l}\n- Train on {[len(x) for x in train_episode_ids_l]} episodes\n- Test on {[len(x) for x in val_episode_ids_l]} episodes\n\n"
    )

    # obtain normalization stats for qpos and action
    # if load_pretrain:
    #     with open(os.path.join('/home/zfu/interbotix_ws/src/act/ckpts/pretrain_all', 'dataset_stats.pkl'), 'rb') as f:
    #         norm_stats = pickle.load(f)
    #     print('Loaded pretrain dataset stats')
    _, all_episode_len = get_norm_stats(dataset_path_list)
    train_episode_len_l = [
        [all_episode_len[i] for i in train_episode_ids] for train_episode_ids in train_episode_ids_l
    ]
    val_episode_len_l = [
        [all_episode_len[i] for i in val_episode_ids] for val_episode_ids in val_episode_ids_l
    ]

    train_episode_len = flatten_list(train_episode_len_l)
    val_episode_len = flatten_list(val_episode_len_l)
    if stats_dir_l is None:
        stats_dir_l = dataset_dir_l
    elif isinstance(stats_dir_l, str):
        stats_dir_l = [stats_dir_l]
    norm_stats, _ = get_norm_stats(
        flatten_list([find_all_hdf5(stats_dir, skip_mirrored_data) for stats_dir in stats_dir_l])
    )
    print(f"Norm stats from: {stats_dir_l}")

    batch_sampler_train = BatchSampler(batch_size_train, train_episode_len_l, sample_weights)
    batch_sampler_val = BatchSampler(batch_size_val, val_episode_len_l, None)

    # construct dataset and dataloader
    train_dataset = EpisodicDataset(
        dataset_path_list,
        camera_names,
        norm_stats,
        train_episode_ids,
        train_episode_len,
        chunk_size,
        policy_class,
    )
    val_dataset = EpisodicDataset(
        dataset_path_list,
        camera_names,
        norm_stats,
        val_episode_ids,
        val_episode_len,
        chunk_size,
        policy_class,
    )
    train_num_workers = (8 if os.getlogin() == "zfu" else 16) if train_dataset.augment_images else 2
    val_num_workers = 8 if train_dataset.augment_images else 2
    print(
        f"Augment images: {train_dataset.augment_images}, train_num_workers: {train_num_workers}, val_num_workers: {val_num_workers}"
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_sampler=batch_sampler_train,
        pin_memory=True,
        num_workers=train_num_workers,
        prefetch_factor=2,
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_sampler=batch_sampler_val,
        pin_memory=True,
        num_workers=val_num_workers,
        prefetch_factor=2,
    )

    return train_dataloader, val_dataloader, norm_stats, train_dataset.is_sim


def calibrate_linear_vel(base_action, c=None):
    if c is None:
        c = 0.0  # 0.19
    v = base_action[..., 0]
    w = base_action[..., 1]
    base_action = base_action.copy()
    base_action[..., 0] = v - c * w
    return base_action


def smooth_base_action(base_action):
    return np.stack(
        [np.convolve(base_action[:, i], np.ones(5) / 5, mode="same") for i in range(base_action.shape[1])],
        axis=-1,
    ).astype(np.float32)


def preprocess_base_action(base_action):
    # base_action = calibrate_linear_vel(base_action)
    base_action = smooth_base_action(base_action)
    return base_action


def postprocess_base_action(base_action):
    linear_vel, angular_vel = base_action
    linear_vel *= 1.0
    angular_vel *= 1.0
    # angular_vel = 0
    # if np.abs(linear_vel) < 0.05:
    #     linear_vel = 0
    return np.array([linear_vel, angular_vel])


### env utils


def sample_box_pose():
    x_range = [0.0, 0.2]
    y_range = [0.4, 0.6]
    z_range = [0.05, 0.05]

    ranges = np.vstack([x_range, y_range, z_range])
    cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1])

    cube_quat = np.array([1, 0, 0, 0])
    return np.concatenate([cube_position, cube_quat])


def sample_insertion_pose():
    # Peg
    x_range = [0.1, 0.2]
    y_range = [0.4, 0.6]
    z_range = [0.05, 0.05]

    ranges = np.vstack([x_range, y_range, z_range])
    peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1])

    peg_quat = np.array([1, 0, 0, 0])
    peg_pose = np.concatenate([peg_position, peg_quat])

    # Socket
    x_range = [-0.2, -0.1]
    y_range = [0.4, 0.6]
    z_range = [0.05, 0.05]

    ranges = np.vstack([x_range, y_range, z_range])
    socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1])

    socket_quat = np.array([1, 0, 0, 0])
    socket_pose = np.concatenate([socket_position, socket_quat])

    return peg_pose, socket_pose


### helper functions


def compute_dict_mean(epoch_dicts):
    result = {k: None for k in epoch_dicts[0]}
    num_items = len(epoch_dicts)
    for k in result:
        value_sum = 0
        for epoch_dict in epoch_dicts:
            value_sum += epoch_dict[k]
        result[k] = value_sum / num_items
    return result


def detach_dict(d):
    new_d = dict()
    for k, v in d.items():
        new_d[k] = v.detach()
    return new_d


def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
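For context, a minimal sketch of how the load_data() above would be driven; the directory, camera names, and batch sizes below are illustrative placeholders, not values taken from this repository:

# Hedged usage sketch for load_data(); all literals here are placeholders.
train_loader, val_loader, norm_stats, is_sim = load_data(
    dataset_dir_l="/path/to/episodes",        # hypothetical dataset directory
    name_filter=lambda name: True,            # keep every .hdf5 file found
    camera_names=["cam_high", "cam_low", "cam_left", "cam_right"],
    batch_size_train=8,
    batch_size_val=8,
    chunk_size=100,                           # action chunk length used by ACT-style policies
    policy_class="ACT",
)
image, qpos, action, is_pad = next(iter(train_loader))  # one normalized training batch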
realman_src/realman_aloha/shadow_rm_act/test/test_camera.py (new file)
@@ -0,0 +1,163 @@
from shadow_camera.realsense import RealSenseCamera
from shadow_rm_robot.realman_arm import RmArm
import yaml
import time
import numpy as np
import collections
import logging
import dm_env
import tracemalloc


class DeviceAloha:
    def __init__(self, aloha_config):
        """
        Initialize the device.

        Args:
            aloha_config (dict): device configuration
        """
        config_left_arm = aloha_config["rm_left_arm"]
        config_right_arm = aloha_config["rm_right_arm"]
        config_head_camera = aloha_config["head_camera"]
        config_bottom_camera = aloha_config["bottom_camera"]
        config_left_camera = aloha_config["left_camera"]
        config_right_camera = aloha_config["right_camera"]
        self.init_left_arm_angle = aloha_config["init_left_arm_angle"]
        self.init_right_arm_angle = aloha_config["init_right_arm_angle"]
        self.arm_left = RmArm(config_left_arm)
        self.arm_right = RmArm(config_right_arm)
        # each camera is constructed from its own serial number in the config
        self.camera_left = RealSenseCamera(config_left_camera, False)
        self.camera_right = RealSenseCamera(config_right_camera, False)
        self.camera_bottom = RealSenseCamera(config_bottom_camera, False)
        self.camera_top = RealSenseCamera(config_head_camera, False)
        self.camera_left.start_camera()
        self.camera_right.start_camera()
        self.camera_bottom.start_camera()
        self.camera_top.start_camera()

    def close(self):
        """Close all cameras."""
        self.camera_left.close()
        self.camera_right.close()
        self.camera_bottom.close()
        self.camera_top.close()

    def get_qpos(self):
        """
        Get the joint angles of both arms.

        Returns:
            np.array: joint angles, left arm first
        """
        left_slave_arm_angle = self.arm_left.get_joint_angle()
        left_joint_angles_array = np.array(list(left_slave_arm_angle.values()))
        right_slave_arm_angle = self.arm_right.get_joint_angle()
        right_joint_angles_array = np.array(list(right_slave_arm_angle.values()))
        return np.concatenate([left_joint_angles_array, right_joint_angles_array])

    def get_qvel(self):
        """
        Get the joint velocities of both arms.

        Returns:
            np.array: joint velocities, left arm first
        """
        left_slave_arm_velocity = self.arm_left.get_joint_velocity()
        left_joint_velocity_array = np.array(list(left_slave_arm_velocity.values()))
        right_slave_arm_velocity = self.arm_right.get_joint_velocity()
        right_joint_velocity_array = np.array(list(right_slave_arm_velocity.values()))
        return np.concatenate([left_joint_velocity_array, right_joint_velocity_array])

    def get_effort(self):
        """
        Get the joint torques of both arms.

        Returns:
            np.array: joint torques, left arm first
        """
        left_slave_arm_effort = self.arm_left.get_joint_effort()
        left_joint_effort_array = np.array(list(left_slave_arm_effort.values()))
        right_slave_arm_effort = self.arm_right.get_joint_effort()
        right_joint_effort_array = np.array(list(right_slave_arm_effort.values()))
        return np.concatenate([left_joint_effort_array, right_joint_effort_array])

    def get_images(self):
        """
        Grab one RGB frame from each camera.

        Returns:
            dict: camera name -> image
        """
        top_image, _, _, _ = self.camera_top.read_frame(True, False, False, False)
        bottom_image, _, _, _ = self.camera_bottom.read_frame(True, False, False, False)
        left_image, _, _, _ = self.camera_left.read_frame(True, False, False, False)
        right_image, _, _, _ = self.camera_right.read_frame(True, False, False, False)
        return {
            "cam_high": top_image,
            "cam_low": bottom_image,
            "cam_left": left_image,
            "cam_right": right_image,
        }

    def get_observation(self):
        obs = collections.OrderedDict()
        obs["qpos"] = self.get_qpos()
        obs["qvel"] = self.get_qvel()
        obs["effort"] = self.get_effort()
        obs["images"] = self.get_images()
        return obs

    def reset(self):
        logging.info("Resetting the environment")
        _ = self.arm_left.set_joint_position(self.init_left_arm_angle[0:6])
        _ = self.arm_right.set_joint_position(self.init_right_arm_angle[0:6])
        self.arm_left.set_gripper_position(0)
        self.arm_right.set_gripper_position(0)
        return dm_env.TimeStep(
            step_type=dm_env.StepType.FIRST,
            reward=0,
            discount=None,
            observation=self.get_observation(),
        )

    def step(self, target_angle):
        # target_angle layout: [left joints 0-5, left gripper, right joints 0-5, right gripper]
        self.arm_left.set_joint_canfd_position(target_angle[0:6])
        self.arm_right.set_joint_canfd_position(target_angle[7:13])
        self.arm_left.set_gripper_position(target_angle[6])
        self.arm_right.set_gripper_position(target_angle[13])
        return dm_env.TimeStep(
            step_type=dm_env.StepType.MID,
            reward=0,
            discount=None,
            observation=self.get_observation(),
        )


if __name__ == '__main__':
    with open("/home/rm/code/shadow_act/config/config.yaml", "r") as f:
        config = yaml.safe_load(f)
    aloha_config = config["robot_env"]
    device = DeviceAloha(aloha_config)
    device.reset()
    image_list = []
    target_angle = np.concatenate([device.init_left_arm_angle, device.init_right_arm_angle])
    while True:
        tracemalloc.start()  # start memory tracing

        # nudge every joint by 0.1, leaving the two gripper channels (indices 6 and 13) unchanged
        target_angle = np.array([angle + 0.1 if i not in [6, 13] else angle for i, angle in enumerate(target_angle)])
        time_step = time.time()
        timestep = device.step(target_angle)
        logging.info(f"Time: {time.time() - time_step}")
        image_list.append(timestep.observation["images"])
        snapshot = tracemalloc.take_snapshot()
        top_stats = snapshot.statistics('lineno')
        # del timestep
        print("[ Top 10 ]")
        for stat in top_stats[:10]:
            print(stat)
        # logging.info(f"Images: {obs}")
realman_src/realman_aloha/shadow_rm_act/test/test_h5.py (new file)
@@ -0,0 +1,32 @@
import os
# import time
import yaml
import torch
import pickle
import dm_env
import logging
import collections
import numpy as np
from einops import rearrange
import matplotlib.pyplot as plt
from torchvision import transforms
from shadow_rm_robot.realman_arm import RmArm
from shadow_camera.realsense import RealSenseCamera
from shadow_act.models.latent_model import Latent_Model_Transformer
from shadow_act.network.policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy
from shadow_act.utils.utils import (
    load_data,
    sample_box_pose,
    sample_insertion_pose,
    compute_dict_mean,
    set_seed,
    detach_dict,
)
logging.basicConfig(level=logging.DEBUG)


print('daasdas')
realman_src/realman_aloha/shadow_rm_act/visualize_episodes.py (new file)
@@ -0,0 +1,147 @@
import os
import numpy as np
import cv2
import h5py
import argparse

import matplotlib.pyplot as plt
from constants import DT

import IPython
e = IPython.embed

JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
STATE_NAMES = JOINT_NAMES + ["gripper"]


def load_hdf5(dataset_dir, dataset_name):
    dataset_path = os.path.join(dataset_dir, dataset_name + '.hdf5')
    if not os.path.isfile(dataset_path):
        print(f'Dataset does not exist at \n{dataset_path}\n')
        exit()

    with h5py.File(dataset_path, 'r') as root:
        is_sim = root.attrs['sim']
        qpos = root['/observations/qpos'][()]
        qvel = root['/observations/qvel'][()]
        action = root['/action'][()]
        image_dict = dict()
        for cam_name in root['/observations/images/'].keys():
            image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()]

    return qpos, qvel, action, image_dict


def main(args):
    dataset_dir = args['dataset_dir']
    episode_idx = args['episode_idx']
    dataset_name = f'episode_{episode_idx}'

    qpos, qvel, action, image_dict = load_hdf5(dataset_dir, dataset_name)
    save_videos(image_dict, DT, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4'))
    visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png'))
    # visualize_timestamp(t_list, dataset_path)  # TODO: add timestamp back


def save_videos(video, dt, video_path=None):
    if isinstance(video, list):
        cam_names = list(video[0].keys())
        h, w, _ = video[0][cam_names[0]].shape
        w = w * len(cam_names)
        fps = int(1 / dt)
        out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
        for ts, image_dict in enumerate(video):
            images = []
            for cam_name in cam_names:
                image = image_dict[cam_name]
                image = image[:, :, [2, 1, 0]]  # swap B and R channel
                images.append(image)
            images = np.concatenate(images, axis=1)
            out.write(images)
        out.release()
        print(f'Saved video to: {video_path}')
    elif isinstance(video, dict):
        cam_names = list(video.keys())
        all_cam_videos = []
        for cam_name in cam_names:
            all_cam_videos.append(video[cam_name])
        all_cam_videos = np.concatenate(all_cam_videos, axis=2)  # width dimension

        n_frames, h, w, _ = all_cam_videos.shape
        fps = int(1 / dt)
        out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
        for t in range(n_frames):
            image = all_cam_videos[t]
            image = image[:, :, [2, 1, 0]]  # swap B and R channel
            out.write(image)
        out.release()
        print(f'Saved video to: {video_path}')


def visualize_joints(qpos_list, command_list, plot_path=None, ylim=None, label_overwrite=None):
    if label_overwrite:
        label1, label2 = label_overwrite
    else:
        label1, label2 = 'State', 'Command'

    qpos = np.array(qpos_list)  # ts, dim
    command = np.array(command_list)
    num_ts, num_dim = qpos.shape
    h, w = 2, num_dim
    num_figs = num_dim
    fig, axs = plt.subplots(num_figs, 1, figsize=(w, h * num_figs))

    # plot joint state
    all_names = [name + '_left' for name in STATE_NAMES] + [name + '_right' for name in STATE_NAMES]
    for dim_idx in range(num_dim):
        ax = axs[dim_idx]
        ax.plot(qpos[:, dim_idx], label=label1)
        ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
        ax.legend()

    # plot arm command
    for dim_idx in range(num_dim):
        ax = axs[dim_idx]
        ax.plot(command[:, dim_idx], label=label2)
        ax.legend()

    if ylim:
        for dim_idx in range(num_dim):
            ax = axs[dim_idx]
            ax.set_ylim(ylim)

    plt.tight_layout()
    plt.savefig(plot_path)
    print(f'Saved qpos plot to: {plot_path}')
    plt.close()


def visualize_timestamp(t_list, dataset_path):
    plot_path = dataset_path.replace('.pkl', '_timestamp.png')
    h, w = 4, 10
    fig, axs = plt.subplots(2, 1, figsize=(w, h * 2))
    # convert (secs, nsecs) pairs to float seconds
    t_float = []
    for secs, nsecs in t_list:
        t_float.append(secs + nsecs * 1e-9)
    t_float = np.array(t_float)

    ax = axs[0]
    ax.plot(np.arange(len(t_float)), t_float)
    ax.set_title('Camera frame timestamps')
    ax.set_xlabel('timestep')
    ax.set_ylabel('time (sec)')

    ax = axs[1]
    ax.plot(np.arange(len(t_float) - 1), t_float[:-1] - t_float[1:])
    ax.set_title('dt')
    ax.set_xlabel('timestep')
    ax.set_ylabel('time (sec)')

    plt.tight_layout()
    plt.savefig(plot_path)
    print(f'Saved timestamp plot to: {plot_path}')
    plt.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_dir', action='store', type=str, help='Dataset dir.', required=True)
    parser.add_argument('--episode_idx', action='store', type=int, help='Episode index.', required=False)
    main(vars(parser.parse_args()))
realman_src/realman_aloha/shadow_rm_aloha/.gitignore (vendored, new file)
@@ -0,0 +1,10 @@
__pycache__/
build/
devel/
dist/
data/
.catkin_workspace
*.pyc
*.pyo
*.pt
.vscode/
realman_src/realman_aloha/shadow_rm_aloha/.idea/.gitignore (generated, vendored, new file)
@@ -0,0 +1,3 @@
# Default ignored files
/shelf/
/workspace.xml
realman_src/realman_aloha/shadow_rm_aloha/.idea/.name (generated, new file)
@@ -0,0 +1 @@
aloha_data_synchronizer.py
realman_src/realman_aloha/shadow_rm_aloha/.idea/inspectionProfiles/Project_Default.xml (generated, new file)
@@ -0,0 +1,17 @@
<component name="InspectionProjectProfileManager">
  <profile version="1.0">
    <option name="myName" value="Project Default" />
    <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
      <option name="ignoredPackages">
        <value>
          <list size="4">
            <item index="0" class="java.lang.String" itemvalue="tensorboard" />
            <item index="1" class="java.lang.String" itemvalue="thop" />
            <item index="2" class="java.lang.String" itemvalue="torch" />
            <item index="3" class="java.lang.String" itemvalue="torchvision" />
          </list>
        </value>
      </option>
    </inspection_tool>
  </profile>
</component>
realman_src/realman_aloha/shadow_rm_aloha/.idea/inspectionProfiles/profiles_settings.xml (generated, new file)
@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
  <settings>
    <option name="USE_PROJECT_PROFILE" value="false" />
    <version value="1.0" />
  </settings>
</component>
realman_src/realman_aloha/shadow_rm_aloha/.idea/misc.xml (generated, new file)
@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="Black">
    <option name="sdkName" value="Python 3.11 (随箱软件)" />
  </component>
  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11 (随箱软件)" project-jdk-type="Python SDK" />
</project>
realman_src/realman_aloha/shadow_rm_aloha/.idea/modules.xml (generated, new file)
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/shadow_rm_aloha.iml" filepath="$PROJECT_DIR$/.idea/shadow_rm_aloha.iml" />
    </modules>
  </component>
</project>
realman_src/realman_aloha/shadow_rm_aloha/.idea/shadow_rm_aloha.iml (generated, new file)
@@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="inheritedJdk" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
  <component name="PyDocumentationSettings">
    <option name="format" value="PLAIN" />
    <option name="myDocStringFormat" value="Plain" />
  </component>
</module>
realman_src/realman_aloha/shadow_rm_aloha/README.md (new file, empty)

@@ -0,0 +1,38 @@
dataset_dir: '/home/rm/code/shadow_rm_aloha/data/dataset'
dataset_name: 'episode'
max_timesteps: 500
state_dim: 14
overwrite: False
arm_axis: 6
camera_names:
  - 'cam_high'
  - 'cam_low'
  - 'cam_left'
  - 'cam_right'
ros_topics:
  camera_left: '/camera_left/rgb/image_raw'
  camera_right: '/camera_right/rgb/image_raw'
  camera_bottom: '/camera_bottom/rgb/image_raw'
  camera_head: '/camera_head/rgb/image_raw'
  left_master_arm: '/left_master_arm_joint_states'
  left_slave_arm: '/left_slave_arm_joint_states'
  right_master_arm: '/right_master_arm_joint_states'
  right_slave_arm: '/right_slave_arm_joint_states'
  left_aloha_state: '/left_slave_arm_aloha_state'
  right_aloha_state: '/right_slave_arm_aloha_state'
robot_env: {
  # TODO change the path to the correct one
  rm_left_arm: '/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml',
  rm_right_arm: '/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml',
  arm_axis: 6,
  head_camera: '241122071186',
  bottom_camera: '152122078546',
  left_camera: '150622070125',
  right_camera: '151222072576',
  init_left_arm_angle: [7.235, 31.816, 51.237, 2.463, 91.054, 12.04, 0.0],
  init_right_arm_angle: [-6.155, 33.925, 62.137, -1.672, 87.892, -3.868, 0.0]
  # init_left_arm_angle: [6.681, 38.496, 66.093, -1.141, 74.529, 3.076, 0.0],
  # init_right_arm_angle: [-4.79, 37.062, 72.393, -0.477, 68.593, -9.526, 0.0]
  # init_left_arm_angle: [6.45, 66.093, 2.9, 20.919, -1.491, 100.756, 18.808, 0.617],
  # init_right_arm_angle: [166.953, -33.575, -163.917, 73.3, -9.581, 69.51, 0.876]
}
@@ -0,0 +1,9 @@
arm_ip: "192.168.1.18"
arm_port: 8080
arm_axis: 6
local_ip: "192.168.1.101"
local_port: 8089
# arm_ki: [7, 7, 7, 3, 3, 3, 3]  # rm75
arm_ki: [7, 7, 7, 3, 3, 3]  # rm65
get_vel: True
get_torque: True
@@ -0,0 +1,9 @@
arm_ip: "192.168.1.19"
arm_port: 8080
arm_axis: 6
local_ip: "192.168.1.101"
local_port: 8090
# arm_ki: [7, 7, 7, 3, 3, 3, 3]  # rm75
arm_ki: [7, 7, 7, 3, 3, 3]  # rm65
get_vel: True
get_torque: True
@@ -0,0 +1,4 @@
port: /dev/ttyUSB1
baudrate: 460800
hex_data: "55 AA 02 00 00 67"
arm_axis: 6
@@ -0,0 +1,4 @@
port: /dev/ttyUSB0
baudrate: 460800
hex_data: "55 AA 02 00 00 67"
arm_axis: 6
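These two small YAML files evidently describe serial-attached gripper adapters. As a hedged sketch only, this is how such a config might be consumed with pyserial; the file name, framing, and reply length are assumptions, not this repository's actual driver:

# Hedged sketch, assuming the gripper listens on a plain serial line (pyserial).
import yaml
import serial

with open("gripper_left.yaml") as f:  # hypothetical file name
    cfg = yaml.safe_load(f)
frame = bytes.fromhex(cfg["hex_data"].replace(" ", ""))  # "55 AA 02 00 00 67" -> b'\x55\xaa\x02\x00\x00\x67'
with serial.Serial(cfg["port"], cfg["baudrate"], timeout=0.1) as ser:
    ser.write(frame)      # send the configured command frame
    print(ser.read(16))   # read back whatever the device answers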
@@ -0,0 +1,6 @@
dataset_dir: '/home/rm/code/shadow_rm_aloha/data/dataset/20250102'
dataset_name: 'episode'
episode_idx: 1
FPS: 30
# joint_names: ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate", "J7"]  # 7 joints
joint_names: ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]  # 6 joints
realman_src/realman_aloha/shadow_rm_aloha/pyproject.toml (new file)
@@ -0,0 +1,39 @@
[tool.poetry]
name = "shadow_rm_aloha"
version = "0.1.1"
description = "aloha package, use D435 and Realman robot arm to build aloha to collect data"
readme = "README.md"
authors = ["Shadow <qiuchengzhan@gmail.com>"]
license = "MIT"
# include = ["realman_vision/pytransform/_pytransform.so",]
classifiers = [
    "Operating System :: POSIX :: Linux amd64",
    "Programming Language :: Python :: 3.10",
]

[tool.poetry.dependencies]
python = ">=3.10"
matplotlib = ">=3.9.2"
h5py = ">=3.12.1"
# rospy = ">=1.17.0"
# shadow_rm_robot = { git = "https://github.com/Shadow2223/shadow_rm_robot.git", branch = "main" }
# shadow_camera = { git = "https://github.com/Shadow2223/shadow_camera.git", branch = "main" }

[tool.poetry.dev-dependencies]  # development-only dependencies, e.g. testing and documentation tools
pytest = ">=8.3"
black = ">=24.10.0"

[tool.poetry.plugins."scripts"]  # command-line entry points, so functions can be run from the shell

[tool.poetry.group.dev.dependencies]

[build-system]
requires = ["poetry-core>=1.8.4"]
build-backend = "poetry.core.masonry.api"
@@ -0,0 +1,42 @@
cmake_minimum_required(VERSION 3.0.2)
project(shadow_rm_aloha)

find_package(catkin REQUIRED COMPONENTS
  rospy
  sensor_msgs
  cv_bridge
  image_transport
  std_msgs
  message_generation
)

add_service_files(
  FILES
  GetArmStatus.srv
  GetImage.srv
  MoveArm.srv
)

generate_messages(
  DEPENDENCIES
  sensor_msgs
  std_msgs
)

catkin_package(
  CATKIN_DEPENDS message_runtime rospy std_msgs
)

include_directories(
  ${catkin_INCLUDE_DIRS}
)

install(PROGRAMS
  arm_node/slave_arm_publisher.py
  arm_node/master_arm_publisher.py
  arm_node/slave_arm_pub_sub.py
  camera_node/camera_publisher.py
  data_sub_process/aloha_data_synchronizer.py
  data_sub_process/aloha_data_collect.py
  DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
)
@@ -0,0 +1 @@
__version__ = '0.1.0'
@@ -0,0 +1,44 @@
#!/usr/bin/env python3

import rospy
import logging
from shadow_rm_robot.servo_robotic_arm import ServoArm
from sensor_msgs.msg import JointState

# configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

class MasterArmPublisher:
    def __init__(self):
        rospy.init_node("master_arm_publisher", anonymous=True)
        arm_config = rospy.get_param("~arm_config", "config/servo_left_arm.yaml")
        hz = rospy.get_param("~hz", 250)
        self.joint_states_topic = rospy.get_param("~joint_states_topic", "/joint_states")

        self.arm = ServoArm(arm_config)
        self.publisher = rospy.Publisher(self.joint_states_topic, JointState, queue_size=1)
        self.rate = rospy.Rate(hz)  # 250 Hz by default

    def publish_joint_states(self):
        while not rospy.is_shutdown():
            joint_state = JointState()
            joint_pos = self.arm.get_joint_actions()

            joint_state.header.stamp = rospy.Time.now()
            joint_state.name = list(joint_pos.keys())
            joint_state.position = list(joint_pos.values())
            joint_state.velocity = [0.0] * len(joint_pos)  # velocities not measured; published as zeros
            joint_state.effort = [0.0] * len(joint_pos)    # efforts not measured; published as zeros

            # rospy.loginfo(f"{self.joint_states_topic}: {joint_state}")
            self.publisher.publish(joint_state)
            self.rate.sleep()

if __name__ == "__main__":
    try:
        arm_publisher = MasterArmPublisher()
        arm_publisher.publish_joint_states()
    except rospy.ROSInterruptException:
        pass
@@ -0,0 +1,63 @@
#!/usr/bin/env python3

import rospy
import logging
from shadow_rm_robot.realman_arm import RmArm
from sensor_msgs.msg import JointState

# configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

class SlaveArmPublisher:
    def __init__(self):
        rospy.init_node("slave_arm_publisher", anonymous=True)
        arm_config = rospy.get_param("~arm_config", default="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml")
        hz = rospy.get_param("~hz", 250)
        joint_states_topic = rospy.get_param("~joint_states_topic", "/joint_states")
        joint_actions_topic = rospy.get_param("~joint_actions_topic", "/joint_actions")
        self.arm = RmArm(arm_config)
        self.publisher = rospy.Publisher(joint_states_topic, JointState, queue_size=1)
        self.subscriber = rospy.Subscriber(joint_actions_topic, JointState, self.callback)
        self.rate = rospy.Rate(hz)

    def publish_joint_states(self):
        while not rospy.is_shutdown():
            joint_state = JointState()
            data = self.arm.get_integrate_data()
            # data = self.arm.get_arm_data()
            joint_state.header.stamp = rospy.Time.now()
            joint_state.name = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"]
            # joint_state.position = data["joint_angle"]
            joint_state.position = data['arm_angle']

            # joint_state.velocity = list(data["arm_velocity"])
            # joint_state.effort = list(data["arm_torque"])
            # rospy.loginfo(f"joint_states_topic: {joint_state}")
            self.publisher.publish(joint_state)
            self.rate.sleep()

    def callback(self, data):
        # rospy.loginfo(f"Received joint_states_topic: {data}")
        start_time = rospy.Time.now()
        if data is None:
            return
        if data.name == ["joint_canfd"]:
            self.arm.set_joint_canfd_position(data.position[0:6])
        elif data.name == ["joint_j"]:
            self.arm.set_joint_position(data.position[0:6])

        # self.arm.set_gripper_position(data.position[6])
        end_time = rospy.Time.now()
        time_cost_ms = (end_time - start_time).to_sec() * 1000
        rospy.loginfo(f"Time cost: {data.name}, {time_cost_ms}")


if __name__ == "__main__":
    try:
        arm_publisher = SlaveArmPublisher()
        arm_publisher.publish_joint_states()
    except rospy.ROSInterruptException:
        pass
@@ -0,0 +1,44 @@
#!/usr/bin/env python3

import rospy
import logging
from shadow_rm_robot.realman_arm import RmArm
from sensor_msgs.msg import JointState
from std_msgs.msg import Int32MultiArray

# configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

class SlaveArmPublisher:
    def __init__(self):
        rospy.init_node("slave_arm_publisher", anonymous=True)
        arm_config = rospy.get_param("~arm_config", default="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml")
        hz = rospy.get_param("~hz", 250)
        joint_states_topic = rospy.get_param("~joint_states_topic", "/joint_states")
        aloha_state_topic = rospy.get_param("~aloha_state_topic", "/aloha_state")
        self.arm = RmArm(arm_config)
        self.publisher = rospy.Publisher(joint_states_topic, JointState, queue_size=1)
        self.aloha_state_pub = rospy.Publisher(aloha_state_topic, Int32MultiArray, queue_size=1)
        self.rate = rospy.Rate(hz)

    def publish_joint_states(self):
        while not rospy.is_shutdown():
            joint_state = JointState()
            data = self.arm.get_integrate_data()
            joint_state.header.stamp = rospy.Time.now()
            joint_state.name = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6", "joint7"]
            joint_state.position = list(data["arm_angle"])
            joint_state.velocity = list(data["arm_velocity"])
            joint_state.effort = list(data["arm_torque"])
            # rospy.loginfo(f"joint_states_topic: {joint_state}")
            self.aloha_state_pub.publish(Int32MultiArray(data=list(data["aloha_state"].values())))
            self.publisher.publish(joint_state)
            self.rate.sleep()

if __name__ == "__main__":
    try:
        arm_publisher = SlaveArmPublisher()
        arm_publisher.publish_joint_states()
    except rospy.ROSInterruptException:
        pass
@@ -0,0 +1,69 @@
#!/usr/bin/env python3

import rospy
from sensor_msgs.msg import Image
from std_msgs.msg import Header
import numpy as np
from shadow_camera.realsense import RealSenseCamera


class CameraPublisher:
    def __init__(self):
        rospy.init_node('camera_publisher', anonymous=True)
        self.serial_number = rospy.get_param('~serial_number', None)
        hz = rospy.get_param('~hz', 30)
        rospy.loginfo(f"Serial number: {self.serial_number}")

        self.rgb_topic = rospy.get_param('~rgb_topic', '/camera/rgb/image_raw')
        self.depth_topic = rospy.get_param('~depth_topic', '/camera/depth/image_raw')
        rospy.loginfo(f"RGB topic: {self.rgb_topic}")
        rospy.loginfo(f"Depth topic: {self.depth_topic}")

        self.rgb_pub = rospy.Publisher(self.rgb_topic, Image, queue_size=10)
        # self.depth_pub = rospy.Publisher(self.depth_topic, Image, queue_size=10)
        self.rate = rospy.Rate(hz)  # 30 Hz by default
        self.camera = RealSenseCamera(self.serial_number, False)

        rospy.loginfo("Camera initialized")

    def publish_images(self):
        self.camera.start_camera()
        rospy.loginfo("Camera started")
        while not rospy.is_shutdown():
            result = self.camera.read_frame(True, False, False, False)
            if result is None:
                rospy.logerr("Failed to read frame from camera")
                continue

            color_image, depth_image, _, _ = result

            if color_image is not None or depth_image is not None:
                rgb_msg = self.create_image_msg(color_image, "bgr8")
                # depth_msg = self.create_image_msg(depth_image, "mono16")

                self.rgb_pub.publish(rgb_msg)
                # self.depth_pub.publish(depth_msg)
                # rospy.loginfo("Published RGB image")
            else:
                rospy.logwarn("Received None for color_image or depth_image")

            self.rate.sleep()
        self.camera.stop_camera()
        rospy.loginfo("Camera stopped")

    def create_image_msg(self, image, encoding):
        msg = Image()
        msg.header = Header()
        msg.header.stamp = rospy.Time.now()
        msg.height, msg.width = image.shape[:2]
        msg.encoding = encoding
        msg.is_bigendian = False
        msg.step = image.strides[0]
        msg.data = np.array(image).tobytes()
        return msg


if __name__ == '__main__':
    try:
        camera_publisher = CameraPublisher()
        camera_publisher.publish_images()
    except rospy.ROSInterruptException:
        pass
@@ -0,0 +1,112 @@
import os
import time
import h5py
import logging
from datetime import datetime

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")


class DataCollector:
    def __init__(self, dataset_dir, dataset_name, max_timesteps, camera_names, state_dim, overwrite=False):
        self.arm_axis = 7
        self.dataset_dir = dataset_dir
        self.dataset_name = dataset_name
        self.max_timesteps = max_timesteps
        self.camera_names = camera_names
        self.state_dim = state_dim
        self.overwrite = overwrite
        self.data_dict = {
            '/observations/qpos': [],
            '/observations/qvel': [],
            '/observations/effort': [],
            '/action': [],
        }
        for cam_name in camera_names:
            self.data_dict[f'/observations/images/{cam_name}'] = []

        # detect and create the dataset directory automatically
        self.create_dataset_dir()
        self.timesteps_collected = 0

    def create_dataset_dir(self):
        # create a subdirectory named after the current date (YYYYMMDD)
        date_str = datetime.now().strftime("%Y%m%d")
        self.dataset_dir = os.path.join(self.dataset_dir, date_str)
        if not os.path.exists(self.dataset_dir):
            os.makedirs(self.dataset_dir)

    def collect_data(self, ts, action):
        self.data_dict['/observations/qpos'].append(ts.observation['qpos'])
        self.data_dict['/observations/qvel'].append(ts.observation['qvel'])
        self.data_dict['/observations/effort'].append(ts.observation['effort'])
        self.data_dict['/action'].append(action)
        for cam_name in self.camera_names:
            self.data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])

    def save_data(self):
        t0 = time.time()
        # write the collected episode to disk
        with h5py.File(self.dataset_path, mode='w', rdcc_nbytes=1024**2 * 2) as root:
            root.attrs['sim'] = False
            obs = root.create_group('observations')
            image = obs.create_group('images')
            for cam_name in self.camera_names:
                _ = image.create_dataset(cam_name, (self.max_timesteps, 480, 640, 3), dtype='uint8',
                                         chunks=(1, 480, 640, 3))
            _ = obs.create_dataset('qpos', (self.max_timesteps, self.state_dim))
            _ = obs.create_dataset('qvel', (self.max_timesteps, self.state_dim))
            _ = obs.create_dataset('effort', (self.max_timesteps, self.state_dim))
            _ = root.create_dataset('action', (self.max_timesteps, self.state_dim))

            for name, array in self.data_dict.items():
                root[name][...] = array
        print(f'Saving: {time.time() - t0:.1f} secs')
        return True

    def load_hdf5(self, origin_path, file):
        self.dataset_path = os.path.join(self.dataset_dir, file)
        if not os.path.isfile(origin_path):
            logging.error(f'Dataset does not exist at {origin_path}')
            exit()

        with h5py.File(origin_path, 'r') as root:
            self.is_sim = root.attrs['sim']
            self.qpos = root['/observations/qpos'][()]
            self.qvel = root['/observations/qvel'][()]
            self.effort = root['/observations/effort'][()]
            self.action = root['/action'][()]
            self.image_dict = {cam_name: root[f'/observations/images/{cam_name}'][()]
                               for cam_name in root['/observations/images/'].keys()}

        # overwrite the two gripper channels of qpos with the commanded values
        self.qpos[:, self.arm_axis] = self.action[:, self.arm_axis]
        self.qpos[:, self.arm_axis * 2 + 1] = self.action[:, self.arm_axis * 2 + 1]

        self.data_dict['/observations/qpos'] = self.qpos
        self.data_dict['/observations/qvel'] = self.qvel
        self.data_dict['/observations/effort'] = self.effort
        self.data_dict['/action'] = self.action
        for cam_name in self.camera_names:
            self.data_dict[f'/observations/images/{cam_name}'] = self.image_dict[cam_name]
        return True


if __name__ == '__main__':
    """
    Rewrite the gripper channels of recorded episodes, replacing the slave-arm
    gripper readings with the master-arm gripper commands.
    """
    dataset_dir = '/home/wang/project/shadow_rm_aloha/data'
    origin_dir = '/home/wang/project/shadow_rm_aloha/data/dataset/20241128'
    dataset_name = 'test'
    max_timesteps = 300
    camera_names = ['cam_high', 'cam_low', 'cam_left', 'cam_right']
    state_dim = 16
    for file in os.listdir(origin_dir):
        collector = DataCollector(dataset_dir, dataset_name, max_timesteps, camera_names, state_dim)
        origin_path = os.path.join(origin_dir, file)
        print(origin_path)
        collector.load_hdf5(origin_path, file)
        collector.save_data()
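A quick way to sanity-check an episode rewritten by the DataCollector above is to read it back with h5py; a minimal sketch with a placeholder path:

# Hedged sketch: read back one rewritten episode and confirm its layout.
import h5py

with h5py.File('/path/to/data/20241128/episode_0.hdf5', 'r') as root:  # placeholder path
    print(root.attrs['sim'])                    # False for real-robot data
    print(root['/observations/qpos'].shape)     # (max_timesteps, state_dim)
    print(list(root['/observations/images/']))  # camera names stored with the episode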
@@ -0,0 +1,67 @@
#!/usr/bin/env python3

import os
import h5py
import yaml
import logging
import time
from shadow_rm_robot.realman_arm import RmArm

logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")


class DataValidator:
    def __init__(self, config):
        self.dataset_dir = config['dataset_dir']
        self.episode_idx = config['episode_idx']
        self.joint_names = config['joint_names']
        self.dataset_name = f'episode_{self.episode_idx}'
        self.dataset_path = os.path.join(self.dataset_dir, self.dataset_name + '.hdf5')
        self.state_names = self.joint_names + ["gripper"]
        self.arm = RmArm('/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml')

    def load_hdf5(self):
        if not os.path.isfile(self.dataset_path):
            logging.error(f'Dataset does not exist at {self.dataset_path}')
            exit()

        with h5py.File(self.dataset_path, 'r') as root:
            self.is_sim = root.attrs['sim']
            self.qpos = root['/observations/qpos'][()]
            # self.qvel = root['/observations/qvel'][()]
            # self.effort = root['/observations/effort'][()]
            self.action = root['/action'][()]
            self.image_dict = {cam_name: root[f'/observations/images/{cam_name}'][()]
                               for cam_name in root['/observations/images/'].keys()}

    def validate_data(self):
        # validate the joint-position data
        if not self.qpos.shape[1] == 14:
            logging.error('qpos shape does not match expected number of joints')
            return False

        logging.info('Data validation passed')
        return True

    def control_robot(self):
        # replay the recorded right-arm trajectory on the real arm
        self.arm.set_joint_position(self.qpos[0][0:6])
        for qpos in self.qpos:
            logging.info(f'qpos: {qpos}')
            self.arm.set_joint_canfd_position(qpos[7:13])
            self.arm.set_gripper_position(qpos[13])
            time.sleep(0.035)

    def run(self):
        self.load_hdf5()
        if self.validate_data():
            self.control_robot()


def load_config(config_path):
    with open(config_path, 'r') as file:
        return yaml.safe_load(file)


if __name__ == '__main__':
    config = load_config('/home/rm/code/shadow_rm_aloha/config/vis_data_path.yaml')
    validator = DataValidator(config)
    validator.run()
@@ -0,0 +1,147 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
import h5py
|
||||
import yaml
|
||||
import logging
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
|
||||
class DataVisualizer:
|
||||
def __init__(self, config):
|
||||
self.dataset_dir = config['dataset_dir']
|
||||
self.episode_idx = config['episode_idx']
|
||||
self.dt = 1/config['FPS']
|
||||
self.joint_names = config['joint_names']
|
||||
self.state_names = self.joint_names + ["gripper"]
|
||||
# self.camera_names = config['camera_names']
|
||||
|
||||
def join_file_path(self, file_name):
|
||||
        self.dataset_path = os.path.join(self.dataset_dir, file_name)

    def load_hdf5(self):
        if not os.path.isfile(self.dataset_path):
            logging.error(f'Dataset does not exist at {self.dataset_path}')
            raise SystemExit(1)

        with h5py.File(self.dataset_path, 'r') as root:
            self.is_sim = root.attrs['sim']
            self.qpos = root['/observations/qpos'][()]
            # self.qvel = root['/observations/qvel'][()]
            # self.effort = root['/observations/effort'][()]
            self.action = root['/action'][()]
            self.image_dict = {cam_name: root[f'/observations/images/{cam_name}'][()]
                               for cam_name in root['/observations/images/'].keys()}

    def save_videos(self, video, dt, video_path=None):
        if isinstance(video, list):
            cam_names = list(video[0].keys())
            h, w, _ = video[0][cam_names[0]].shape
            w = w * len(cam_names)
            fps = int(1 / dt)
            out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
            for image_dict in video:
                # Swap RGB -> BGR channel order before writing with OpenCV
                images = [image_dict[cam_name][:, :, [2, 1, 0]] for cam_name in cam_names]
                out.write(np.concatenate(images, axis=1))
            out.release()
            logging.info(f'Saved video to: {video_path}')
        elif isinstance(video, dict):
            cam_names = list(video.keys())
            all_cam_videos = np.concatenate([video[cam_name] for cam_name in cam_names], axis=2)
            n_frames, h, w, _ = all_cam_videos.shape
            fps = int(1 / dt)
            out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
            for t in range(n_frames):
                out.write(all_cam_videos[t][:, :, [2, 1, 0]])
            out.release()
            logging.info(f'Saved video to: {video_path}')

    def visualize_joints(self, qpos_list, command_list, plot_path, ylim=None, label_overwrite=None):
        label1, label2 = label_overwrite if label_overwrite else ('State', 'Command')
        qpos = np.array(qpos_list)
        command = np.array(command_list)
        num_ts, num_dim = qpos.shape
        logging.info(f'qpos shape: {qpos.shape}, command shape: {command.shape}')
        fig, axs = plt.subplots(num_dim, 1, figsize=(num_dim, 2 * num_dim))

        all_names = [name + '_left' for name in self.state_names] + [name + '_right' for name in self.state_names]
        for dim_idx in range(num_dim):
            ax = axs[dim_idx]
            ax.plot(qpos[:, dim_idx], label=label1)
            ax.plot(command[:, dim_idx], label=label2)
            ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
            ax.legend()
            if ylim:
                ax.set_ylim(ylim)

        plt.tight_layout()
        plt.savefig(plot_path)
        logging.info(f'Saved qpos plot to: {plot_path}')
        plt.close()

    def visualize_single(self, data_list, label, plot_path, ylim=None):
        data = np.array(data_list)
        num_ts, num_dim = data.shape
        fig, axs = plt.subplots(num_dim, 1, figsize=(num_dim, 2 * num_dim))

        all_names = [name + '_left' for name in self.state_names] + [name + '_right' for name in self.state_names]
        for dim_idx in range(num_dim):
            ax = axs[dim_idx]
            ax.plot(data[:, dim_idx], label=label)
            ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
            ax.legend()
            if ylim:
                ax.set_ylim(ylim)

        plt.tight_layout()
        plt.savefig(plot_path)
        logging.info(f'Saved {label} plot to: {plot_path}')
        plt.close()

    def visualize_timestamp(self, t_list):
        plot_path = self.dataset_path.replace('.hdf5', '_timestamp.png')
        fig, axs = plt.subplots(2, 1, figsize=(10, 8))
        t_float = np.array([secs + nsecs * 1e-9 for secs, nsecs in t_list])

        axs[0].plot(np.arange(len(t_float)), t_float)
        axs[0].set_title('Camera frame timestamps')
        axs[0].set_xlabel('timestep')
        axs[0].set_ylabel('time (sec)')

        # Successive differences t[i+1] - t[i], so dt is positive
        axs[1].plot(np.arange(len(t_float) - 1), t_float[1:] - t_float[:-1])
        axs[1].set_title('dt')
        axs[1].set_xlabel('timestep')
        axs[1].set_ylabel('time (sec)')

        plt.tight_layout()
        plt.savefig(plot_path)
        logging.info(f'Saved timestamp plot to: {plot_path}')
        plt.close()

    def run(self):
        for file_name in os.listdir(self.dataset_dir):
            if file_name.endswith('.hdf5'):
                self.join_file_path(file_name)
                self.load_hdf5()
                video_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_video.mp4'))
                self.save_videos(self.image_dict, self.dt, video_path)
                qpos_plot_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_qpos.png'))
                self.visualize_joints(self.qpos, self.action, qpos_plot_path)
                # effort_plot_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_effort.png'))
                # self.visualize_single(self.effort, 'effort', effort_plot_path)
                # error_plot_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_error.png'))
                # self.visualize_single(self.action - self.qpos, 'tracking_error', error_plot_path)
                # self.visualize_timestamp(t_list)  # TODO: Add timestamp visualization back


def load_config(config_path):
    with open(config_path, 'r') as file:
        return yaml.safe_load(file)


if __name__ == '__main__':
    config = load_config('/home/rm/code/shadow_rm_aloha/config/vis_data_path.yaml')
    visualizer = DataVisualizer(config)
    visualizer.run()
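For reference, a minimal sketch of driving DataVisualizer directly from Python. The config keys below are assumptions inferred from the attributes the class reads (self.dataset_dir, self.dt, self.state_names); the real schema lives in vis_data_path.yaml, which is not shown in this diff.

config = {
    'dataset_dir': '/home/rm/datasets/20250101',  # hypothetical episode folder
    'dt': 0.02,                                   # 0.02 s per frame -> 50 fps videos
    'state_names': ['joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6', 'gripper'],
}
visualizer = DataVisualizer(config)
visualizer.run()  # writes *_video.mp4 and *_qpos.png next to each .hdf5 episode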
@@ -0,0 +1,180 @@
#!/usr/bin/env python3
import time
import yaml
import rospy
import dm_env
import numpy as np
import collections
from datetime import datetime
from sensor_msgs.msg import Image, JointState
from shadow_rm_robot.realman_arm import RmArm
from message_filters import Subscriber, ApproximateTimeSynchronizer


class DataSynchronizer:
    def __init__(self, config_path="config"):
        rospy.init_node("synchronizer", anonymous=True)

        with open(config_path, "r") as file:
            config = yaml.safe_load(file)

        self.init_left_arm_angle = config["robot_env"]["init_left_arm_angle"]
        self.init_right_arm_angle = config["robot_env"]["init_right_arm_angle"]
        self.arm_axis = config["robot_env"]["arm_axis"]
        self.camera_names = config["camera_names"]

        # Create subscribers
        self.camera_left_sub = Subscriber(config["ros_topics"]["camera_left"], Image)
        self.camera_right_sub = Subscriber(config["ros_topics"]["camera_right"], Image)
        self.camera_bottom_sub = Subscriber(
            config["ros_topics"]["camera_bottom"], Image
        )
        self.camera_head_sub = Subscriber(config["ros_topics"]["camera_head"], Image)

        self.left_slave_arm_sub = Subscriber(
            config["ros_topics"]["left_slave_arm_sub"], JointState
        )
        self.right_slave_arm_sub = Subscriber(
            config["ros_topics"]["right_slave_arm_sub"], JointState
        )
        self.left_slave_arm_pub = rospy.Publisher(
            config["ros_topics"]["left_slave_arm_pub"], JointState, queue_size=1
        )
        self.right_slave_arm_pub = rospy.Publisher(
            config["ros_topics"]["right_slave_arm_pub"], JointState, queue_size=1
        )

        # Create the approximate time synchronizer
        self.ats = ApproximateTimeSynchronizer(
            [
                self.camera_left_sub,
                self.camera_right_sub,
                self.camera_bottom_sub,
                self.camera_head_sub,
                self.left_slave_arm_sub,
                self.right_slave_arm_sub,
            ],
            queue_size=1,
            slop=0.1,
        )

        self.ats.registerCallback(self.callback)
        self.ts = None
        self.is_first_step = True

    def callback(
        self,
        camera_left_img,
        camera_right_img,
        camera_bottom_img,
        camera_head_img,
        left_slave_arm,
        right_slave_arm,
    ):
        # Convert the ROS image messages to NumPy arrays
        camera_left_np_img = np.frombuffer(
            camera_left_img.data, dtype=np.uint8
        ).reshape(camera_left_img.height, camera_left_img.width, -1)
        camera_right_np_img = np.frombuffer(
            camera_right_img.data, dtype=np.uint8
        ).reshape(camera_right_img.height, camera_right_img.width, -1)
        camera_bottom_np_img = np.frombuffer(
            camera_bottom_img.data, dtype=np.uint8
        ).reshape(camera_bottom_img.height, camera_bottom_img.width, -1)
        camera_head_np_img = np.frombuffer(
            camera_head_img.data, dtype=np.uint8
        ).reshape(camera_head_img.height, camera_head_img.width, -1)

        left_slave_arm_angle = left_slave_arm.position
        left_slave_arm_velocity = left_slave_arm.velocity
        left_slave_arm_force = left_slave_arm.effort
        # With the Inspire gripper, the gripper angle matches the master arm; comment out for other grippers
        # left_slave_arm_angle[self.arm_axis] = left_master_arm_angle[self.arm_axis]

        right_slave_arm_angle = right_slave_arm.position
        right_slave_arm_velocity = right_slave_arm.velocity
        right_slave_arm_force = right_slave_arm.effort
        # With the Inspire gripper, the gripper angle matches the master arm; comment out for other grippers
        # right_slave_arm_angle[self.arm_axis] = right_master_arm_angle[self.arm_axis]

        # Collect the data
        obs = collections.OrderedDict(
            {
                "qpos": np.concatenate([left_slave_arm_angle, right_slave_arm_angle]),
                "qvel": np.concatenate(
                    [left_slave_arm_velocity, right_slave_arm_velocity]
                ),
                "effort": np.concatenate([left_slave_arm_force, right_slave_arm_force]),
                "images": {
                    self.camera_names[0]: camera_head_np_img,
                    self.camera_names[1]: camera_bottom_np_img,
                    self.camera_names[2]: camera_left_np_img,
                    self.camera_names[3]: camera_right_np_img,
                },
            }
        )
        self.ts = dm_env.TimeStep(
            step_type=(
                dm_env.StepType.FIRST if self.is_first_step else dm_env.StepType.MID
            ),
            reward=0.0,
            discount=1.0,
            observation=obs,
        )

    def reset(self):
        left_joint_state = JointState()
        left_joint_state.header.stamp = rospy.Time.now()
        left_joint_state.name = ["joint_j"]
        left_joint_state.position = self.init_left_arm_angle[0 : self.arm_axis + 1]

        right_joint_state = JointState()
        right_joint_state.header.stamp = rospy.Time.now()
        right_joint_state.name = ["joint_j"]
        right_joint_state.position = self.init_right_arm_angle[0 : self.arm_axis + 1]

        self.left_slave_arm_pub.publish(left_joint_state)
        self.right_slave_arm_pub.publish(right_joint_state)
        while self.ts is None:
            time.sleep(0.002)

        return self.ts

    def step(self, target_angle):
        self.is_first_step = False
        left_joint_state = JointState()
        left_joint_state.header.stamp = rospy.Time.now()
        left_joint_state.name = ["joint_canfd"]
        left_joint_state.position = target_angle[0 : self.arm_axis + 1]
        # print("left_joint_state: ", left_joint_state)

        right_joint_state = JointState()
        right_joint_state.header.stamp = rospy.Time.now()
        right_joint_state.name = ["joint_canfd"]
        right_joint_state.position = target_angle[self.arm_axis + 1 : (self.arm_axis + 1) * 2]
        # print("right_joint_state: ", right_joint_state)

        self.left_slave_arm_pub.publish(left_joint_state)
        self.right_slave_arm_pub.publish(right_joint_state)
        # time.sleep(0.013)
        return self.ts

    def run(self):
        rospy.loginfo("Starting ROS spin")
        data = np.concatenate([self.init_left_arm_angle, self.init_right_arm_angle])
        self.reset()
        # print("data: ", data)
        while not rospy.is_shutdown():
            self.step(data)
            rospy.sleep(0.010)


if __name__ == "__main__":
    synchronizer = DataSynchronizer("/home/rm/code/shadow_act/config/config.yaml")
    start_time = time.time()
    synchronizer.run()
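A hedged sketch of how this class could sit in an inference loop (this loop is not part of the original file): reset() blocks until the first synchronized bundle arrives, and step() streams joint targets back to both slave arms. `policy` is a placeholder for any callable mapping the observation to target angles.

synchronizer = DataSynchronizer("/home/rm/code/shadow_act/config/config.yaml")
ts = synchronizer.reset()                 # publishes the init pose, waits for first obs
while not rospy.is_shutdown():
    qpos = ts.observation["qpos"]         # concatenated left + right slave joint angles
    images = ts.observation["images"]     # dict keyed by camera_names
    target_angle = policy(qpos, images)   # hypothetical policy call
    ts = synchronizer.step(target_angle)  # streams joint_canfd targets to both arms
    rospy.sleep(0.010)                    # ~100 Hz, matching run()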
@@ -0,0 +1,280 @@
#!/usr/bin/env python3
import os
import time
import h5py
import yaml
import rospy
import dm_env
import numpy as np
import collections
from datetime import datetime
from std_msgs.msg import Int32MultiArray
from sensor_msgs.msg import Image, JointState
from message_filters import Subscriber, ApproximateTimeSynchronizer


class DataCollector:
    def __init__(
        self,
        dataset_dir,
        dataset_name,
        max_timesteps,
        camera_names,
        state_dim,
        overwrite=False,
    ):
        self.dataset_dir = dataset_dir
        self.dataset_name = dataset_name
        self.max_timesteps = max_timesteps
        self.camera_names = camera_names
        self.state_dim = state_dim
        self.overwrite = overwrite
        self.init_dict()
        self.create_dataset_dir()

    def init_dict(self):
        self.data_dict = {
            "/observations/qpos": [],
            "/observations/qvel": [],
            "/observations/effort": [],
            "/action": [],
        }
        for cam_name in self.camera_names:
            self.data_dict[f"/observations/images/{cam_name}"] = []

    def create_dataset_dir(self):
        # Create a subdirectory named after the current date (YYYYMMDD)
        date_str = datetime.now().strftime("%Y%m%d")
        self.dataset_dir = os.path.join(self.dataset_dir, date_str)
        if not os.path.exists(self.dataset_dir):
            os.makedirs(self.dataset_dir)

    def create_file(self):
        # If the dataset name already exists, increment the suffix until it is unique
        counter = 0
        dataset_path = os.path.join(self.dataset_dir, f"{self.dataset_name}_{counter}")
        if not self.overwrite:
            while os.path.exists(dataset_path + ".hdf5"):
                counter += 1
                dataset_path = os.path.join(
                    self.dataset_dir, f"{self.dataset_name}_{counter}"
                )
        self.dataset_path = dataset_path

    def collect_data(self, ts, action):
        self.data_dict["/observations/qpos"].append(ts.observation["qpos"])
        self.data_dict["/observations/qvel"].append(ts.observation["qvel"])
        self.data_dict["/observations/effort"].append(ts.observation["effort"])
        self.data_dict["/action"].append(action)
        for cam_name in self.camera_names:
            self.data_dict[f"/observations/images/{cam_name}"].append(
                ts.observation["images"][cam_name]
            )

    def save_data(self):
        self.create_file()
        t0 = time.time()
        # Save the episode to HDF5
        with h5py.File(
            self.dataset_path + ".hdf5", mode="w", rdcc_nbytes=1024**2 * 2
        ) as root:
            root.attrs["sim"] = False
            obs = root.create_group("observations")
            image = obs.create_group("images")
            for cam_name in self.camera_names:
                _ = image.create_dataset(
                    cam_name,
                    (self.max_timesteps, 480, 640, 3),
                    dtype="uint8",
                    chunks=(1, 480, 640, 3),
                )
            _ = obs.create_dataset("qpos", (self.max_timesteps, self.state_dim))
            _ = obs.create_dataset("qvel", (self.max_timesteps, self.state_dim))
            _ = obs.create_dataset("effort", (self.max_timesteps, self.state_dim))
            _ = root.create_dataset("action", (self.max_timesteps, self.state_dim))

            for name, array in self.data_dict.items():
                root[name][...] = array
        print(f"Saving: {time.time() - t0:.1f} secs")
        return True


class DataSynchronizer:
    def __init__(self, config_path="config"):
        rospy.init_node("synchronizer", anonymous=True)
        rospy.loginfo("ROS node initialized")
        with open(config_path, "r") as file:
            config = yaml.safe_load(file)
        self.arm_axis = config["arm_axis"]
        # Create subscribers
        self.camera_left_sub = Subscriber(config["ros_topics"]["camera_left"], Image)
        self.camera_right_sub = Subscriber(config["ros_topics"]["camera_right"], Image)
        self.camera_bottom_sub = Subscriber(
            config["ros_topics"]["camera_bottom"], Image
        )
        self.camera_head_sub = Subscriber(config["ros_topics"]["camera_head"], Image)

        self.left_master_arm_sub = Subscriber(
            config["ros_topics"]["left_master_arm"], JointState
        )
        self.left_slave_arm_sub = Subscriber(
            config["ros_topics"]["left_slave_arm"], JointState
        )
        self.right_master_arm_sub = Subscriber(
            config["ros_topics"]["right_master_arm"], JointState
        )
        self.right_slave_arm_sub = Subscriber(
            config["ros_topics"]["right_slave_arm"], JointState
        )
        self.left_aloha_state_sub = rospy.Subscriber(
            config["ros_topics"]["left_aloha_state"],
            Int32MultiArray,
            self.aloha_state_callback,
        )

        rospy.loginfo("Subscribers created")
        self.camera_names = config["camera_names"]

        # Create the approximate time synchronizer
        self.ats = ApproximateTimeSynchronizer(
            [
                self.camera_left_sub,
                self.camera_right_sub,
                self.camera_bottom_sub,
                self.camera_head_sub,
                self.left_master_arm_sub,
                self.left_slave_arm_sub,
                self.right_master_arm_sub,
                self.right_slave_arm_sub,
            ],
            queue_size=1,
            slop=0.05,
        )
        self.ats.registerCallback(self.callback)
        rospy.loginfo("Time synchronizer created and callback registered")

        self.data_collector = DataCollector(
            dataset_dir=config["dataset_dir"],
            dataset_name=config["dataset_name"],
            max_timesteps=config["max_timesteps"],
            camera_names=config["camera_names"],
            state_dim=config["state_dim"],
            overwrite=config["overwrite"],
        )
        self.timesteps_collected = 0
        self.begin_collect = False
        self.last_time = None

    def callback(
        self,
        camera_left_img,
        camera_right_img,
        camera_bottom_img,
        camera_head_img,
        left_master_arm,
        left_slave_arm,
        right_master_arm,
        right_slave_arm,
    ):
        if self.begin_collect:
            self.timesteps_collected += 1
            rospy.loginfo(
                f"Collecting data: {self.timesteps_collected}/{self.data_collector.max_timesteps}"
            )
        else:
            self.timesteps_collected = 0
            return
        if self.timesteps_collected == 0:
            return
        current_time = time.time()
        if self.last_time is not None:
            frequency = 1.0 / (current_time - self.last_time)
            rospy.loginfo(f"Callback frequency: {frequency:.2f} Hz")
        self.last_time = current_time
        # Convert the ROS image messages to NumPy arrays
        camera_left_np_img = np.frombuffer(
            camera_left_img.data, dtype=np.uint8
        ).reshape(camera_left_img.height, camera_left_img.width, -1)
        camera_right_np_img = np.frombuffer(
            camera_right_img.data, dtype=np.uint8
        ).reshape(camera_right_img.height, camera_right_img.width, -1)
        camera_bottom_np_img = np.frombuffer(
            camera_bottom_img.data, dtype=np.uint8
        ).reshape(camera_bottom_img.height, camera_bottom_img.width, -1)
        camera_head_np_img = np.frombuffer(
            camera_head_img.data, dtype=np.uint8
        ).reshape(camera_head_img.height, camera_head_img.width, -1)

        # Extract joint angles, velocities, and efforts for each arm
        left_master_arm_angle = left_master_arm.position
        # left_master_arm_velocity = left_master_arm.velocity
        # left_master_arm_force = left_master_arm.effort

        left_slave_arm_angle = left_slave_arm.position
        left_slave_arm_velocity = left_slave_arm.velocity
        left_slave_arm_force = left_slave_arm.effort
        # With the Inspire gripper, the gripper angle matches the master arm; comment out for other grippers
        # left_slave_arm_angle[self.arm_axis] = left_master_arm_angle[self.arm_axis]

        right_master_arm_angle = right_master_arm.position
        # right_master_arm_velocity = right_master_arm.velocity
        # right_master_arm_force = right_master_arm.effort

        right_slave_arm_angle = right_slave_arm.position
        right_slave_arm_velocity = right_slave_arm.velocity
        right_slave_arm_force = right_slave_arm.effort
        # With the Inspire gripper, the gripper angle matches the master arm; comment out for other grippers
        # right_slave_arm_angle[self.arm_axis] = right_master_arm_angle[self.arm_axis]

        # Collect the data
        obs = collections.OrderedDict(
            {
                "qpos": np.concatenate([left_slave_arm_angle, right_slave_arm_angle]),
                "qvel": np.concatenate(
                    [left_slave_arm_velocity, right_slave_arm_velocity]
                ),
                "effort": np.concatenate([left_slave_arm_force, right_slave_arm_force]),
                "images": {
                    self.camera_names[0]: camera_head_np_img,
                    self.camera_names[1]: camera_bottom_np_img,
                    self.camera_names[2]: camera_left_np_img,
                    self.camera_names[3]: camera_right_np_img,
                },
            }
        )
        # print(self.camera_names[0])
        ts = dm_env.TimeStep(
            step_type=dm_env.StepType.MID, reward=0, discount=None, observation=obs
        )
        action = np.concatenate([left_master_arm_angle, right_master_arm_angle])
        self.data_collector.collect_data(ts, action)

        # Check whether enough timesteps have been collected
        if self.timesteps_collected >= self.data_collector.max_timesteps:
            self.data_collector.save_data()

            rospy.loginfo("Data collection complete")

            self.data_collector.init_dict()
            self.begin_collect = False
            self.timesteps_collected = 0

    def aloha_state_callback(self, data):
        if not self.begin_collect:
            self.aloha_state = data.data
            print(self.aloha_state[0], self.aloha_state[1])
            if self.aloha_state[0] == 1 and self.aloha_state[1] == 1:
                self.begin_collect = True

    def run(self):
        rospy.loginfo("Starting ROS spin")
        rospy.spin()


if __name__ == "__main__":
    synchronizer = DataSynchronizer(
        "/home/rm/code/shadow_rm_aloha/config/data_synchronizer.yaml"
    )
    synchronizer.run()
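The exact data_synchronizer.yaml schema is not included in this diff, but every key in the sketch below is read by DataSynchronizer or DataCollector above; the topic values are examples matching the launch files in this branch, and the numeric values are illustrative only.

import yaml

config = {
    "arm_axis": 6,  # index of the gripper joint; assumption
    "camera_names": ["cam_front", "cam_low", "cam_left", "cam_right"],
    "ros_topics": {
        "camera_left": "/camera_left/rgb/image_raw",
        "camera_right": "/camera_right/rgb/image_raw",
        "camera_bottom": "/camera_bottom/rgb/image_raw",
        "camera_head": "/camera_head/rgb/image_raw",
        "left_master_arm": "/left_master_arm_joint_states",
        "right_master_arm": "/right_master_arm_joint_states",
        "left_slave_arm": "/left_slave_arm_joint_states",
        "right_slave_arm": "/right_slave_arm_joint_states",
        "left_aloha_state": "/left_slave_arm_aloha_state",
    },
    "dataset_dir": "/home/rm/datasets",  # a dated subfolder is created inside
    "dataset_name": "pick_place",        # hypothetical task name
    "max_timesteps": 500,
    "state_dim": 14,                     # 7 DoF per arm; example value
    "overwrite": False,
}
with open("data_synchronizer.yaml", "w") as f:
    yaml.safe_dump(config, f)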
@@ -0,0 +1,63 @@
<launch>
  <!-- Left slave arm node -->
  <node name="slave_arm_publisher_left" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
    <param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml" type="string"/>
    <param name="joint_states_topic" value="/left_slave_arm_joint_states" type="string"/>
    <param name="aloha_state_topic" value="/left_slave_arm_aloha_state" type="string"/>
    <param name="hz" value="120" type="int"/>
  </node>

  <!-- Right slave arm node -->
  <node name="slave_arm_publisher_right" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
    <param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml" type="string"/>
    <param name="joint_states_topic" value="/right_slave_arm_joint_states" type="string"/>
    <param name="aloha_state_topic" value="/right_slave_arm_aloha_state" type="string"/>
    <param name="hz" value="120" type="int"/>
  </node>

  <!-- Left master arm node -->
  <node name="master_arm_publisher_left" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
    <param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/servo_left_arm.yaml" type="string"/>
    <param name="joint_states_topic" value="/left_master_arm_joint_states" type="string"/>
    <param name="hz" value="90" type="int"/>
  </node>

  <!-- Right master arm node -->
  <node name="master_arm_publisher_right" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
    <param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/servo_right_arm.yaml" type="string"/>
    <param name="joint_states_topic" value="/right_master_arm_joint_states" type="string"/>
    <param name="hz" value="90" type="int"/>
  </node>

  <!-- Right arm camera node -->
  <node name="camera_publisher_right" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
    <param name="serial_number" value="151222072576" type="string"/>
    <param name="rgb_topic" value="/camera_right/rgb/image_raw" type="string"/>
    <param name="depth_topic" value="/camera_right/depth/image_raw" type="string"/>
    <param name="hz" value="50" type="int"/>
  </node>

  <!-- Left arm camera node -->
  <node name="camera_publisher_left" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
    <param name="serial_number" value="150622070125" type="string"/>
    <param name="rgb_topic" value="/camera_left/rgb/image_raw" type="string"/>
    <param name="depth_topic" value="/camera_left/depth/image_raw" type="string"/>
    <param name="hz" value="50" type="int"/>
  </node>

  <!-- Head (top) camera node -->
  <node name="camera_publisher_head" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
    <param name="serial_number" value="241122071186" type="string"/>
    <param name="rgb_topic" value="/camera_head/rgb/image_raw" type="string"/>
    <param name="depth_topic" value="/camera_head/depth/image_raw" type="string"/>
    <param name="hz" value="50" type="int"/>
  </node>

  <!-- Bottom camera node -->
  <node name="camera_publisher_bottom" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
    <param name="serial_number" value="152122078546" type="string"/>
    <param name="rgb_topic" value="/camera_bottom/rgb/image_raw" type="string"/>
    <param name="depth_topic" value="/camera_bottom/depth/image_raw" type="string"/>
    <param name="hz" value="50" type="int"/>
  </node>
</launch>
@@ -0,0 +1,49 @@
<launch>
  <!-- Left slave arm node -->
  <node name="slave_arm_publisher_left" pkg="shadow_rm_aloha" type="slave_arm_pub_sub.py" output="screen">
    <param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml" type="string"/>
    <param name="joint_states_topic" value="/left_slave_arm_joint_states" type="string"/>
    <param name="joint_actions_topic" value="/left_slave_arm_joint_actions" type="string"/>
    <param name="hz" value="90" type="int"/>
  </node>

  <!-- Right slave arm node -->
  <node name="slave_arm_publisher_right" pkg="shadow_rm_aloha" type="slave_arm_pub_sub.py" output="screen">
    <param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml" type="string"/>
    <param name="joint_states_topic" value="/right_slave_arm_joint_states" type="string"/>
    <param name="joint_actions_topic" value="/right_slave_arm_joint_actions" type="string"/>
    <param name="hz" value="90" type="int"/>
  </node>

  <!-- Right arm camera node -->
  <node name="camera_publisher_right" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
    <param name="serial_number" value="151222072576" type="string"/>
    <param name="rgb_topic" value="/camera_right/rgb/image_raw" type="string"/>
    <param name="depth_topic" value="/camera_right/depth/image_raw" type="string"/>
    <param name="hz" value="50" type="int"/>
  </node>

  <!-- Left arm camera node -->
  <node name="camera_publisher_left" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
    <param name="serial_number" value="150622070125" type="string"/>
    <param name="rgb_topic" value="/camera_left/rgb/image_raw" type="string"/>
    <param name="depth_topic" value="/camera_left/depth/image_raw" type="string"/>
    <param name="hz" value="50" type="int"/>
  </node>

  <!-- Head (top) camera node -->
  <node name="camera_publisher_head" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
    <param name="serial_number" value="241122071186" type="string"/>
    <param name="rgb_topic" value="/camera_head/rgb/image_raw" type="string"/>
    <param name="depth_topic" value="/camera_head/depth/image_raw" type="string"/>
    <param name="hz" value="50" type="int"/>
  </node>

  <!-- Bottom camera node -->
  <node name="camera_publisher_bottom" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
    <param name="serial_number" value="152122078546" type="string"/>
    <param name="rgb_topic" value="/camera_bottom/rgb/image_raw" type="string"/>
    <param name="depth_topic" value="/camera_bottom/depth/image_raw" type="string"/>
    <param name="hz" value="50" type="int"/>
  </node>
</launch>
@@ -0,0 +1,61 @@
<launch>
  <!-- Left slave arm node -->
  <node name="slave_arm_publisher_left" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
    <param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/rm_left_arm.yaml" type="string"/>
    <param name="joint_states_topic" value="/left_slave_arm_joint_states" type="string"/>
    <param name="hz" value="50" type="int"/>
  </node>

  <!-- Right slave arm node -->
  <node name="slave_arm_publisher_right" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
    <param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/rm_right_arm.yaml" type="string"/>
    <param name="joint_states_topic" value="/right_slave_arm_joint_states" type="string"/>
    <param name="hz" value="50" type="int"/>
  </node>

  <!-- Left master arm node -->
  <node name="master_arm_publisher_left" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
    <param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/servo_left_arm.yaml" type="string"/>
    <param name="joint_states_topic" value="/left_master_arm_joint_states" type="string"/>
    <param name="hz" value="50" type="int"/>
  </node>

  <!-- Right master arm node -->
  <node name="master_arm_publisher_right" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
    <param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/servo_right_arm.yaml" type="string"/>
    <param name="joint_states_topic" value="/right_master_arm_joint_states" type="string"/>
    <param name="hz" value="50" type="int"/>
  </node>

  <!-- Right arm camera node -->
  <node name="camera_publisher_right" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
    <param name="serial_number" value="216322070299" type="string"/>
    <param name="rgb_topic" value="/camera_right/rgb/image_raw" type="string"/>
    <param name="depth_topic" value="/camera_right/depth/image_raw" type="string"/>
    <param name="hz" value="50" type="int"/>
  </node>

  <!-- Left arm camera node -->
  <node name="camera_publisher_left" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
    <param name="serial_number" value="216322074992" type="string"/>
    <param name="rgb_topic" value="/camera_left/rgb/image_raw" type="string"/>
    <param name="depth_topic" value="/camera_left/depth/image_raw" type="string"/>
    <param name="hz" value="50" type="int"/>
  </node>

  <!-- Head (top) camera node -->
  <node name="camera_publisher_head" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
    <param name="serial_number" value="215322076086" type="string"/>
    <param name="rgb_topic" value="/camera_head/rgb/image_raw" type="string"/>
    <param name="depth_topic" value="/camera_head/depth/image_raw" type="string"/>
    <param name="hz" value="50" type="int"/>
  </node>

  <!-- Bottom camera node -->
  <node name="camera_publisher_bottom" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
    <param name="serial_number" value="215222074360" type="string"/>
    <param name="rgb_topic" value="/camera_bottom/rgb/image_raw" type="string"/>
    <param name="depth_topic" value="/camera_bottom/depth/image_raw" type="string"/>
    <param name="hz" value="50" type="int"/>
  </node>
</launch>
@@ -0,0 +1,284 @@
#!/usr/bin/env python3
import os
import cv2
import time
import h5py
import yaml
import json
import rospy
import dm_env
import socket
import numpy as np
import collections
from datetime import datetime
from std_msgs.msg import Int32MultiArray
from sensor_msgs.msg import Image, JointState
from message_filters import Subscriber, ApproximateTimeSynchronizer


class DataCollector:
    def __init__(
        self,
        dataset_dir,
        dataset_name,
        max_timesteps,
        camera_names,
        state_dim,
        overwrite=False,
    ):
        self.dataset_dir = dataset_dir
        self.dataset_name = dataset_name
        self.max_timesteps = max_timesteps
        self.camera_names = camera_names
        self.state_dim = state_dim
        self.overwrite = overwrite
        self.init_dict()
        self.create_dataset_dir()

    def init_dict(self):
        self.data_dict = {
            "/observations/qpos": [],
            "/observations/qvel": [],
            "/observations/effort": [],
            "/action": [],
        }
        for cam_name in self.camera_names:
            self.data_dict[f"/observations/images/{cam_name}"] = []

    def create_dataset_dir(self):
        # Create a subdirectory named after the current date (YYYYMMDD)
        date_str = datetime.now().strftime("%Y%m%d")
        self.dataset_dir = os.path.join(self.dataset_dir, date_str)
        if not os.path.exists(self.dataset_dir):
            os.makedirs(self.dataset_dir)

    def create_file(self):
        # If the dataset name already exists, increment the suffix until it is unique
        counter = 0
        dataset_path = os.path.join(self.dataset_dir, f"{self.dataset_name}_{counter}")
        if not self.overwrite:
            while os.path.exists(dataset_path + ".hdf5"):
                counter += 1
                dataset_path = os.path.join(
                    self.dataset_dir, f"{self.dataset_name}_{counter}"
                )
        self.dataset_path = dataset_path

    def collect_data(self, ts, action):
        self.data_dict["/observations/qpos"].append(ts.observation["qpos"])
        self.data_dict["/observations/qvel"].append(ts.observation["qvel"])
        self.data_dict["/observations/effort"].append(ts.observation["effort"])
        self.data_dict["/action"].append(action)
        for cam_name in self.camera_names:
            self.data_dict[f"/observations/images/{cam_name}"].append(
                ts.observation["images"][cam_name]
            )

    def save_data(self):
        self.create_file()
        t0 = time.time()
        # Save the episode to HDF5
        with h5py.File(
            self.dataset_path + ".hdf5", mode="w", rdcc_nbytes=1024**2 * 2
        ) as root:
            root.attrs["sim"] = False
            obs = root.create_group("observations")
            image = obs.create_group("images")
            for cam_name in self.camera_names:
                _ = image.create_dataset(
                    cam_name,
                    (self.max_timesteps, 480, 640, 3),
                    dtype="uint8",
                    chunks=(1, 480, 640, 3),
                )
            _ = obs.create_dataset("qpos", (self.max_timesteps, self.state_dim))
            _ = obs.create_dataset("qvel", (self.max_timesteps, self.state_dim))
            _ = obs.create_dataset("effort", (self.max_timesteps, self.state_dim))
            _ = root.create_dataset("action", (self.max_timesteps, self.state_dim))

            for name, array in self.data_dict.items():
                root[name][...] = array
        print(f"Saving: {time.time() - t0:.1f} secs")
        return True


class DataSynchronizer:
    def __init__(self, config_path="config"):
        rospy.init_node("synchronizer", anonymous=True)
        rospy.loginfo("ROS node initialized")
        with open(config_path, "r") as file:
            config = yaml.safe_load(file)
        self.arm_axis = config["arm_axis"]
        # Create subscribers
        self.camera_left_sub = Subscriber(config["ros_topics"]["camera_left"], Image)
        self.camera_right_sub = Subscriber(config["ros_topics"]["camera_right"], Image)
        self.camera_bottom_sub = Subscriber(
            config["ros_topics"]["camera_bottom"], Image
        )
        self.camera_head_sub = Subscriber(config["ros_topics"]["camera_head"], Image)

        self.left_master_arm_sub = Subscriber(
            config["ros_topics"]["left_master_arm"], JointState
        )
        self.left_slave_arm_sub = Subscriber(
            config["ros_topics"]["left_slave_arm"], JointState
        )
        self.right_master_arm_sub = Subscriber(
            config["ros_topics"]["right_master_arm"], JointState
        )
        self.right_slave_arm_sub = Subscriber(
            config["ros_topics"]["right_slave_arm"], JointState
        )
        self.left_aloha_state_sub = rospy.Subscriber(
            config["ros_topics"]["left_aloha_state"],
            Int32MultiArray,
            self.aloha_state_callback,
        )

        rospy.loginfo("Subscribers created")

        # Create the approximate time synchronizer
        self.ats = ApproximateTimeSynchronizer(
            [
                self.camera_left_sub,
                self.camera_right_sub,
                self.camera_bottom_sub,
                self.camera_head_sub,
                self.left_master_arm_sub,
                self.left_slave_arm_sub,
                self.right_master_arm_sub,
                self.right_slave_arm_sub,
            ],
            queue_size=1,
            slop=0.05,
        )
        self.ats.registerCallback(self.callback)
        rospy.loginfo("Time synchronizer created and callback registered")

        self.data_collector = DataCollector(
            dataset_dir=config["dataset_dir"],
            dataset_name=config["dataset_name"],
            max_timesteps=config["max_timesteps"],
            camera_names=config["camera_names"],
            state_dim=config["state_dim"],
            overwrite=config["overwrite"],
        )
        self.timesteps_collected = 0
        self.begin_collect = False
        self.last_time = None

    def callback(
        self,
        camera_left_img,
        camera_right_img,
        camera_bottom_img,
        camera_head_img,
        left_master_arm,
        left_slave_arm,
        right_master_arm,
        right_slave_arm,
    ):
        if self.begin_collect:
            self.timesteps_collected += 1
            rospy.loginfo(
                f"Collecting data: {self.timesteps_collected}/{self.data_collector.max_timesteps}"
            )
        else:
            self.timesteps_collected = 0
            return
        if self.timesteps_collected == 0:
            return
        current_time = time.time()
        if self.last_time is not None:
            frequency = 1.0 / (current_time - self.last_time)
            rospy.loginfo(f"Callback frequency: {frequency:.2f} Hz")
        self.last_time = current_time
        # Convert the ROS image messages to NumPy arrays
        camera_left_np_img = np.frombuffer(
            camera_left_img.data, dtype=np.uint8
        ).reshape(camera_left_img.height, camera_left_img.width, -1)
        camera_right_np_img = np.frombuffer(
            camera_right_img.data, dtype=np.uint8
        ).reshape(camera_right_img.height, camera_right_img.width, -1)
        camera_bottom_np_img = np.frombuffer(
            camera_bottom_img.data, dtype=np.uint8
        ).reshape(camera_bottom_img.height, camera_bottom_img.width, -1)
        camera_head_np_img = np.frombuffer(
            camera_head_img.data, dtype=np.uint8
        ).reshape(camera_head_img.height, camera_head_img.width, -1)

        # Extract joint angles, velocities, and efforts for each arm
        left_master_arm_angle = left_master_arm.position
        # left_master_arm_velocity = left_master_arm.velocity
        # left_master_arm_force = left_master_arm.effort

        left_slave_arm_angle = left_slave_arm.position
        left_slave_arm_velocity = left_slave_arm.velocity
        left_slave_arm_force = left_slave_arm.effort
        # With the Inspire gripper, the gripper angle matches the master arm; comment out for other grippers
        # left_slave_arm_angle[self.arm_axis] = left_master_arm_angle[self.arm_axis]

        right_master_arm_angle = right_master_arm.position
        # right_master_arm_velocity = right_master_arm.velocity
        # right_master_arm_force = right_master_arm.effort

        right_slave_arm_angle = right_slave_arm.position
        right_slave_arm_velocity = right_slave_arm.velocity
        right_slave_arm_force = right_slave_arm.effort
        # With the Inspire gripper, the gripper angle matches the master arm; comment out for other grippers
        # right_slave_arm_angle[self.arm_axis] = right_master_arm_angle[self.arm_axis]

        # Collect the data
        obs = collections.OrderedDict(
            {
                "qpos": np.concatenate([left_slave_arm_angle, right_slave_arm_angle]),
                "qvel": np.concatenate(
                    [left_slave_arm_velocity, right_slave_arm_velocity]
                ),
                "effort": np.concatenate([left_slave_arm_force, right_slave_arm_force]),
                "images": {
                    "cam_front": camera_head_np_img,
                    "cam_low": camera_bottom_np_img,
                    "cam_left": camera_left_np_img,
                    "cam_right": camera_right_np_img,
                },
            }
        )
        ts = dm_env.TimeStep(
            step_type=dm_env.StepType.MID, reward=0, discount=None, observation=obs
        )
        action = np.concatenate([left_master_arm_angle, right_master_arm_angle])
        self.data_collector.collect_data(ts, action)

        # Check whether enough timesteps have been collected
        if self.timesteps_collected >= self.data_collector.max_timesteps:
            self.data_collector.save_data()

            rospy.loginfo("Data collection complete")

            self.data_collector.init_dict()
            self.begin_collect = False
            self.timesteps_collected = 0

    def aloha_state_callback(self, data):
        if not self.begin_collect:
            self.aloha_state = data.data
            print(self.aloha_state[0], self.aloha_state[1])
            if self.aloha_state[0] == 1 and self.aloha_state[1] == 1:
                self.begin_collect = True

    def run(self):
        rospy.loginfo("Starting ROS spin")
        rospy.spin()


if __name__ == "__main__":
    synchronizer = DataSynchronizer(
        "/home/rm/code/shadow_rm_aloha/config/data_synchronizer.yaml"
    )
    synchronizer.run()
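A quick way to sanity-check an episode written by save_data() above; the path is hypothetical, but the group layout and shapes follow directly from the datasets created in save_data().

import h5py

with h5py.File("/home/rm/datasets/20250101/pick_place_0.hdf5", "r") as root:  # hypothetical path
    print("sim:", root.attrs["sim"])                  # False for real-robot episodes
    print("qpos:", root["/observations/qpos"].shape)  # (max_timesteps, state_dim)
    print("action:", root["/action"].shape)           # (max_timesteps, state_dim)
    for cam_name in root["/observations/images"]:
        # each camera stream is (max_timesteps, 480, 640, 3) uint8
        print(cam_name, root[f"/observations/images/{cam_name}"].shape)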
@@ -0,0 +1,31 @@
<?xml version="1.0"?>
<package format="2">
  <name>shadow_rm_aloha</name>
  <version>0.0.1</version>
  <description>The shadow_rm_aloha package</description>

  <maintainer email="your_email@example.com">Your Name</maintainer>

  <license>TODO</license>

  <buildtool_depend>catkin</buildtool_depend>

  <build_depend>rospy</build_depend>
  <build_depend>sensor_msgs</build_depend>
  <build_depend>std_msgs</build_depend>
  <build_depend>cv_bridge</build_depend>
  <build_depend>image_transport</build_depend>
  <build_depend>message_generation</build_depend>
  <build_depend>message_runtime</build_depend>

  <exec_depend>rospy</exec_depend>
  <exec_depend>sensor_msgs</exec_depend>
  <exec_depend>std_msgs</exec_depend>
  <exec_depend>cv_bridge</exec_depend>
  <exec_depend>image_transport</exec_depend>
  <exec_depend>message_runtime</exec_depend>

  <export>
  </export>
</package>
@@ -0,0 +1,5 @@
# GetArmStatus.srv
# Request: empty
---
sensor_msgs/JointState joint_status
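A hedged client sketch for this service definition: the request is empty and the response carries one sensor_msgs/JointState. The service name and the node that advertises it are not shown in this diff, so "/get_arm_status" is an assumption.

import rospy
from shadow_rm_aloha.srv import GetArmStatus

rospy.init_node("arm_status_client", anonymous=True)
rospy.wait_for_service("/get_arm_status")        # hypothetical service name
get_arm_status = rospy.ServiceProxy("/get_arm_status", GetArmStatus)
resp = get_arm_status()                          # empty request
print(resp.joint_status.position)                # JointState from the response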
Some files were not shown because too many files have changed in this diff.