Organize test folders (#856)

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
Simon Alibert
2025-03-13 14:05:55 +01:00
committed by GitHub
parent a36ed39487
commit 974028bd28
79 changed files with 63 additions and 106 deletions

View File

@@ -0,0 +1,87 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from pathlib import Path
from typing import Any
import pytest
from lerobot.common.utils.io_utils import deserialize_json_into_object
@pytest.fixture
def tmp_json_file(tmp_path: Path):
"""Writes `data` to a temporary JSON file and returns the file's path."""
def _write(data: Any) -> Path:
file_path = tmp_path / "data.json"
with file_path.open("w", encoding="utf-8") as f:
json.dump(data, f)
return file_path
return _write
def test_simple_dict(tmp_json_file):
data = {"name": "Alice", "age": 30}
json_path = tmp_json_file(data)
obj = {"name": "", "age": 0}
assert deserialize_json_into_object(json_path, obj) == data
def test_nested_structure(tmp_json_file):
data = {"items": [1, 2, 3], "info": {"active": True}}
json_path = tmp_json_file(data)
obj = {"items": [0, 0, 0], "info": {"active": False}}
assert deserialize_json_into_object(json_path, obj) == data
def test_tuple_conversion(tmp_json_file):
data = {"coords": [10.5, 20.5]}
json_path = tmp_json_file(data)
obj = {"coords": (0.0, 0.0)}
result = deserialize_json_into_object(json_path, obj)
assert result["coords"] == (10.5, 20.5)
def test_type_mismatch_raises(tmp_json_file):
data = {"numbers": {"bad": "structure"}}
json_path = tmp_json_file(data)
obj = {"numbers": [0, 0]}
with pytest.raises(TypeError):
deserialize_json_into_object(json_path, obj)
def test_missing_key_raises(tmp_json_file):
data = {"one": 1}
json_path = tmp_json_file(data)
obj = {"one": 0, "two": 0}
with pytest.raises(ValueError):
deserialize_json_into_object(json_path, obj)
def test_extra_key_raises(tmp_json_file):
data = {"one": 1, "two": 2}
json_path = tmp_json_file(data)
obj = {"one": 0}
with pytest.raises(ValueError):
deserialize_json_into_object(json_path, obj)
def test_list_length_mismatch_raises(tmp_json_file):
data = {"nums": [1, 2, 3]}
json_path = tmp_json_file(data)
obj = {"nums": [0, 0]}
with pytest.raises(ValueError):
deserialize_json_into_object(json_path, obj)

View File

@@ -0,0 +1,120 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
@pytest.fixture
def mock_metrics():
return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}
def test_average_meter_initialization():
meter = AverageMeter("loss", ":.2f")
assert meter.name == "loss"
assert meter.fmt == ":.2f"
assert meter.val == 0.0
assert meter.avg == 0.0
assert meter.sum == 0.0
assert meter.count == 0.0
def test_average_meter_update():
meter = AverageMeter("accuracy")
meter.update(5, n=2)
assert meter.val == 5
assert meter.sum == 10
assert meter.count == 2
assert meter.avg == 5
def test_average_meter_reset():
meter = AverageMeter("loss")
meter.update(3, 4)
meter.reset()
assert meter.val == 0.0
assert meter.avg == 0.0
assert meter.sum == 0.0
assert meter.count == 0.0
def test_average_meter_str():
meter = AverageMeter("metric", ":.1f")
meter.update(4.567, 3)
assert str(meter) == "metric:4.6"
def test_metrics_tracker_initialization(mock_metrics):
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=10
)
assert tracker.steps == 10
assert tracker.samples == 10 * 32
assert tracker.episodes == tracker.samples / (1000 / 50)
assert tracker.epochs == tracker.samples / 1000
assert "loss" in tracker.metrics
assert "accuracy" in tracker.metrics
def test_metrics_tracker_step(mock_metrics):
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=5
)
tracker.step()
assert tracker.steps == 6
assert tracker.samples == 6 * 32
assert tracker.episodes == tracker.samples / (1000 / 50)
assert tracker.epochs == tracker.samples / 1000
def test_metrics_tracker_getattr(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
assert tracker.loss == mock_metrics["loss"]
assert tracker.accuracy == mock_metrics["accuracy"]
with pytest.raises(AttributeError):
_ = tracker.non_existent_metric
def test_metrics_tracker_setattr(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker.loss = 2.0
assert tracker.loss.val == 2.0
def test_metrics_tracker_str(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker.loss.update(3.456, 1)
tracker.accuracy.update(0.876, 1)
output = str(tracker)
assert "loss:3.456" in output
assert "accuracy:0.88" in output
def test_metrics_tracker_to_dict(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker.loss.update(5, 2)
metrics_dict = tracker.to_dict()
assert isinstance(metrics_dict, dict)
assert metrics_dict["loss"] == 5 # average value
assert metrics_dict["steps"] == tracker.steps
def test_metrics_tracker_reset_averages(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker.loss.update(10, 3)
tracker.accuracy.update(0.95, 5)
tracker.reset_averages()
assert tracker.loss.avg == 0.0
assert tracker.accuracy.avg == 0.0

View File

@@ -0,0 +1,122 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import numpy as np
import pytest
import torch
from lerobot.common.utils.random_utils import (
deserialize_numpy_rng_state,
deserialize_python_rng_state,
deserialize_rng_state,
deserialize_torch_rng_state,
get_rng_state,
seeded_context,
serialize_numpy_rng_state,
serialize_python_rng_state,
serialize_rng_state,
serialize_torch_rng_state,
set_rng_state,
set_seed,
)
@pytest.fixture
def fixed_seed():
"""Fixture to set a consistent initial seed for each test."""
set_seed(12345)
yield
def test_serialize_deserialize_python_rng(fixed_seed):
# Save state after generating val1
_ = random.random()
st = serialize_python_rng_state()
# Next random is val2
val2 = random.random()
# Restore the state, so the next random should match val2
deserialize_python_rng_state(st)
val3 = random.random()
assert val2 == val3
def test_serialize_deserialize_numpy_rng(fixed_seed):
_ = np.random.rand()
st = serialize_numpy_rng_state()
val2 = np.random.rand()
deserialize_numpy_rng_state(st)
val3 = np.random.rand()
assert val2 == val3
def test_serialize_deserialize_torch_rng(fixed_seed):
_ = torch.rand(1).item()
st = serialize_torch_rng_state()
val2 = torch.rand(1).item()
deserialize_torch_rng_state(st)
val3 = torch.rand(1).item()
assert val2 == val3
def test_serialize_deserialize_rng(fixed_seed):
# Generate one from each library
_ = random.random()
_ = np.random.rand()
_ = torch.rand(1).item()
# Serialize
st = serialize_rng_state()
# Generate second set
val_py2 = random.random()
val_np2 = np.random.rand()
val_th2 = torch.rand(1).item()
# Restore, so the next draws should match val_py2, val_np2, val_th2
deserialize_rng_state(st)
assert random.random() == val_py2
assert np.random.rand() == val_np2
assert torch.rand(1).item() == val_th2
def test_get_set_rng_state(fixed_seed):
st = get_rng_state()
val1 = (random.random(), np.random.rand(), torch.rand(1).item())
# Change states
random.random()
np.random.rand()
torch.rand(1)
# Restore
set_rng_state(st)
val2 = (random.random(), np.random.rand(), torch.rand(1).item())
assert val1 == val2
def test_set_seed():
set_seed(1337)
val1 = (random.random(), np.random.rand(), torch.rand(1).item())
set_seed(1337)
val2 = (random.random(), np.random.rand(), torch.rand(1).item())
assert val1 == val2
def test_seeded_context(fixed_seed):
val1 = (random.random(), np.random.rand(), torch.rand(1).item())
with seeded_context(1337):
seeded_val1 = (random.random(), np.random.rand(), torch.rand(1).item())
val2 = (random.random(), np.random.rand(), torch.rand(1).item())
with seeded_context(1337):
seeded_val2 = (random.random(), np.random.rand(), torch.rand(1).item())
assert seeded_val1 == seeded_val2
assert all(a != b for a, b in zip(val1, seeded_val1, strict=True)) # changed inside the context
assert all(a != b for a, b in zip(val2, seeded_val2, strict=True)) # changed again after exiting

View File

@@ -0,0 +1,97 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from unittest.mock import Mock, patch
from lerobot.common.constants import (
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
OPTIMIZER_PARAM_GROUPS,
OPTIMIZER_STATE,
RNG_STATE,
SCHEDULER_STATE,
TRAINING_STATE_DIR,
TRAINING_STEP,
)
from lerobot.common.utils.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_state,
load_training_step,
save_checkpoint,
save_training_state,
save_training_step,
update_last_checkpoint,
)
def test_get_step_identifier():
assert get_step_identifier(5, 1000) == "000005"
assert get_step_identifier(123, 100_000) == "000123"
assert get_step_identifier(456789, 1_000_000) == "0456789"
def test_get_step_checkpoint_dir():
output_dir = Path("/checkpoints")
step_dir = get_step_checkpoint_dir(output_dir, 1000, 5)
assert step_dir == output_dir / CHECKPOINTS_DIR / "000005"
def test_save_load_training_step(tmp_path):
save_training_step(5000, tmp_path)
assert (tmp_path / TRAINING_STEP).is_file()
def test_load_training_step(tmp_path):
step = 5000
save_training_step(step, tmp_path)
loaded_step = load_training_step(tmp_path)
assert loaded_step == step
def test_update_last_checkpoint(tmp_path):
checkpoint = tmp_path / "0005"
checkpoint.mkdir()
update_last_checkpoint(checkpoint)
last_checkpoint = tmp_path / LAST_CHECKPOINT_LINK
assert last_checkpoint.is_symlink()
assert last_checkpoint.resolve() == checkpoint
@patch("lerobot.common.utils.train_utils.save_training_state")
def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer):
policy = Mock()
cfg = Mock()
save_checkpoint(tmp_path, 10, cfg, policy, optimizer)
policy.save_pretrained.assert_called_once()
cfg.save_pretrained.assert_called_once()
mock_save_training_state.assert_called_once()
def test_save_training_state(tmp_path, optimizer, scheduler):
save_training_state(tmp_path, 10, optimizer, scheduler)
assert (tmp_path / TRAINING_STATE_DIR).is_dir()
assert (tmp_path / TRAINING_STATE_DIR / TRAINING_STEP).is_file()
assert (tmp_path / TRAINING_STATE_DIR / RNG_STATE).is_file()
assert (tmp_path / TRAINING_STATE_DIR / OPTIMIZER_STATE).is_file()
assert (tmp_path / TRAINING_STATE_DIR / OPTIMIZER_PARAM_GROUPS).is_file()
assert (tmp_path / TRAINING_STATE_DIR / SCHEDULER_STATE).is_file()
def test_save_load_training_state(tmp_path, optimizer, scheduler):
save_training_state(tmp_path, 10, optimizer, scheduler)
loaded_step, loaded_optimizer, loaded_scheduler = load_training_state(tmp_path, optimizer, scheduler)
assert loaded_step == 10
assert loaded_optimizer is optimizer
assert loaded_scheduler is scheduler