[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
AdilZouitine
parent
76df8a31b3
commit
38f5fa4523
@@ -44,13 +44,23 @@ def make_new_buffer(
|
||||
return buffer, write_dir
|
||||
|
||||
|
||||
def make_spoof_data_frames(n_episodes: int, n_frames_per_episode: int) -> dict[str, np.ndarray]:
|
||||
def make_spoof_data_frames(
|
||||
n_episodes: int, n_frames_per_episode: int
|
||||
) -> dict[str, np.ndarray]:
|
||||
new_data = {
|
||||
data_key: np.arange(n_frames_per_episode * n_episodes * np.prod(data_shape)).reshape(-1, *data_shape),
|
||||
data_key: np.arange(
|
||||
n_frames_per_episode * n_episodes * np.prod(data_shape)
|
||||
).reshape(-1, *data_shape),
|
||||
OnlineBuffer.INDEX_KEY: np.arange(n_frames_per_episode * n_episodes),
|
||||
OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(np.arange(n_episodes), n_frames_per_episode),
|
||||
OnlineBuffer.FRAME_INDEX_KEY: np.tile(np.arange(n_frames_per_episode), n_episodes),
|
||||
OnlineBuffer.TIMESTAMP_KEY: np.tile(np.arange(n_frames_per_episode) / fps, n_episodes),
|
||||
OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(
|
||||
np.arange(n_episodes), n_frames_per_episode
|
||||
),
|
||||
OnlineBuffer.FRAME_INDEX_KEY: np.tile(
|
||||
np.arange(n_frames_per_episode), n_episodes
|
||||
),
|
||||
OnlineBuffer.TIMESTAMP_KEY: np.tile(
|
||||
np.arange(n_frames_per_episode) / fps, n_episodes
|
||||
),
|
||||
}
|
||||
return new_data
|
||||
|
||||
@@ -219,47 +229,72 @@ def test_compute_sampler_weights_trivial(
|
||||
online_dataset_size: int,
|
||||
online_sampling_ratio: float,
|
||||
):
|
||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size)
|
||||
offline_dataset = lerobot_dataset_factory(
|
||||
tmp_path, total_episodes=1, total_frames=offline_dataset_size
|
||||
)
|
||||
online_dataset, _ = make_new_buffer()
|
||||
if online_dataset_size > 0:
|
||||
online_dataset.add_data(
|
||||
make_spoof_data_frames(n_episodes=2, n_frames_per_episode=online_dataset_size // 2)
|
||||
make_spoof_data_frames(
|
||||
n_episodes=2, n_frames_per_episode=online_dataset_size // 2
|
||||
)
|
||||
)
|
||||
|
||||
weights = compute_sampler_weights(
|
||||
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
|
||||
offline_dataset,
|
||||
online_dataset=online_dataset,
|
||||
online_sampling_ratio=online_sampling_ratio,
|
||||
)
|
||||
if offline_dataset_size == 0 or online_dataset_size == 0:
|
||||
expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
|
||||
elif online_sampling_ratio == 0:
|
||||
expected_weights = torch.cat([torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)])
|
||||
expected_weights = torch.cat(
|
||||
[torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)]
|
||||
)
|
||||
elif online_sampling_ratio == 1:
|
||||
expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)])
|
||||
expected_weights = torch.cat(
|
||||
[torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)]
|
||||
)
|
||||
expected_weights /= expected_weights.sum()
|
||||
torch.testing.assert_close(weights, expected_weights)
|
||||
|
||||
|
||||
def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
|
||||
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
|
||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
|
||||
offline_dataset = lerobot_dataset_factory(
|
||||
tmp_path, total_episodes=1, total_frames=4
|
||||
)
|
||||
online_dataset, _ = make_new_buffer()
|
||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
||||
online_dataset.add_data(
|
||||
make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
|
||||
)
|
||||
online_sampling_ratio = 0.8
|
||||
weights = compute_sampler_weights(
|
||||
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
|
||||
offline_dataset,
|
||||
online_dataset=online_dataset,
|
||||
online_sampling_ratio=online_sampling_ratio,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
|
||||
)
|
||||
|
||||
|
||||
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path):
|
||||
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(
|
||||
lerobot_dataset_factory, tmp_path
|
||||
):
|
||||
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
|
||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
|
||||
offline_dataset = lerobot_dataset_factory(
|
||||
tmp_path, total_episodes=1, total_frames=4
|
||||
)
|
||||
online_dataset, _ = make_new_buffer()
|
||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
||||
online_dataset.add_data(
|
||||
make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
|
||||
)
|
||||
weights = compute_sampler_weights(
|
||||
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1
|
||||
offline_dataset,
|
||||
online_dataset=online_dataset,
|
||||
online_sampling_ratio=0.8,
|
||||
online_drop_n_last_frames=1,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
|
||||
@@ -268,9 +303,13 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_datase
|
||||
|
||||
def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path):
|
||||
"""Note: test copied from test_sampler."""
|
||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2)
|
||||
offline_dataset = lerobot_dataset_factory(
|
||||
tmp_path, total_episodes=1, total_frames=2
|
||||
)
|
||||
online_dataset, _ = make_new_buffer()
|
||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
||||
online_dataset.add_data(
|
||||
make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
|
||||
)
|
||||
|
||||
weights = compute_sampler_weights(
|
||||
offline_dataset,
|
||||
|
||||
Reference in New Issue
Block a user