Move normalization to policy for act and diffusion (#90)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
@@ -6,7 +6,6 @@ import torch
|
||||
from gymnasium.utils.env_checker import check_env
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
@@ -38,12 +37,14 @@ def test_factory(env_name):
|
||||
overrides=[f"env={env_name}", f"device={DEVICE}"],
|
||||
)
|
||||
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
env = make_env(cfg, num_parallel_envs=1)
|
||||
obs, _ = env.reset()
|
||||
obs = preprocess_observation(obs, transform=dataset.transform)
|
||||
for key in dataset.image_keys:
|
||||
obs = preprocess_observation(obs)
|
||||
|
||||
# test image keys are float32 in range [0,1]
|
||||
for key in obs:
|
||||
if "image" not in key:
|
||||
continue
|
||||
img = obs[key]
|
||||
assert img.dtype == torch.float32
|
||||
# TODO(rcadene): we assume for now that image normalization takes place in the model
|
||||
|
||||
@@ -51,7 +51,7 @@ def test_examples_4_and_3():
|
||||
# Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
|
||||
exec(file_contents, {})
|
||||
|
||||
for file_name in ["model.pt", "stats.pth", "config.yaml"]:
|
||||
for file_name in ["model.pt", "config.yaml"]:
|
||||
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
|
||||
|
||||
path = "examples/3_evaluate_pretrained_policy.py"
|
||||
|
||||
@@ -6,10 +6,10 @@ from lerobot.common.datasets.utils import cycle
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
|
||||
from .utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
|
||||
|
||||
|
||||
# TODO(aliberts): refactor using lerobot/__init__.py variables
|
||||
@@ -44,14 +44,16 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
]
|
||||
+ extra_overrides,
|
||||
)
|
||||
|
||||
# Check that we can make the policy object.
|
||||
policy = make_policy(cfg)
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(cfg, dataset_stats=dataset.stats)
|
||||
# Check that the policy follows the required protocol.
|
||||
assert isinstance(
|
||||
policy, Policy
|
||||
), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."
|
||||
|
||||
# Check that we run select_actions and get the appropriate output.
|
||||
dataset = make_dataset(cfg)
|
||||
env = make_env(cfg, num_parallel_envs=2)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
@@ -77,7 +79,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
observation, _ = env.reset(seed=cfg.seed)
|
||||
|
||||
# apply transform to normalize the observations
|
||||
observation = preprocess_observation(observation, dataset.transform)
|
||||
observation = preprocess_observation(observation)
|
||||
|
||||
# send observation to device/gpu
|
||||
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
|
||||
@@ -86,8 +88,115 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation, step=0)
|
||||
|
||||
# apply inverse transform to unnormalize the action
|
||||
action = postprocess_action(action, dataset.transform)
|
||||
# convert action to cpu numpy array
|
||||
action = postprocess_action(action)
|
||||
|
||||
# Test step through policy
|
||||
env.step(action)
|
||||
|
||||
# Test load state_dict
|
||||
if policy_name != "tdmpc":
|
||||
# TODO(rcadene, alexander-soare): make it work for tdmpc
|
||||
new_policy = make_policy(cfg)
|
||||
new_policy.load_state_dict(policy.state_dict())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"insert_temporal_dim",
|
||||
[
|
||||
False,
|
||||
True,
|
||||
],
|
||||
)
|
||||
def test_normalize(insert_temporal_dim):
|
||||
"""
|
||||
Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise
|
||||
an exception when the forward pass is called without the stats having been provided.
|
||||
|
||||
TODO(rcadene, alexander-soare): This should also test that the normalization / unnormalization works as
|
||||
expected.
|
||||
"""
|
||||
|
||||
input_shapes = {
|
||||
"observation.image": [3, 96, 96],
|
||||
"observation.state": [10],
|
||||
}
|
||||
output_shapes = {
|
||||
"action": [5],
|
||||
}
|
||||
|
||||
normalize_input_modes = {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
unnormalize_output_modes = {
|
||||
"action": "min_max",
|
||||
}
|
||||
|
||||
dataset_stats = {
|
||||
"observation.image": {
|
||||
"mean": torch.randn(3, 1, 1),
|
||||
"std": torch.randn(3, 1, 1),
|
||||
"min": torch.randn(3, 1, 1),
|
||||
"max": torch.randn(3, 1, 1),
|
||||
},
|
||||
"observation.state": {
|
||||
"mean": torch.randn(10),
|
||||
"std": torch.randn(10),
|
||||
"min": torch.randn(10),
|
||||
"max": torch.randn(10),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.randn(5),
|
||||
"std": torch.randn(5),
|
||||
"min": torch.randn(5),
|
||||
"max": torch.randn(5),
|
||||
},
|
||||
}
|
||||
|
||||
bsize = 2
|
||||
input_batch = {
|
||||
"observation.image": torch.randn(bsize, 3, 96, 96),
|
||||
"observation.state": torch.randn(bsize, 10),
|
||||
}
|
||||
output_batch = {
|
||||
"action": torch.randn(bsize, 5),
|
||||
}
|
||||
|
||||
if insert_temporal_dim:
|
||||
tdim = 4
|
||||
|
||||
for key in input_batch:
|
||||
# [2,3,96,96] -> [2,tdim,3,96,96]
|
||||
input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1)
|
||||
|
||||
for key in output_batch:
|
||||
output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1)
|
||||
|
||||
# test without stats
|
||||
normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
|
||||
with pytest.raises(AssertionError):
|
||||
normalize(input_batch)
|
||||
|
||||
# test with stats
|
||||
normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats)
|
||||
normalize(input_batch)
|
||||
|
||||
# test loading pretrained models
|
||||
new_normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
|
||||
new_normalize.load_state_dict(normalize.state_dict())
|
||||
new_normalize(input_batch)
|
||||
|
||||
# test without stats
|
||||
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
|
||||
with pytest.raises(AssertionError):
|
||||
unnormalize(output_batch)
|
||||
|
||||
# test with stats
|
||||
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats)
|
||||
unnormalize(output_batch)
|
||||
|
||||
# test loading pretrained models
|
||||
new_unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
|
||||
new_unnormalize.load_state_dict(unnormalize.state_dict())
|
||||
unnormalize(output_batch)
|
||||
|
||||
Reference in New Issue
Block a user