Files
issacdataengine/policy/openpi-InternData-A1/examples/arx/merge_lerobot_data_v2.py
2026-03-17 23:05:23 +08:00

1510 lines
69 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import contextlib
import json
import os
import shutil
import traceback
import numpy as np
import pandas as pd
from termcolor import colored
def load_jsonl(file_path):
    """Load records from a JSONL file.

    ``episodes_stats.jsonl`` files are sometimes written as a single JSON
    array (with or without the surrounding brackets) instead of one object
    per line, so for that filename the whole-file array forms are tried
    first before falling back to standard line-by-line parsing.

    Args:
        file_path (str): Path to the JSONL file.

    Returns:
        list: JSON objects parsed from the file. Individual lines that fail
        to parse are silently skipped.
    """

    def _parse_lines(path, suppress_errors=False):
        # Standard JSONL: one JSON object per non-empty line; undecodable
        # lines are skipped. When suppress_errors is True, I/O or other
        # unexpected errors are reported instead of raised (fallback mode).
        records = []
        try:
            with open(path) as f:
                for line in f:
                    if line.strip():
                        with contextlib.suppress(json.JSONDecodeError):
                            records.append(json.loads(line))
        except Exception as e:
            if not suppress_errors:
                raise
            print(f"Error loading {path} line by line: {e}")
        return records

    if "episodes_stats.jsonl" in file_path:
        try:
            with open(file_path) as f:
                content = f.read()
            stripped = content.strip()
            if stripped.startswith("[") and stripped.endswith("]"):
                return json.loads(content)
            # Some writers omit the surrounding brackets; try wrapping once.
            with contextlib.suppress(json.JSONDecodeError):
                return json.loads("[" + content + "]")
        except Exception as e:
            print(f"Error loading {file_path} as JSON array: {e}")
        # Fall back to line-by-line parsing (best effort, never raises).
        return _parse_lines(file_path, suppress_errors=True)
    return _parse_lines(file_path)
def save_jsonl(data, file_path):
    """Write *data* to *file_path* in JSONL format (one JSON object per line).

    Args:
        data (list): JSON-serializable objects to save.
        file_path (str): Output file path (overwritten if it exists).
    """
    with open(file_path, "w") as out:
        out.writelines(json.dumps(record) + "\n" for record in data)
def _merge_image_stat(values, stat_type):
    """Merge one per-pixel image statistic, preserving the nested
    [channel][row][value] list structure of the inputs.

    ``mean`` and ``std`` are merged as plain unweighted averages (averaging
    stds is an approximation, kept for parity with the original behavior);
    ``max``/``min`` take the elementwise extreme.
    """
    result = []
    for channel_idx in range(len(values[0])):
        channel_result = []
        for pixel_idx in range(len(values[0][channel_idx])):
            pixel_result = []
            for value_idx in range(len(values[0][channel_idx][pixel_idx])):
                samples = [val[channel_idx][pixel_idx][value_idx] for val in values]
                if stat_type in ("mean", "std"):
                    pixel_result.append(sum(samples) / len(samples))
                elif stat_type == "max":
                    pixel_result.append(max(samples))
                elif stat_type == "min":
                    pixel_result.append(min(samples))
            channel_result.append(pixel_result)
        result.append(channel_result)
    return result


def _merge_aligned_stat(stats_list, feature, stat_type):
    """Merge one statistic for a feature whose shape matches across datasets.

    Means are count-weighted when every dataset carries a ``count``; stds are
    merged by count-weighting the variances. Without counts, plain averages
    are used. ``max``/``min`` reduce elementwise.
    """
    values = [stats[feature][stat_type] for stats in stats_list]
    have_counts = all("count" in stats[feature] for stats in stats_list)
    if stat_type == "mean":
        if have_counts:
            counts = [stats[feature]["count"][0] for stats in stats_list]
            total_count = sum(counts)
            weighted = [np.array(val) * count / total_count for val, count in zip(values, counts)]
            return np.sum(weighted, axis=0).tolist()
        return np.mean(np.array(values), axis=0).tolist()
    if stat_type == "std":
        if have_counts:
            counts = [stats[feature]["count"][0] for stats in stats_list]
            total_count = sum(counts)
            weighted_vars = [
                (np.array(std) ** 2) * count / total_count for std, count in zip(values, counts)
            ]
            return np.sqrt(np.sum(weighted_vars, axis=0)).tolist()
        return np.mean(np.array(values), axis=0).tolist()
    if stat_type == "max":
        return np.maximum.reduce(np.array(values)).tolist()
    if stat_type == "min":
        return np.minimum.reduce(np.array(values)).tolist()
    return None


def _merge_ragged_vector_stat(stats_list, feature, stat_type, max_dim):
    """Merge one statistic for a state/action vector whose length differs
    across datasets.

    Each dimension is merged only over the datasets that actually have it;
    dimensions absent from every dataset stay 0.0.
    """
    flat = [np.array(stats[feature][stat_type]).flatten() for stats in stats_list]
    dims = [len(v) for v in flat]
    have_counts = all("count" in stats[feature] for stats in stats_list)
    counts = [stats[feature]["count"][0] for stats in stats_list] if have_counts else None
    result = np.zeros(max_dim)
    for d in range(max_dim):
        present = [i for i in range(len(flat)) if d < dims[i]]
        if not present:
            continue  # no dataset covers this dimension; leave it at 0.0
        vals = [flat[i][d] for i in present]
        if stat_type == "mean":
            if have_counts:
                weights = [counts[i] for i in present]
                result[d] = sum(v * w for v, w in zip(vals, weights)) / sum(weights)
            else:
                result[d] = sum(vals) / len(vals)
        elif stat_type == "std":
            if have_counts:
                weights = [counts[i] for i in present]
                # Weighted average of variances, then back to std.
                weighted_var = sum((v**2) * w for v, w in zip(vals, weights)) / sum(weights)
                result[d] = np.sqrt(weighted_var)
            else:
                result[d] = sum(vals) / len(vals)
        elif stat_type == "max":
            result[d] = max(vals)
        elif stat_type == "min":
            result[d] = min(vals)
    return result.tolist()


