Move normalize/unnormalize transforms to policy for act and diffusion

This commit is contained in:
Cadene
2024-04-20 21:08:14 +00:00
parent c1bcf857c5
commit 42ed7bb670
19 changed files with 145 additions and 195 deletions

View File

@@ -42,7 +42,7 @@ def test_factory(env_name):
env = make_env(cfg, num_parallel_envs=1)
obs, _ = env.reset()
obs = preprocess_observation(obs, transform=dataset.transform)
obs = preprocess_observation(obs)
for key in dataset.image_keys:
img = obs[key]
assert img.dtype == torch.float32

View File

@@ -51,7 +51,7 @@ def test_examples_4_and_3():
# Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
exec(file_contents, {})
for file_name in ["model.pt", "stats.pth", "config.yaml"]:
for file_name in ["model.pt", "config.yaml"]:
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
path = "examples/3_evaluate_pretrained_policy.py"

View File

@@ -44,14 +44,16 @@ def test_policy(env_name, policy_name, extra_overrides):
]
+ extra_overrides,
)
# Check that we can make the policy object.
policy = make_policy(cfg)
dataset = make_dataset(cfg)
policy = make_policy(cfg, dataset_stats=dataset.stats)
# Check that the policy follows the required protocol.
assert isinstance(
policy, Policy
), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."
# Check that we run select_actions and get the appropriate output.
dataset = make_dataset(cfg)
env = make_env(cfg, num_parallel_envs=2)
dataloader = torch.utils.data.DataLoader(
@@ -77,7 +79,7 @@ def test_policy(env_name, policy_name, extra_overrides):
observation, _ = env.reset(seed=cfg.seed)
# apply transform to normalize the observations
observation = preprocess_observation(observation, dataset.transform)
observation = preprocess_observation(observation)
# send observation to device/gpu
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
@@ -86,8 +88,8 @@ def test_policy(env_name, policy_name, extra_overrides):
with torch.inference_mode():
action = policy.select_action(observation, step=0)
# apply inverse transform to unnormalize the action
action = postprocess_action(action, dataset.transform)
# convert action to cpu numpy array
action = postprocess_action(action)
# Test step through policy
env.step(action)