[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-28 17:20:38 +00:00
parent 808cf63221
commit c05e4835d0
16 changed files with 93 additions and 91 deletions

View File

@@ -134,15 +134,15 @@ def test_multi_adam_configuration(base_params_dict, config_params, expected_valu
# Create config with the given parameters
config = MultiAdamConfig(**config_params)
optimizers = config.build(base_params_dict)
# Verify optimizer count and keys
assert len(optimizers) == len(expected_values)
assert set(optimizers.keys()) == set(expected_values.keys())
# Check that all optimizers are Adam instances
for opt in optimizers.values():
assert isinstance(opt, torch.optim.Adam)
# Verify hyperparameters for each optimizer
for name, expected in expected_values.items():
optimizer = optimizers[name]
@@ -166,7 +166,7 @@ def multi_optimizers(base_params_dict):
def test_save_multi_optimizer_state(multi_optimizers, tmp_path):
# Save optimizer states
save_optimizer_state(multi_optimizers, tmp_path)
# Verify that directories were created for each optimizer
for name in multi_optimizers.keys():
assert (tmp_path / name).is_dir()
@@ -185,10 +185,10 @@ def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers,
multi_optimizers[name].step()
# Zero gradients for next steps
multi_optimizers[name].zero_grad()
# Save optimizer states
save_optimizer_state(multi_optimizers, tmp_path)
# Create new optimizers with the same config
config = MultiAdamConfig(
lr=1e-3,
@@ -199,16 +199,13 @@ def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers,
},
)
new_optimizers = config.build(base_params_dict)
# Load optimizer states
loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
# Verify state dictionaries match
for name in multi_optimizers.keys():
torch.testing.assert_close(
multi_optimizers[name].state_dict(),
loaded_optimizers[name].state_dict()
)
torch.testing.assert_close(multi_optimizers[name].state_dict(), loaded_optimizers[name].state_dict())
def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
@@ -223,25 +220,23 @@ def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
},
)
optimizers = config.build(base_params_dict)
# Save optimizer states without any backward pass (empty state)
save_optimizer_state(optimizers, tmp_path)
# Create new optimizers with the same config
new_optimizers = config.build(base_params_dict)
# Load optimizer states
loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
# Verify hyperparameters match even with empty state
for name, optimizer in optimizers.items():
assert optimizer.defaults["lr"] == loaded_optimizers[name].defaults["lr"]
assert optimizer.defaults["weight_decay"] == loaded_optimizers[name].defaults["weight_decay"]
assert optimizer.defaults["betas"] == loaded_optimizers[name].defaults["betas"]
# Verify state dictionaries match (they will be empty)
torch.testing.assert_close(
optimizer.state_dict()["param_groups"],
loaded_optimizers[name].state_dict()["param_groups"]
optimizer.state_dict()["param_groups"], loaded_optimizers[name].state_dict()["param_groups"]
)