def merge_stats(stats_list):
    """Merge per-dataset statistics dictionaries, ensuring dimensional
    consistency.

    Only features present in every dataset are merged. Image features
    (``observation.images.*``) are merged elementwise on their nested list
    structure; features with matching shapes are merged with (count-)weighted
    numpy reductions; ``observation.state``/``action`` vectors of differing
    lengths are merged per-dimension over the datasets that have each
    dimension; any other shape mismatch falls back to the first dataset's
    values.

    Args:
        stats_list (list): One statistics dict per dataset, mapping feature
            name -> {"mean"/"std"/"max"/"min"/"count": values}.

    Returns:
        dict: Merged statistics with the same layout.
    """
    merged_stats = {}
    # Only merge features that every dataset provides.
    common_features = set(stats_list[0].keys())
    for stats in stats_list[1:]:
        common_features = common_features.intersection(set(stats.keys()))
    # Process features in the order they appear in the first stats dict.
    for feature in stats_list[0]:
        if feature not in common_features:
            continue
        merged_stats[feature] = {}
        common_stat_types = [
            stat_type
            for stat_type in ["mean", "std", "max", "min"]
            if all(stat_type in stats[feature] for stats in stats_list)
        ]
        original_shapes = [
            np.array(stats[feature]["mean"]).shape
            for stats in stats_list
            if "mean" in stats[feature]
        ]
        if feature.startswith("observation.images."):
            for stat_type in common_stat_types:
                # Computed before the try so the fallback below is always
                # defined (previously this could raise NameError).
                values = [stats[feature][stat_type] for stats in stats_list]
                try:
                    merged_stats[feature][stat_type] = _merge_image_stat(values, stat_type)
                except Exception as e:
                    print(f"Warning: Error processing image feature {feature}.{stat_type}: {e}")
                    # Fallback to first value
                    merged_stats[feature][stat_type] = values[0]
        elif len({str(shape) for shape in original_shapes}) == 1:
            # All shapes agree: standard (weighted) elementwise merging.
            for stat_type in common_stat_types:
                try:
                    merged_stats[feature][stat_type] = _merge_aligned_stat(
                        stats_list, feature, stat_type
                    )
                except Exception as e:
                    print(f"Warning: Error processing {feature}.{stat_type}: {e}")
                    continue
        elif feature in ["observation.state", "action"]:
            # Vectors of different lengths: merge per dimension.
            max_dim = max(len(np.array(stats[feature]["mean"]).flatten()) for stats in stats_list)
            for stat_type in common_stat_types:
                try:
                    merged_stats[feature][stat_type] = _merge_ragged_vector_stat(
                        stats_list, feature, stat_type, max_dim
                    )
                except Exception as e:
                    print(
                        f"Warning: Error processing {feature}.{stat_type} with different dimensions: {e}"
                    )
                    continue
        else:
            # Other shape mismatches: keep the first dataset's values as-is.
            template_shape = original_shapes[0]
            print(f"Using shape {template_shape} as template for {feature}")
            for stat_type in common_stat_types:
                try:
                    merged_stats[feature][stat_type] = stats_list[0][feature][stat_type]
                except Exception as e:
                    print(
                        f"Warning: Error processing {feature}.{stat_type} with shape {template_shape}: {e}"
                    )
                    continue
        # Total count, when every dataset provides one.
        if all("count" in stats[feature] for stats in stats_list):
            try:
                merged_stats[feature]["count"] = [
                    sum(stats[feature]["count"][0] for stats in stats_list)
                ]
            except Exception as e:
                print(f"Warning: Error processing {feature}.count: {e}")
    return merged_stats
def copy_videos(source_folders, output_folder, episode_mapping):
    """Copy video files from the source folders into the merged output
    folder, renaming them to the new episode indices and chunk layout.

    The video path template and chunk size are read from the FIRST source
    folder's meta/info.json and applied to all folders.

    Args:
        source_folders (list): Source dataset folder paths.
        output_folder (str): Destination dataset folder path.
        episode_mapping (list): Tuples of (old_folder, old_index, new_index).
    """
    # Get info.json to determine video structure
    info_path = os.path.join(source_folders[0], "meta", "info.json")
    with open(info_path) as f:
        info = json.load(f)
    video_path_template = info["video_path"]
    # Identify video keys from the template
    # Example: "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
    video_keys = []
    for feature_name, feature_info in info["features"].items():
        if feature_info.get("dtype") == "video":
            # Use the full feature name as the video key
            video_keys.append(feature_name)
    print(f"Found video keys: {video_keys}")
    # Copy videos for each episode
    for old_folder, old_index, new_index in episode_mapping:
        # Determine episode chunk (usually 0 for small datasets)
        episode_chunk = old_index // info["chunks_size"]
        new_episode_chunk = new_index // info["chunks_size"]
        for video_key in video_keys:
            # Try different possible source paths
            # NOTE(review): the second and fourth patterns always point at
            # episode 0 / chunk 0 — if one of them matches, the SAME
            # episode-0 video is copied for every episode. Presumably
            # intended for single-episode datasets only — TODO confirm.
            source_patterns = [
                # Standard path with the episode index from metadata
                os.path.join(
                    old_folder,
                    video_path_template.format(
                        episode_chunk=episode_chunk, video_key=video_key, episode_index=old_index
                    ),
                ),
                # Try with 0-based indexing
                os.path.join(
                    old_folder,
                    video_path_template.format(episode_chunk=0, video_key=video_key, episode_index=0),
                ),
                # Try with different formatting
                os.path.join(
                    old_folder, f"videos/chunk-{episode_chunk:03d}/{video_key}/episode_{old_index}.mp4"
                ),
                os.path.join(old_folder, f"videos/chunk-000/{video_key}/episode_000000.mp4"),
            ]
            # Find the first existing source path
            source_video_path = None
            for pattern in source_patterns:
                if os.path.exists(pattern):
                    source_video_path = pattern
                    break
            if source_video_path:
                # Construct destination path
                dest_video_path = os.path.join(
                    output_folder,
                    video_path_template.format(
                        episode_chunk=new_episode_chunk, video_key=video_key, episode_index=new_index
                    ),
                )
                # Create destination directory if it doesn't exist
                os.makedirs(os.path.dirname(dest_video_path), exist_ok=True)
                print(f"Copying video: {source_video_path} -> {dest_video_path}")
                shutil.copy2(source_video_path, dest_video_path)
            else:
                # If no file is found, search the directory recursively
                # (first .mp4 whose path contains the video key wins).
                found = False
                for root, _, files in os.walk(os.path.join(old_folder, "videos")):
                    for file in files:
                        if file.endswith(".mp4") and video_key in root:
                            source_video_path = os.path.join(root, file)
                            # Construct destination path
                            dest_video_path = os.path.join(
                                output_folder,
                                video_path_template.format(
                                    episode_chunk=new_episode_chunk,
                                    video_key=video_key,
                                    episode_index=new_index,
                                ),
                            )
                            # Create destination directory if it doesn't exist
                            os.makedirs(os.path.dirname(dest_video_path), exist_ok=True)
                            print(
                                f"Copying video (found by search): {source_video_path} -> {dest_video_path}"
                            )
                            shutil.copy2(source_video_path, dest_video_path)
                            found = True
                            break
                    if found:
                        break
                if not found:
                    print(
                        f"Warning: Video file not found for {video_key}, episode {old_index} in {old_folder}"
                    )
