from collections import deque
from typing import Any, Deque, Dict, List, Optional, Sequence, Union

import datasets
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
def check_final(
    last_states: Union[Deque[Sequence[float]], Sequence[Sequence[float]], torch.Tensor],
    *,
    # State layout and reference initial pose.
    arm_dofs: int = 6,  # number of arm joints at the front of each state vector
    gripper_index: int = -1,  # index of the gripper value in the state vector (default: last dim)
    mean_initial_arm_state: Optional[Sequence[float]] = (0.0107, 0.0527, 0.0463, -0.0415, 0.0187, 0.0108),
    mean_initial_gripper_state: float = 4.8438,  # currently unused in the decision; kept for future extension
    # Decision thresholds (angles are given in degrees for easy tuning and converted to radians internally).
    stability_window: int = 5,  # how many recent frames are used to judge "no large change"
    per_joint_range_deg: float = 2.0,  # max per-joint range (max - min) inside the window, in degrees
    mean_speed_deg: float = 0.5,  # mean L2 norm of frame-to-frame joint deltas, in degrees/step
    min_change_from_initial_deg: float = 15.0,  # minimum L2 distance (degrees) of the last frame from the initial pose
    gripper_closed_thresh: float = 0.8,  # gripper counts as closed below this value
) -> bool:
    """
    Return True when the arm has "arrived": (1) the pose barely changes over the
    most recent window, (2) the gripper is closed, and (3) the final frame differs
    enough from the initial pose.

    All angle thresholds are given in degrees and converted to radians before
    comparison.

    Raises:
        ValueError: if ``last_states`` is not 2D, if each frame has fewer than
            ``arm_dofs + 1`` dims, or if ``mean_initial_arm_state`` has the
            wrong length.
    """
    # --- Normalize input to an (N, D) tensor ---
    if isinstance(last_states, torch.Tensor):
        states = last_states
    else:
        states = torch.as_tensor(list(last_states), dtype=torch.float32)

    if states.ndim != 2:
        raise ValueError(f"last_states should be 2D, got shape {tuple(states.shape)}")
    N, D = states.shape
    # BUGFIX: the original check (D < arm_dofs) contradicted its own error message
    # and, for D == arm_dofs, silently aliased the default gripper_index (-1) onto
    # the last arm joint. Require room for both the arm and the gripper.
    if D < arm_dofs + 1:
        raise ValueError(f"Expected at least {arm_dofs + 1} dims for arm + gripper, got {D}")
    if N < 2:
        return False  # too few samples to judge stability

    # Take the most recent window.
    w = min(N, stability_window)
    window = states[-w:]  # (w, D)
    arm = window[:, :arm_dofs]  # (w, arm_dofs)
    last_arm = arm[-1]  # (arm_dofs,)
    last_gripper = float(window[-1, gripper_index])

    # --- 1) "No large change" over the last w frames ---
    # Two criteria: small per-joint range (max - min) and small mean frame-to-frame "speed".
    deg2rad = torch.pi / 180.0
    range_tol = per_joint_range_deg * deg2rad
    speed_tol = mean_speed_deg * deg2rad

    ranges = arm.max(dim=0).values - arm.min(dim=0).values  # (arm_dofs,)
    max_range = float(ranges.abs().max())  # scalar
    diffs = arm[1:] - arm[:-1]  # (w - 1, arm_dofs)
    # BUGFIX: guard against an empty diff set (stability_window == 1), where
    # .mean() would be NaN and silently fail every subsequent comparison.
    mean_speed = float(torch.linalg.norm(diffs, dim=1).mean()) if diffs.numel() > 0 else 0.0

    stable = (max_range <= range_tol) and (mean_speed <= speed_tol)

    # --- 2) Gripper closed ---
    gripper_closed = (last_gripper < gripper_closed_thresh)

    # --- 3) Final frame must differ enough from the "initial" pose ---
    init = torch.as_tensor(mean_initial_arm_state, dtype=last_arm.dtype, device=last_arm.device)
    if init.numel() != arm_dofs:
        raise ValueError(f"mean_initial_arm_state length {init.numel()} != arm_dofs {arm_dofs}")
    dist_from_init = float(torch.linalg.norm(last_arm - init))
    far_from_init = (dist_from_init >= (min_change_from_initial_deg * deg2rad))

    # Combined decision.
    return bool(stable and gripper_closed and far_from_init)
