[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
@@ -31,7 +31,11 @@ def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3):
|
||||
zarr_data = zarr.group(store=store)
|
||||
|
||||
zarr_data.create_dataset(
|
||||
"data/action", shape=(num_frames, 1), chunks=(num_frames, 1), dtype=np.float32, overwrite=True
|
||||
"data/action",
|
||||
shape=(num_frames, 1),
|
||||
chunks=(num_frames, 1),
|
||||
dtype=np.float32,
|
||||
overwrite=True,
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"data/img",
|
||||
@@ -41,20 +45,38 @@ def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3):
|
||||
overwrite=True,
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"data/n_contacts", shape=(num_frames, 2), chunks=(num_frames, 2), dtype=np.float32, overwrite=True
|
||||
"data/n_contacts",
|
||||
shape=(num_frames, 2),
|
||||
chunks=(num_frames, 2),
|
||||
dtype=np.float32,
|
||||
overwrite=True,
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"data/state", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
|
||||
"data/state",
|
||||
shape=(num_frames, 5),
|
||||
chunks=(num_frames, 5),
|
||||
dtype=np.float32,
|
||||
overwrite=True,
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"data/keypoint", shape=(num_frames, 9, 2), chunks=(num_frames, 9, 2), dtype=np.float32, overwrite=True
|
||||
"data/keypoint",
|
||||
shape=(num_frames, 9, 2),
|
||||
chunks=(num_frames, 9, 2),
|
||||
dtype=np.float32,
|
||||
overwrite=True,
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
|
||||
"meta/episode_ends",
|
||||
shape=(num_episodes,),
|
||||
chunks=(num_episodes,),
|
||||
dtype=np.int32,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
zarr_data["data/action"][:] = np.random.randn(num_frames, 1)
|
||||
zarr_data["data/img"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
|
||||
zarr_data["data/img"][:] = np.random.randint(
|
||||
0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8
|
||||
)
|
||||
zarr_data["data/n_contacts"][:] = np.random.randn(num_frames, 2)
|
||||
zarr_data["data/state"][:] = np.random.randn(num_frames, 5)
|
||||
zarr_data["data/keypoint"][:] = np.random.randn(num_frames, 9, 2)
|
||||
@@ -93,7 +115,11 @@ def _mock_download_raw_umi(raw_dir, num_frames=4, num_episodes=3):
|
||||
overwrite=True,
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"data/robot0_eef_pos", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
|
||||
"data/robot0_eef_pos",
|
||||
shape=(num_frames, 5),
|
||||
chunks=(num_frames, 5),
|
||||
dtype=np.float32,
|
||||
overwrite=True,
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"data/robot0_eef_rot_axis_angle",
|
||||
@@ -110,10 +136,16 @@ def _mock_download_raw_umi(raw_dir, num_frames=4, num_episodes=3):
|
||||
overwrite=True,
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
|
||||
"meta/episode_ends",
|
||||
shape=(num_episodes,),
|
||||
chunks=(num_episodes,),
|
||||
dtype=np.int32,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
zarr_data["data/camera0_rgb"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
|
||||
zarr_data["data/camera0_rgb"][:] = np.random.randint(
|
||||
0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8
|
||||
)
|
||||
zarr_data["data/robot0_demo_end_pose"][:] = np.random.randn(num_frames, 5)
|
||||
zarr_data["data/robot0_demo_start_pose"][:] = np.random.randn(num_frames, 5)
|
||||
zarr_data["data/robot0_eef_pos"][:] = np.random.randn(num_frames, 5)
|
||||
@@ -129,7 +161,9 @@ def _mock_download_raw_xarm(raw_dir, num_frames=4):
|
||||
|
||||
dataset_dict = {
|
||||
"observations": {
|
||||
"rgb": np.random.randint(0, 255, size=(num_frames, 3, 84, 84), dtype=np.uint8),
|
||||
"rgb": np.random.randint(
|
||||
0, 255, size=(num_frames, 3, 84, 84), dtype=np.uint8
|
||||
),
|
||||
"state": np.random.randn(num_frames, 4),
|
||||
},
|
||||
"actions": np.random.randn(num_frames, 3),
|
||||
@@ -151,13 +185,24 @@ def _mock_download_raw_aloha(raw_dir, num_frames=6, num_episodes=3):
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
path_h5 = raw_dir / f"episode_{ep_idx}.hdf5"
|
||||
with h5py.File(str(path_h5), "w") as f:
|
||||
f.create_dataset("action", data=np.random.randn(num_frames // num_episodes, 14))
|
||||
f.create_dataset("observations/qpos", data=np.random.randn(num_frames // num_episodes, 14))
|
||||
f.create_dataset("observations/qvel", data=np.random.randn(num_frames // num_episodes, 14))
|
||||
f.create_dataset(
|
||||
"action", data=np.random.randn(num_frames // num_episodes, 14)
|
||||
)
|
||||
f.create_dataset(
|
||||
"observations/qpos",
|
||||
data=np.random.randn(num_frames // num_episodes, 14),
|
||||
)
|
||||
f.create_dataset(
|
||||
"observations/qvel",
|
||||
data=np.random.randn(num_frames // num_episodes, 14),
|
||||
)
|
||||
f.create_dataset(
|
||||
"observations/images/top",
|
||||
data=np.random.randint(
|
||||
0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8
|
||||
0,
|
||||
255,
|
||||
size=(num_frames // num_episodes, 480, 640, 3),
|
||||
dtype=np.uint8,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -191,7 +236,12 @@ def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30):
|
||||
action = np.random.randn(21).tolist()
|
||||
state = np.random.randn(21).tolist()
|
||||
ep_idx = episode_indices_mapping[i]
|
||||
frame = [{"path": f"videos/{cam_key}_episode_{ep_idx:06d}.mp4", "timestamp": frame_idx / fps}]
|
||||
frame = [
|
||||
{
|
||||
"path": f"videos/{cam_key}_episode_{ep_idx:06d}.mp4",
|
||||
"timestamp": frame_idx / fps,
|
||||
}
|
||||
]
|
||||
timestamps.append(t_utc)
|
||||
actions.append(action)
|
||||
states.append(state)
|
||||
@@ -204,7 +254,9 @@ def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30):
|
||||
|
||||
# write fake mp4 file for each episode
|
||||
for ep_idx in range(num_episodes):
|
||||
imgs_array = np.random.randint(0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8)
|
||||
imgs_array = np.random.randint(
|
||||
0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8
|
||||
)
|
||||
|
||||
tmp_imgs_dir = raw_dir / "tmp_images"
|
||||
save_images_concurrently(imgs_array, tmp_imgs_dir)
|
||||
@@ -263,7 +315,9 @@ def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
|
||||
],
|
||||
)
|
||||
@require_package_arg
|
||||
def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id, make_test_data):
|
||||
def test_push_dataset_to_hub_format(
|
||||
required_packages, tmpdir, raw_format, repo_id, make_test_data
|
||||
):
|
||||
num_episodes = 3
|
||||
tmpdir = Path(tmpdir)
|
||||
|
||||
@@ -315,7 +369,10 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_
|
||||
== lerobot_dataset.hf_dataset["episode_index"][:num_frames]
|
||||
)
|
||||
for k in ["from", "to"]:
|
||||
assert torch.equal(test_dataset.episode_data_index[k], lerobot_dataset.episode_data_index[k][:1])
|
||||
assert torch.equal(
|
||||
test_dataset.episode_data_index[k],
|
||||
lerobot_dataset.episode_data_index[k][:1],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -359,8 +416,12 @@ def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, re
|
||||
assert item1.keys() == item2.keys(), "Keys mismatch"
|
||||
|
||||
for key in item1:
|
||||
if isinstance(item1[key], torch.Tensor) and isinstance(item2[key], torch.Tensor):
|
||||
assert torch.equal(item1[key], item2[key]), f"Mismatch found in key: {key}"
|
||||
if isinstance(item1[key], torch.Tensor) and isinstance(
|
||||
item2[key], torch.Tensor
|
||||
):
|
||||
assert torch.equal(
|
||||
item1[key], item2[key]
|
||||
), f"Mismatch found in key: {key}"
|
||||
else:
|
||||
assert item1[key] == item2[key], f"Mismatch found in key: {key}"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user