def validate_timestamps(source_folders, tolerance_s=1e-4):
    """Validate the timestamp structure of the source datasets and report
    potential problems.

    For each folder this reads the FPS from meta/info.json, locates the
    first parquet file (under ``parquet/`` or ``data/``), and checks that it
    contains at least one timestamp-like column. Finally, FPS consistency
    across datasets is checked.

    Args:
        source_folders (list): Source dataset folder paths.
        tolerance_s (float): Tolerance for timestamp discontinuities in
            seconds (currently unused — kept for interface compatibility).

    Returns:
        tuple: (issues, fps_values) — detected problems and the FPS value
        read from each dataset.
    """
    issues = []
    fps_values = []

    def _first_parquet(base_dir):
        # Return the first .parquet file found under base_dir, or None.
        for root, _, files in os.walk(base_dir):
            for name in files:
                if name.endswith(".parquet"):
                    return os.path.join(root, name)
        return None

    for folder in source_folders:
        try:
            # Read the FPS from info.json when available.
            info_file = os.path.join(folder, "meta", "info.json")
            if os.path.exists(info_file):
                with open(info_file) as fh:
                    meta = json.load(fh)
                if "fps" in meta:
                    fps = meta["fps"]
                    fps_values.append(fps)
                    print(f"数据集 {folder} FPS={fps} (Dataset {folder} FPS={fps})")
            # Look for a parquet file to inspect for timestamp columns.
            parquet_file = _first_parquet(os.path.join(folder, "parquet")) or _first_parquet(
                os.path.join(folder, "data")
            )
            if parquet_file:
                frame = pd.read_parquet(parquet_file)
                timestamp_cols = [col for col in frame.columns if "timestamp" in col or "time" in col]
                if timestamp_cols:
                    print(
                        f"数据集 {folder} 包含时间戳列: {timestamp_cols} (Dataset {folder} contains timestamp columns: {timestamp_cols})"
                    )
                else:
                    issues.append(
                        f"警告: 数据集 {folder} 没有时间戳列 (Warning: Dataset {folder} has no timestamp columns)"
                    )
            else:
                issues.append(
                    f"警告: 数据集 {folder} 未找到parquet文件 (Warning: No parquet files found in dataset {folder})"
                )
        except Exception as e:
            issues.append(
                f"错误: 验证数据集 {folder} 失败: {e} (Error: Failed to validate dataset {folder}: {e})"
            )
            print(f"验证错误: {e} (Validation error: {e})")
            traceback.print_exc()
    # All datasets must agree on FPS.
    if len(set(fps_values)) > 1:
        issues.append(
            f"警告: 数据集FPS不一致: {fps_values} (Warning: Inconsistent FPS across datasets: {fps_values})"
        )
    return issues, fps_values
def _process_and_save_episode(
    source_path,
    old_folder,
    new_index,
    output_folder,
    episode_to_frame_index,
    folder_task_mapping,
    chunks_size,
):
    """Read one episode parquet file, remap its index columns to the merged
    dataset, and save it into the chunked output layout.

    Remaps ``episode_index`` to *new_index*, rebases ``index`` (using the
    global cumulative frame count when provided, otherwise
    ``new_index * len(df)``), and remaps ``task_index`` via
    *folder_task_mapping*. Raises on any read/write failure; the caller
    records the failure.

    Returns:
        str: The destination parquet path.
    """
    df = pd.read_parquet(source_path)
    # Remap the episode_index column to the merged index.
    if "episode_index" in df.columns:
        print(
            f"更新episode_index从 {df['episode_index'].iloc[0]}{new_index} (Update episode_index from {df['episode_index'].iloc[0]} to {new_index})"
        )
        df["episode_index"] = new_index
    # Rebase the global frame index column.
    if "index" in df.columns:
        if episode_to_frame_index and new_index in episode_to_frame_index:
            # Use the pre-computed global cumulative frame count.
            first_index = episode_to_frame_index[new_index]
            print(
                f"更新index列起始值: {first_index}(使用全局累积帧计数)(Update index column, start value: {first_index} (using global cumulative frame count))"
            )
        else:
            # Fallback: episode index multiplied by episode length.
            first_index = new_index * len(df)
            print(
                f"更新index列起始值: {first_index}使用episode索引乘以长度(Update index column, start value: {first_index} (using episode index multiplied by length))"
            )
        df["index"] = [first_index + i for i in range(len(df))]
    # Remap task_index through the per-folder task mapping.
    if "task_index" in df.columns and folder_task_mapping and old_folder in folder_task_mapping:
        current_task_index = df["task_index"].iloc[0]
        if current_task_index in folder_task_mapping[old_folder]:
            new_task_index = folder_task_mapping[old_folder][current_task_index]
            print(
                f"更新task_index从 {current_task_index}{new_task_index} (Update task_index from {current_task_index} to {new_task_index})"
            )
            df["task_index"] = new_task_index
        else:
            print(
                f"警告: 找不到task_index {current_task_index}的映射关系 (Warning: No mapping found for task_index {current_task_index})"
            )
    # Write into the chunked layout: data/chunk-XXX/episode_XXXXXX.parquet.
    chunk_index = new_index // chunks_size
    chunk_dir = os.path.join(output_folder, "data", f"chunk-{chunk_index:03d}")
    os.makedirs(chunk_dir, exist_ok=True)
    dest_path = os.path.join(chunk_dir, f"episode_{new_index:06d}.parquet")
    df.to_parquet(dest_path, index=False)
    print(f"已处理并保存: {dest_path} (Processed and saved: {dest_path})")
    return dest_path


