Move normalize/unnormalize transforms to policy for act and diffusion
This commit is contained in:
@@ -42,7 +42,7 @@ def test_factory(env_name):
|
||||
|
||||
env = make_env(cfg, num_parallel_envs=1)
|
||||
obs, _ = env.reset()
|
||||
obs = preprocess_observation(obs, transform=dataset.transform)
|
||||
obs = preprocess_observation(obs)
|
||||
for key in dataset.image_keys:
|
||||
img = obs[key]
|
||||
assert img.dtype == torch.float32
|
||||
|
||||
@@ -51,7 +51,7 @@ def test_examples_4_and_3():
|
||||
# Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
|
||||
exec(file_contents, {})
|
||||
|
||||
for file_name in ["model.pt", "stats.pth", "config.yaml"]:
|
||||
for file_name in ["model.pt", "config.yaml"]:
|
||||
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
|
||||
|
||||
path = "examples/3_evaluate_pretrained_policy.py"
|
||||
|
||||
@@ -44,14 +44,16 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
]
|
||||
+ extra_overrides,
|
||||
)
|
||||
|
||||
# Check that we can make the policy object.
|
||||
policy = make_policy(cfg)
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(cfg, dataset_stats=dataset.stats)
|
||||
# Check that the policy follows the required protocol.
|
||||
assert isinstance(
|
||||
policy, Policy
|
||||
), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."
|
||||
|
||||
# Check that we run select_actions and get the appropriate output.
|
||||
dataset = make_dataset(cfg)
|
||||
env = make_env(cfg, num_parallel_envs=2)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
@@ -77,7 +79,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
observation, _ = env.reset(seed=cfg.seed)
|
||||
|
||||
# apply transform to normalize the observations
|
||||
observation = preprocess_observation(observation, dataset.transform)
|
||||
observation = preprocess_observation(observation)
|
||||
|
||||
# send observation to device/gpu
|
||||
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
|
||||
@@ -86,8 +88,8 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation, step=0)
|
||||
|
||||
# apply inverse transform to unnormalize the action
|
||||
action = postprocess_action(action, dataset.transform)
|
||||
# convert action to cpu numpy array
|
||||
action = postprocess_action(action)
|
||||
|
||||
# Test step through policy
|
||||
env.step(action)
|
||||
|
||||
Reference in New Issue
Block a user