multi-node openpi commit

This commit is contained in:
Leon998
2026-03-17 23:05:23 +08:00
parent 28833f0c0f
commit 7411e0e004
156 changed files with 33951 additions and 1 deletions

View File

@@ -0,0 +1,212 @@
from collections import deque
from typing import List, Dict, Optional, Any, Sequence, Deque, Union
import datasets
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
def check_final(
last_states: Union[Deque[Sequence[float]], Sequence[Sequence[float]], torch.Tensor],
*,
# 索引与初始状态
arm_dofs: int = 6, # 左臂关节数(这里按你给的 6
gripper_index: int = -1, # 夹爪在向量中的索引(默认最后一维)
mean_initial_arm_state: Optional[Sequence[float]] = (0.0107, 0.0527, 0.0463, -0.0415, 0.0187, 0.0108),
mean_initial_gripper_state: float = 4.8438, # 目前不参与判定,保留以便后续扩展
# 判定阈值(角度阈值用“度”直观易调,内部会转换为弧度)
stability_window: int = 5, # 最近多少帧用于判“没有太大变化”
per_joint_range_deg: float = 2.0, # 窗口内每个关节的最大幅度max-min阈值
mean_speed_deg: float = 0.5, # 邻帧关节差的平均 L2每步阈值度/步)
min_change_from_initial_deg: float = 15.0, # 末帧相对初始的“至少变化量”L2
gripper_closed_thresh: float = 0.8, # 夹爪关闭阈值(数值越小说明越闭合)
) -> bool:
"""
返回 True 表示“到位”:(1) 最近窗口内姿态变化不大 & (2) 夹爪关闭 & (3) 末帧与初始相差足够大。
所有角度的阈值以“度”给出,这里会自动转弧度再比较。
"""
# --- 数据整理为 (N, D) tensor ---
if isinstance(last_states, torch.Tensor):
states = last_states
else:
states = torch.as_tensor(list(last_states), dtype=torch.float32)
if states.ndim != 2:
raise ValueError(f"last_states should be 2D, got shape {tuple(states.shape)}")
N, D = states.shape
if D < arm_dofs:
raise ValueError(f"Expected at least {arm_dofs} dims for arm + gripper, got {D}")
if N < 2:
return False # 样本太少,无法判定稳定
# 取最近窗口
w = min(N, stability_window)
window = states[-w:] # (w, D)
arm = window[:, :arm_dofs] # (w, 6)
last_arm = arm[-1] # (6,)
last_gripper = float(window[-1, gripper_index])
# --- 1) 最近 w 帧“没有太大变化” ---
# 两个指标每关节rangemax-min要小、相邻帧的平均“速度”要小
deg2rad = torch.pi / 180.0
range_tol = per_joint_range_deg * deg2rad
speed_tol = mean_speed_deg * deg2rad
ranges = arm.max(dim=0).values - arm.min(dim=0).values # (6,)
max_range = float(ranges.abs().max()) # 标量
diffs = arm[1:] - arm[:-1] # (w-1, 6)
mean_speed = float(torch.linalg.norm(diffs, dim=1).mean()) # 每步的平均 L2
stable = (max_range <= range_tol) and (mean_speed <= speed_tol)
# --- 2) 夹爪关闭 ---
gripper_closed = (last_gripper < gripper_closed_thresh)
# --- 3) 末帧与“初始”差距要大 ---
init = torch.as_tensor(mean_initial_arm_state, dtype=last_arm.dtype, device=last_arm.device)
if init.numel() != arm_dofs:
raise ValueError(f"mean_initial_arm_state length {init.numel()} != arm_dofs {arm_dofs}")
dist_from_init = float(torch.linalg.norm(last_arm - init))
far_from_init = (dist_from_init >= (min_change_from_initial_deg * deg2rad))
# 组合判定
return bool(stable and gripper_closed and far_from_init)
# return bool(gripper_closed and far_from_init)
def get_last_frames(ds: LeRobotDataset, include_images: bool = False, keys=None):
"""
Quickly fetch the last frame of each episode in a LeRobotDataset.
- include_images=False: Return only scalar/vector fields from parquet (faster, no video decoding).
- include_images=True : Additionally decode the corresponding image/video frame for the last frame.
- keys: Limit the set of columns to retrieve (default: all non-image/video fields + timestamp, etc.).
Returns: list[dict], where each element contains the last frame info of one episode.
"""
# 1) Compute the global index of the last row for each episode.
# ds.episode_data_index['to'] is the exclusive end index, so last frame = to - 1.
end_idxs = (ds.episode_data_index["to"] - 1).tolist()
# 2) Determine which columns to load.
# By default, exclude video/image columns to avoid triggering slow video decoding.
if keys is None:
non_media_keys = [k for k, ft in ds.features.items() if ft["dtype"] not in ("image", "video")]
keys = list(set(non_media_keys + ["timestamp", "episode_index", "task_index"]))
# 3) Select all last-frame rows at once (does not call __getitem__, so no video decoding is triggered).
last_rows = ds.hf_dataset.select(end_idxs)
# 4) Build a dictionary of tensors for each requested key.
out = []
col = {k: last_rows[k] for k in keys}
# Convert lists of tensors into stacked tensors for easier indexing.
for k, v in col.items():
# datasets.arrow_dataset.Column is the HuggingFace internal type for columns.
if isinstance(v, datasets.arrow_dataset.Column) and len(v) > 0 and hasattr(v[0], "shape"):
col[k] = torch.stack(v[:])
# Iterate through each episodes last frame and build a dict with its values.
for i, ep_end in enumerate(end_idxs):
item = {}
for k in keys:
val = col[k][i]
# Unpack 0-dimensional tensors into Python scalars.
if torch.is_tensor(val) and val.ndim == 0:
val = val.item()
item[k] = val
# Map task_index back to the human-readable task string.
if "task_index" in item:
item["task"] = ds.meta.tasks[int(item["task_index"])]
out.append(item)
# 5) Optionally decode the actual image/video frame for each last timestamp.
if include_images and len(ds.meta.video_keys) > 0:
for i, ep_end in enumerate(end_idxs):
ep_idx = int(out[i]["episode_index"])
ts = float(out[i]["timestamp"])
# Prepare a query dictionary: one timestamp per camera key.
query_ts = {k: [ts] for k in ds.meta.video_keys}
# Decode video frames at the specified timestamps for this episode.
frames = ds._query_videos(query_ts, ep_idx)
# Attach the decoded frame tensors to the output dictionary.
for k, v in frames.items():
out[i][k] = v
return out
if __name__ == "__main__":
# Initialize your dataset (replace with your repo ID or local path).
ds = LeRobotDataset(repo_id="arx_lift2/pick_parcel_20250915")
# Retrieve metadata only (timestamps, states, actions, tasks) without decoding video.
last_infos = get_last_frames(ds, include_images=False)
# Stack all 'observation.state' vectors into a single tensor for further processing.
states = torch.stack([info['observation.state'] for info in last_infos])
# Extract the left-arm joint states (first 7 values of each state vector).
left_arm_states = states[:, 0:7]
mean_state = torch.mean(left_arm_states, dim=0)
std_state = torch.std(left_arm_states, dim=0)
# Print the collected metadata for verification.
print(last_infos)
# --- Run check_final per episode using the last <=50 states ---
EP_ARM_DOFS = 6 # number of left-arm joints we use in check_final
GRIPPER_COL_FULL = -1 # gripper is the last element in the full state vector
STABILITY_WINDOW = 120 # must be consistent with check_final's default
# Determine which episodes to iterate
episode_indices = ds.episodes if ds.episodes is not None else sorted(ds.meta.episodes.keys())
episode_flags = {}
num_true, num_false = 0, 0
for ep_idx in episode_indices:
# Global index range [from_idx, to_idx) for this episode
from_idx = int(ds.episode_data_index["from"][ep_idx])
to_idx = int(ds.episode_data_index["to"][ep_idx])
if to_idx - from_idx <= 0:
episode_flags[ep_idx] = False
num_false += 1
continue
# Take the last <= STABILITY_WINDOW frames from this episode
idxs = list(range(max(from_idx, to_idx - STABILITY_WINDOW), to_idx))
rows = ds.hf_dataset.select(idxs)
# Collect full "observation.state" (shape ~ [W, S])
s_col = rows["observation.state"]
if isinstance(s_col, datasets.arrow_dataset.Column):
S = torch.stack(s_col[:]) # Column -> list[tensor] -> stack
else:
S = torch.stack(s_col) # already a list[tensor]
# Build the 7D small state per frame: first 6 joints + gripper
# (Assumes the gripper signal is at the last position of the full state vector)
small_states = torch.cat([S[:, :EP_ARM_DOFS], S[:, EP_ARM_DOFS:EP_ARM_DOFS+1]], dim=1)
# Run your stopping logic
ok = check_final(
small_states,
arm_dofs=EP_ARM_DOFS,
gripper_index=-1,
stability_window=STABILITY_WINDOW,
)
episode_flags[ep_idx] = bool(ok)
num_true += int(ok)
num_false += int(not ok)
# Summary
total_eps = len(episode_indices)
print(f"[check_final] passed: {num_true} / {total_eps} ({(num_true/max(total_eps,1)):.1%})")
# List some failed episodes for quick inspection
failed_eps = [e for e, passed in episode_flags.items() if not passed]
print("Failed episode indices (first 20):", failed_eps[:20])