def copy_data_files(
    source_folders,
    output_folder,
    episode_mapping,
    fps=None,
    episode_to_frame_index=None,
    folder_task_mapping=None,
    chunks_size=1000,
    default_fps=20,
):
    """Copy per-episode parquet data files from the source folders to the
    output folder, remapping index columns along the way.

    For each (old_folder, old_index, new_index) mapping entry the episode's
    parquet file is looked up under ``parquet/`` and ``data/``, falling back
    to a recursive search of the whole source folder. Processing is done by
    ``_process_and_save_episode``; failures are collected and reported at
    the end.

    Args:
        source_folders (list): Source dataset folder paths.
        output_folder (str): Output folder path.
        episode_mapping (list): Tuples of (old_folder, old_index, new_index).
        fps (float): Frame rate; read from the first dataset's info.json
            when None. NOTE(review): currently only printed, not used for
            processing.
        episode_to_frame_index (dict): Optional new_index -> global start
            frame index mapping.
        folder_task_mapping (dict): Optional old_folder -> {old task_index:
            new task_index} mapping.
        chunks_size (int): Episodes per output chunk directory.
        default_fps (float): Fallback frame rate.

    Returns:
        bool: True if at least one data file was copied.
    """
    # Resolve the FPS from the first dataset when not provided.
    if fps is None:
        info_path = os.path.join(source_folders[0], "meta", "info.json")
        if os.path.exists(info_path):
            with open(info_path) as f:
                info = json.load(f)
            fps = info.get("fps", default_fps)
        else:
            fps = default_fps
    print(f"使用FPS={fps}")
    total_copied = 0
    total_failed = 0
    # Record every failed file together with the reason.
    failed_files = []
    for old_folder, old_index, new_index in episode_mapping:
        # Try the two standard locations first.
        episode_str = f"episode_{old_index:06d}.parquet"
        candidates = [
            os.path.join(old_folder, "parquet", episode_str),
            os.path.join(old_folder, "data", episode_str),
        ]
        source_path = next((path for path in candidates if os.path.exists(path)), None)
        if source_path:
            try:
                _process_and_save_episode(
                    source_path,
                    old_folder,
                    new_index,
                    output_folder,
                    episode_to_frame_index,
                    folder_task_mapping,
                    chunks_size,
                )
                total_copied += 1
            except Exception as e:
                error_msg = f"处理 {source_path} 失败: {e} (Processing {source_path} failed: {e})"
                print(error_msg)
                traceback.print_exc()
                failed_files.append({"file": source_path, "reason": str(e), "episode": old_index})
                total_failed += 1
        else:
            # Not in a standard location: search the folder recursively and
            # process the first matching parquet file that succeeds.
            found = False
            for root, _, files in os.walk(old_folder):
                for file in files:
                    if file.endswith(".parquet") and f"episode_{old_index:06d}" in file:
                        source_path = os.path.join(root, file)
                        try:
                            _process_and_save_episode(
                                source_path,
                                old_folder,
                                new_index,
                                output_folder,
                                episode_to_frame_index,
                                folder_task_mapping,
                                chunks_size,
                            )
                            total_copied += 1
                            found = True
                            break
                        except Exception as e:
                            error_msg = f"处理 {source_path} 失败: {e} (Processing {source_path} failed: {e})"
                            print(error_msg)
                            traceback.print_exc()
                            failed_files.append(
                                {"file": source_path, "reason": str(e), "episode": old_index}
                            )
                            total_failed += 1
                if found:
                    break
            if not found:
                error_msg = f"找不到episode {old_index}的parquet文件源文件夹: {old_folder}"
                print(error_msg)
                failed_files.append(
                    {"file": f"episode_{old_index:06d}.parquet", "reason": "文件未找到", "folder": old_folder}
                )
                total_failed += 1
    print(f"共复制 {total_copied} 个数据文件,{total_failed} 个失败")
    # Print details of all failed files.
    if failed_files:
        print("\n失败的文件详情 (Details of failed files):")
        for i, failed in enumerate(failed_files):
            print(f"{i + 1}. 文件 (File): {failed['file']}")
            if "folder" in failed:
                print(f" 文件夹 (Folder): {failed['folder']}")
            if "episode" in failed:
                print(f" Episode索引 (Episode index): {failed['episode']}")
            print(f" 原因 (Reason): {failed['reason']}")
            print("---")
    return total_copied > 0
def pad_parquet_data(source_path, target_path, original_dim=14, target_dim=18):
    """Zero-pad vector columns of a parquet file up to *target_dim*.

    Reads *source_path*, extends the ``observation.state`` and ``action``
    columns (when they hold list/array vectors shorter than *target_dim*)
    with trailing zeros, and writes the result to *target_path*.

    Args:
        source_path (str): Source parquet file path.
        target_path (str): Target parquet file path.
        original_dim (int): Original vector dimension (informational only).
        target_dim (int): Target vector dimension.

    Returns:
        pandas.DataFrame: The (possibly padded) data that was written.
    """
    df = pd.read_parquet(source_path)
    # Show the columns to ease debugging of unexpected schemas.
    print(f"Columns in {source_path}: {df.columns.tolist()}")
    new_df = df.copy()

    def _pad_column(column):
        # Pad the given column in new_df when it holds short vectors.
        sample = df[column].iloc[0]
        print(f"First {column} type: {type(sample)}, value: {sample}")
        if isinstance(sample, (list, np.ndarray)):
            dim = len(sample)
            print(f"{column} dimension: {dim}")
            if dim < target_dim:
                print(f"Padding {column} from {dim} to {target_dim} dimensions")
                new_df[column] = df[column].apply(
                    lambda vec: np.pad(vec, (0, target_dim - len(vec)), "constant").tolist()
                )

    for column in ("observation.state", "action"):
        if column in df.columns:
            _pad_column(column)

    # Write the result next to its (created-if-needed) parent directory.
    os.makedirs(os.path.dirname(target_path), exist_ok=True)
    new_df.to_parquet(target_path, index=False)
    print(f"已将{source_path}处理并保存到{target_path}")
    return new_df
