[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

committed by Michel Aractingi
parent bb69cb3c8c
commit 85fe8a3f4e
@@ -364,10 +364,16 @@ def test_save_each_transform(img_tensor_factory, tmp_path):
     for transform in transforms:
         transform_dir = tmp_path / transform
         assert transform_dir.exists(), f"{transform} directory was not created."
-        assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory."
+        assert any(
+            transform_dir.iterdir()
+        ), f"No transformed images found in {transform} directory."

         # Check for specific files within each transform directory
-        expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"]
+        expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + [
+            "min.png",
+            "max.png",
+            "mean.png",
+        ]
         for file_name in expected_files:
             assert (transform_dir / file_name).exists(), (
                 f"{file_name} was not found in {transform} directory."
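The reflow above is purely cosmetic; the list the test builds is unchanged. For reference, a standalone sketch of the expected_files construction, with an arbitrary n_examples = 3 rather than the test fixture's value:

    # Standalone sketch; n_examples = 3 is arbitrary.
    n_examples = 3
    expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + [
        "min.png",
        "max.png",
        "mean.png",
    ]
    print(expected_files)
    # ['1.png', '2.png', '3.png', 'min.png', 'max.png', 'mean.png']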
@@ -187,7 +187,9 @@ def test_save_image_torch(tmp_path, img_tensor_factory):
         writer.wait_until_done()
         assert fpath.exists()
         saved_image = np.array(Image.open(fpath))
-        expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
+        expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(
+            np.uint8
+        )
         assert np.array_equal(expected_image, saved_image)
     finally:
         writer.stop()
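The wrapped expression is the usual conversion for comparing a channels-first float tensor in [0, 1] against an image read back from disk. A minimal, self-contained sketch of that step (the random tensor is a stand-in for the img_tensor_factory fixture):

    import numpy as np
    import torch

    # Stand-in for the test fixture: channels-first float image in [0, 1].
    image_tensor = torch.rand(3, 32, 32)

    # Same conversion as in the diff: CHW -> HWC, scale to [0, 255], cast to uint8.
    expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    print(expected_image.shape, expected_image.dtype)  # (32, 32, 3) uint8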
@@ -202,7 +204,9 @@ def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory):
         writer.wait_until_done()
         assert fpath.exists()
         saved_image = np.array(Image.open(fpath))
-        expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
+        expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(
+            np.uint8
+        )
         assert np.array_equal(expected_image, saved_image)
     finally:
         writer.stop()
@@ -292,7 +296,9 @@ def test_wait_until_done(tmp_path, img_array_factory):
     writer = AsyncImageWriter(num_processes=0, num_threads=4)
     try:
         num_images = 100
-        image_arrays = [img_array_factory(height=500, width=500) for _ in range(num_images)]
+        image_arrays = [
+            img_array_factory(height=500, width=500) for _ in range(num_images)
+        ]
         fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
         for image_array, fpath in zip(image_arrays, fpaths, strict=True):
             fpath.parent.mkdir(parents=True, exist_ok=True)
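Unrelated to the reflow itself, the context line here uses zip(..., strict=True) (Python 3.10+), which raises instead of silently truncating when the two lists differ in length. A quick illustration with throwaway lists:

    paths = ["a.png", "b.png"]
    arrays = [1, 2, 3]
    try:
        list(zip(paths, arrays, strict=True))  # lengths 2 vs 3
    except ValueError as err:
        print(err)  # zip() argument 2 is longer than argument 1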
@@ -44,13 +44,23 @@ def make_new_buffer(
     return buffer, write_dir


-def make_spoof_data_frames(n_episodes: int, n_frames_per_episode: int) -> dict[str, np.ndarray]:
+def make_spoof_data_frames(
+    n_episodes: int, n_frames_per_episode: int
+) -> dict[str, np.ndarray]:
     new_data = {
-        data_key: np.arange(n_frames_per_episode * n_episodes * np.prod(data_shape)).reshape(-1, *data_shape),
+        data_key: np.arange(
+            n_frames_per_episode * n_episodes * np.prod(data_shape)
+        ).reshape(-1, *data_shape),
         OnlineBuffer.INDEX_KEY: np.arange(n_frames_per_episode * n_episodes),
-        OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(np.arange(n_episodes), n_frames_per_episode),
-        OnlineBuffer.FRAME_INDEX_KEY: np.tile(np.arange(n_frames_per_episode), n_episodes),
-        OnlineBuffer.TIMESTAMP_KEY: np.tile(np.arange(n_frames_per_episode) / fps, n_episodes),
+        OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(
+            np.arange(n_episodes), n_frames_per_episode
+        ),
+        OnlineBuffer.FRAME_INDEX_KEY: np.tile(
+            np.arange(n_frames_per_episode), n_episodes
+        ),
+        OnlineBuffer.TIMESTAMP_KEY: np.tile(
+            np.arange(n_frames_per_episode) / fps, n_episodes
+        ),
     }
     return new_data
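The helper's index bookkeeping is easiest to see with tiny inputs. A standalone sketch of the four arrays it builds (fps = 10 is an arbitrary stand-in for the module-level value, and plain names stand in for the OnlineBuffer key constants):

    import numpy as np

    n_episodes, n_frames_per_episode, fps = 2, 3, 10

    index = np.arange(n_frames_per_episode * n_episodes)                    # [0 1 2 3 4 5]
    episode_index = np.repeat(np.arange(n_episodes), n_frames_per_episode)  # [0 0 0 1 1 1]
    frame_index = np.tile(np.arange(n_frames_per_episode), n_episodes)      # [0 1 2 0 1 2]
    timestamp = np.tile(np.arange(n_frames_per_episode) / fps, n_episodes)  # [0. 0.1 0.2 0. 0.1 0.2]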
@@ -219,47 +229,72 @@ def test_compute_sampler_weights_trivial(
     online_dataset_size: int,
     online_sampling_ratio: float,
 ):
-    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size)
+    offline_dataset = lerobot_dataset_factory(
+        tmp_path, total_episodes=1, total_frames=offline_dataset_size
+    )
     online_dataset, _ = make_new_buffer()
     if online_dataset_size > 0:
         online_dataset.add_data(
-            make_spoof_data_frames(n_episodes=2, n_frames_per_episode=online_dataset_size // 2)
+            make_spoof_data_frames(
+                n_episodes=2, n_frames_per_episode=online_dataset_size // 2
+            )
         )

     weights = compute_sampler_weights(
-        offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
+        offline_dataset,
+        online_dataset=online_dataset,
+        online_sampling_ratio=online_sampling_ratio,
     )
     if offline_dataset_size == 0 or online_dataset_size == 0:
         expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
     elif online_sampling_ratio == 0:
-        expected_weights = torch.cat([torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)])
+        expected_weights = torch.cat(
+            [torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)]
+        )
     elif online_sampling_ratio == 1:
-        expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)])
+        expected_weights = torch.cat(
+            [torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)]
+        )
     expected_weights /= expected_weights.sum()
     torch.testing.assert_close(weights, expected_weights)


 def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
     # Arbitrarily set small dataset sizes, making sure to have uneven sizes.
-    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
+    offline_dataset = lerobot_dataset_factory(
+        tmp_path, total_episodes=1, total_frames=4
+    )
     online_dataset, _ = make_new_buffer()
-    online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
+    online_dataset.add_data(
+        make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
+    )
     online_sampling_ratio = 0.8
     weights = compute_sampler_weights(
-        offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
+        offline_dataset,
+        online_dataset=online_dataset,
+        online_sampling_ratio=online_sampling_ratio,
     )
     torch.testing.assert_close(
         weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
     )


-def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path):
+def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(
+    lerobot_dataset_factory, tmp_path
+):
     # Arbitrarily set small dataset sizes, making sure to have uneven sizes.
-    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
+    offline_dataset = lerobot_dataset_factory(
+        tmp_path, total_episodes=1, total_frames=4
+    )
     online_dataset, _ = make_new_buffer()
-    online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
+    online_dataset.add_data(
+        make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
+    )
     weights = compute_sampler_weights(
-        offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1
+        offline_dataset,
+        online_dataset=online_dataset,
+        online_sampling_ratio=0.8,
+        online_drop_n_last_frames=1,
    )
     torch.testing.assert_close(
         weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
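The hard-coded expectations follow from how the sampling mass is split: with 4 offline and 8 online frames at online_sampling_ratio=0.8, the offline pool shares 0.2 uniformly (0.05 per frame) and the online pool shares 0.8 uniformly (0.1 per frame); with online_drop_n_last_frames=1, the last frame of each 2-frame online episode gets weight 0 and the 4 surviving frames share 0.8 (0.2 each). A sketch of that arithmetic (read off the expected tensors above; it assumes, rather than reproduces, the compute_sampler_weights implementation):

    import torch

    n_offline, n_online, ratio = 4, 8, 0.8

    # Nontrivial-ratio case: mass split uniformly within each pool.
    offline = torch.full((n_offline,), (1 - ratio) / n_offline)  # 0.05 each
    online = torch.full((n_online,), ratio / n_online)           # 0.10 each
    print(torch.cat([offline, online]))

    # Drop-last-n case: the second frame of each 2-frame episode is masked out,
    # so the online mass is shared by the 4 surviving frames (0.2 each).
    keep = torch.tensor([1.0, 0.0] * 4)
    online_dropped = ratio * keep / keep.sum()
    print(torch.cat([offline, online_dropped]))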
@@ -268,9 +303,13 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_datase

 def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path):
     """Note: test copied from test_sampler."""
-    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2)
+    offline_dataset = lerobot_dataset_factory(
+        tmp_path, total_episodes=1, total_frames=2
+    )
     online_dataset, _ = make_new_buffer()
-    online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
+    online_dataset.add_data(
+        make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
+    )

     weights = compute_sampler_weights(
         offline_dataset,
@@ -15,7 +15,9 @@
 # limitations under the License.
 from datasets import Dataset

-from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
+from lerobot.common.datasets.push_dataset_to_hub.utils import (
+    calculate_episode_data_index,
+)
 from lerobot.common.datasets.sampler import EpisodeAwareSampler
 from lerobot.common.datasets.utils import (
     hf_transform_to_torch,