forked from tangger/lerobot
Move normalization to policy for act and diffusion (#90)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
@@ -46,7 +46,6 @@ from huggingface_hub import snapshot_download
|
||||
from PIL import Image as PILImage
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import hf_transform_to_torch
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
||||
@@ -64,8 +63,6 @@ def eval_policy(
|
||||
policy: torch.nn.Module,
|
||||
max_episodes_rendered: int = 0,
|
||||
video_dir: Path = None,
|
||||
# TODO(rcadene): make it possible to overwrite fps? we should use env.fps
|
||||
transform: callable = None,
|
||||
return_episode_data: bool = False,
|
||||
seed=None,
|
||||
):
|
||||
@@ -132,10 +129,6 @@ def eval_policy(
|
||||
if return_episode_data:
|
||||
observations.append(deepcopy(observation))
|
||||
|
||||
# apply transform to normalize the observations
|
||||
for key in observation:
|
||||
observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]])
|
||||
|
||||
# send observation to device/gpu
|
||||
observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
|
||||
|
||||
@@ -143,8 +136,8 @@ def eval_policy(
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation, step=step)
|
||||
|
||||
# apply inverse transform to unnormalize the action
|
||||
action = postprocess_action(action, transform)
|
||||
# convert to cpu numpy
|
||||
action = postprocess_action(action)
|
||||
|
||||
# apply the next action
|
||||
observation, reward, terminated, truncated, info = env.step(action)
|
||||
@@ -360,7 +353,7 @@ def eval_policy(
|
||||
return info
|
||||
|
||||
|
||||
def eval(cfg: dict, out_dir=None, stats_path=None):
|
||||
def eval(cfg: dict, out_dir=None):
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -375,10 +368,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
|
||||
|
||||
log_output_dir(out_dir)
|
||||
|
||||
logging.info("Making transforms.")
|
||||
# TODO(alexander-soare): Completely decouple datasets from evaluation.
|
||||
transform = make_dataset(cfg, stats_path=stats_path).transform
|
||||
|
||||
logging.info("Making environment.")
|
||||
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
|
||||
|
||||
@@ -390,7 +379,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
|
||||
policy,
|
||||
max_episodes_rendered=10,
|
||||
video_dir=Path(out_dir) / "eval",
|
||||
transform=transform,
|
||||
return_episode_data=False,
|
||||
seed=cfg.seed,
|
||||
)
|
||||
@@ -423,17 +411,13 @@ if __name__ == "__main__":
|
||||
if args.config is not None:
|
||||
# Note: For the config_path, Hydra wants a path relative to this script file.
|
||||
cfg = init_hydra_config(args.config, args.overrides)
|
||||
# TODO(alexander-soare): Save and load stats in trained model directory.
|
||||
stats_path = None
|
||||
elif args.hub_id is not None:
|
||||
folder = Path(snapshot_download(args.hub_id, revision=args.revision))
|
||||
cfg = init_hydra_config(
|
||||
folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
|
||||
)
|
||||
stats_path = folder / "stats.pth"
|
||||
|
||||
eval(
|
||||
cfg,
|
||||
out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
|
||||
stats_path=stats_path,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user