def count_video_frames_torchvision(video_path):
    """Count the number of frames in a video file using torchvision.

    Args:
        video_path (str): Path of the video file.

    Returns:
        int: The frame count, or 0 when the video cannot be read
        (torchvision missing, unreadable/corrupt file, or empty stream).
    """
    try:
        import torchvision

        # VideoReader takes the path plus the stream selector ("video").
        reader = torchvision.io.VideoReader(video_path, "video")
        # First attempt: read the count from the container metadata; the
        # "video" stream usually exposes a "num_frames" list.
        metadata = reader.get_metadata()
        frame_count = 0
        if "video" in metadata and "num_frames" in metadata["video"] and len(metadata["video"]["num_frames"]) > 0:
            frame_count = int(metadata["video"]["num_frames"][0])
        if frame_count > 0:
            # A positive metadata count is trusted without decoding.
            return frame_count
        # Second attempt: decode the stream and count frames (slower,
        # but robust when metadata is absent or zero).
        decoded_frames = sum(1 for _ in reader)
        if decoded_frames > 0:
            return decoded_frames
        if frame_count > 0:
            print(f"Warning: Manual count is 0, but metadata indicates {frame_count} frames. Video might be empty or there was a read issue. Returning metadata count.")
            return frame_count
        # Both metadata and manual iteration yielded zero frames.
        print(f"Video appears to have no frames: {video_path}")
        return 0
    except ImportError:
        print("Warning: torchvision or its dependencies (like ffmpeg) not installed, cannot count video frames")
        return 0
    except RuntimeError as e:
        # VideoReader raises RuntimeError for missing/corrupt files or
        # codecs the active backend cannot demux.
        message = str(e)
        if "No video stream found" in message:
            print(f"Error: No video stream found in video file: {video_path}")
        elif "Could not open" in message or "Demuxing video" in message:
            print(f"Error: Could not open or demux video file (possibly unsupported format or corrupted file): {video_path} - {e}")
        else:
            print(f"Runtime error counting video frames: {e}")
        return 0
    except Exception as e:
        print(f"Error counting video frames: {e}")
        return 0
def early_validation(source_folders, episode_mapping, default_fps=20, fps=None):
    """Validate image directories and video files for every mapped episode
    before any copying happens, re-encoding videos from images when needed.

    For each episode the expected frame count is read from the source
    folder's meta/episodes.jsonl and compared against the on-disk image
    count and the decoded video frame count. Missing or mismatching videos
    are re-encoded from the image directory when one exists; otherwise the
    episode is marked invalid.

    NOTE(review): the image-count checks below use ``assert`` and are
    therefore stripped when Python runs with ``-O``; kept as-is so callers
    catching AssertionError keep working.

    Args:
        source_folders (list): Source dataset folder paths; the first one's
            meta/info.json supplies the video path template and chunk size.
        episode_mapping (list): Tuples of (old_folder, old_index, new_index).
        default_fps (int): Fallback frame rate when info.json has none.
        fps (int): Frame rate for video re-encoding; read from info.json
            when None.

    Returns:
        dict: Per-episode validation results keyed by
        ``"{old_folder}_{old_index}"``, each with expected frame count,
        per-camera image/video counts, index mapping, and an ``is_valid``
        flag. (Previously this function returned None despite documenting a
        dict return.)
    """
    # Resolve the FPS used for (re-)encoding videos.
    if fps is None:
        info_path = os.path.join(source_folders[0], "meta", "info.json")
        if os.path.exists(info_path):
            with open(info_path) as f:
                info = json.load(f)
            fps = info.get("fps", default_fps)
        else:
            fps = default_fps
    print(f"Using FPS={fps}")
    # Get the video path template and the video/image feature keys.
    info_path = os.path.join(source_folders[0], "meta", "info.json")
    with open(info_path) as f:
        info = json.load(f)
    video_path_template = info["video_path"]
    image_keys = []
    for feature_name, feature_info in info["features"].items():
        if feature_info.get("dtype") == "video":
            image_keys.append(feature_name)
    print(f"Found video/image keys: {image_keys}")
    print("Starting validation of images and videos...")
    validation_results = {}
    validation_failed = False
    # Cache parsed episodes.jsonl files per source folder.
    episode_file_mapping = {}
    for old_folder, old_index, new_index in episode_mapping:
        # Expected frame count comes from episodes.jsonl ("length" field).
        episode_file = os.path.join(old_folder, "meta", "episodes.jsonl")
        expected_frames = 0
        if os.path.exists(episode_file):
            if episode_file not in episode_file_mapping:
                episodes = load_jsonl(episode_file)
                episode_file_mapping[episode_file] = {ep["episode_index"]: ep for ep in episodes}
            episode_data = episode_file_mapping[episode_file].get(old_index, None)
            if episode_data and "length" in episode_data:
                expected_frames = episode_data["length"]
        validation_key = f"{old_folder}_{old_index}"
        validation_results[validation_key] = {
            "expected_frames": expected_frames,
            "image_counts": {},
            "video_frames": {},
            "old_index": old_index,
            "new_index": new_index,
            "is_valid": True,  # default to valid; flipped on failures below
        }
        # Check each camera's image directory and video file.
        episode_chunk = old_index // info["chunks_size"]
        for image_dir in image_keys:
            source_video_path = os.path.join(
                old_folder,
                video_path_template.format(
                    episode_chunk=episode_chunk, video_key=image_dir, episode_index=old_index
                ),
            )
            source_image_dir = os.path.join(old_folder, "images", image_dir, f"episode_{old_index:06d}")
            image_dir_exists = os.path.exists(source_image_dir)
            if not os.path.exists(source_video_path):
                print(f"{colored('WARNING', 'yellow', attrs=['bold'])}: Video file not found for {image_dir}, episode {old_index} in {old_folder}")
                if image_dir_exists:
                    # Recover by encoding the video from the image frames.
                    print(" Image directory exists, encoding video from images.")
                    from lerobot.common.datasets.video_utils import encode_video_frames

                    encode_video_frames(source_image_dir, source_video_path, fps, overwrite=True)
                    print(" Encoded video frames successfully.")
                else:
                    print(f"{colored('ERROR', 'red', attrs=['bold'])}: No video or image directory found for {image_dir}, episode {old_index} in {old_folder}")
                    validation_results[validation_key]["is_valid"] = False
                    validation_failed = True
                    continue
            # Count the decoded frames in the (possibly just encoded) video.
            video_frame_count = count_video_frames_torchvision(source_video_path)
            validation_results[validation_key]["video_frames"][image_dir] = video_frame_count
            if image_dir_exists:
                # Count image files and cross-check against metadata.
                image_files = sorted([f for f in os.listdir(source_image_dir) if f.endswith('.png')])
                images_count = len(image_files)
                validation_results[validation_key]["image_counts"][image_dir] = images_count
                error_msg = f"expected_frames: {expected_frames}, images_count: {images_count}, video_frame_count: {video_frame_count}"
                assert expected_frames > 0 and expected_frames == images_count, (
                    f"{colored('ERROR', 'red', attrs=['bold'])}: Image count should match expected frames for {source_image_dir}.\n {error_msg}"
                )
                assert expected_frames >= video_frame_count, (
                    f"{colored('ERROR', 'red', attrs=['bold'])}: Video frame count should be less or equal than expected frames for {source_video_path}.\n {error_msg}"
                )
                # A short video can be rebuilt from the images.
                if video_frame_count != expected_frames:
                    print(f"{colored('WARNING', 'yellow', attrs=['bold'])}: Video frame count mismatch for {source_video_path}")
                    print(f" Expected: {expected_frames}, Found: {video_frame_count}")
                    print(f" Re-encoded video frames from {source_image_dir} to {source_video_path}")
                    from lerobot.common.datasets.video_utils import encode_video_frames

                    encode_video_frames(source_image_dir, source_video_path, fps, overwrite=True)
                    print(" Re-encoded video frames successfully.")
            else:
                print(f"{colored('WARNING', 'yellow', attrs=['bold'])}: No image directory {image_dir} found for episode {old_index} in {old_folder}")
                print(" You can ignore this if you are not using images and your video frame count is equal to expected frames.")
                # Without images, the video itself must match the metadata.
                if expected_frames > 0 and video_frame_count != expected_frames:
                    print(f"{colored('ERROR', 'red', attrs=['bold'])}: Video frame count mismatch for {source_video_path}")
                    print(f" Expected: {expected_frames}, Found: {video_frame_count}")
                    validation_results[validation_key]["is_valid"] = False
                    validation_failed = True
    # Print validation summary.
    print("\nValidation Results:")
    valid_count = sum(1 for result in validation_results.values() if result["is_valid"])
    print(f"{valid_count} of {len(validation_results)} episodes are valid")
    if validation_failed:
        print(colored("Validation failed. Please fix the issues before continuing.", "red", attrs=["bold"]))
    return validation_results
