address comments

This commit is contained in:
Cadene
2024-04-24 20:57:09 +00:00
parent bc96284ca0
commit 0ec28bf71a
9 changed files with 74 additions and 57 deletions

View File

@@ -96,9 +96,8 @@ def test_policy(env_name, policy_name, extra_overrides):
# Test load state_dict
if policy_name != "tdmpc":
# TODO(rcadene, alexander-soar): make it work for tdmpc
# TODO(rcadene, alexander-soar): how to remove need for dataset_stats?
new_policy = make_policy(cfg, dataset_stats=dataset.stats)
# TODO(rcadene, alexander-soare): make it work for tdmpc
new_policy = make_policy(cfg)
new_policy.load_state_dict(policy.state_dict())
@@ -110,7 +109,7 @@ def test_policy(env_name, policy_name, extra_overrides):
],
)
def test_normalize(insert_temporal_dim):
# TODO(rcadene, alexander-soar): test with real data and assert results of normalization/unnormalization
# TODO(rcadene, alexander-soare): test with real data and assert results of normalization/unnormalization
input_shapes = {
"observation.image": [3, 96, 96],
@@ -170,7 +169,8 @@ def test_normalize(insert_temporal_dim):
# test without stats
normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
normalize(input_batch)
with pytest.raises(AssertionError):
normalize(input_batch)
# test with stats
normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats)
@@ -183,7 +183,8 @@ def test_normalize(insert_temporal_dim):
# test wihtout stats
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
unnormalize(output_batch)
with pytest.raises(AssertionError):
unnormalize(output_batch)
# test with stats
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats)