[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-04 13:38:47 +00:00
parent d8a1758122
commit 584cad808e
108 changed files with 3894 additions and 1189 deletions

View File

@@ -44,13 +44,23 @@ def make_new_buffer(
return buffer, write_dir
def make_spoof_data_frames(n_episodes: int, n_frames_per_episode: int) -> dict[str, np.ndarray]:
def make_spoof_data_frames(
n_episodes: int, n_frames_per_episode: int
) -> dict[str, np.ndarray]:
new_data = {
data_key: np.arange(n_frames_per_episode * n_episodes * np.prod(data_shape)).reshape(-1, *data_shape),
data_key: np.arange(
n_frames_per_episode * n_episodes * np.prod(data_shape)
).reshape(-1, *data_shape),
OnlineBuffer.INDEX_KEY: np.arange(n_frames_per_episode * n_episodes),
OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(np.arange(n_episodes), n_frames_per_episode),
OnlineBuffer.FRAME_INDEX_KEY: np.tile(np.arange(n_frames_per_episode), n_episodes),
OnlineBuffer.TIMESTAMP_KEY: np.tile(np.arange(n_frames_per_episode) / fps, n_episodes),
OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(
np.arange(n_episodes), n_frames_per_episode
),
OnlineBuffer.FRAME_INDEX_KEY: np.tile(
np.arange(n_frames_per_episode), n_episodes
),
OnlineBuffer.TIMESTAMP_KEY: np.tile(
np.arange(n_frames_per_episode) / fps, n_episodes
),
}
return new_data
@@ -133,8 +143,8 @@ def test_fifo():
n_more_episodes = 2
# Developer sanity check (in case someone changes the global `buffer_capacity`).
assert (
n_episodes + n_more_episodes
) * n_frames_per_episode > buffer_capacity, "Something went wrong with the test code."
(n_episodes + n_more_episodes) * n_frames_per_episode > buffer_capacity
), "Something went wrong with the test code."
more_new_data = make_spoof_data_frames(n_more_episodes, n_frames_per_episode)
buffer.add_data(more_new_data)
assert len(buffer) == buffer_capacity, "The buffer should be full."
@@ -166,7 +176,9 @@ def test_delta_timestamps_within_tolerance():
buffer.tolerance_s = 0.04
item = buffer[2]
data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"]
assert torch.allclose(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
assert torch.allclose(
data, torch.tensor([0, 2, 3])
), "Data does not match expected values"
assert not is_pad.any(), "Unexpected padding detected"
@@ -202,7 +214,9 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range():
buffer.tolerance_s = 0.04
item = buffer[2]
data, is_pad = item["index"], item["index_is_pad"]
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
assert torch.equal(
data, torch.tensor([0, 0, 2, 4, 4])
), "Data does not match expected values"
assert torch.equal(
is_pad, torch.tensor([True, False, False, True, True])
), "Padding does not match expected values"
@@ -219,58 +233,89 @@ def test_compute_sampler_weights_trivial(
online_dataset_size: int,
online_sampling_ratio: float,
):
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size)
offline_dataset = lerobot_dataset_factory(
tmp_path, total_episodes=1, total_frames=offline_dataset_size
)
online_dataset, _ = make_new_buffer()
if online_dataset_size > 0:
online_dataset.add_data(
make_spoof_data_frames(n_episodes=2, n_frames_per_episode=online_dataset_size // 2)
make_spoof_data_frames(
n_episodes=2, n_frames_per_episode=online_dataset_size // 2
)
)
weights = compute_sampler_weights(
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
offline_dataset,
online_dataset=online_dataset,
online_sampling_ratio=online_sampling_ratio,
)
if offline_dataset_size == 0 or online_dataset_size == 0:
expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
elif online_sampling_ratio == 0:
expected_weights = torch.cat([torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)])
expected_weights = torch.cat(
[torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)]
)
elif online_sampling_ratio == 1:
expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)])
expected_weights = torch.cat(
[torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)]
)
expected_weights /= expected_weights.sum()
assert torch.allclose(weights, expected_weights)
def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
offline_dataset = lerobot_dataset_factory(
tmp_path, total_episodes=1, total_frames=4
)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
online_dataset.add_data(
make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
)
online_sampling_ratio = 0.8
weights = compute_sampler_weights(
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
offline_dataset,
online_dataset=online_dataset,
online_sampling_ratio=online_sampling_ratio,
)
assert torch.allclose(
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
weights,
torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]),
)
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path):
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(
lerobot_dataset_factory, tmp_path
):
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
offline_dataset = lerobot_dataset_factory(
tmp_path, total_episodes=1, total_frames=4
)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
online_dataset.add_data(
make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
)
weights = compute_sampler_weights(
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1
offline_dataset,
online_dataset=online_dataset,
online_sampling_ratio=0.8,
online_drop_n_last_frames=1,
)
assert torch.allclose(
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
weights,
torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0]),
)
def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path):
"""Note: test copied from test_sampler."""
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2)
offline_dataset = lerobot_dataset_factory(
tmp_path, total_episodes=1, total_frames=2
)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
online_dataset.add_data(
make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
)
weights = compute_sampler_weights(
offline_dataset,
@@ -279,4 +324,6 @@ def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp
online_sampling_ratio=0.5,
online_drop_n_last_frames=1,
)
assert torch.allclose(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))
assert torch.allclose(
weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0])
)