def copy_images(source_folders, output_folder, episode_mapping, default_fps=20, fps=None):
    """
    Copy per-episode image files from source folders into the output folder.

    This function assumes validation has already been performed with
    early_validation(). Destination files are renamed to a uniform
    ``frame_{N:06d}.png`` scheme.

    Args:
        source_folders (list): List of source dataset folder paths.
        output_folder (str): Output folder path.
        episode_mapping (list): Tuples of (old_folder, old_index, new_index).
        default_fps (int): Unused; kept for backward compatibility.
        fps (int): Unused; kept for backward compatibility.

    Returns:
        int: Number of image files copied.
    """
    # Discover which features are stored as videos; their names double as the
    # image sub-directory names under <folder>/images/.
    # NOTE: only the first source folder's info.json is consulted — assumes all
    # sources share the same feature set (same assumption as the rest of the
    # merge pipeline).
    info_path = os.path.join(source_folders[0], "meta", "info.json")
    with open(info_path) as f:
        info = json.load(f)
    video_keys = [
        feature_name
        for feature_name, feature_info in info["features"].items()
        if feature_info.get("dtype") == "video"
    ]

    os.makedirs(os.path.join(output_folder, "images"), exist_ok=True)
    print(f"Starting to copy images for {len(video_keys)} video keys...")

    total_copied = 0
    skipped_episodes = 0
    for old_folder, old_index, new_index in episode_mapping:
        episode_copied = False
        for key in video_keys:
            os.makedirs(os.path.join(output_folder, "images", key), exist_ok=True)
            source_image_dir = os.path.join(old_folder, "images", key, f"episode_{old_index:06d}")
            if not os.path.exists(source_image_dir):
                continue
            target_image_dir = os.path.join(output_folder, "images", key, f"episode_{new_index:06d}")
            os.makedirs(target_image_dir, exist_ok=True)
            image_files = sorted(f for f in os.listdir(source_image_dir) if f.endswith('.png'))
            if image_files:
                print(f"Copying {len(image_files)} images from {source_image_dir} to {target_image_dir}")
            for image_file in image_files:
                try:
                    # Filenames look like "frame_000123.png" (or "000123.png");
                    # extract the numeric frame index so destinations use a
                    # consistent zero-padded naming scheme.
                    frame_part = image_file.split('_')[1] if '_' in image_file else image_file
                    frame_num = int(frame_part.split('.')[0])
                    dest_file = os.path.join(target_image_dir, f"frame_{frame_num:06d}.png")
                    shutil.copy2(os.path.join(source_image_dir, image_file), dest_file)
                    total_copied += 1
                    episode_copied = True
                except Exception as e:
                    # Best effort: one malformed filename should not abort the merge.
                    print(f"Error copying image {image_file}: {e}")
        if not episode_copied:
            skipped_episodes += 1

    print(f"\nCopied {total_copied} images for {len(episode_mapping) - skipped_episodes} episodes")
    if skipped_episodes > 0:
        print(f"{colored('WARNING', 'yellow', attrs=['bold'])}: Skipped {skipped_episodes} episodes with no images")
    return total_copied
