Format file

This commit is contained in:
AdilZouitine
2025-05-07 10:26:18 +02:00
parent adbf8bb85e
commit b36ec31fea
13 changed files with 43 additions and 169 deletions

View File

@@ -227,9 +227,7 @@ def test_compute_sampler_weights_trivial(
)
weights = compute_sampler_weights(
offline_dataset,
online_dataset=online_dataset,
online_sampling_ratio=online_sampling_ratio,
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
)
if offline_dataset_size == 0 or online_dataset_size == 0:
expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
@@ -248,13 +246,10 @@ def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_p
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
online_sampling_ratio = 0.8
weights = compute_sampler_weights(
offline_dataset,
online_dataset=online_dataset,
online_sampling_ratio=online_sampling_ratio,
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
)
torch.testing.assert_close(
weights,
torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]),
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
)
@@ -264,14 +259,10 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_datase
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
weights = compute_sampler_weights(
offline_dataset,
online_dataset=online_dataset,
online_sampling_ratio=0.8,
online_drop_n_last_frames=1,
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1
)
torch.testing.assert_close(
weights,
torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0]),
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
)