def get_last_frames(ds: LeRobotDataset, include_images: bool = False, keys=None):
    """
    Quickly fetch the last frame of each episode in a LeRobotDataset.
    - include_images=False: Return only scalar/vector fields from parquet (faster, no video decoding).
    - include_images=True : Additionally decode the corresponding image/video frame for the last frame.
    - keys: Limit the set of columns to retrieve (default: all non-image/video fields + timestamp, etc.).
    Returns: list[dict], where each element contains the last frame info of one episode.
    """
    # Last row of each episode: episode_data_index['to'] is an exclusive end
    # index, so the last frame sits at to - 1.
    last_indices = (ds.episode_data_index["to"] - 1).tolist()

    # Default column set: every non-image/video field (so no slow video decoding
    # is triggered) plus a few bookkeeping fields.
    if keys is None:
        scalar_keys = [name for name, ft in ds.features.items() if ft["dtype"] not in ("image", "video")]
        keys = list(set(scalar_keys + ["timestamp", "episode_index", "task_index"]))

    # Select every last-frame row at once; this bypasses __getitem__, so no
    # video decoding happens here.
    tail_rows = ds.hf_dataset.select(last_indices)

    # Build a per-key column mapping.
    columns = {name: tail_rows[name] for name in keys}

    # Stack lists of tensors into one tensor per column for easier indexing.
    for name, values in columns.items():
        # datasets.arrow_dataset.Column is the HuggingFace internal column type.
        if isinstance(values, datasets.arrow_dataset.Column) and len(values) > 0 and hasattr(values[0], "shape"):
            columns[name] = torch.stack(values[:])

    # Assemble one dict per episode's final frame.
    results = []
    for row_pos in range(len(last_indices)):
        record = {}
        for name in keys:
            value = columns[name][row_pos]
            # Unpack 0-dimensional tensors into plain Python scalars.
            if torch.is_tensor(value) and value.ndim == 0:
                value = value.item()
            record[name] = value

        # Resolve task_index back into the human-readable task string.
        if "task_index" in record:
            record["task"] = ds.meta.tasks[int(record["task_index"])]
        results.append(record)

    # Optionally decode the actual image/video frame at each last timestamp.
    if include_images and len(ds.meta.video_keys) > 0:
        for row_pos in range(len(last_indices)):
            ep_idx = int(results[row_pos]["episode_index"])
            ts = float(results[row_pos]["timestamp"])
            # One timestamp per camera key.
            query_ts = {cam: [ts] for cam in ds.meta.video_keys}
            # Decode the requested frames for this episode.
            decoded = ds._query_videos(query_ts, ep_idx)
            # Attach the decoded frame tensors to the output dict.
            for cam, frame in decoded.items():
                results[row_pos][cam] = frame

    return results
if __name__ == "__main__":
    # Initialize your dataset (replace with your repo ID or local path).
    ds = LeRobotDataset(repo_id="arx_lift2/pick_parcel_20250915")

    # Retrieve metadata only (timestamps, states, actions, tasks) without decoding video.
    last_infos = get_last_frames(ds, include_images=False)

    # Stack all 'observation.state' vectors into a single tensor for further processing.
    states = torch.stack([info['observation.state'] for info in last_infos])
    # Left-arm portion of the state: 6 joints followed by the gripper (7 values).
    left_arm_states = states[:, 0:7]
    # Per-dimension statistics over the episode-final frames; useful for
    # re-deriving check_final's mean_initial_arm_state default from real data.
    mean_state = torch.mean(left_arm_states, dim=0)
    std_state = torch.std(left_arm_states, dim=0)

    # Print the collected metadata for verification.
    print(last_infos)
    # FIX: mean_state/std_state were previously computed but never used.
    print("final-frame left-arm mean:", mean_state.tolist())
    print("final-frame left-arm std :", std_state.tolist())

    # --- Run check_final per episode using the last <= STABILITY_WINDOW states ---

    EP_ARM_DOFS = 6  # number of left-arm joints fed to check_final
    # FIX: this constant was declared as -1 ("last element") but never used, while
    # the slice below always read column 6. Make the constant name the column the
    # code actually uses so they cannot drift apart.
    # NOTE(review): assumes the left gripper sits right after the 6 left-arm
    # joints in the full state vector — confirm against the dataset's layout.
    GRIPPER_COL_FULL = EP_ARM_DOFS
    # Window size for the tail of each episode; passed to check_final below so
    # the two always stay in sync (it overrides check_final's default of 5).
    STABILITY_WINDOW = 120

    # Determine which episodes to iterate.
    episode_indices = ds.episodes if ds.episodes is not None else sorted(ds.meta.episodes.keys())

    episode_flags = {}
    num_true, num_false = 0, 0

    for ep_idx in episode_indices:
        # Global index range [from_idx, to_idx) for this episode.
        from_idx = int(ds.episode_data_index["from"][ep_idx])
        to_idx = int(ds.episode_data_index["to"][ep_idx])

        if to_idx - from_idx <= 0:
            # Empty episode: count it as failed.
            episode_flags[ep_idx] = False
            num_false += 1
            continue

        # Take the last <= STABILITY_WINDOW frames from this episode.
        idxs = list(range(max(from_idx, to_idx - STABILITY_WINDOW), to_idx))
        rows = ds.hf_dataset.select(idxs)

        # Collect the full "observation.state" column (shape ~ [W, S]).
        s_col = rows["observation.state"]
        if isinstance(s_col, datasets.arrow_dataset.Column):
            S = torch.stack(s_col[:])  # Column -> list[tensor] -> stack
        else:
            S = torch.stack(s_col)  # already a list[tensor]

        # Build the 7-D small state per frame: first 6 joints + the gripper column.
        small_states = torch.cat(
            [S[:, :EP_ARM_DOFS], S[:, GRIPPER_COL_FULL:GRIPPER_COL_FULL + 1]], dim=1
        )

        # Run the stopping logic on this episode's tail.
        ok = check_final(
            small_states,
            arm_dofs=EP_ARM_DOFS,
            gripper_index=-1,  # gripper is the last column of small_states by construction
            stability_window=STABILITY_WINDOW,
        )
        episode_flags[ep_idx] = bool(ok)
        num_true += int(ok)
        num_false += int(not ok)

    # Summary.
    total_eps = len(episode_indices)
    print(f"[check_final] passed: {num_true} / {total_eps} ({(num_true/max(total_eps,1)):.1%})")

    # List some failed episodes for quick inspection.
    failed_eps = [e for e, passed in episode_flags.items() if not passed]
    print("Failed episode indices (first 20):", failed_eps[:20])