def _padded_stat_arrays(values, key, max_dim):
    """Flatten each episode's stat array for ``key`` and zero-pad it to ``max_dim``.

    Episodes coming from different datasets may have state/action vectors of
    different lengths; zero-padding lets them be reduced with one numpy call.
    """
    padded = []
    for val in values:
        flat = np.array(val[key]).flatten()
        if len(flat) < max_dim:
            buf = np.zeros(max_dim)
            buf[: len(flat)] = flat
            padded.append(buf)
        else:
            padded.append(flat)
    return padded


def _recompute_merged_feature_stats(merged_stats, all_stats_data):
    """Recompute global per-feature statistics from per-episode statistics.

    ``merged_stats`` is updated in place. Only features present in the first
    episode's stats are recomputed. When every episode carries a ``count``,
    means and stds are count-weighted; otherwise a simple average is used.
    """
    for feature in merged_stats:
        if feature not in all_stats_data[0]:
            continue
        values = [stat[feature] for stat in all_stats_data if feature in stat]
        # Widest flattened dimension for this feature across all episodes.
        max_dim = max(
            len(np.array(val.get("mean", [0])).flatten()) for val in values if "mean" in val
        )
        if "count" in merged_stats[feature]:
            merged_stats[feature]["count"] = [
                sum(stat.get("count", [0])[0] for stat in values if "count" in stat)
            ]
        if "min" in merged_stats[feature] and all("min" in stat for stat in values):
            padded_mins = _padded_stat_arrays(values, "min", max_dim)
            merged_stats[feature]["min"] = np.minimum.reduce(padded_mins).tolist()
        if "max" in merged_stats[feature] and all("max" in stat for stat in values):
            padded_maxs = _padded_stat_arrays(values, "max", max_dim)
            merged_stats[feature]["max"] = np.maximum.reduce(padded_maxs).tolist()
        if "mean" in merged_stats[feature] and all("mean" in stat for stat in values):
            padded_means = _padded_stat_arrays(values, "mean", max_dim)
            if all("count" in stat for stat in values):
                counts = [stat["count"][0] for stat in values]
                total_count = sum(counts)
                weighted_means = [
                    mean * count / total_count
                    for mean, count in zip(padded_means, counts, strict=False)
                ]
                merged_stats[feature]["mean"] = np.sum(weighted_means, axis=0).tolist()
            else:
                merged_stats[feature]["mean"] = np.mean(padded_means, axis=0).tolist()
        if "std" in merged_stats[feature] and all("std" in stat for stat in values):
            padded_stds = _padded_stat_arrays(values, "std", max_dim)
            if all("count" in stat for stat in values):
                counts = [stat["count"][0] for stat in values]
                total_count = sum(counts)
                # Approximate pooling: weighted average of per-episode variances
                # (ignores the between-episode mean spread, as the original did).
                weighted_variances = [
                    (std**2) * count / total_count
                    for std, count in zip(padded_stds, counts, strict=False)
                ]
                merged_stats[feature]["std"] = np.sqrt(
                    np.sum(weighted_variances, axis=0)
                ).tolist()
            else:
                # Simple average of standard deviations.
                merged_stats[feature]["std"] = np.mean(padded_stds, axis=0).tolist()


