Merge remote-tracking branch 'origin/main' into user/rcadene/2025_02_19_port_openx
This commit is contained in:
@@ -2,7 +2,6 @@ import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
@@ -91,9 +90,9 @@ def calculate_coverage(zarr_data):
|
||||
|
||||
num_frames = len(block_pos)
|
||||
|
||||
coverage = np.zeros((num_frames,))
|
||||
coverage = np.zeros((num_frames,), dtype=np.float32)
|
||||
# 8 keypoints with 2 coords each
|
||||
keypoints = np.zeros((num_frames, 16))
|
||||
keypoints = np.zeros((num_frames, 16), dtype=np.float32)
|
||||
|
||||
# Set x, y, theta (in radians)
|
||||
goal_pos_angle = np.array([256, 256, np.pi / 4])
|
||||
@@ -119,7 +118,7 @@ def calculate_coverage(zarr_data):
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
goal_area = goal_geom.area
|
||||
coverage[i] = intersection_area / goal_area
|
||||
keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
|
||||
keypoints[i] = PushTEnv.get_keypoints(block_shapes).flatten()
|
||||
|
||||
return coverage, keypoints
|
||||
|
||||
@@ -181,20 +180,21 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
|
||||
|
||||
for frame_idx in range(num_frames):
|
||||
i = from_idx + frame_idx
|
||||
idx = i + (frame_idx < num_frames - 1)
|
||||
frame = {
|
||||
"action": torch.from_numpy(action[i]),
|
||||
"action": action[i],
|
||||
# Shift reward and success by +1 until the last item of the episode
|
||||
"next.reward": reward[i + (frame_idx < num_frames - 1)],
|
||||
"next.success": success[i + (frame_idx < num_frames - 1)],
|
||||
"next.reward": reward[idx : idx + 1],
|
||||
"next.success": success[idx : idx + 1],
|
||||
"task": PUSHT_TASK,
|
||||
}
|
||||
|
||||
frame["observation.state"] = torch.from_numpy(agent_pos[i])
|
||||
frame["observation.state"] = agent_pos[i]
|
||||
|
||||
if mode == "keypoints":
|
||||
frame["observation.environment_state"] = torch.from_numpy(keypoints[i])
|
||||
frame["observation.environment_state"] = keypoints[i]
|
||||
else:
|
||||
frame["observation.image"] = torch.from_numpy(image[i])
|
||||
frame["observation.image"] = image[i]
|
||||
|
||||
dataset.add_frame(frame)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user