address comments
This commit is contained in:
@@ -96,9 +96,8 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
|
||||
# Test load state_dict
|
||||
if policy_name != "tdmpc":
|
||||
# TODO(rcadene, alexander-soar): make it work for tdmpc
|
||||
# TODO(rcadene, alexander-soar): how to remove need for dataset_stats?
|
||||
new_policy = make_policy(cfg, dataset_stats=dataset.stats)
|
||||
# TODO(rcadene, alexander-soare): make it work for tdmpc
|
||||
new_policy = make_policy(cfg)
|
||||
new_policy.load_state_dict(policy.state_dict())
|
||||
|
||||
|
||||
@@ -110,7 +109,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
],
|
||||
)
|
||||
def test_normalize(insert_temporal_dim):
|
||||
# TODO(rcadene, alexander-soar): test with real data and assert results of normalization/unnormalization
|
||||
# TODO(rcadene, alexander-soare): test with real data and assert results of normalization/unnormalization
|
||||
|
||||
input_shapes = {
|
||||
"observation.image": [3, 96, 96],
|
||||
@@ -170,7 +169,8 @@ def test_normalize(insert_temporal_dim):
|
||||
|
||||
# test without stats
|
||||
normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
|
||||
normalize(input_batch)
|
||||
with pytest.raises(AssertionError):
|
||||
normalize(input_batch)
|
||||
|
||||
# test with stats
|
||||
normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats)
|
||||
@@ -183,7 +183,8 @@ def test_normalize(insert_temporal_dim):
|
||||
|
||||
# test wihtout stats
|
||||
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
|
||||
unnormalize(output_batch)
|
||||
with pytest.raises(AssertionError):
|
||||
unnormalize(output_batch)
|
||||
|
||||
# test with stats
|
||||
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats)
|
||||
|
||||
Reference in New Issue
Block a user