most unit tests passing (TODO: convert datasets)

This commit is contained in:
Remi Cadene
2025-04-16 21:30:58 +02:00
parent c2a05a1fde
commit 6b6a990f4c
22 changed files with 150 additions and 136 deletions

View File

@@ -32,7 +32,7 @@ def test_drop_n_first_frames():
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1)
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], drop_n_first_frames=1)
assert sampler.indices == [1, 4, 5]
assert len(sampler) == 3
assert list(sampler) == [1, 4, 5]
@@ -48,7 +48,7 @@ def test_drop_n_last_frames():
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1)
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], drop_n_last_frames=1)
assert sampler.indices == [0, 3, 4]
assert len(sampler) == 3
assert list(sampler) == [0, 3, 4]
@@ -64,7 +64,9 @@ def test_episode_indices_to_use():
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2])
sampler = EpisodeAwareSampler(
episode_data_index["from"], episode_data_index["to"], episode_indices_to_use=[0, 2]
)
assert sampler.indices == [0, 1, 3, 4, 5]
assert len(sampler) == 5
assert list(sampler) == [0, 1, 3, 4, 5]
@@ -80,11 +82,11 @@ def test_shuffle():
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, shuffle=False)
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], shuffle=False)
assert sampler.indices == [0, 1, 2, 3, 4, 5]
assert len(sampler) == 6
assert list(sampler) == [0, 1, 2, 3, 4, 5]
sampler = EpisodeAwareSampler(episode_data_index, shuffle=True)
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], shuffle=True)
assert sampler.indices == [0, 1, 2, 3, 4, 5]
assert len(sampler) == 6
assert set(sampler) == {0, 1, 2, 3, 4, 5}