Files
issacdataengine/policy/openpi-InternData-A1/examples/arx/action_stats.py
2026-03-17 23:05:23 +08:00

213 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from collections import deque
from typing import List, Dict, Optional, Any, Sequence, Deque, Union
import datasets
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
def check_final(
    last_states: Union[Deque[Sequence[float]], Sequence[Sequence[float]], torch.Tensor],
    *,
    # Layout and reference initial pose
    arm_dofs: int = 6,  # number of arm joints at the front of each state vector
    gripper_index: int = -1,  # index of the gripper value within each state vector (default: last dim)
    mean_initial_arm_state: Optional[Sequence[float]] = (0.0107, 0.0527, 0.0463, -0.0415, 0.0187, 0.0108),
    mean_initial_gripper_state: float = 4.8438,  # not used in the decision yet; kept for future extension
    # Thresholds (angles are given in degrees for easy tuning; converted to radians internally)
    stability_window: int = 5,  # number of most recent frames used for the "no big change" test
    per_joint_range_deg: float = 2.0,  # max allowed per-joint range (max - min) within the window
    mean_speed_deg: float = 0.5,  # max allowed mean L2 of frame-to-frame joint deltas (deg/step)
    min_change_from_initial_deg: float = 15.0,  # required minimum L2 distance of last frame from initial pose
    gripper_closed_thresh: float = 0.8,  # gripper values below this count as "closed"
) -> bool:
    """Decide whether the arm has reached its final "in place" condition.

    Returns True when all three hold:
      (1) the pose barely changes over the most recent window,
      (2) the gripper is closed, and
      (3) the last frame differs enough from the mean initial pose.
    All angular thresholds are given in degrees and converted to radians here.

    Args:
        last_states: recent states as a deque/sequence of vectors or a 2D tensor, shape (N, D).
        arm_dofs: how many leading dims of each state vector are arm joints.
        gripper_index: index of the gripper value in each state vector.
        mean_initial_arm_state: reference initial arm pose (length must equal arm_dofs).
        mean_initial_gripper_state: reserved; currently ignored.
        stability_window / per_joint_range_deg / mean_speed_deg /
        min_change_from_initial_deg / gripper_closed_thresh: decision thresholds.

    Raises:
        ValueError: if the input is not 2D, D is too small to hold arm + gripper,
            or mean_initial_arm_state has the wrong length.
    """
    # --- Normalize input to an (N, D) float tensor ---
    if isinstance(last_states, torch.Tensor):
        states = last_states
    else:
        states = torch.as_tensor(list(last_states), dtype=torch.float32)
    if states.ndim != 2:
        raise ValueError(f"last_states should be 2D, got shape {tuple(states.shape)}")
    N, D = states.shape
    # Require a gripper dim *in addition to* the arm joints: with the default
    # gripper_index=-1, D == arm_dofs would silently alias the last arm joint.
    if D < arm_dofs + 1:
        raise ValueError(f"Expected at least {arm_dofs + 1} dims (arm + gripper), got {D}")
    if N < 2:
        return False  # too few samples to judge stability

    # Most recent window; clamp to >= 2 so the frame-to-frame diff below is
    # never empty (an empty mean would yield NaN and poison the comparison).
    w = min(N, max(stability_window, 2))
    window = states[-w:]  # (w, D)
    arm = window[:, :arm_dofs]  # (w, arm_dofs)
    last_arm = arm[-1]  # (arm_dofs,)
    last_gripper = float(window[-1, gripper_index])

    deg2rad = torch.pi / 180.0
    range_tol = per_joint_range_deg * deg2rad
    speed_tol = mean_speed_deg * deg2rad

    # --- 1) Pose barely changes over the window ---
    # Two metrics: per-joint range (max - min) must be small, and the mean
    # L2 of adjacent-frame joint deltas ("speed" per step) must be small.
    ranges = arm.max(dim=0).values - arm.min(dim=0).values  # (arm_dofs,)
    max_range = float(ranges.abs().max())
    diffs = arm[1:] - arm[:-1]  # (w-1, arm_dofs)
    mean_speed = float(torch.linalg.norm(diffs, dim=1).mean())
    stable = (max_range <= range_tol) and (mean_speed <= speed_tol)

    # --- 2) Gripper closed (smaller value means more closed) ---
    gripper_closed = last_gripper < gripper_closed_thresh

    # --- 3) Last frame must be far enough from the initial pose ---
    init = torch.as_tensor(mean_initial_arm_state, dtype=last_arm.dtype, device=last_arm.device)
    if init.numel() != arm_dofs:
        raise ValueError(f"mean_initial_arm_state length {init.numel()} != arm_dofs {arm_dofs}")
    dist_from_init = float(torch.linalg.norm(last_arm - init))
    far_from_init = dist_from_init >= min_change_from_initial_deg * deg2rad

    return bool(stable and gripper_closed and far_from_init)
