Move normalization to policy for act and diffusion (#90)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Remi
2024-04-25 11:47:38 +02:00
committed by GitHub
parent c1bcf857c5
commit e760e4cd63
25 changed files with 543 additions and 288 deletions

View File

@@ -6,7 +6,6 @@ import torch
from gymnasium.utils.env_checker import check_env
import lerobot
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.utils.utils import init_hydra_config
@@ -38,12 +37,14 @@ def test_factory(env_name):
overrides=[f"env={env_name}", f"device={DEVICE}"],
)
dataset = make_dataset(cfg)
env = make_env(cfg, num_parallel_envs=1)
obs, _ = env.reset()
obs = preprocess_observation(obs, transform=dataset.transform)
for key in dataset.image_keys:
obs = preprocess_observation(obs)
# test image keys are float32 in range [0,1]
for key in obs:
if "image" not in key:
continue
img = obs[key]
assert img.dtype == torch.float32
# TODO(rcadene): we assume for now that image normalization takes place in the model

View File

@@ -51,7 +51,7 @@ def test_examples_4_and_3():
# Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
exec(file_contents, {})
for file_name in ["model.pt", "stats.pth", "config.yaml"]:
for file_name in ["model.pt", "config.yaml"]:
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
path = "examples/3_evaluate_pretrained_policy.py"

View File

@@ -6,10 +6,10 @@ from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config
from .utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
# TODO(aliberts): refactor using lerobot/__init__.py variables
@@ -44,14 +44,16 @@ def test_policy(env_name, policy_name, extra_overrides):
]
+ extra_overrides,
)
# Check that we can make the policy object.
policy = make_policy(cfg)
dataset = make_dataset(cfg)
policy = make_policy(cfg, dataset_stats=dataset.stats)
# Check that the policy follows the required protocol.
assert isinstance(
policy, Policy
), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."
# Check that we run select_actions and get the appropriate output.
dataset = make_dataset(cfg)
env = make_env(cfg, num_parallel_envs=2)
dataloader = torch.utils.data.DataLoader(
@@ -77,7 +79,7 @@ def test_policy(env_name, policy_name, extra_overrides):
observation, _ = env.reset(seed=cfg.seed)
# apply transform to normalize the observations
observation = preprocess_observation(observation, dataset.transform)
observation = preprocess_observation(observation)
# send observation to device/gpu
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
@@ -86,8 +88,115 @@ def test_policy(env_name, policy_name, extra_overrides):
with torch.inference_mode():
action = policy.select_action(observation, step=0)
# apply inverse transform to unnormalize the action
action = postprocess_action(action, dataset.transform)
# convert action to cpu numpy array
action = postprocess_action(action)
# Test step through policy
env.step(action)
# Test load state_dict
if policy_name != "tdmpc":
# TODO(rcadene, alexander-soare): make it work for tdmpc
new_policy = make_policy(cfg)
new_policy.load_state_dict(policy.state_dict())
@pytest.mark.parametrize(
"insert_temporal_dim",
[
False,
True,
],
)
def test_normalize(insert_temporal_dim):
"""
Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise
an exception when the forward pass is called without the stats having been provided.
TODO(rcadene, alexander-soare): This should also test that the normalization / unnormalization works as
expected.
"""
input_shapes = {
"observation.image": [3, 96, 96],
"observation.state": [10],
}
output_shapes = {
"action": [5],
}
normalize_input_modes = {
"observation.image": "mean_std",
"observation.state": "min_max",
}
unnormalize_output_modes = {
"action": "min_max",
}
dataset_stats = {
"observation.image": {
"mean": torch.randn(3, 1, 1),
"std": torch.randn(3, 1, 1),
"min": torch.randn(3, 1, 1),
"max": torch.randn(3, 1, 1),
},
"observation.state": {
"mean": torch.randn(10),
"std": torch.randn(10),
"min": torch.randn(10),
"max": torch.randn(10),
},
"action": {
"mean": torch.randn(5),
"std": torch.randn(5),
"min": torch.randn(5),
"max": torch.randn(5),
},
}
bsize = 2
input_batch = {
"observation.image": torch.randn(bsize, 3, 96, 96),
"observation.state": torch.randn(bsize, 10),
}
output_batch = {
"action": torch.randn(bsize, 5),
}
if insert_temporal_dim:
tdim = 4
for key in input_batch:
# [2,3,96,96] -> [2,tdim,3,96,96]
input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1)
for key in output_batch:
output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1)
# test without stats
normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
with pytest.raises(AssertionError):
normalize(input_batch)
# test with stats
normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats)
normalize(input_batch)
# test loading pretrained models
new_normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
new_normalize.load_state_dict(normalize.state_dict())
new_normalize(input_batch)
# test without stats
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
with pytest.raises(AssertionError):
unnormalize(output_batch)
# test with stats
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats)
unnormalize(output_batch)
# test loading pretrained models
new_unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
new_unnormalize.load_state_dict(unnormalize.state_dict())
unnormalize(output_batch)