[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
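Context for readers unfamiliar with the setup: pre-commit.ci runs the repository's pre-commit hooks on each pull request and pushes commits like this one containing whatever the hooks auto-fixed. Here the fix is mechanical code formatting: two multi-line torch.testing.assert_close(...) calls are collapsed (see the hunks below). A minimal .pre-commit-config.yaml that would produce this kind of reformatting might look as follows; the hook choice and version pin are illustrative, not taken from this repository:

    # Illustrative only; the repository's real hook set is not shown in this diff.
    repos:
      - repo: https://github.com/psf/black
        rev: 24.3.0   # assumed pin; any recent tag works
        hooks:
          - id: black # a formatter like this joins short multi-line calls onto one line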
@@ -134,15 +134,15 @@ def test_multi_adam_configuration(base_params_dict, config_params, expected_values
     # Create config with the given parameters
     config = MultiAdamConfig(**config_params)
     optimizers = config.build(base_params_dict)

     # Verify optimizer count and keys
     assert len(optimizers) == len(expected_values)
     assert set(optimizers.keys()) == set(expected_values.keys())

     # Check that all optimizers are Adam instances
     for opt in optimizers.values():
         assert isinstance(opt, torch.optim.Adam)

     # Verify hyperparameters for each optimizer
     for name, expected in expected_values.items():
         optimizer = optimizers[name]
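The diff does not include MultiAdamConfig itself, but the assertions above pin down its contract: build() maps a dict of named parameter groups to a dict of torch.optim.Adam instances, applying per-group hyperparameter overrides on top of the defaults. A minimal sketch consistent with that contract; the field and override names here are guesses, not the real API:

    from dataclasses import dataclass, field

    import torch


    @dataclass
    class MultiAdamConfig:
        """Hypothetical sketch; the real class is not part of this diff."""

        lr: float = 1e-3
        betas: tuple[float, float] = (0.9, 0.999)
        weight_decay: float = 0.0
        # Per-group hyperparameter overrides, keyed by group name (name guessed).
        optimizers: dict[str, dict] = field(default_factory=dict)

        def build(self, params_dict: dict) -> dict[str, torch.optim.Adam]:
            # One Adam instance per named parameter group.
            built = {}
            for name, params in params_dict.items():
                overrides = self.optimizers.get(name, {})
                built[name] = torch.optim.Adam(
                    params,
                    lr=overrides.get("lr", self.lr),
                    betas=overrides.get("betas", self.betas),
                    weight_decay=overrides.get("weight_decay", self.weight_decay),
                )
            return built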
@@ -166,7 +166,7 @@ def multi_optimizers(base_params_dict):
 def test_save_multi_optimizer_state(multi_optimizers, tmp_path):
     # Save optimizer states
     save_optimizer_state(multi_optimizers, tmp_path)

     # Verify that directories were created for each optimizer
     for name in multi_optimizers.keys():
         assert (tmp_path / name).is_dir()
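save_optimizer_state and load_optimizer_state are also outside this diff. The test above only pins down that saving creates one subdirectory per optimizer name, and the round-trip test below uses the return value of load_optimizer_state. A plausible sketch under those assumptions; the file name and serialization format are guesses:

    from pathlib import Path

    import torch


    def save_optimizer_state(optimizers: dict, path) -> None:
        # One subdirectory per named optimizer, matching the
        # `(tmp_path / name).is_dir()` assertion in the test above.
        for name, opt in optimizers.items():
            subdir = Path(path) / name
            subdir.mkdir(parents=True, exist_ok=True)
            torch.save(opt.state_dict(), subdir / "state.pt")  # file name is a guess


    def load_optimizer_state(optimizers: dict, path) -> dict:
        # Mutates the passed-in optimizers and returns them, consistent
        # with how the round-trip test below uses the return value.
        for name, opt in optimizers.items():
            opt.load_state_dict(torch.load(Path(path) / name / "state.pt"))
        return optimizers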
@@ -185,10 +185,10 @@ def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers,
         multi_optimizers[name].step()
         # Zero gradients for next steps
         multi_optimizers[name].zero_grad()

     # Save optimizer states
     save_optimizer_state(multi_optimizers, tmp_path)

     # Create new optimizers with the same config
     config = MultiAdamConfig(
         lr=1e-3,
@@ -199,16 +199,13 @@ def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers,
         },
     )
     new_optimizers = config.build(base_params_dict)

     # Load optimizer states
     loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)

     # Verify state dictionaries match
     for name in multi_optimizers.keys():
-        torch.testing.assert_close(
-            multi_optimizers[name].state_dict(),
-            loaded_optimizers[name].state_dict()
-        )
+        torch.testing.assert_close(multi_optimizers[name].state_dict(), loaded_optimizers[name].state_dict())


 def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
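A side note on the call collapsed above: torch.testing.assert_close recurses into mappings and sequences, so two optimizer state_dict()s can be compared in a single call; tensors are checked elementwise within tolerance and plain Python scalars are compared as well. For example:

    import torch

    # assert_close recurses into dicts/lists, so nested optimizer state
    # (tensors plus plain Python scalars) can be compared in one call.
    torch.testing.assert_close(
        {"step": 1, "exp_avg": torch.zeros(3)},
        {"step": 1, "exp_avg": torch.zeros(3)},
    )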
@@ -223,25 +220,23 @@ def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
         },
     )
     optimizers = config.build(base_params_dict)

     # Save optimizer states without any backward pass (empty state)
     save_optimizer_state(optimizers, tmp_path)

     # Create new optimizers with the same config
     new_optimizers = config.build(base_params_dict)

     # Load optimizer states
     loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)

     # Verify hyperparameters match even with empty state
     for name, optimizer in optimizers.items():
         assert optimizer.defaults["lr"] == loaded_optimizers[name].defaults["lr"]
         assert optimizer.defaults["weight_decay"] == loaded_optimizers[name].defaults["weight_decay"]
         assert optimizer.defaults["betas"] == loaded_optimizers[name].defaults["betas"]

         # Verify state dictionaries match (they will be empty)
         torch.testing.assert_close(
-            optimizer.state_dict()["param_groups"],
-            loaded_optimizers[name].state_dict()["param_groups"]
+            optimizer.state_dict()["param_groups"], loaded_optimizers[name].state_dict()["param_groups"]
         )
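The empty-state test works because PyTorch optimizers populate their per-parameter state lazily: Adam allocates its exp_avg/exp_avg_sq buffers only on the first step(), so an optimizer that has never stepped serializes an empty "state" section and only "param_groups" carries information. A quick demonstration:

    import torch

    params = [torch.nn.Parameter(torch.zeros(3))]
    opt = torch.optim.Adam(params, lr=1e-3)

    # No step yet: the per-parameter state is empty, which is why the
    # test above compares only the "param_groups" section.
    assert opt.state_dict()["state"] == {}

    params[0].grad = torch.ones(3)
    opt.step()
    assert 0 in opt.state_dict()["state"]  # exp_avg/exp_avg_sq now exist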