def merge_datasets(
    source_folders, output_folder, validate_ts=False, tolerance_s=1e-4,
    default_fps=30, copy_images_flag=None
):
    """
    Merge multiple dataset folders into one, handling indices, dimensions, and metadata.

    Episode, stats, and task metadata from every source are re-indexed into a
    single consecutive numbering, global stats are re-merged, info.json totals
    are recomputed, and video/data (and optionally image) files are copied.

    Args:
        source_folders (list): Source dataset folder paths. The first folder's
            meta/info.json is used as the template for the merged info.json.
        output_folder (str): Output folder path.
        validate_ts (bool): Unused; kept for backward compatibility.
        tolerance_s (float): Unused; kept for backward compatibility.
        default_fps (float): Frame rate passed to copy_data_files().
        copy_images_flag (bool | None): Whether to copy per-episode images.
            ``None`` (default) falls back to the CLI ``--copy_images`` flag when
            the script is run from the command line, preserving the old behavior.
    """
    os.makedirs(output_folder, exist_ok=True)
    os.makedirs(os.path.join(output_folder, "meta"), exist_ok=True)
    fps = default_fps
    print(f"使用默认FPS值: {fps}")

    # Accumulators for the merged metadata.
    all_episodes = []           # episode entries with re-numbered indices
    all_episodes_stats = []     # per-episode stats with re-numbered indices
    all_tasks = []              # de-duplicated task list
    total_frames = 0
    total_episodes = 0
    episode_mapping = []        # (old_folder, old_index, new_index)
    all_stats_data = []         # raw per-episode "stats" dicts for re-merging
    cumulative_frame_count = 0
    episode_to_frame_index = {}  # new episode index -> first global frame index
    task_desc_to_new_index = {}  # task description -> new task index
    folder_task_mapping = {}     # folder -> {old task index -> new task index}
    all_unique_tasks = []

    # chunks_size comes from the first folder's info.json (default 1000).
    info_path = os.path.join(source_folders[0], "meta", "info.json")
    # Early validation only runs when every source ships an images directory.
    images_dir_exists = all(os.path.exists(os.path.join(folder, "images")) for folder in source_folders)
    chunks_size = 1000
    if os.path.exists(info_path):
        with open(info_path) as f:
            info = json.load(f)
            chunks_size = info.get("chunks_size", 1000)

    # Sum total_videos straight from each dataset's info.json.
    total_videos = 0
    for folder in source_folders:
        try:
            folder_info_path = os.path.join(folder, "meta", "info.json")
            if os.path.exists(folder_info_path):
                with open(folder_info_path) as f:
                    folder_info = json.load(f)
                    if "total_videos" in folder_info:
                        folder_videos = folder_info["total_videos"]
                        total_videos += folder_videos
                        print(
                            f"{folder}的info.json中读取到视频数量: {folder_videos} (Read video count from {folder}'s info.json: {folder_videos})"
                        )
            # Load episodes.
            episodes_path = os.path.join(folder, "meta", "episodes.jsonl")
            if not os.path.exists(episodes_path):
                print(f"Warning: Episodes file not found in {folder}, skipping")
                continue
            episodes = load_jsonl(episodes_path)

            # Load per-episode stats and index them by old episode_index.
            episodes_stats_path = os.path.join(folder, "meta", "episodes_stats.jsonl")
            episodes_stats = []
            if os.path.exists(episodes_stats_path):
                episodes_stats = load_jsonl(episodes_stats_path)
            stats_map = {}
            for stat in episodes_stats:
                if "episode_index" in stat:
                    stats_map[stat["episode_index"]] = stat

            # Load tasks and map this folder's old task indices onto the global,
            # de-duplicated task list (keyed by the task description string).
            tasks_path = os.path.join(folder, "meta", "tasks.jsonl")
            folder_tasks = []
            if os.path.exists(tasks_path):
                folder_tasks = load_jsonl(tasks_path)
            folder_task_mapping[folder] = {}
            for task in folder_tasks:
                task_desc = task["task"]
                old_index = task["task_index"]
                if task_desc not in task_desc_to_new_index:
                    new_index = len(all_unique_tasks)
                    task_desc_to_new_index[task_desc] = new_index
                    all_unique_tasks.append({"task_index": new_index, "task": task_desc})
                folder_task_mapping[folder][old_index] = task_desc_to_new_index[task_desc]

            # Re-number every episode from this folder into the global sequence.
            for episode in episodes:
                old_index = episode["episode_index"]
                new_index = total_episodes
                episode["episode_index"] = new_index
                all_episodes.append(episode)
                if old_index in stats_map:
                    stats = stats_map[old_index]
                    stats["episode_index"] = new_index
                    all_episodes_stats.append(stats)
                    if "stats" in stats:
                        all_stats_data.append(stats["stats"])
                episode_mapping.append((folder, old_index, new_index))
                total_episodes += 1
                total_frames += episode["length"]
                # Track each new episode's starting frame index in the merged data.
                episode_to_frame_index[new_index] = cumulative_frame_count
                cumulative_frame_count += episode["length"]
        except Exception as e:
            print(f"Error processing folder {folder}: {e}")
            continue
    # The de-duplicated task list is the final task table (assigned once after
    # the loop rather than redundantly on every folder iteration).
    all_tasks = all_unique_tasks

    print(f"Processed {total_episodes} episodes from {len(source_folders)} folders")
    # Save combined episodes, stats, and tasks.
    save_jsonl(all_episodes, os.path.join(output_folder, "meta", "episodes.jsonl"))
    save_jsonl(all_episodes_stats, os.path.join(output_folder, "meta", "episodes_stats.jsonl"))
    save_jsonl(all_tasks, os.path.join(output_folder, "meta", "tasks.jsonl"))

    # Merge global stats.json files, then refine them with per-episode stats.
    stats_list = []
    for folder in source_folders:
        stats_path = os.path.join(folder, "meta", "stats.json")
        if os.path.exists(stats_path):
            with open(stats_path) as f:
                stats_list.append(json.load(f))
    if stats_list:
        merged_stats = merge_stats(stats_list)
        if all_stats_data:
            _recompute_merged_feature_stats(merged_stats, all_stats_data)
        with open(os.path.join(output_folder, "meta", "stats.json"), "w") as f:
            json.dump(merged_stats, f, indent=4)

    # Update and save info.json based on the first source folder's template.
    info_path = os.path.join(source_folders[0], "meta", "info.json")
    with open(info_path) as f:
        info = json.load(f)
    info["total_episodes"] = total_episodes
    info["total_frames"] = total_frames
    info["total_tasks"] = len(all_tasks)
    info["total_chunks"] = (total_episodes + info["chunks_size"] - 1) // info[
        "chunks_size"
    ]  # Ceiling division
    info["splits"] = {"train": f"0:{total_episodes}"}
    info["total_videos"] = total_videos
    print(f"更新视频总数为: {total_videos} (Update total videos to: {total_videos})")
    with open(os.path.join(output_folder, "meta", "info.json"), "w") as f:
        json.dump(info, f, indent=4)

    # Validate frame counts before copying any media.
    if images_dir_exists:
        early_validation(
            source_folders,
            episode_mapping,
        )
    # Copy video and data files.
    copy_videos(source_folders, output_folder, episode_mapping)
    copy_data_files(
        source_folders,
        output_folder,
        episode_mapping,
        fps=fps,
        episode_to_frame_index=episode_to_frame_index,
        folder_task_mapping=folder_task_mapping,
        chunks_size=chunks_size,
    )

    # BUG FIX: the original read the module-global ``args`` here, which raised
    # NameError whenever merge_datasets() was imported and called as a library
    # function. Fall back to the CLI flag only when it actually exists.
    if copy_images_flag is None:
        copy_images_flag = bool(getattr(globals().get("args"), "copy_images", False))
    if copy_images_flag:
        print("Starting to copy images and validate video frame counts")
        copy_images(source_folders, output_folder, episode_mapping)

    print(f"Merged {total_episodes} episodes with {total_frames} frames into {output_folder}")
if __name__ == "__main__":
    # Command-line entry point: merge several dataset folders into one output
    # folder. ``args`` is intentionally module-global — merge_datasets() falls
    # back to it for the --copy_images flag when run as a script.
    parser = argparse.ArgumentParser(description="Merge datasets from multiple sources.")
    parser.add_argument("--sources", nargs="+", required=True, help="List of source folder paths")
    parser.add_argument("--output", required=True, help="Output folder path")
    # BUG FIX: the help text previously claimed "default: 20" while the actual
    # default has always been 30.
    parser.add_argument("--fps", type=int, default=30, help="Your datasets FPS (default: 30)")
    parser.add_argument("--copy_images", action="store_true", help="Whether to copy images (default: False)")
    args = parser.parse_args()
    merge_datasets(
        args.sources,
        args.output,
        default_fps=args.fps
    )