Move normalization to policy for act and diffusion (#90)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
Author: Remi
Date: 2024-04-25 11:47:38 +02:00
Committed by: GitHub
Commit: e760e4cd63 (parent c1bcf857c5)
25 changed files with 543 additions and 288 deletions
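
The substance of the change lives in the ACT and diffusion policy classes rather than in the evaluation script below: instead of a dataset transform normalizing observations and unnormalizing actions around the policy call, each policy now carries the dataset statistics and does both steps itself. A minimal sketch of that idea, assuming per-key `mean`/`std` statistics including an `"action"` entry (class and attribute names here are illustrative, not the actual lerobot modules):

```python
import torch
from torch import nn


class NormalizedPolicy(nn.Module):
    """Illustrative wrapper: the policy owns its normalization statistics."""

    def __init__(self, model: nn.Module, stats: dict[str, dict[str, torch.Tensor]]):
        super().__init__()
        self.model = model
        # Register the statistics as buffers so they are saved in and restored from
        # the policy checkpoint, rather than shipped in a separate stats file.
        for key, s in stats.items():
            safe = key.replace(".", "_")
            self.register_buffer(f"{safe}_mean", s["mean"])
            self.register_buffer(f"{safe}_std", s["std"])

    def _normalize(self, key: str, x: torch.Tensor) -> torch.Tensor:
        safe = key.replace(".", "_")
        return (x - getattr(self, f"{safe}_mean")) / getattr(self, f"{safe}_std")

    def _unnormalize_action(self, action: torch.Tensor) -> torch.Tensor:
        # Assumes `stats` contained an "action" entry.
        return action * self.action_std + self.action_mean

    @torch.inference_mode()
    def select_action(self, observation: dict[str, torch.Tensor], step: int) -> torch.Tensor:
        # Normalize every observation key we have statistics for, run the underlying
        # model, then map the predicted action back to environment units.
        normalized = {
            key: self._normalize(key, value)
            if hasattr(self, f"{key.replace('.', '_')}_mean")
            else value
            for key, value in observation.items()
        }
        action = self.model(normalized)
        return self._unnormalize_action(action)
```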


```diff
@@ -46,7 +46,6 @@ from huggingface_hub import snapshot_download
 from PIL import Image as PILImage
 from tqdm import trange

-from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import hf_transform_to_torch
 from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.utils import postprocess_action, preprocess_observation
```
```diff
@@ -64,8 +63,6 @@ def eval_policy(
     policy: torch.nn.Module,
     max_episodes_rendered: int = 0,
     video_dir: Path = None,
-    # TODO(rcadene): make it possible to overwrite fps? we should use env.fps
-    transform: callable = None,
     return_episode_data: bool = False,
     seed=None,
 ):
```
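
Putting the surviving parameters together, the rollout entry point presumably reads as follows after this change (reconstructed from the context lines of the hunk above; any parameters listed before `policy`, such as the environment, are not visible in this diff and are omitted):

```python
from pathlib import Path

import torch


def eval_policy(
    policy: torch.nn.Module,
    max_episodes_rendered: int = 0,
    video_dir: Path = None,
    return_episode_data: bool = False,
    seed=None,
):
    ...
```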
```diff
@@ -132,10 +129,6 @@ def eval_policy(
     if return_episode_data:
         observations.append(deepcopy(observation))

-    # apply transform to normalize the observations
-    for key in observation:
-        observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]])

     # send observation to device/gpu
     observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
```
```diff
@@ -143,8 +136,8 @@ def eval_policy(
     with torch.inference_mode():
         action = policy.select_action(observation, step=step)

-    # apply inverse transform to unnormalize the action
-    action = postprocess_action(action, transform)
+    # convert to cpu numpy
+    action = postprocess_action(action)

     # apply the next action
     observation, reward, terminated, truncated, info = env.step(action)
```
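
The net effect on the rollout loop: observations go straight from `preprocess_observation` to the device and into `policy.select_action`, which is now expected to normalize its inputs and return actions already in environment units, so `postprocess_action` only has to move them to CPU/NumPy. A minimal sketch of one step under that assumption (the helper bodies and the surrounding loop are not shown in this diff; `rollout_step` is an illustrative wrapper, not a function from the script):

```python
import torch

from lerobot.common.envs.utils import postprocess_action, preprocess_observation


def rollout_step(env, policy, raw_observation, device, step):
    """One evaluation step after this change (illustrative, not the actual loop)."""
    # Raw gym output -> dict of torch tensors, keyed like the training data.
    observation = preprocess_observation(raw_observation)

    # Send observation to device/GPU; normalization now happens inside the policy.
    observation = {key: observation[key].to(device, non_blocking=True) for key in observation}

    with torch.inference_mode():
        # The policy normalizes observations and unnormalizes its action internally.
        action = policy.select_action(observation, step=step)

    # Convert to CPU NumPy before stepping the vectorized environment.
    action = postprocess_action(action)
    return env.step(action)
```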
```diff
@@ -360,7 +353,7 @@ def eval_policy(
     return info

-def eval(cfg: dict, out_dir=None, stats_path=None):
+def eval(cfg: dict, out_dir=None):
     if out_dir is None:
         raise NotImplementedError()
```
```diff
@@ -375,10 +368,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
     log_output_dir(out_dir)

-    logging.info("Making transforms.")
-    # TODO(alexander-soare): Completely decouple datasets from evaluation.
-    transform = make_dataset(cfg, stats_path=stats_path).transform

     logging.info("Making environment.")
     env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
```
```diff
@@ -390,7 +379,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
         policy,
         max_episodes_rendered=10,
         video_dir=Path(out_dir) / "eval",
-        transform=transform,
         return_episode_data=False,
         seed=cfg.seed,
     )
```
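
Taken together, the three hunks above reduce eval() to building the environment and the policy and handing them to eval_policy, with no dataset, stats file, or transform in between. A rough sketch of the resulting flow (the policy construction step and the exact positional arguments of eval_policy are not shown in this diff, so the `make_policy` call and the leading `env` argument are assumptions; the other helpers are the ones already imported in the script):

```python
import logging
from pathlib import Path


def eval(cfg: dict, out_dir=None):
    if out_dir is None:
        raise NotImplementedError()

    log_output_dir(out_dir)

    logging.info("Making environment.")
    env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)

    # Assumed: the policy factory loads weights that already contain the
    # normalization statistics, so no stats_path is needed here.
    policy = make_policy(cfg)

    info = eval_policy(
        env,  # assumed leading argument; the hunk above only shows `policy` onward
        policy,
        max_episodes_rendered=10,
        video_dir=Path(out_dir) / "eval",
        return_episode_data=False,
        seed=cfg.seed,
    )
```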
```diff
@@ -423,17 +411,13 @@ if __name__ == "__main__":
     if args.config is not None:
         # Note: For the config_path, Hydra wants a path relative to this script file.
         cfg = init_hydra_config(args.config, args.overrides)
-        # TODO(alexander-soare): Save and load stats in trained model directory.
-        stats_path = None
     elif args.hub_id is not None:
         folder = Path(snapshot_download(args.hub_id, revision=args.revision))
         cfg = init_hydra_config(
             folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
         )
-        stats_path = folder / "stats.pth"

     eval(
         cfg,
         out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
-        stats_path=stats_path,
     )
```
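
With stats_path gone, evaluating a pretrained policy pulled from the Hub needs only the snapshot's config.yaml and model.pt; the normalization statistics are expected to travel inside the policy weights. A hypothetical programmatic equivalent of the hub branch above (the `hub_id` value and the import path for `init_hydra_config` are placeholders, not taken from this diff):

```python
from datetime import datetime as dt
from pathlib import Path

from huggingface_hub import snapshot_download

# Placeholder import path; use whatever module eval.py actually imports it from.
from lerobot.common.utils import init_hydra_config

hub_id = "user/some_pretrained_policy"  # hypothetical repo id
folder = Path(snapshot_download(hub_id))
cfg = init_hydra_config(
    folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}"]
)

# `eval` here is the script's own eval() function shown above, not the builtin.
eval(
    cfg,
    out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
)
```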