[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-04 13:38:47 +00:00
parent d8a1758122
commit 584cad808e
108 changed files with 3894 additions and 1189 deletions

View File

@@ -31,7 +31,11 @@ def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3):
zarr_data = zarr.group(store=store)
zarr_data.create_dataset(
"data/action", shape=(num_frames, 1), chunks=(num_frames, 1), dtype=np.float32, overwrite=True
"data/action",
shape=(num_frames, 1),
chunks=(num_frames, 1),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"data/img",
@@ -41,20 +45,38 @@ def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3):
overwrite=True,
)
zarr_data.create_dataset(
"data/n_contacts", shape=(num_frames, 2), chunks=(num_frames, 2), dtype=np.float32, overwrite=True
"data/n_contacts",
shape=(num_frames, 2),
chunks=(num_frames, 2),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"data/state", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
"data/state",
shape=(num_frames, 5),
chunks=(num_frames, 5),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"data/keypoint", shape=(num_frames, 9, 2), chunks=(num_frames, 9, 2), dtype=np.float32, overwrite=True
"data/keypoint",
shape=(num_frames, 9, 2),
chunks=(num_frames, 9, 2),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
"meta/episode_ends",
shape=(num_episodes,),
chunks=(num_episodes,),
dtype=np.int32,
overwrite=True,
)
zarr_data["data/action"][:] = np.random.randn(num_frames, 1)
zarr_data["data/img"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
zarr_data["data/img"][:] = np.random.randint(
0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8
)
zarr_data["data/n_contacts"][:] = np.random.randn(num_frames, 2)
zarr_data["data/state"][:] = np.random.randn(num_frames, 5)
zarr_data["data/keypoint"][:] = np.random.randn(num_frames, 9, 2)
@@ -93,7 +115,11 @@ def _mock_download_raw_umi(raw_dir, num_frames=4, num_episodes=3):
overwrite=True,
)
zarr_data.create_dataset(
"data/robot0_eef_pos", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
"data/robot0_eef_pos",
shape=(num_frames, 5),
chunks=(num_frames, 5),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"data/robot0_eef_rot_axis_angle",
@@ -110,10 +136,16 @@ def _mock_download_raw_umi(raw_dir, num_frames=4, num_episodes=3):
overwrite=True,
)
zarr_data.create_dataset(
"meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
"meta/episode_ends",
shape=(num_episodes,),
chunks=(num_episodes,),
dtype=np.int32,
overwrite=True,
)
zarr_data["data/camera0_rgb"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
zarr_data["data/camera0_rgb"][:] = np.random.randint(
0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8
)
zarr_data["data/robot0_demo_end_pose"][:] = np.random.randn(num_frames, 5)
zarr_data["data/robot0_demo_start_pose"][:] = np.random.randn(num_frames, 5)
zarr_data["data/robot0_eef_pos"][:] = np.random.randn(num_frames, 5)
@@ -129,7 +161,9 @@ def _mock_download_raw_xarm(raw_dir, num_frames=4):
dataset_dict = {
"observations": {
"rgb": np.random.randint(0, 255, size=(num_frames, 3, 84, 84), dtype=np.uint8),
"rgb": np.random.randint(
0, 255, size=(num_frames, 3, 84, 84), dtype=np.uint8
),
"state": np.random.randn(num_frames, 4),
},
"actions": np.random.randn(num_frames, 3),
@@ -151,13 +185,24 @@ def _mock_download_raw_aloha(raw_dir, num_frames=6, num_episodes=3):
raw_dir.mkdir(parents=True, exist_ok=True)
path_h5 = raw_dir / f"episode_{ep_idx}.hdf5"
with h5py.File(str(path_h5), "w") as f:
f.create_dataset("action", data=np.random.randn(num_frames // num_episodes, 14))
f.create_dataset("observations/qpos", data=np.random.randn(num_frames // num_episodes, 14))
f.create_dataset("observations/qvel", data=np.random.randn(num_frames // num_episodes, 14))
f.create_dataset(
"action", data=np.random.randn(num_frames // num_episodes, 14)
)
f.create_dataset(
"observations/qpos",
data=np.random.randn(num_frames // num_episodes, 14),
)
f.create_dataset(
"observations/qvel",
data=np.random.randn(num_frames // num_episodes, 14),
)
f.create_dataset(
"observations/images/top",
data=np.random.randint(
0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8
0,
255,
size=(num_frames // num_episodes, 480, 640, 3),
dtype=np.uint8,
),
)
@@ -191,7 +236,12 @@ def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30):
action = np.random.randn(21).tolist()
state = np.random.randn(21).tolist()
ep_idx = episode_indices_mapping[i]
frame = [{"path": f"videos/{cam_key}_episode_{ep_idx:06d}.mp4", "timestamp": frame_idx / fps}]
frame = [
{
"path": f"videos/{cam_key}_episode_{ep_idx:06d}.mp4",
"timestamp": frame_idx / fps,
}
]
timestamps.append(t_utc)
actions.append(action)
states.append(state)
@@ -204,7 +254,9 @@ def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30):
# write fake mp4 file for each episode
for ep_idx in range(num_episodes):
imgs_array = np.random.randint(0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8)
imgs_array = np.random.randint(
0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8
)
tmp_imgs_dir = raw_dir / "tmp_images"
save_images_concurrently(imgs_array, tmp_imgs_dir)
@@ -263,7 +315,9 @@ def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
],
)
@require_package_arg
def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id, make_test_data):
def test_push_dataset_to_hub_format(
required_packages, tmpdir, raw_format, repo_id, make_test_data
):
num_episodes = 3
tmpdir = Path(tmpdir)
@@ -315,7 +369,10 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_
== lerobot_dataset.hf_dataset["episode_index"][:num_frames]
)
for k in ["from", "to"]:
assert torch.equal(test_dataset.episode_data_index[k], lerobot_dataset.episode_data_index[k][:1])
assert torch.equal(
test_dataset.episode_data_index[k],
lerobot_dataset.episode_data_index[k][:1],
)
@pytest.mark.parametrize(
@@ -359,8 +416,12 @@ def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, re
assert item1.keys() == item2.keys(), "Keys mismatch"
for key in item1:
if isinstance(item1[key], torch.Tensor) and isinstance(item2[key], torch.Tensor):
assert torch.equal(item1[key], item2[key]), f"Mismatch found in key: {key}"
if isinstance(item1[key], torch.Tensor) and isinstance(
item2[key], torch.Tensor
):
assert torch.equal(
item1[key], item2[key]
), f"Mismatch found in key: {key}"
else:
assert item1[key] == item2[key], f"Mismatch found in key: {key}"