View File

@@ -0,0 +1,88 @@
import os
import cv2
from pathlib import Path
from tqdm import tqdm
def extract_last_frame_from_videos(root_dir, output_dir, xx_last_frame=1):
"""
遍历目录找到所有images.rgb.hand_right视频文件提取最后一帧并保存
"""
# 查找所有mp4视频文件
video_files = []
for root, dirs, files in os.walk(root_dir):
for file in files:
if file.endswith('.mp4') and 'observation/head' in root:
video_files.append(os.path.join(root, file))
print(f"找到 {len(video_files)} 个视频文件")
# 处理每个视频文件
for video_path in tqdm(video_files):
try:
# 提取set名称和episode名称
path_parts = Path(video_path).parts
set_name = None
episode_name = None
for part in path_parts:
if part.startswith('set'):
set_name = part
if part.startswith('000'):
episode_name = part.replace('.mp4', '')
if not set_name or not episode_name:
print(f"无法从路径中提取set和episode信息: {video_path}")
continue
# 生成输出文件名
output_filename = f"{set_name}_{episode_name}.jpg"
output_path = os.path.join(output_dir, output_filename)
# 打开视频文件
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
print(f"无法打开视频: {video_path}")
continue
# 获取总帧数
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
if total_frames == 0:
print(f"视频没有帧: {video_path}")
cap.release()
continue
# 跳转到最后一帧
cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - xx_last_frame)
ret, frame = cap.read()
if ret:
# 保存最后一帧
cv2.imwrite(output_path, frame)
print(f"已保存:\n {output_path}")
else:
print(f"无法读取最后一帧: {video_path}")
# 释放资源
cap.release()
except Exception as e:
print(f"处理视频时出错 {video_path}: {str(e)}")
if __name__ == "__main__":
# 指定要遍历的根目录
root_directory = "/home/caijunhao/h-ceph/InternData-A1-raw/arx_lift2/Pick_the_industrial_components_from_the_conveyor" # 当前目录,您可以修改为您的目录路径
output_path = 'data/Pick_the_industrial_components_from_the_conveyor/'
os.makedirs(output_path, exist_ok=True)
sub_list = os.listdir(root_directory)
exclude_list = []
# exclude_list = [f"{i}" for i in range(16)] + [f"{i}" for i in range(26, 29)]
xx_last_frame = 1
# import pdb
# pdb.set_trace()
for sub in tqdm(sub_list):
if sub.split('-')[1].split('_')[0] in exclude_list:
continue
# print("os.path.join([root_directory, sub])\n", os.path.join(root_directory, sub))
extract_last_frame_from_videos(os.path.join(root_directory, sub), output_path, xx_last_frame=xx_last_frame)
print("处理完成!")

