Format file
@@ -61,10 +61,7 @@ def test_sample_indices():
     assert len(indices) == estimate_num_samples(10)


-@patch(
-    "lerobot.common.datasets.compute_stats.load_image_as_numpy",
-    side_effect=mock_load_image_as_numpy,
-)
+@patch("lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
 def test_sample_images(mock_load):
     image_paths = [f"image_{i}.jpg" for i in range(100)]
     images = sample_images(image_paths)
@@ -148,8 +145,7 @@ def test_compute_episode_stats():
     }

     with patch(
-        "lerobot.common.datasets.compute_stats.load_image_as_numpy",
-        side_effect=mock_load_image_as_numpy,
+        "lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
     ):
         stats = compute_episode_stats(episode_data, features)

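The two hunks above reflow the same mocking pattern, once in decorator form and once as a context manager. As a minimal sketch of how `unittest.mock.patch(..., side_effect=...)` swaps a loader out during a test — the patch target `__main__.load_image_as_numpy` and the helper bodies here are illustrative stand-ins, not LeRobot's actual code:

```python
from unittest.mock import patch

import numpy as np


def load_image_as_numpy(path):
    # The real loader would read `path` from disk; a stub so this sketch runs.
    raise IOError("should never be called while patched")


def mock_load_image_as_numpy(path):
    # Fake loader: a fixed-size image, no disk access, fully deterministic.
    return np.ones((32, 32, 3), dtype=np.uint8)


def sample_images(paths):
    return [load_image_as_numpy(p) for p in paths]


# Decorator form: the patch lasts for the whole test and the mock object is
# injected as the first positional argument.
@patch("__main__.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
def test_decorator_form(mock_load):
    images = sample_images([f"image_{i}.jpg" for i in range(3)])
    assert len(images) == 3 and mock_load.call_count == 3


def test_context_manager_form():
    # Context-manager form: the patch is undone when the `with` block exits.
    with patch("__main__.load_image_as_numpy", side_effect=mock_load_image_as_numpy):
        assert sample_images(["a.jpg"])[0].shape == (32, 32, 3)


if __name__ == "__main__":  # the "__main__" patch target assumes script execution
    test_decorator_form()
    test_context_manager_form()
```

With `side_effect` set to a function, every call to the mock is forwarded to that function while call counts and arguments are still recorded on the mock.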
@@ -237,13 +233,7 @@ def test_aggregate_stats():
             "std": [2.87, 5.87, 8.87],
             "count": 10,
         },
-        "observation.state": {
-            "min": 1,
-            "max": 10,
-            "mean": 5.5,
-            "std": 2.87,
-            "count": 10,
-        },
+        "observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10},
         "extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6},
     },
     {
@@ -254,13 +244,7 @@ def test_aggregate_stats():
             "std": [3.42, 2.42, 1.42],
             "count": 15,
         },
-        "observation.state": {
-            "min": 2,
-            "max": 15,
-            "mean": 8.5,
-            "std": 3.42,
-            "count": 15,
-        },
+        "observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15},
         "extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5},
     },
 ]
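These two fixtures feed `test_aggregate_stats`, which must merge per-dataset statistics from the summaries alone, weighted by `count`. The diff does not show LeRobot's implementation, but the standard approach is a pooled mean plus a pooled variance via E[X²] = Var[X] + E[X]²; a hedged sketch, with the `pool_stats` helper name invented here:

```python
import numpy as np


def pool_stats(stats_list):
    """Combine per-dataset stats into one, weighting each dataset by its count.

    Uses Var[X] = E[X^2] - E[X]^2 so variances merge without the raw samples.
    """
    counts = np.array([s["count"] for s in stats_list], dtype=float)
    means = np.array([s["mean"] for s in stats_list], dtype=float)
    variances = np.array([s["std"] for s in stats_list], dtype=float) ** 2

    total = counts.sum()
    pooled_mean = (counts * means).sum() / total
    # Count-weighted second moment of the pool:
    pooled_sq = (counts * (variances + means**2)).sum() / total
    pooled_std = np.sqrt(pooled_sq - pooled_mean**2)

    return {
        "min": min(s["min"] for s in stats_list),
        "max": max(s["max"] for s in stats_list),
        "mean": pooled_mean,
        "std": pooled_std,
        "count": int(total),
    }


# With the "observation.state" fixtures above (counts 10 and 15):
agg = pool_stats(
    [
        {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10},
        {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15},
    ]
)
print(agg)  # mean = (10*5.5 + 15*8.5) / 25 = 7.3; std ≈ 3.53
```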
@@ -322,9 +306,6 @@ def test_aggregate_stats():
         np.testing.assert_allclose(results[fkey]["max"], expected_agg_stats[fkey]["max"])
         np.testing.assert_allclose(results[fkey]["mean"], expected_agg_stats[fkey]["mean"])
         np.testing.assert_allclose(
-            results[fkey]["std"],
-            expected_agg_stats[fkey]["std"],
-            atol=1e-04,
-            rtol=1e-04,
+            results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04
         )
         np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"])

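The tolerances in the reformatted call are worth spelling out: `np.testing.assert_allclose` passes elementwise when `|actual - desired| <= atol + rtol * |desired|`, so both an absolute and a relative slack of 1e-4 apply to the std comparison. A quick self-contained illustration:

```python
import numpy as np

desired = np.array([2.87, 5.87, 8.87])
# Off by less than atol + rtol * |desired| = 1e-4 + 1e-4 * 8.87 ≈ 9.87e-4 on the last entry:
actual = desired + np.array([5e-5, 1e-4, 9e-4])

np.testing.assert_allclose(actual, desired, atol=1e-04, rtol=1e-04)  # passes
# np.testing.assert_allclose(actual, desired, atol=1e-05, rtol=1e-05)  # would raise AssertionError
```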
@@ -104,8 +104,7 @@ def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
     features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
     with pytest.raises(
-        ValueError,
-        match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n",
+        ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n"
     ):
         dataset.add_frame({"state": torch.randn(1)})

@@ -114,8 +113,7 @@ def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory):
     features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
     with pytest.raises(
-        ValueError,
-        match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n",
+        ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n"
     ):
         dataset.add_frame({"task": "Dummy task"})

@@ -124,8 +122,7 @@ def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
     features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
     with pytest.raises(
-        ValueError,
-        match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n",
+        ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n"
     ):
         dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"})

@@ -134,8 +131,7 @@ def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
     features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
     with pytest.raises(
-        ValueError,
-        match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n",
+        ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n"
     ):
         dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"})

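The four hunks above all collapse the same `pytest.raises(..., match=...)` idiom. One detail to keep in mind: `match` is applied with `re.search` against the exception message, so regex metacharacters such as `(` or `[` in a message need `re.escape`; the braces and newlines in these particular messages happen to be regex-safe. A minimal sketch — the `add_frame` validator below is a hypothetical stand-in that only mirrors the error text:

```python
import re

import pytest


def add_frame(frame, expected_features=("state",)):
    # Hypothetical validator echoing the test's error message format.
    missing = set(expected_features) - frame.keys()
    if missing:
        raise ValueError(f"Feature mismatch in `frame` dictionary:\nMissing features: {missing}\n")


def test_missing_feature_message():
    # `match` is re.search'd against str(exc); escape it when the message
    # contains regex metacharacters.
    with pytest.raises(ValueError, match=re.escape("Missing features: {'state'}")):
        add_frame({"task": "Dummy task"})
```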
@@ -54,9 +54,7 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.ndarray]:

 @pytest.fixture(scope="module")
 def synced_timestamps_factory(hf_dataset_factory):
-    def _create_synced_timestamps(
-        fps: int = 30,
-    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+    def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
         hf_dataset = hf_dataset_factory(fps=fps)
         timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
         episode_indices = torch.stack(hf_dataset["episode_index"]).numpy()
@@ -93,9 +91,7 @@ def slightly_off_timestamps_factory(synced_timestamps_factory):
 @pytest.fixture(scope="module")
 def valid_delta_timestamps_factory():
     def _create_valid_delta_timestamps(
-        fps: int = 30,
-        keys: list = DUMMY_MOTOR_FEATURES,
-        min_max_range: tuple[int, int] = (-10, 10),
+        fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)
     ) -> dict:
         delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys}
         return delta_timestamps

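Both fixtures in these hunks use pytest's factory-fixture pattern: the module-scoped fixture returns an inner function, so the caching applies to the factory itself while each test can still request differently parameterized data. A condensed sketch of the pattern using the signature from the hunk above; the value given to `DUMMY_MOTOR_FEATURES` is made up for self-containment:

```python
import pytest

# Stand-in for the constant imported by the real fixtures; hypothetical value.
DUMMY_MOTOR_FEATURES = ["action", "observation.state"]


@pytest.fixture(scope="module")
def valid_delta_timestamps_factory():
    def _create_valid_delta_timestamps(
        fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)
    ) -> dict:
        # One timestamp offset per frame step in [min, max), spaced at 1/fps.
        return {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys}

    return _create_valid_delta_timestamps


def test_uses_factory(valid_delta_timestamps_factory):
    # Each test picks its own parameters from the shared factory.
    delta = valid_delta_timestamps_factory(fps=10)
    assert delta["action"][0] == -1.0  # -10 * (1 / 10)
```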
@@ -33,9 +33,7 @@ from lerobot.scripts.visualize_image_transforms import (
     save_all_transforms,
     save_each_transform,
 )
-from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import (
-    ARTIFACT_DIR,
-)
+from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import ARTIFACT_DIR
 from tests.utils import require_x86_64_kernel

@@ -93,8 +91,7 @@ def test_get_image_transforms_brightness(img_tensor_factory, min_max):
 def test_get_image_transforms_contrast(img_tensor_factory, min_max):
     img_tensor = img_tensor_factory()
     tf_cfg = ImageTransformsConfig(
-        enable=True,
-        tfs={"contrast": ImageTransformConfig(type="ColorJitter", kwargs={"contrast": min_max})},
+        enable=True, tfs={"contrast": ImageTransformConfig(type="ColorJitter", kwargs={"contrast": min_max})}
     )
     tf_actual = ImageTransforms(tf_cfg)
     tf_expected = v2.ColorJitter(contrast=min_max)
@@ -117,8 +114,7 @@ def test_get_image_transforms_saturation(img_tensor_factory, min_max):
 def test_get_image_transforms_hue(img_tensor_factory, min_max):
     img_tensor = img_tensor_factory()
     tf_cfg = ImageTransformsConfig(
-        enable=True,
-        tfs={"hue": ImageTransformConfig(type="ColorJitter", kwargs={"hue": min_max})},
+        enable=True, tfs={"hue": ImageTransformConfig(type="ColorJitter", kwargs={"hue": min_max})}
     )
     tf_actual = ImageTransforms(tf_cfg)
     tf_expected = v2.ColorJitter(hue=min_max)
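Both image-transform hunks build a one-transform pipeline next to a reference `v2.ColorJitter`. Presumably the surrounding test compares their outputs under a shared RNG seed, since ColorJitter draws its jitter factor randomly; a sketch of that comparison idea under that assumption (plain torchvision, not LeRobot's exact harness):

```python
import torch
from torchvision.transforms import v2

img_tensor = torch.rand(3, 64, 64)  # stand-in for img_tensor_factory()
min_max = (0.5, 1.5)

tf_expected = v2.ColorJitter(contrast=min_max)

# ColorJitter samples its factor from the global torch RNG, so equal seeds
# are required for a pixel-exact comparison between two pipeline objects.
torch.manual_seed(1337)
out_a = tf_expected(img_tensor)
torch.manual_seed(1337)
out_b = v2.ColorJitter(contrast=min_max)(img_tensor)

torch.testing.assert_close(out_a, out_b)
```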
@@ -371,11 +367,7 @@ def test_save_each_transform(img_tensor_factory, tmp_path):
         assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory."

         # Check for specific files within each transform directory
-        expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + [
-            "min.png",
-            "max.png",
-            "mean.png",
-        ]
+        expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"]
         for file_name in expected_files:
             assert (transform_dir / file_name).exists(), (
                 f"{file_name} was not found in {transform} directory."

@@ -227,9 +227,7 @@ def test_compute_sampler_weights_trivial(
     )

     weights = compute_sampler_weights(
-        offline_dataset,
-        online_dataset=online_dataset,
-        online_sampling_ratio=online_sampling_ratio,
+        offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
     )
     if offline_dataset_size == 0 or online_dataset_size == 0:
         expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
@@ -248,13 +246,10 @@ def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
     online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
     online_sampling_ratio = 0.8
     weights = compute_sampler_weights(
-        offline_dataset,
-        online_dataset=online_dataset,
-        online_sampling_ratio=online_sampling_ratio,
+        offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
     )
     torch.testing.assert_close(
-        weights,
-        torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]),
+        weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
     )

@@ -264,14 +259,10 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path):
     online_dataset, _ = make_new_buffer()
     online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
     weights = compute_sampler_weights(
-        offline_dataset,
-        online_dataset=online_dataset,
-        online_sampling_ratio=0.8,
-        online_drop_n_last_frames=1,
+        offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1
     )
     torch.testing.assert_close(
-        weights,
-        torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0]),
+        weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
     )

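The expected tensors in these two hunks follow from simple arithmetic: with 4 offline frames and 8 online frames (4 episodes × 2) at `online_sampling_ratio=0.8`, the offline mass 0.2 splits into 0.05 per frame and the online mass 0.8 into 0.1 per frame; dropping the last frame of each 2-frame online episode zeroes 4 weights and concentrates 0.8 on the remaining 4 (0.2 each). A sketch reproducing the numbers, assuming this uniform-within-pool weighting — the `sampler_weights` helper is invented here, not LeRobot's `compute_sampler_weights`:

```python
import torch


def sampler_weights(n_offline, n_online, online_ratio, episode_len=None, drop_n_last=0):
    # Offline frames share (1 - ratio) uniformly; online frames share `ratio`,
    # optionally zeroing the last `drop_n_last` frames of each online episode.
    offline = torch.full((n_offline,), (1 - online_ratio) / n_offline)
    keep = torch.ones(n_online)
    if drop_n_last and episode_len:
        keep = keep.reshape(-1, episode_len)
        keep[:, -drop_n_last:] = 0.0
        keep = keep.reshape(-1)
    online = online_ratio * keep / keep.sum()
    return torch.cat([offline, online])


print(sampler_weights(4, 8, 0.8))
# tensor([0.0500, 0.0500, 0.0500, 0.0500, 0.1000, ..., 0.1000])
print(sampler_weights(4, 8, 0.8, episode_len=2, drop_n_last=1))
# tensor([0.0500, 0.0500, 0.0500, 0.0500, 0.2000, 0.0000, 0.2000, 0.0000, ...])
```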
@@ -15,9 +15,7 @@
 # limitations under the License.
 from datasets import Dataset

-from lerobot.common.datasets.push_dataset_to_hub.utils import (
-    calculate_episode_data_index,
-)
+from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
 from lerobot.common.datasets.sampler import EpisodeAwareSampler
 from lerobot.common.datasets.utils import (
     hf_transform_to_torch,

@@ -18,13 +18,8 @@ import torch
 from datasets import Dataset
 from huggingface_hub import DatasetCard

-from lerobot.common.datasets.push_dataset_to_hub.utils import (
-    calculate_episode_data_index,
-)
-from lerobot.common.datasets.utils import (
-    create_lerobot_dataset_card,
-    hf_transform_to_torch,
-)
+from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
+from lerobot.common.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch


 def test_default_parameters():