Update pre-commits (#733)

@@ -49,17 +49,17 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
     # save 2 first frames of first episode
     i = dataset.episode_data_index["from"][0].item()
     save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
-    save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
+    save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")

     # save 2 frames at the middle of first episode
     i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
     save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
-    save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
+    save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")

     # save 2 last frames of first episode
     i = dataset.episode_data_index["to"][0].item()
-    save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
-    save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
+    save_file(dataset[i - 2], repo_dir / f"frame_{i - 2}.safetensors")
+    save_file(dataset[i - 1], repo_dir / f"frame_{i - 1}.safetensors")

     # TODO(rcadene): Enable testing on second and last episode
     # We currently can't because our test dataset only contains the first episode

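The save_file edits above only add spaces inside the f-string replacement fields; the rendered filenames are unchanged. A minimal standalone check (plain Python, not part of the commit):

    i = 0
    # Spacing inside {...} does not affect the formatted result.
    assert f"frame_{i+1}.safetensors" == f"frame_{i + 1}.safetensors"
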
@@ -336,9 +336,9 @@ def test_backward_compatibility(repo_id):
         assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"

         for key in new_frame:
-            assert torch.isclose(
-                new_frame[key], old_frame[key]
-            ).all(), f"{key=} for index={i} does not contain the same value"
+            assert torch.isclose(new_frame[key], old_frame[key]).all(), (
+                f"{key=} for index={i} does not contain the same value"
+            )

     # test 2 first frames of first episode
     i = dataset.episode_data_index["from"][0].item()

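This hunk and the ones below re-wrap multi-line asserts so the condition stays on one line and only the failure message is parenthesized; a parenthesized string is still just a string, so the asserts behave identically. A standalone sketch of the new layout (illustrative only, no repo code assumed):

    x = 1
    try:
        assert x == 2, (
            "x should be 2"  # new layout: message wrapped in parentheses
        )
    except AssertionError as e:
        print(e)  # -> x should be 2

Note the parentheses must wrap the message alone: assert (x == 2, "msg") would assert a two-element tuple, which is always truthy and never fails.
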
@@ -343,13 +343,13 @@ def test_save_all_transforms(img_tensor_factory, tmp_path):
     # Check if the combined transforms directory exists and contains the right files
     combined_transforms_dir = tmp_path / "all"
     assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
-    assert any(
-        combined_transforms_dir.iterdir()
-    ), "No transformed images found in combined transforms directory."
+    assert any(combined_transforms_dir.iterdir()), (
+        "No transformed images found in combined transforms directory."
+    )
     for i in range(1, n_examples + 1):
-        assert (
-            combined_transforms_dir / f"{i}.png"
-        ).exists(), f"Combined transform image {i}.png was not found."
+        assert (combined_transforms_dir / f"{i}.png").exists(), (
+            f"Combined transform image {i}.png was not found."
+        )


 def test_save_each_transform(img_tensor_factory, tmp_path):
@@ -369,6 +369,6 @@ def test_save_each_transform(img_tensor_factory, tmp_path):
         # Check for specific files within each transform directory
         expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"]
         for file_name in expected_files:
-            assert (
-                transform_dir / file_name
-            ).exists(), f"{file_name} was not found in {transform} directory."
+            assert (transform_dir / file_name).exists(), (
+                f"{file_name} was not found in {transform} directory."
+            )

@@ -132,9 +132,9 @@ def test_fifo():
     buffer.add_data(new_data)
     n_more_episodes = 2
     # Developer sanity check (in case someone changes the global `buffer_capacity`).
-    assert (
-        n_episodes + n_more_episodes
-    ) * n_frames_per_episode > buffer_capacity, "Something went wrong with the test code."
+    assert (n_episodes + n_more_episodes) * n_frames_per_episode > buffer_capacity, (
+        "Something went wrong with the test code."
+    )
     more_new_data = make_spoof_data_frames(n_more_episodes, n_frames_per_episode)
     buffer.add_data(more_new_data)
     assert len(buffer) == buffer_capacity, "The buffer should be full."
@@ -203,9 +203,9 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range():
     item = buffer[2]
     data, is_pad = item["index"], item["index_is_pad"]
     assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
-    assert torch.equal(
-        is_pad, torch.tensor([True, False, False, True, True])
-    ), "Padding does not match expected values"
+    assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), (
+        "Padding does not match expected values"
+    )


 # Arbitrarily set small dataset sizes, making sure to have uneven sizes.

@@ -193,12 +193,12 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
     observation_ = deepcopy(observation)
     with torch.inference_mode():
         action = policy.select_action(observation).cpu().numpy()
-    assert set(observation) == set(
-        observation_
-    ), "Observation batch keys are not the same after a forward pass."
-    assert all(
-        torch.equal(observation[k], observation_[k]) for k in observation
-    ), "Observation batch values are not the same after a forward pass."
+    assert set(observation) == set(observation_), (
+        "Observation batch keys are not the same after a forward pass."
+    )
+    assert all(torch.equal(observation[k], observation_[k]) for k in observation), (
+        "Observation batch values are not the same after a forward pass."
+    )

     # Test step through policy
     env.step(action)
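
Assuming the updated hook configuration from this commit is installed, a formatting sweep like this is typically regenerated with pre-commit run --all-files rather than edited by hand.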