[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Michel Aractingi
parent
2abbd60a0d
commit
0ea27704f6
@@ -31,11 +31,7 @@ from lerobot.common.datasets.compute_stats import (
|
||||
|
||||
|
||||
def mock_load_image_as_numpy(path, dtype, channel_first):
|
||||
return (
|
||||
np.ones((3, 32, 32), dtype=dtype)
|
||||
if channel_first
|
||||
else np.ones((32, 32, 3), dtype=dtype)
|
||||
)
|
||||
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -81,20 +77,9 @@ def test_sample_images(mock_load):
|
||||
def test_get_feature_stats_images():
|
||||
data = np.random.rand(100, 3, 32, 32)
|
||||
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
|
||||
assert (
|
||||
"min" in stats
|
||||
and "max" in stats
|
||||
and "mean" in stats
|
||||
and "std" in stats
|
||||
and "count" in stats
|
||||
)
|
||||
assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
|
||||
np.testing.assert_equal(stats["count"], np.array([100]))
|
||||
assert (
|
||||
stats["min"].shape
|
||||
== stats["max"].shape
|
||||
== stats["mean"].shape
|
||||
== stats["std"].shape
|
||||
)
|
||||
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
|
||||
|
||||
|
||||
def test_get_feature_stats_axis_0_keepdims(sample_array):
|
||||
@@ -315,47 +300,31 @@ def test_aggregate_stats():
|
||||
for ep_stats in all_stats:
|
||||
for fkey, stats in ep_stats.items():
|
||||
for k in stats:
|
||||
stats[k] = np.array(
|
||||
stats[k], dtype=np.int64 if k == "count" else np.float32
|
||||
)
|
||||
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
|
||||
if fkey == "observation.image" and k != "count":
|
||||
stats[k] = stats[k].reshape(
|
||||
3, 1, 1
|
||||
) # for normalization on image channels
|
||||
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
|
||||
else:
|
||||
stats[k] = stats[k].reshape(1)
|
||||
|
||||
# cast to numpy
|
||||
for fkey, stats in expected_agg_stats.items():
|
||||
for k in stats:
|
||||
stats[k] = np.array(
|
||||
stats[k], dtype=np.int64 if k == "count" else np.float32
|
||||
)
|
||||
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
|
||||
if fkey == "observation.image" and k != "count":
|
||||
stats[k] = stats[k].reshape(
|
||||
3, 1, 1
|
||||
) # for normalization on image channels
|
||||
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
|
||||
else:
|
||||
stats[k] = stats[k].reshape(1)
|
||||
|
||||
results = aggregate_stats(all_stats)
|
||||
|
||||
for fkey in expected_agg_stats:
|
||||
np.testing.assert_allclose(
|
||||
results[fkey]["min"], expected_agg_stats[fkey]["min"]
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
results[fkey]["max"], expected_agg_stats[fkey]["max"]
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
results[fkey]["mean"], expected_agg_stats[fkey]["mean"]
|
||||
)
|
||||
np.testing.assert_allclose(results[fkey]["min"], expected_agg_stats[fkey]["min"])
|
||||
np.testing.assert_allclose(results[fkey]["max"], expected_agg_stats[fkey]["max"])
|
||||
np.testing.assert_allclose(results[fkey]["mean"], expected_agg_stats[fkey]["mean"])
|
||||
np.testing.assert_allclose(
|
||||
results[fkey]["std"],
|
||||
expected_agg_stats[fkey]["std"],
|
||||
atol=1e-04,
|
||||
rtol=1e-04,
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
results[fkey]["count"], expected_agg_stats[fkey]["count"]
|
||||
)
|
||||
np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"])
|
||||
|
||||
@@ -72,9 +72,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
||||
# Instantiate both ways
|
||||
robot = make_robot("koch", mock=True)
|
||||
root_create = tmp_path / "create"
|
||||
dataset_create = LeRobotDataset.create(
|
||||
repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create
|
||||
)
|
||||
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
|
||||
|
||||
root_init = tmp_path / "init"
|
||||
dataset_init = lerobot_dataset_factory(root=root_init)
|
||||
@@ -129,9 +127,7 @@ def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
|
||||
ValueError,
|
||||
match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n",
|
||||
):
|
||||
dataset.add_frame(
|
||||
{"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"}
|
||||
)
|
||||
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"})
|
||||
|
||||
|
||||
def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
|
||||
@@ -141,9 +137,7 @@ def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
|
||||
ValueError,
|
||||
match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n",
|
||||
):
|
||||
dataset.add_frame(
|
||||
{"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"}
|
||||
)
|
||||
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"})
|
||||
|
||||
|
||||
def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
|
||||
@@ -151,9 +145,7 @@ def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape(
|
||||
"The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"
|
||||
),
|
||||
match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"),
|
||||
):
|
||||
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
|
||||
|
||||
@@ -175,9 +167,7 @@ def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_fact
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape(
|
||||
"The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"
|
||||
),
|
||||
match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"),
|
||||
):
|
||||
dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"})
|
||||
|
||||
@@ -471,9 +461,7 @@ def test_flatten_unflatten_dict():
|
||||
d = unflatten_dict(flatten_dict(d))
|
||||
|
||||
# test equality between nested dicts
|
||||
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), (
|
||||
f"{original_d} != {d}"
|
||||
)
|
||||
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -527,13 +515,7 @@ def test_backward_compatibility(repo_id):
|
||||
load_and_compare(i + 1)
|
||||
|
||||
# test 2 frames at the middle of first episode
|
||||
i = int(
|
||||
(
|
||||
dataset.episode_data_index["to"][0].item()
|
||||
- dataset.episode_data_index["from"][0].item()
|
||||
)
|
||||
/ 2
|
||||
)
|
||||
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
|
||||
load_and_compare(i)
|
||||
load_and_compare(i + 1)
|
||||
|
||||
|
||||
@@ -71,12 +71,8 @@ def unsynced_timestamps_factory(synced_timestamps_factory):
|
||||
def _create_unsynced_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(
|
||||
fps=fps
|
||||
)
|
||||
timestamps[30] += (
|
||||
tolerance_s * 1.1
|
||||
) # Modify a single timestamp just outside tolerance
|
||||
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
|
||||
timestamps[30] += tolerance_s * 1.1 # Modify a single timestamp just outside tolerance
|
||||
return timestamps, episode_indices, episode_data_index
|
||||
|
||||
return _create_unsynced_timestamps
|
||||
@@ -87,12 +83,8 @@ def slightly_off_timestamps_factory(synced_timestamps_factory):
|
||||
def _create_slightly_off_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(
|
||||
fps=fps
|
||||
)
|
||||
timestamps[30] += (
|
||||
tolerance_s * 0.9
|
||||
) # Modify a single timestamp just inside tolerance
|
||||
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
|
||||
timestamps[30] += tolerance_s * 0.9 # Modify a single timestamp just inside tolerance
|
||||
return timestamps, episode_indices, episode_data_index
|
||||
|
||||
return _create_slightly_off_timestamps
|
||||
@@ -105,9 +97,7 @@ def valid_delta_timestamps_factory():
|
||||
keys: list = DUMMY_MOTOR_FEATURES,
|
||||
min_max_range: tuple[int, int] = (-10, 10),
|
||||
) -> dict:
|
||||
delta_timestamps = {
|
||||
key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys
|
||||
}
|
||||
delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys}
|
||||
return delta_timestamps
|
||||
|
||||
return _create_valid_delta_timestamps
|
||||
@@ -144,9 +134,7 @@ def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def delta_indices_factory():
|
||||
def _delta_indices(
|
||||
keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)
|
||||
) -> dict:
|
||||
def _delta_indices(keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)) -> dict:
|
||||
return {key: list(range(*min_max_range)) for key in keys}
|
||||
|
||||
return _delta_indices
|
||||
@@ -198,9 +186,7 @@ def test_check_timestamps_sync_unsynced_no_exception(unsynced_timestamps_factory
|
||||
def test_check_timestamps_sync_slightly_off(slightly_off_timestamps_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(
|
||||
fps, tolerance_s
|
||||
)
|
||||
timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(fps, tolerance_s)
|
||||
result = check_timestamps_sync(
|
||||
timestamps=timestamps,
|
||||
episode_indices=ep_idx,
|
||||
@@ -241,9 +227,7 @@ def test_check_delta_timestamps_valid(valid_delta_timestamps_factory):
|
||||
def test_check_delta_timestamps_slightly_off(slightly_off_delta_timestamps_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(
|
||||
fps, tolerance_s
|
||||
)
|
||||
slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(fps, tolerance_s)
|
||||
result = check_delta_timestamps(
|
||||
delta_timestamps=slightly_off_delta_timestamps,
|
||||
fps=fps,
|
||||
|
||||
@@ -82,11 +82,7 @@ def test_get_image_transforms_brightness(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_cfg = ImageTransformsConfig(
|
||||
enable=True,
|
||||
tfs={
|
||||
"brightness": ImageTransformConfig(
|
||||
type="ColorJitter", kwargs={"brightness": min_max}
|
||||
)
|
||||
},
|
||||
tfs={"brightness": ImageTransformConfig(type="ColorJitter", kwargs={"brightness": min_max})},
|
||||
)
|
||||
tf_actual = ImageTransforms(tf_cfg)
|
||||
tf_expected = v2.ColorJitter(brightness=min_max)
|
||||
@@ -98,11 +94,7 @@ def test_get_image_transforms_contrast(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_cfg = ImageTransformsConfig(
|
||||
enable=True,
|
||||
tfs={
|
||||
"contrast": ImageTransformConfig(
|
||||
type="ColorJitter", kwargs={"contrast": min_max}
|
||||
)
|
||||
},
|
||||
tfs={"contrast": ImageTransformConfig(type="ColorJitter", kwargs={"contrast": min_max})},
|
||||
)
|
||||
tf_actual = ImageTransforms(tf_cfg)
|
||||
tf_expected = v2.ColorJitter(contrast=min_max)
|
||||
@@ -114,11 +106,7 @@ def test_get_image_transforms_saturation(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_cfg = ImageTransformsConfig(
|
||||
enable=True,
|
||||
tfs={
|
||||
"saturation": ImageTransformConfig(
|
||||
type="ColorJitter", kwargs={"saturation": min_max}
|
||||
)
|
||||
},
|
||||
tfs={"saturation": ImageTransformConfig(type="ColorJitter", kwargs={"saturation": min_max})},
|
||||
)
|
||||
tf_actual = ImageTransforms(tf_cfg)
|
||||
tf_expected = v2.ColorJitter(saturation=min_max)
|
||||
@@ -142,11 +130,7 @@ def test_get_image_transforms_sharpness(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_cfg = ImageTransformsConfig(
|
||||
enable=True,
|
||||
tfs={
|
||||
"sharpness": ImageTransformConfig(
|
||||
type="SharpnessJitter", kwargs={"sharpness": min_max}
|
||||
)
|
||||
},
|
||||
tfs={"sharpness": ImageTransformConfig(type="SharpnessJitter", kwargs={"sharpness": min_max})},
|
||||
)
|
||||
tf_actual = ImageTransforms(tf_cfg)
|
||||
tf_expected = SharpnessJitter(sharpness=min_max)
|
||||
@@ -362,9 +346,7 @@ def test_save_all_transforms(img_tensor_factory, tmp_path):
|
||||
|
||||
# Check if the combined transforms directory exists and contains the right files
|
||||
combined_transforms_dir = tmp_path / "all"
|
||||
assert combined_transforms_dir.exists(), (
|
||||
"Combined transforms directory was not created."
|
||||
)
|
||||
assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
|
||||
assert any(combined_transforms_dir.iterdir()), (
|
||||
"No transformed images found in combined transforms directory."
|
||||
)
|
||||
@@ -386,9 +368,7 @@ def test_save_each_transform(img_tensor_factory, tmp_path):
|
||||
for transform in transforms:
|
||||
transform_dir = tmp_path / transform
|
||||
assert transform_dir.exists(), f"{transform} directory was not created."
|
||||
assert any(transform_dir.iterdir()), (
|
||||
f"No transformed images found in {transform} directory."
|
||||
)
|
||||
assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory."
|
||||
|
||||
# Check for specific files within each transform directory
|
||||
expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + [
|
||||
|
||||
@@ -187,9 +187,7 @@ def test_save_image_torch(tmp_path, img_tensor_factory):
|
||||
writer.wait_until_done()
|
||||
assert fpath.exists()
|
||||
saved_image = np.array(Image.open(fpath))
|
||||
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(
|
||||
np.uint8
|
||||
)
|
||||
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
|
||||
assert np.array_equal(expected_image, saved_image)
|
||||
finally:
|
||||
writer.stop()
|
||||
@@ -204,9 +202,7 @@ def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory):
|
||||
writer.wait_until_done()
|
||||
assert fpath.exists()
|
||||
saved_image = np.array(Image.open(fpath))
|
||||
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(
|
||||
np.uint8
|
||||
)
|
||||
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
|
||||
assert np.array_equal(expected_image, saved_image)
|
||||
finally:
|
||||
writer.stop()
|
||||
@@ -296,9 +292,7 @@ def test_wait_until_done(tmp_path, img_array_factory):
|
||||
writer = AsyncImageWriter(num_processes=0, num_threads=4)
|
||||
try:
|
||||
num_images = 100
|
||||
image_arrays = [
|
||||
img_array_factory(height=500, width=500) for _ in range(num_images)
|
||||
]
|
||||
image_arrays = [img_array_factory(height=500, width=500) for _ in range(num_images)]
|
||||
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
|
||||
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -44,23 +44,13 @@ def make_new_buffer(
|
||||
return buffer, write_dir
|
||||
|
||||
|
||||
def make_spoof_data_frames(
|
||||
n_episodes: int, n_frames_per_episode: int
|
||||
) -> dict[str, np.ndarray]:
|
||||
def make_spoof_data_frames(n_episodes: int, n_frames_per_episode: int) -> dict[str, np.ndarray]:
|
||||
new_data = {
|
||||
data_key: np.arange(
|
||||
n_frames_per_episode * n_episodes * np.prod(data_shape)
|
||||
).reshape(-1, *data_shape),
|
||||
data_key: np.arange(n_frames_per_episode * n_episodes * np.prod(data_shape)).reshape(-1, *data_shape),
|
||||
OnlineBuffer.INDEX_KEY: np.arange(n_frames_per_episode * n_episodes),
|
||||
OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(
|
||||
np.arange(n_episodes), n_frames_per_episode
|
||||
),
|
||||
OnlineBuffer.FRAME_INDEX_KEY: np.tile(
|
||||
np.arange(n_frames_per_episode), n_episodes
|
||||
),
|
||||
OnlineBuffer.TIMESTAMP_KEY: np.tile(
|
||||
np.arange(n_frames_per_episode) / fps, n_episodes
|
||||
),
|
||||
OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(np.arange(n_episodes), n_frames_per_episode),
|
||||
OnlineBuffer.FRAME_INDEX_KEY: np.tile(np.arange(n_frames_per_episode), n_episodes),
|
||||
OnlineBuffer.TIMESTAMP_KEY: np.tile(np.arange(n_frames_per_episode) / fps, n_episodes),
|
||||
}
|
||||
return new_data
|
||||
|
||||
@@ -176,9 +166,7 @@ def test_delta_timestamps_within_tolerance():
|
||||
buffer.tolerance_s = 0.04
|
||||
item = buffer[2]
|
||||
data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"]
|
||||
torch.testing.assert_close(
|
||||
data, torch.tensor([0, 2, 3]), msg="Data does not match expected values"
|
||||
)
|
||||
torch.testing.assert_close(data, torch.tensor([0, 2, 3]), msg="Data does not match expected values")
|
||||
assert not is_pad.any(), "Unexpected padding detected"
|
||||
|
||||
|
||||
@@ -214,9 +202,7 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range():
|
||||
buffer.tolerance_s = 0.04
|
||||
item = buffer[2]
|
||||
data, is_pad = item["index"], item["index_is_pad"]
|
||||
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), (
|
||||
"Data does not match expected values"
|
||||
)
|
||||
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
|
||||
assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), (
|
||||
"Padding does not match expected values"
|
||||
)
|
||||
@@ -233,15 +219,11 @@ def test_compute_sampler_weights_trivial(
|
||||
online_dataset_size: int,
|
||||
online_sampling_ratio: float,
|
||||
):
|
||||
offline_dataset = lerobot_dataset_factory(
|
||||
tmp_path, total_episodes=1, total_frames=offline_dataset_size
|
||||
)
|
||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size)
|
||||
online_dataset, _ = make_new_buffer()
|
||||
if online_dataset_size > 0:
|
||||
online_dataset.add_data(
|
||||
make_spoof_data_frames(
|
||||
n_episodes=2, n_frames_per_episode=online_dataset_size // 2
|
||||
)
|
||||
make_spoof_data_frames(n_episodes=2, n_frames_per_episode=online_dataset_size // 2)
|
||||
)
|
||||
|
||||
weights = compute_sampler_weights(
|
||||
@@ -252,26 +234,18 @@ def test_compute_sampler_weights_trivial(
|
||||
if offline_dataset_size == 0 or online_dataset_size == 0:
|
||||
expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
|
||||
elif online_sampling_ratio == 0:
|
||||
expected_weights = torch.cat(
|
||||
[torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)]
|
||||
)
|
||||
expected_weights = torch.cat([torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)])
|
||||
elif online_sampling_ratio == 1:
|
||||
expected_weights = torch.cat(
|
||||
[torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)]
|
||||
)
|
||||
expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)])
|
||||
expected_weights /= expected_weights.sum()
|
||||
torch.testing.assert_close(weights, expected_weights)
|
||||
|
||||
|
||||
def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
|
||||
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
|
||||
offline_dataset = lerobot_dataset_factory(
|
||||
tmp_path, total_episodes=1, total_frames=4
|
||||
)
|
||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
|
||||
online_dataset, _ = make_new_buffer()
|
||||
online_dataset.add_data(
|
||||
make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
|
||||
)
|
||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
||||
online_sampling_ratio = 0.8
|
||||
weights = compute_sampler_weights(
|
||||
offline_dataset,
|
||||
@@ -284,17 +258,11 @@ def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_p
|
||||
)
|
||||
|
||||
|
||||
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(
|
||||
lerobot_dataset_factory, tmp_path
|
||||
):
|
||||
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path):
|
||||
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
|
||||
offline_dataset = lerobot_dataset_factory(
|
||||
tmp_path, total_episodes=1, total_frames=4
|
||||
)
|
||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
|
||||
online_dataset, _ = make_new_buffer()
|
||||
online_dataset.add_data(
|
||||
make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
|
||||
)
|
||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
||||
weights = compute_sampler_weights(
|
||||
offline_dataset,
|
||||
online_dataset=online_dataset,
|
||||
@@ -309,13 +277,9 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(
|
||||
|
||||
def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path):
|
||||
"""Note: test copied from test_sampler."""
|
||||
offline_dataset = lerobot_dataset_factory(
|
||||
tmp_path, total_episodes=1, total_frames=2
|
||||
)
|
||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2)
|
||||
online_dataset, _ = make_new_buffer()
|
||||
online_dataset.add_data(
|
||||
make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
|
||||
)
|
||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
||||
|
||||
weights = compute_sampler_weights(
|
||||
offline_dataset,
|
||||
@@ -324,6 +288,4 @@ def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp
|
||||
online_sampling_ratio=0.5,
|
||||
online_drop_n_last_frames=1,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0])
|
||||
)
|
||||
torch.testing.assert_close(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))
|
||||
|
||||
Reference in New Issue
Block a user