forked from tangger/lerobot
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Michel Aractingi
parent
2abbd60a0d
commit
0ea27704f6
@@ -91,9 +91,7 @@ def test_metrics_tracker_step(mock_metrics):
|
||||
|
||||
|
||||
def test_metrics_tracker_getattr(mock_metrics):
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
|
||||
)
|
||||
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
|
||||
assert tracker.loss == mock_metrics["loss"]
|
||||
assert tracker.accuracy == mock_metrics["accuracy"]
|
||||
with pytest.raises(AttributeError):
|
||||
@@ -101,17 +99,13 @@ def test_metrics_tracker_getattr(mock_metrics):
|
||||
|
||||
|
||||
def test_metrics_tracker_setattr(mock_metrics):
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
|
||||
)
|
||||
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
|
||||
tracker.loss = 2.0
|
||||
assert tracker.loss.val == 2.0
|
||||
|
||||
|
||||
def test_metrics_tracker_str(mock_metrics):
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
|
||||
)
|
||||
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
|
||||
tracker.loss.update(3.456, 1)
|
||||
tracker.accuracy.update(0.876, 1)
|
||||
output = str(tracker)
|
||||
@@ -120,9 +114,7 @@ def test_metrics_tracker_str(mock_metrics):
|
||||
|
||||
|
||||
def test_metrics_tracker_to_dict(mock_metrics):
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
|
||||
)
|
||||
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
|
||||
tracker.loss.update(5, 2)
|
||||
metrics_dict = tracker.to_dict()
|
||||
assert isinstance(metrics_dict, dict)
|
||||
@@ -131,9 +123,7 @@ def test_metrics_tracker_to_dict(mock_metrics):
|
||||
|
||||
|
||||
def test_metrics_tracker_reset_averages(mock_metrics):
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
|
||||
)
|
||||
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
|
||||
tracker.loss.update(10, 3)
|
||||
tracker.accuracy.update(0.95, 5)
|
||||
tracker.reset_averages()
|
||||
|
||||
@@ -118,9 +118,5 @@ def test_seeded_context(fixed_seed):
|
||||
seeded_val2 = (random.random(), np.random.rand(), torch.rand(1).item())
|
||||
|
||||
assert seeded_val1 == seeded_val2
|
||||
assert all(
|
||||
a != b for a, b in zip(val1, seeded_val1, strict=True)
|
||||
) # changed inside the context
|
||||
assert all(
|
||||
a != b for a, b in zip(val2, seeded_val2, strict=True)
|
||||
) # changed again after exiting
|
||||
assert all(a != b for a, b in zip(val1, seeded_val1, strict=True)) # changed inside the context
|
||||
assert all(a != b for a, b in zip(val2, seeded_val2, strict=True)) # changed again after exiting
|
||||
|
||||
@@ -91,9 +91,7 @@ def test_save_training_state(tmp_path, optimizer, scheduler):
|
||||
|
||||
def test_save_load_training_state(tmp_path, optimizer, scheduler):
|
||||
save_training_state(tmp_path, 10, optimizer, scheduler)
|
||||
loaded_step, loaded_optimizer, loaded_scheduler = load_training_state(
|
||||
tmp_path, optimizer, scheduler
|
||||
)
|
||||
loaded_step, loaded_optimizer, loaded_scheduler = load_training_state(tmp_path, optimizer, scheduler)
|
||||
assert loaded_step == 10
|
||||
assert loaded_optimizer is optimizer
|
||||
assert loaded_scheduler is scheduler
|
||||
|
||||
Reference in New Issue
Block a user