使用原生的数据搜集代码

This commit is contained in:
2025-04-13 21:41:45 +08:00
parent 3df284ddd1
commit a4fe5ee09a
20 changed files with 1477 additions and 1544 deletions

272
collect_data/utils.py Normal file
View File

@@ -0,0 +1,272 @@
import cv2
import numpy as np
import h5py
import time
def display_camera_grid(image_dict, grid_shape=None, window_name="MindRobot-V1 Data Collection", scale=1.0):
"""
显示多摄像头画面(保持原始比例,但可整体缩放)
参数:
image_dict: {摄像头名称: 图像numpy数组}
grid_shape: (行, 列) 布局None自动计算
window_name: 窗口名称
scale: 整体显示缩放比例0.5表示显示为原尺寸的50%
"""
# 输入验证和数据处理(保持原代码不变)
if not isinstance(image_dict, dict):
raise TypeError("输入必须是字典类型")
valid_data = []
for name, img in image_dict.items():
if not isinstance(img, np.ndarray):
continue
if img.dtype != np.uint8:
img = img.astype(np.uint8)
if img.ndim == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif img.shape[2] == 4:
img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
elif img.shape[2] == 3:
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
valid_data.append((name, img))
if not valid_data:
print("错误: 没有有效的图像可显示!")
return None
# 自动计算网格布局
num_valid = len(valid_data)
if grid_shape is None:
grid_shape = (1, num_valid) if num_valid <= 3 else (2, int(np.ceil(num_valid/2)))
rows, cols = grid_shape
# 计算每行/列的最大尺寸
row_heights = [0]*rows
col_widths = [0]*cols
for i, (_, img) in enumerate(valid_data[:rows*cols]):
r, c = i//cols, i%cols
row_heights[r] = max(row_heights[r], img.shape[0])
col_widths[c] = max(col_widths[c], img.shape[1])
# 计算画布总尺寸(应用整体缩放)
canvas_h = int(sum(row_heights) * scale)
canvas_w = int(sum(col_widths) * scale)
# 创建画布
canvas = np.zeros((canvas_h, canvas_w, 3), dtype=np.uint8)
# 计算每个子画面的显示区域
row_pos = [0] + [int(sum(row_heights[:i+1])*scale) for i in range(rows)]
col_pos = [0] + [int(sum(col_widths[:i+1])*scale) for i in range(cols)]
# 填充图像
for i, (name, img) in enumerate(valid_data[:rows*cols]):
r, c = i//cols, i%cols
# 计算当前图像的显示区域
x1, x2 = col_pos[c], col_pos[c+1]
y1, y2 = row_pos[r], row_pos[r+1]
# 计算当前图像的缩放后尺寸
display_h = int(img.shape[0] * scale)
display_w = int(img.shape[1] * scale)
# 缩放图像(保持比例)
resized_img = cv2.resize(img, (display_w, display_h))
# 放置到画布
canvas[y1:y1+display_h, x1:x1+display_w] = resized_img
# 添加标签(按比例缩放字体)
font_scale = 0.8 *scale
thickness = max(2, int(2 * scale))
cv2.putText(canvas, name, (x1+10, y1+30),
cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255,255,255), thickness)
# 显示窗口(自动适应屏幕)
cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
cv2.imshow(window_name, canvas)
cv2.resizeWindow(window_name, canvas_w, canvas_h)
cv2.waitKey(1)
return canvas
# 保存数据函数
def save_data(args, timesteps, actions, dataset_path):
# 数据字典
data_size = len(actions)
data_dict = {
# 一个是奖励里面的qposqvel effort ,一个是实际发的acition
'/observations/qpos': [],
'/observations/qvel': [],
'/observations/effort': [],
'/action': [],
'/base_action': [],
# '/base_action_t265': [],
}
# 相机字典 观察的图像
for cam_name in args.camera_names:
data_dict[f'/observations/images/{cam_name}'] = []
if args.use_depth_image:
data_dict[f'/observations/images_depth/{cam_name}'] = []
# len(action): max_timesteps, len(time_steps): max_timesteps + 1
# 动作长度 遍历动作
while actions:
# 循环弹出一个队列
action = actions.pop(0) # 动作 当前动作
ts = timesteps.pop(0) # 奖励 前一帧
# 往字典里面添值
# Timestep返回的qposqvel,effort
data_dict['/observations/qpos'].append(ts.observation['qpos'])
data_dict['/observations/qvel'].append(ts.observation['qvel'])
data_dict['/observations/effort'].append(ts.observation['effort'])
# 实际发的action
data_dict['/action'].append(action)
data_dict['/base_action'].append(ts.observation['base_vel'])
# 相机数据
# data_dict['/base_action_t265'].append(ts.observation['base_vel_t265'])
for cam_name in args.camera_names:
data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])
if args.use_depth_image:
data_dict[f'/observations/images_depth/{cam_name}'].append(ts.observation['images_depth'][cam_name])
t0 = time.time()
with h5py.File(dataset_path + '.hdf5', 'w', rdcc_nbytes=1024**2*2) as root:
# 文本的属性:
# 1 是否仿真
# 2 图像是否压缩
#
root.attrs['sim'] = False
root.attrs['compress'] = False
# 创建一个新的组observations观测状态组
# 图像组
obs = root.create_group('observations')
image = obs.create_group('images')
for cam_name in args.camera_names:
_ = image.create_dataset(cam_name, (data_size, 480, 640, 3), dtype='uint8',
chunks=(1, 480, 640, 3), )
if args.use_depth_image:
image_depth = obs.create_group('images_depth')
for cam_name in args.camera_names:
_ = image_depth.create_dataset(cam_name, (data_size, 480, 640), dtype='uint16',
chunks=(1, 480, 640), )
_ = obs.create_dataset('qpos', (data_size, 14))
_ = obs.create_dataset('qvel', (data_size, 14))
_ = obs.create_dataset('effort', (data_size, 14))
_ = root.create_dataset('action', (data_size, 14))
_ = root.create_dataset('base_action', (data_size, 2))
# data_dict write into h5py.File
for name, array in data_dict.items():
root[name][...] = array
print(f'\033[32m\nSaving: {time.time() - t0:.1f} secs. %s \033[0m\n'%dataset_path)
def is_headless():
"""
Check if the environment is headless (no display available).
Returns:
bool: True if the environment is headless, False otherwise.
"""
try:
import tkinter as tk
root = tk.Tk()
root.withdraw()
root.update()
root.destroy()
return False
except:
return True
def init_keyboard_listener():
"""
Initialize keyboard listener for control events with new key mappings:
- Left arrow: Start data recording
- Right arrow: Save current data
- Down arrow: Discard current data
- Up arrow: Replay current data
- ESC: Early termination
Returns:
tuple: (listener, events) - Keyboard listener and events dictionary
"""
events = {
"exit_early": False,
"record_start": False,
"save_data": False,
"discard_data": False,
"replay_data": False
}
if is_headless():
print(
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
)
return None, events
# Only import pynput if not in a headless environment
from pynput import keyboard
def on_press(key):
try:
if key == keyboard.Key.left:
print("← Left arrow: STARTING data recording...")
events.update({
"record_start": True,
"exit_early": False,
"save_data": False,
"discard_data": False
})
elif key == keyboard.Key.right:
print("→ Right arrow: SAVING current data...")
events.update({
"save_data": True,
"exit_early": False,
"record_start": False
})
elif key == keyboard.Key.down:
print("↓ Down arrow: DISCARDING current data...")
events.update({
"discard_data": True,
"exit_early": False,
"record_start": False
})
elif key == keyboard.Key.up:
print("↑ Up arrow: REPLAYING current data...")
events.update({
"replay_data": True,
"exit_early": False
})
elif key == keyboard.Key.esc:
print("ESC: EARLY TERMINATION requested")
events.update({
"exit_early": True,
"record_start": False
})
except Exception as e:
print(f"Error handling key press: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
return listener, events