def get_last_frames(ds: LeRobotDataset, include_images: bool = False, keys=None):
    """Fetch the last frame of every episode in a LeRobotDataset.

    - include_images=False: only scalar/vector parquet fields are returned
      (fast — no video decoding is triggered).
    - include_images=True : additionally decodes the image/video frame that
      matches each episode's last timestamp.
    - keys: restrict the columns to retrieve (default: every non-image/video
      field plus timestamp / episode_index / task_index).

    Returns:
        list[dict] — one dict per episode holding its last-frame values.
    """
    # 'to' in episode_data_index is an exclusive end bound, so the last row
    # of each episode sits at to - 1.
    last_indices = (ds.episode_data_index["to"] - 1).tolist()

    # Default column set: skip image/video columns so that no slow video
    # decoding is triggered, but keep the bookkeeping fields.
    if keys is None:
        scalar_keys = [name for name, feat in ds.features.items() if feat["dtype"] not in ("image", "video")]
        keys = list(set(scalar_keys + ["timestamp", "episode_index", "task_index"]))

    # Grab every last row in one select(); this bypasses __getitem__ and
    # therefore never touches the video pipeline.
    rows = ds.hf_dataset.select(last_indices)

    # Materialize the requested columns, stacking tensor-valued HF Columns
    # into a single tensor per key for easy per-episode indexing.
    columns = {name: rows[name] for name in keys}
    for name, values in columns.items():
        # datasets.arrow_dataset.Column is HuggingFace's internal column type.
        if isinstance(values, datasets.arrow_dataset.Column) and len(values) > 0 and hasattr(values[0], "shape"):
            columns[name] = torch.stack(values[:])

    # Build one dict per episode from the stacked columns.
    results = []
    for pos, _ in enumerate(last_indices):
        record = {}
        for name in keys:
            value = columns[name][pos]
            # Unwrap 0-d tensors into plain Python scalars.
            if torch.is_tensor(value) and value.ndim == 0:
                value = value.item()
            record[name] = value
        # Resolve task_index to the human-readable task string.
        if "task_index" in record:
            record["task"] = ds.meta.tasks[int(record["task_index"])]
        results.append(record)

    # Optionally decode the camera frame(s) at each episode's last timestamp.
    if include_images and len(ds.meta.video_keys) > 0:
        for pos, _ in enumerate(last_indices):
            ep_idx = int(results[pos]["episode_index"])
            ts = float(results[pos]["timestamp"])
            # One query timestamp per camera key.
            query = {cam: [ts] for cam in ds.meta.video_keys}
            decoded = ds._query_videos(query, ep_idx)
            for cam, frame in decoded.items():
                results[pos][cam] = frame
    return results
if __name__ == "__main__":
    # Initialize your dataset (replace with your repo ID or local path).
    ds = LeRobotDataset(repo_id="arx_lift2/pick_parcel_20250915")
    # Retrieve metadata only (timestamps, states, actions, tasks) without decoding video.
    last_infos = get_last_frames(ds, include_images=False)
    # Stack all 'observation.state' vectors into a single tensor for further processing.
    states = torch.stack([info['observation.state'] for info in last_infos])
    # Extract the left-arm joint states (first 7 values of each state vector —
    # presumably 6 arm joints + 1 gripper; TODO confirm the full state layout).
    left_arm_states = states[:, 0:7]
    # NOTE(review): mean_state/std_state are computed but never used below; they
    # look like candidate values for check_final's mean_initial_arm_state default.
    mean_state = torch.mean(left_arm_states, dim=0)
    std_state = torch.std(left_arm_states, dim=0)
    # Print the collected metadata for verification.
    print(last_infos)
    # --- Run check_final per episode using the last <= STABILITY_WINDOW states ---
    EP_ARM_DOFS = 6  # number of left-arm joints we use in check_final
    # NOTE(review): GRIPPER_COL_FULL is never used. The gripper column actually
    # selected below is index EP_ARM_DOFS (i.e. 6), NOT the last element of the
    # full state vector — confirm which is correct for this dataset's layout.
    GRIPPER_COL_FULL = -1  # gripper is the last element in the full state vector
    # NOTE(review): check_final's default stability_window is 5; the value here
    # is passed explicitly to check_final, so the two need not match.
    STABILITY_WINDOW = 120  # number of trailing frames fed into check_final
    # Determine which episodes to iterate.
    episode_indices = ds.episodes if ds.episodes is not None else sorted(ds.meta.episodes.keys())
    episode_flags = {}
    num_true, num_false = 0, 0
    for ep_idx in episode_indices:
        # Global index range [from_idx, to_idx) for this episode.
        from_idx = int(ds.episode_data_index["from"][ep_idx])
        to_idx = int(ds.episode_data_index["to"][ep_idx])
        if to_idx - from_idx <= 0:
            # Empty episode: record it as failed and move on.
            episode_flags[ep_idx] = False
            num_false += 1
            continue
        # Take the last <= STABILITY_WINDOW frames from this episode.
        idxs = list(range(max(from_idx, to_idx - STABILITY_WINDOW), to_idx))
        rows = ds.hf_dataset.select(idxs)
        # Collect the full "observation.state" column (shape ~ [W, S]).
        s_col = rows["observation.state"]
        if isinstance(s_col, datasets.arrow_dataset.Column):
            S = torch.stack(s_col[:])  # Column -> list[tensor] -> stack
        else:
            S = torch.stack(s_col)  # already a list[tensor]
        # Build the 7D small state per frame: first 6 joints + the column right
        # after them (assumed to be the left gripper — TODO confirm; this is NOT
        # the last column of the full state vector, despite GRIPPER_COL_FULL).
        small_states = torch.cat([S[:, :EP_ARM_DOFS], S[:, EP_ARM_DOFS:EP_ARM_DOFS+1]], dim=1)
        # Run the stopping logic. gripper_index=-1 refers to the last column of
        # the 7D small state, i.e. the gripper column appended just above.
        ok = check_final(
            small_states,
            arm_dofs=EP_ARM_DOFS,
            gripper_index=-1,
            stability_window=STABILITY_WINDOW,
        )
        episode_flags[ep_idx] = bool(ok)
        num_true += int(ok)
        num_false += int(not ok)
    # Summary.
    total_eps = len(episode_indices)
    print(f"[check_final] passed: {num_true} / {total_eps} ({(num_true/max(total_eps,1)):.1%})")
    # List some failed episodes for quick inspection.
    failed_eps = [e for e, passed in episode_flags.items() if not passed]
    print("Failed episode indices (first 20):", failed_eps[:20])