Fix tests

This commit is contained in:
Simon Alibert
2024-11-05 19:09:12 +01:00
parent aed9f4036a
commit f3630ad910
13 changed files with 437 additions and 496 deletions

View File

@@ -213,15 +213,13 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range():
@pytest.mark.parametrize("online_dataset_size", [0, 4])
@pytest.mark.parametrize("online_sampling_ratio", [0.0, 1.0])
def test_compute_sampler_weights_trivial(
lerobot_dataset_from_episodes_factory,
lerobot_dataset_factory,
tmp_path,
offline_dataset_size: int,
online_dataset_size: int,
online_sampling_ratio: float,
):
offline_dataset = lerobot_dataset_from_episodes_factory(
tmp_path, total_episodes=1, total_frames=offline_dataset_size
)
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size)
online_dataset, _ = make_new_buffer()
if online_dataset_size > 0:
online_dataset.add_data(
@@ -241,9 +239,9 @@ def test_compute_sampler_weights_trivial(
assert torch.allclose(weights, expected_weights)
def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_from_episodes_factory, tmp_path):
def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
offline_dataset = lerobot_dataset_from_episodes_factory(tmp_path, total_episodes=1, total_frames=4)
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
online_sampling_ratio = 0.8
@@ -255,11 +253,9 @@ def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_from_episodes_
)
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(
lerobot_dataset_from_episodes_factory, tmp_path
):
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path):
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
offline_dataset = lerobot_dataset_from_episodes_factory(tmp_path, total_episodes=1, total_frames=4)
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
weights = compute_sampler_weights(
@@ -270,9 +266,9 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(
)
def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_from_episodes_factory, tmp_path):
def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path):
"""Note: test copied from test_sampler."""
offline_dataset = lerobot_dataset_from_episodes_factory(tmp_path, total_episodes=1, total_frames=2)
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))