forked from tangger/lerobot
Format file
This commit is contained in:
@@ -227,9 +227,7 @@ def test_compute_sampler_weights_trivial(
|
||||
)
|
||||
|
||||
weights = compute_sampler_weights(
|
||||
offline_dataset,
|
||||
online_dataset=online_dataset,
|
||||
online_sampling_ratio=online_sampling_ratio,
|
||||
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
|
||||
)
|
||||
if offline_dataset_size == 0 or online_dataset_size == 0:
|
||||
expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
|
||||
@@ -248,13 +246,10 @@ def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_p
|
||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
||||
online_sampling_ratio = 0.8
|
||||
weights = compute_sampler_weights(
|
||||
offline_dataset,
|
||||
online_dataset=online_dataset,
|
||||
online_sampling_ratio=online_sampling_ratio,
|
||||
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
weights,
|
||||
torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]),
|
||||
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
|
||||
)
|
||||
|
||||
|
||||
@@ -264,14 +259,10 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_datase
|
||||
online_dataset, _ = make_new_buffer()
|
||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
||||
weights = compute_sampler_weights(
|
||||
offline_dataset,
|
||||
online_dataset=online_dataset,
|
||||
online_sampling_ratio=0.8,
|
||||
online_drop_n_last_frames=1,
|
||||
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
weights,
|
||||
torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0]),
|
||||
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user