make load_state_dict work

This commit is contained in:
Cadene
2024-04-24 15:40:09 +00:00
parent 0660f71556
commit 72751b7cf6
9 changed files with 376 additions and 87 deletions

View File

@@ -6,10 +6,10 @@ from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config
from .utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
# TODO(aliberts): refactor using lerobot/__init__.py variables
@@ -93,3 +93,111 @@ def test_policy(env_name, policy_name, extra_overrides):
# Test step through policy
env.step(action)
# Test load state_dict
if policy_name != "tdmpc":
# TODO(rcadene, alexander-soar): make it work for tdmpc
# TODO(rcadene, alexander-soar): how to remove need for dataset_stats?
new_policy = make_policy(cfg, dataset_stats=dataset.stats)
new_policy.load_state_dict(policy.state_dict())
new_policy.update(batch, step=0)
@pytest.mark.parametrize(
"insert_temporal_dim",
[
False,
True,
],
)
def test_normalize(insert_temporal_dim):
# TODO(rcadene, alexander-soar): test with real data and assert results of normalization/unnormalization
input_shapes = {
"observation.image": [3, 96, 96],
"observation.state": [10],
}
output_shapes = {
"action": [5],
}
normalize_input_modes = {
"observation.image": "mean_std",
"observation.state": "min_max",
}
unnormalize_output_modes = {
"action": "min_max",
}
dataset_stats = {
"observation.image": {
"mean": torch.randn(3, 1, 1),
"std": torch.randn(3, 1, 1),
"min": torch.randn(3, 1, 1),
"max": torch.randn(3, 1, 1),
},
"observation.state": {
"mean": torch.randn(10),
"std": torch.randn(10),
"min": torch.randn(10),
"max": torch.randn(10),
},
"action": {
"mean": torch.randn(5),
"std": torch.randn(5),
"min": torch.randn(5),
"max": torch.randn(5),
},
}
bsize = 2
input_batch = {
"observation.image": torch.randn(bsize, 3, 96, 96),
"observation.state": torch.randn(bsize, 10),
}
output_batch = {
"action": torch.randn(bsize, 5),
}
if insert_temporal_dim:
tdim = 4
for key in input_batch:
# [2,3,96,96] -> [2,tdim,3,96,96]
input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1)
for key in output_batch:
output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1)
# test without stats
normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
normalize(input_batch)
# test with stats
normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats)
normalize(input_batch)
# test loading pretrained models
new_normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
new_normalize.load_state_dict(normalize.state_dict())
new_normalize(input_batch)
# test wihtout stats
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
unnormalize(output_batch)
# test with stats
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats)
unnormalize(output_batch)
# test loading pretrained models
new_unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
new_unnormalize.load_state_dict(unnormalize.state_dict())
unnormalize(output_batch)
if __name__ == "__main__":
test_policy(
*("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"])
)
# test_policy(insert_temporal_dim=True)