From b017bdb00df60b4e5ab31999483e739a8c19c349 Mon Sep 17 00:00:00 2001
From: Cadene
Date: Sun, 31 Mar 2024 08:44:09 +0000
Subject: [PATCH] WIP

---
 lerobot/common/datasets/factory.py            | 17 +++++++++++++++++
 .../common/envs/simxarm/simxarm/tasks/base.py |  3 +++
 lerobot/common/policies/factory.py            | 11 +++++++++--
 lerobot/common/policies/tdmpc/helper.py       |  2 +-
 lerobot/scripts/eval.py                       |  5 ++++-
 lerobot/scripts/train.py                      |  7 ++++---
 6 files changed, 38 insertions(+), 7 deletions(-)

diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index 04077034f..2f96c8339 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -123,6 +123,23 @@ def make_offline_buffer(
         stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
         stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
 
+    # if (
+    #     cfg.env.name == "aloha"
+    #     and cfg.env.task == "sim_transfer_cube"
+    #     and cfg.policy.name == "act"
+    # ):
+    #     import pickle
+
+    #     with open(
+    #         "/home/rcadene/code/act/tmp/2024_03_10_sim_transfer_cube_scripted/dataset_stats.pkl", "rb"
+    #     ) as file:
+    #         dataset_stats = pickle.load(file)
+
+    #     stats["action", "mean"] = torch.from_numpy(dataset_stats["action_mean"])
+    #     stats["action", "std"] = torch.from_numpy(dataset_stats["action_std"])
+    #     stats["observation", "state", "mean"] = torch.from_numpy(dataset_stats["qpos_mean"])
+    #     stats["observation", "state", "std"] = torch.from_numpy(dataset_stats["qpos_std"])
+
     # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
     normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
     transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/base.py b/lerobot/common/envs/simxarm/simxarm/tasks/base.py
index b937b2901..11ace5800 100644
--- a/lerobot/common/envs/simxarm/simxarm/tasks/base.py
+++ b/lerobot/common/envs/simxarm/simxarm/tasks/base.py
@@ -138,6 +138,9 @@ class Base(robot_env.MujocoRobotEnv):
         # HACK
         self.model.vis.global_.offwidth = width
         self.model.vis.global_.offheight = height
+        if mode in self.mujoco_renderer._viewers:
+            self.mujoco_renderer._viewers.get(mode).viewport.width = width
+            self.mujoco_renderer._viewers.get(mode).viewport.height = height
         return self.mujoco_renderer.render(mode)
 
     def close(self):
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index 934f09629..6a4759377 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -26,15 +26,22 @@ def make_policy(cfg):
         policy = ActionChunkingTransformerPolicy(
             cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
         )
+
+        # if cfg.env.name == "aloha" and cfg.env.task == "sim_transfer_cube_scripted":
+        #     import torch
+        #     state_dict = torch.load(
+        #         "/home/rcadene/code/act/tmp/2024_03_10_sim_transfer_cube_scripted/policy_best.ckpt"
+        #     )
+        #     policy.load_state_dict(state_dict)
     else:
         raise ValueError(cfg.policy.name)
 
     if cfg.policy.pretrained_model_path:
         # TODO(rcadene): hack for old pretrained models from fowm
         if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
-            if "offline" in cfg.pretrained_model_path:
+            if "offline" in cfg.policy.pretrained_model_path:
                 policy.step[0] = 25000
-            elif "final" in cfg.pretrained_model_path:
+            elif "final" in cfg.policy.pretrained_model_path:
                 policy.step[0] = 100000
             else:
                 raise NotImplementedError()
diff --git a/lerobot/common/policies/tdmpc/helper.py b/lerobot/common/policies/tdmpc/helper.py
index 964f17184..e345be9f1 100644
--- a/lerobot/common/policies/tdmpc/helper.py
+++ b/lerobot/common/policies/tdmpc/helper.py
@@ -152,7 +152,7 @@ def enc(cfg):
     if cfg.modality in {"pixels", "all"}:
         C = int(3 * cfg.frame_stack)  # noqa: N806
         pixels_enc_layers = [
-            NormalizeImg(),
+            nn.Identity(),  # NormalizeImg(),  # TODO(rcadene): we need to clean this
             nn.Conv2d(C, cfg.num_channels, 7, stride=2),
             nn.ReLU(),
             nn.Conv2d(cfg.num_channels, cfg.num_channels, 5, stride=2),
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 28a25e43c..09c0f2966 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -86,6 +86,8 @@ def eval_policy(
 
     def maybe_render_frame(env: EnvBase, _):
         if save_video or (return_first_video and i == 0):  # noqa: B023
+            # TODO(rcadene): set width and height from config or env
+            # ep_frames.append(env.render(width=384, height=384))  # noqa: B023
             ep_frames.append(env.render())  # noqa: B023
 
     # Clear the policy's action queue before the start of a new rollout.
@@ -266,7 +268,8 @@ if __name__ == "__main__":
         )
         cfg = hydra.compose(Path(args.config).stem, args.overrides)
         # TODO(alexander-soare): Save and load stats in trained model directory.
-        stats_path = None
+        # stats_path = None
+        stats_path = "/home/rcadene/code/lerobot/outputs/pretrained_models/act_aloha_sim_transfer_cube_scripted/stats.pth"
     elif args.hub_id is not None:
         folder = Path(snapshot_download(args.hub_id, revision=args.revision))
         cfg = hydra.initialize(config_path=str(_relative_path_between(folder, Path(__file__).parent)))
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 18c3715bf..dfee0f830 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -172,7 +172,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
 
     # Note: this helper will be used in offline and online training loops.
     def _maybe_eval_and_maybe_save(step):
-        if step % cfg.eval_freq == 0:
+        # if step % cfg.eval_freq == 0:
+        if True:
             logging.info(f"Eval policy at step {step}")
             eval_info, first_video = eval_policy(
                 env,
@@ -187,8 +188,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
             if cfg.wandb.enable:
                 logger.log_video(first_video, step, mode="eval")
             logging.info("Resume training")
-
-        if cfg.save_model and step % cfg.save_freq == 0:
+        if True:
+            # if cfg.save_model and step % cfg.save_freq == 0:
             logging.info(f"Checkpoint policy after step {step}")
             logger.save_model(policy, identifier=step)
             logging.info("Resume training")