View File

@@ -0,0 +1,670 @@
# source /fs-computility/efm/liyang/miniconda3/etc/profile.d/conda.sh
# conda activate act
import argparse
import json
import logging
import os
import gc
import shutil
from concurrent.futures import ALL_COMPLETED, ProcessPoolExecutor, ThreadPoolExecutor, as_completed, wait
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple
import torchvision
import cv2
import h5py
import lmdb
import numpy as np
import pickle
import torch
from PIL import Image
from scipy.spatial.transform import Rotation
from tqdm import tqdm
import logging
import pdb
import os
import imageio # imageio-ffmpeg
from lerobot.common.datasets.compute_stats import auto_downsample_height_width, get_feature_stats, sample_indices
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import check_timestamps_sync, get_episode_data_index, validate_episode_buffer
import time
# import ray
# from ray.runtime_env import RuntimeEnv
"""
Store both camera image and robot state as a combined observation.
Args:
observation: images(camera), states (robot state)
actions: joint, gripper, ee_pose
"""
FEATURES = {
"images.rgb.head": {
"dtype": "video",
"shape": (368, 640, 3),
"names": ["height", "width", "channel"],
},
"images.rgb.hand_left": {
"dtype": "video",
"shape": (480, 640, 3),
"names": ["height", "width", "channel"],
},
"images.rgb.hand_right": {
"dtype": "video",
"shape": (480, 640, 3),
"names": ["height", "width", "channel"],
},
# "states.left_joint.position": {
# "dtype": "float32",
# "shape": (6,),
# "names": ["left_joint_0", "left_joint_1", "left_joint_2", "left_joint_3", "left_joint_4", "left_joint_5",],
# },
# "states.left_gripper.position": {
# "dtype": "float32",
# "shape": (1,),
# "names": ["left_gripper_0",],
# },
# "states.right_joint.position": {
# "dtype": "float32",
# "shape": (6,),
# "names": ["right_joint_0", "right_joint_1", "right_joint_2", "right_joint_3", "right_joint_4", "right_joint_5",],
# },
# "states.right_gripper.position": {
# "dtype": "float32",
# "shape": (1,),
# "names": ["right_gripper_0",],
# },
"observation.state": {
"dtype": "float32",
"shape": (14,),
"names": ["left_joint_0", "left_joint_1", "left_joint_2", "left_joint_3", "left_joint_4", "left_joint_5", "left_gripper_0",
"right_joint_0", "right_joint_1", "right_joint_2", "right_joint_3", "right_joint_4", "right_joint_5","right_gripper_0"],
},
"action": {
"dtype": "float32",
"shape": (14,),
"names": ["left_joint_0", "left_joint_1", "left_joint_2", "left_joint_3", "left_joint_4", "left_joint_5", "left_gripper_0",
"right_joint_0", "right_joint_1", "right_joint_2", "right_joint_3", "right_joint_4", "right_joint_5","right_gripper_0"],
},
# "actions.left_joint.position": {
# "dtype": "float32",
# "shape": (6,),
# "names": ["left_joint_0", "left_joint_1", "left_joint_2", "left_joint_3", "left_joint_4", "left_joint_5",],
# },
# "actions.left_gripper.position": {
# "dtype": "float32",
# "shape": (1,),
# "names": ["left_gripper_0",],
# },
# "actions.right_joint.position": {
# "dtype": "float32",
# "shape": (6,),
# "names": ["right_joint_0", "right_joint_1", "right_joint_2", "right_joint_3", "right_joint_4", "right_joint_5",],
# },
# "actions.right_gripper.position": {
# "dtype": "float32",
# "shape": (1,),
# "names": ["right_gripper_0", ],
# },
}
import numpy as np
def filter_forbidden_frames(state_dict, position_threshold=0.001, velocity_threshold=0.005):
"""
过滤禁止的帧,基于位置和速度阈值
参数:
- state_dict: 形状为 (n, 14) 的状态数组
- position_threshold: 位置变化的阈值
- velocity_threshold: 速度变化的阈值
返回:
- valid_mask: 布尔数组True表示有效帧
"""
# 排除夹爪列第6和第13列索引从0开始
qpos_columns = [i for i in range(14)]
qpos_data = state_dict[:, qpos_columns]
n_frames = len(state_dict)
valid_mask = np.ones(n_frames, dtype=bool)
# import pdb
# pdb.set_trace()
# 计算帧间差异(速度)
if n_frames > 1:
diff_sum = np.sum(np.abs(np.diff(qpos_data, axis=0)), axis=1)
# sorted_indices = np.argsort(diff_sum)[::-1]
# sorted_abs_sums = diff_sum[sorted_indices]
# velocities = np.diff(qpos_data, axis=0)
# 检查速度是否超过阈值
for i in range(n_frames - 1):
if np.any(np.abs(diff_sum[i]) > position_threshold):
valid_mask[i] = True # 有运动,有效帧
else:
valid_mask[i] = False # 静止,可能是禁止帧
valid_mask[i] = True
return valid_mask
def statistical_filter(state_dict, std_multiplier=2.0):
"""
使用统计方法检测异常(禁止)帧
"""
# 排除夹爪列
qpos_columns = [i for i in range(14) if i not in [6, 13]]
qpos_data = state_dict[:, qpos_columns]
# 计算每列的均值和标准差
means = np.mean(qpos_data, axis=0)
stds = np.std(qpos_data, axis=0)
# 创建有效掩码
valid_mask = np.ones(len(state_dict), dtype=bool)
for i in range(len(state_dict)):
# 检查每个关节位置是否在合理范围内
deviations = np.abs(qpos_data[i] - means)
if np.any(deviations > std_multiplier * stds):
valid_mask[i] = False # 异常帧
return valid_mask
class ARXDataset(LeRobotDataset):
def __init__(
self,
repo_id: str,
root: str | Path | None = None,
episodes: list[int] | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
tolerance_s: float = 1e-4,
download_videos: bool = True,
local_files_only: bool = False,
video_backend: str | None = None,
):
super().__init__(
repo_id=repo_id,
root=root,
episodes=episodes,
image_transforms=image_transforms,
delta_timestamps=delta_timestamps,
tolerance_s=tolerance_s,
download_videos=download_videos,
local_files_only=local_files_only,
video_backend=video_backend,
)
def save_episode(self, episode_data: dict | None = None, videos: dict | None = None) -> None:
if not episode_data:
episode_buffer = self.episode_buffer
validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
episode_length = episode_buffer.pop("size")
tasks = episode_buffer.pop("task")
episode_tasks = list(set(tasks))
episode_index = episode_buffer["episode_index"]
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
for task in episode_tasks:
task_index = self.meta.get_task_index(task)
if task_index is None:
self.meta.add_task(task)
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
for key, ft in self.features.items():
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["video"]:
continue
episode_buffer[key] = np.stack(episode_buffer[key]).squeeze()
for key in self.meta.video_keys:
video_path = self.root / self.meta.get_video_file_path(episode_index, key)
episode_buffer[key] = str(video_path) # PosixPath -> str
video_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copyfile(videos[key], video_path)
ep_stats = compute_episode_stats(episode_buffer, self.features)
self._save_episode_table(episode_buffer, episode_index)
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
check_timestamps_sync(
episode_buffer["timestamp"],
episode_buffer["episode_index"],
ep_data_index_np,
self.fps,
self.tolerance_s,
)
if not episode_data:
self.episode_buffer = self.create_episode_buffer()
def add_frame(self, frame: dict) -> None:
for name in frame:
if isinstance(frame[name], torch.Tensor):
frame[name] = frame[name].numpy()
features = {key: value for key, value in self.features.items() if key in self.hf_features}
if self.episode_buffer is None:
self.episode_buffer = self.create_episode_buffer()
frame_index = self.episode_buffer["size"]
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
self.episode_buffer["frame_index"].append(frame_index)
self.episode_buffer["timestamp"].append(timestamp)
for key in frame:
if key == "task":
self.episode_buffer["task"].append(frame["task"])
continue
if key not in self.features:
print("key ", key)
raise ValueError(f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'.")
# import pdb
# pdb.set_trace()
self.episode_buffer[key].append(frame[key])
self.episode_buffer["size"] += 1
# def crop_resize_no_padding(image, target_size=(480, 640)):
# """
# Crop and scale to target size (no padding)
# :param image: input image (NumPy array)
# :param target_size: target size (height, width)
# :return: processed image
# """
# h, w = image.shape[:2]
# target_h, target_w = target_size
# target_ratio = target_w / target_h # Target aspect ratio (e.g. 640/480=1.333)
# # the original image aspect ratio and cropping direction
# if w / h > target_ratio: # Original image is wider → crop width
# crop_w = int(h * target_ratio) # Calculate crop width based on target aspect ratio
# crop_h = h
# start_x = (w - crop_w) // 2 # Horizontal center starting point
# start_y = 0
# else: # Original image is higher → crop height
# crop_h = int(w / target_ratio) # Calculate clipping height according to target aspect ratio
# crop_w = w
# start_x = 0
# start_y = (h - crop_h) // 2 # Vertical center starting point
# # Perform centered cropping (to prevent out-of-bounds)
# start_x, start_y = max(0, start_x), max(0, start_y)
# end_x, end_y = min(w, start_x + crop_w), min(h, start_y + crop_h)
# cropped = image[start_y:end_y, start_x:end_x]
# # Resize to target size (bilinear interpolation)
# resized = cv2.resize(cropped, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
# return resized
def load_lmdb_data(episode_path: Path, sava_path: Path, fps_factor: int, target_fps: int) -> Optional[Dict]:
def load_image(txn, key):
raw = txn.get(key)
data = pickle.loads(raw)
image = cv2.imdecode(data, cv2.IMREAD_COLOR)
# Convert to RGB if necessary
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# image = crop_resize_no_padding(image, target_size=(480, 640))
return image
try:
env = lmdb.open(
str(episode_path / "lmdb"),
readonly=True,
lock=False,
max_readers=128,
readahead=False
)
with env.begin(write=False) as txn:
keys = [k for k, _ in txn.cursor()]
image_keys = sorted([k for k in keys if b'head' in k])
if not image_keys:
return None
all_qpos = pickle.loads(txn.get(b'/observations/qpos'))
if np.isscalar(all_qpos):
total_steps = len(image_keys)
all_qpos = [all_qpos] * total_steps
else:
total_steps = len(all_qpos)
all_qpos = np.stack(all_qpos)
state_action_dict = {}
state_action_dict["states.left_joint.position"] = all_qpos[:, :6]
state_action_dict["states.left_gripper.position"] = all_qpos[:, 6][:, None] # np.expand_dims(all_qpos[:, 6], axis=1)
state_action_dict["states.right_joint.position"] = all_qpos[:, 7:13]
state_action_dict["states.right_gripper.position"] = all_qpos[:, 13][:, None] # np.expand_dims(all_qpos[:, 13], axis=1)
# state_keys = list(state_action_dict.keys())
# for k in state_keys:
# state_action_dict[k.replace("states", "actions")] = np.concatenate([state_action_dict[k][1:, :], state_action_dict[k][-1, :][None,:]], axis=0)
# action_dict = {}
# action_dict["actions.left_joint.position"] = np.concatenate([state_dict["states.left_joint.position"][1:, :], state_dict["states.left_joint.position"][-1, :][None,:]], axis=0)
# action_dict["actions.left_gripper.position"] = state_dict["states.left_gripper.position"][1:, :]
# action_dict["actions.right_joint.position"] = state_dict["states.right_joint.position"][1:, :]
# action_dict["actions.right_gripper.position"] = state_dict["states.right_gripper.position"][1:, :]
action_dict = {}
action_dict["action"] = np.concatenate([all_qpos[1:,], all_qpos[-1,].reshape(-1, 14)], axis=0)
state_dict = {}
state_dict["observation.state"] = all_qpos
mask1 = filter_forbidden_frames(state_dict["observation.state"])
# state_dict["observation.state"] = state_dict["observation.state"][mask1]
# action_dict["actions.left_gripper.position"] = state_dict["states.left_gripper.position"][1:, :]
# action_dict["actions.right_arm.position"] = np.concatenate([state_action_dict["states.right_joint.position"][1:, :], state_action_dict["states.right_joint.position"][-1, :][None,:]], axis=0)
# action_dict["actions.left_arm.position"] = state_dict["states.right_gripper.position"][1:, :]
assert total_steps == len(image_keys), "qpos length mismatch"
selected_steps = [step for step in range(total_steps) if step % fps_factor == 0 and mask1[step]]
frames = []
image_observations = {
"images.rgb.head": [],
"images.rgb.hand_left": [],
"images.rgb.hand_right": []
}
start_time = time.time()
for step_index, step in enumerate(selected_steps):
step_str = f"{step:04d}"
head_key = f"observation/head/color_image/{step_str}".encode()
left_key = f"observation/left_wrist/color_image/{step_str}".encode()
right_key = f"observation/right_wrist/color_image/{step_str}".encode()
if not (head_key in keys and left_key in keys and right_key in keys):
continue
# state = all_qpos[step]
# if step_index < len(selected_steps) - 1:
# action = all_qpos[selected_steps[step_index + 1]]
# else:
# action = state
data_dict = {}
# for key, value in state_action_dict.items():
# data_dict[key] = value[step]
data_dict['action'] = action_dict["action"][step]
data_dict["task"] = " ".join(episode_path.parent.parent.name.split("_"))
data_dict['observation.state'] = state_dict["observation.state"][step]
# frames.append({
# "observation.states.joint.position": state,
# "actions.joint.position": action,
# "task": task_name,
# })
frames.append(data_dict)
image_observations["images.rgb.head"].append(load_image(txn, head_key))
image_observations["images.rgb.hand_left"].append(load_image(txn, left_key))
image_observations["images.rgb.hand_right"].append(load_image(txn, right_key))
end_time = time.time()
elapsed_time = end_time - start_time
print(f"load image_observations of {episode_path}")
env.close()
if not frames:
return None
os.makedirs(sava_path, exist_ok=True)
os.makedirs(sava_path/episode_path.name, exist_ok=True)
imageio.mimsave(sava_path/episode_path.name/'head.mp4', image_observations["images.rgb.head"], fps=target_fps)
imageio.mimsave(sava_path/episode_path.name/'hand_left.mp4', image_observations["images.rgb.hand_left"], fps=target_fps)
imageio.mimsave(sava_path/episode_path.name/'hand_right.mp4', image_observations["images.rgb.hand_right"], fps=target_fps)
print(f"imageio.mimsave time taken of {episode_path}")
return {
"frames": frames,
"videos": {
"images.rgb.head": sava_path/episode_path.name/"head.mp4",
"images.rgb.hand_left": sava_path/episode_path.name/"hand_left.mp4",
"images.rgb.hand_right": sava_path/episode_path.name/"hand_right.mp4",
},
}
except Exception as e:
logging.error(f"Failed to load LMDB data: {e}")
return None
def get_all_tasks(src_path: Path, output_path: Path) -> Tuple[Path, Path]:
src_dirs = sorted(list(src_path.glob("*"))) # "set*-*_collector*_datatime" as the conversion unit
save_dirs = [output_path/_dir.parent.name/_dir.name for _dir in src_dirs]
tasks_tuples = zip(src_dirs, save_dirs)
for task in tasks_tuples:
yield task
def compute_episode_stats(episode_data: Dict[str, List[str] | np.ndarray], features: Dict) -> Dict:
ep_stats = {}
for key, data in episode_data.items():
if features[key]["dtype"] == "string":
continue
elif features[key]["dtype"] in ["image", "video"]:
ep_ft_array = sample_images(data)
axes_to_reduce = (0, 2, 3) # keep channel dim
keepdims = True
else:
ep_ft_array = data # data is already a np.ndarray
axes_to_reduce = 0 # compute stats over the first axis
keepdims = data.ndim == 1 # keep as np.array
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
if features[key]["dtype"] in ["image", "video"]:
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
}
return ep_stats
def sample_images(input):
if type(input) is str:
video_path = input
reader = torchvision.io.VideoReader(video_path, stream="video")
frames = [frame["data"] for frame in reader]
frames_array = torch.stack(frames).numpy() # Shape: [T, C, H, W]
sampled_indices = sample_indices(len(frames_array))
images = None
for i, idx in enumerate(sampled_indices):
img = frames_array[idx]
img = auto_downsample_height_width(img)
if images is None:
images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
images[i] = img
elif type(input) is np.ndarray:
frames_array = input[:, None, :, :] # Shape: [T, C, H, W]
sampled_indices = sample_indices(len(frames_array))
images = None
for i, idx in enumerate(sampled_indices):
img = frames_array[idx]
img = auto_downsample_height_width(img)
if images is None:
images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
images[i] = img
return images
def load_local_dataset(episode_path: str, save_path:str, origin_fps=30, target_fps=30):
fps_factor = origin_fps // target_fps
# print(f"fps downsample factor: {fps_factor}")
# logging.info(f"fps downsample factor: {fps_factor}")
# for format_str in [f"{episode_id:07d}", f"{episode_id:06d}", str(episode_id)]:
# episode_path = Path(src_path) / format_str
# save_path = Path(save_path) / format_str
# if episode_path.exists():
# break
# else:
# logging.warning(f"Episode directory not found for ID {episode_id}")
# return None, None
episode_path = Path(episode_path)
if not episode_path.exists():
logging.warning(f"{episode_path} does not exist")
return None, None
if not (episode_path / "lmdb/data.mdb").exists():
logging.warning(f"LMDB data not found for episode {episode_path}")
return None, None
raw_dataset = load_lmdb_data(episode_path, save_path, fps_factor, target_fps)
if raw_dataset is None:
return None, None
frames = raw_dataset["frames"] # states, actions, task
videos = raw_dataset["videos"] # image paths
## check the frames
for camera_name, video_path in videos.items():
if not os.path.exists(video_path):
logging.error(f"Video file {video_path} does not exist.")
print(f"Camera {camera_name} Video file {video_path} does not exist.")
return None, None
return frames, videos
def save_as_lerobot_dataset(task: tuple[Path, Path], repo_id, num_threads, debug, origin_fps=30, target_fps=30, robot_type="piper", delete_downsampled_videos=True):
src_path, save_path = task
print(f"**Processing collected** {src_path}")
print(f"**saving to** {save_path}")
if save_path.exists():
# print(f"Output directory {save_path} already exists. Deleting it.")
# logging.warning(f"Output directory {save_path} already exists. Deleting it.")
# shutil.rmtree(save_path)
print(f"Output directory {save_path} already exists.")
return
dataset = ARXDataset.create(
repo_id=f"{repo_id}",
root=save_path,
fps=target_fps,
robot_type=robot_type,
features=FEATURES,
)
all_episode_paths = sorted([f.as_posix() for f in src_path.glob(f"*") if f.is_dir()])
# all_subdir_eids = [int(Path(path).name) for path in all_subdir]
if debug:
for i in range(1):
# pdb.set_trace()
frames, videos = load_local_dataset(episode_path=all_episode_paths[i], save_path=save_path, origin_fps=origin_fps, target_fps=target_fps)
for frame_data in frames:
dataset.add_frame(frame_data)
dataset.save_episode(videos=videos)
if delete_downsampled_videos:
for _, video_path in videos.items():
parent_dir = os.path.dirname(video_path)
try:
shutil.rmtree(parent_dir)
# os.remove(video_path)
# print(f"Successfully deleted: {parent_dir}")
print(f"Successfully deleted: {video_path}")
except Exception as e:
pass # Handle the case where the directory might not exist or is already deleted
else:
for batch_index in range(len(all_episode_paths)//num_threads+1):
batch_episode_paths = all_episode_paths[batch_index*num_threads:(batch_index+1)*num_threads]
if len(batch_episode_paths) == 0:
continue
with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = []
for episode_path in batch_episode_paths:
print("starting to process episode: ", episode_path)
futures.append(
executor.submit(load_local_dataset, episode_path=episode_path, save_path=save_path, origin_fps=origin_fps, target_fps=target_fps)
)
for raw_dataset in as_completed(futures):
frames, videos = raw_dataset.result()
if frames is None or videos is None:
print(f"Skipping episode {episode_path} due to missing data.")
continue
for frame_data in frames:
dataset.add_frame(frame_data)
dataset.save_episode(videos=videos)
gc.collect()
print(f"finishing processed {videos}")
if delete_downsampled_videos:
for _, video_path in videos.items():
# Get the parent directory of the video
parent_dir = os.path.dirname(video_path)
try:
shutil.rmtree(parent_dir)
print(f"Successfully deleted: {parent_dir}")
except Exception as e:
pass
def main(src_path, save_path, repo_id, num_threads=60, debug=False, origin_fps=30, target_fps=30):
logging.info("Scanning for episodes...")
tasks = get_all_tasks(src_path, save_path)
# import pdb
# pdb.set_trace()
if debug:
task = next(tasks)
save_as_lerobot_dataset(task, repo_id, num_threads=num_threads, debug=debug, origin_fps=origin_fps, target_fps=target_fps)
else:
for task in tasks:
save_as_lerobot_dataset(task, repo_id, num_threads=num_threads, debug=debug, origin_fps=origin_fps, target_fps=target_fps)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert collected data from Piper to Lerobot format.")
parser.add_argument(
"--src_path",
type=str,
# required=False,
default="/fs-computility/efm/shared/datasets/myData-A1/real/raw_data/agilex_split_aloha/",
help="Path to the input file containing collected data in Piper format.",
#help="/fs-computility/efm/shared/datasets/myData-A1/real/raw_data/agilex_split_aloha/Make_a_beef_sandwich",
)
parser.add_argument(
"--save_path",
type=str,
# required=False,
default="/fs-computility/efm/shared/datasets/myData-A1/real/lerobot_v2_1/agilex_split_aloha/",
help="Path to the output file where the converted Lerobot format will be saved.",
#help="Path to the output file where the converted Lerobot format will be saved.",
)
parser.add_argument(
"--debug",
action="store_true",
help="Run in debug mode with limited episodes",
)
parser.add_argument(
"--num-threads",
type=int,
default=50,
help="Number of threads per process",
)
# parser.add_argument(
# "--task_name",
# type=str,
# required=True,
# default="Pick_up_the_marker_and_put_it_into_the_pen_holder",
# help="Name of the task to be processed. Default is 'Pick_up_the_marker_and_put_it_into_the_pen_holder'.",
# )
parser.add_argument(
"--repo_id",
type=str,
required=True,
# default="SplitAloha_20250714",
help="identifier for the dataset repository.",
)
parser.add_argument(
"--origin_fps",
type=int,
default=30,
help="Frames per second for the obervation video. Default is 30.",
)
parser.add_argument(
"--target_fps",
type=int,
default=30,
help="Frames per second for the downsample video. Default is 30.",
)
args = parser.parse_args()
assert int(args.origin_fps) % int(args.target_fps) == 0, "origin_fps must be an integer multiple of target_fps"
start_time = time.time()
main(
src_path=Path(args.src_path),
save_path=Path(args.save_path),
repo_id=args.repo_id,
num_threads=args.num_threads,
debug=args.debug,
origin_fps=args.origin_fps,
target_fps=args.target_fps
)
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Total time taken: {elapsed_time:.2f} seconds")
# --target_fps 10
# --src_path /fs-computility/efm/shared/datasets/myData-A1/real/raw_data/agilex_split_aloha/Put_the_bananas_in_the_basket
# --save_path /mnt/shared-storage-user/internvla/Users/liyang/data/processed_data/arx_lift2

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff