Compare commits

...

1 Commits

Author SHA1 Message Date
Cadene
b017bdb00d WIP 2024-03-31 08:44:09 +00:00
6 changed files with 38 additions and 7 deletions

View File

@@ -123,6 +123,23 @@ def make_offline_buffer(
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
# if (
# cfg.env.name == "aloha"
# and cfg.env.task == "sim_transfer_cube"
# and cfg.policy.name == "act"
# ):
# import pickle
# with open(
# "/home/rcadene/code/act/tmp/2024_03_10_sim_transfer_cube_scripted/dataset_stats.pkl", "rb"
# ) as file:
# dataset_stats = pickle.load(file)
# stats["action", "mean"] = torch.from_numpy(dataset_stats["action_mean"])
# stats["action", "std"] = torch.from_numpy(dataset_stats["action_std"])
# stats["observation", "state", "mean"] = torch.from_numpy(dataset_stats["qpos_mean"])
# stats["observation", "state", "std"] = torch.from_numpy(dataset_stats["qpos_std"])
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))

View File

@@ -138,6 +138,9 @@ class Base(robot_env.MujocoRobotEnv):
# HACK
self.model.vis.global_.offwidth = width
self.model.vis.global_.offheight = height
if mode in self.mujoco_renderer._viewers:
self.mujoco_renderer._viewers.get(mode).viewport.width = width
self.mujoco_renderer._viewers.get(mode).viewport.height = height
return self.mujoco_renderer.render(mode)
def close(self):

View File

@@ -26,15 +26,22 @@ def make_policy(cfg):
policy = ActionChunkingTransformerPolicy(
cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
)
# if cfg.env.name == "aloha" and cfg.env.task == "sim_transfer_cube_scripted":
# import torch
# state_dict = torch.load(
# "/home/rcadene/code/act/tmp/2024_03_10_sim_transfer_cube_scripted/policy_best.ckpt"
# )
# policy.load_state_dict(state_dict)
else:
raise ValueError(cfg.policy.name)
if cfg.policy.pretrained_model_path:
# TODO(rcadene): hack for old pretrained models from fowm
if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
if "offline" in cfg.pretrained_model_path:
if "offline" in cfg.policy.pretrained_model_path:
policy.step[0] = 25000
elif "final" in cfg.pretrained_model_path:
elif "final" in cfg.policy.pretrained_model_path:
policy.step[0] = 100000
else:
raise NotImplementedError()

View File

@@ -152,7 +152,7 @@ def enc(cfg):
if cfg.modality in {"pixels", "all"}:
C = int(3 * cfg.frame_stack) # noqa: N806
pixels_enc_layers = [
NormalizeImg(),
nn.Identity(), # NormalizeImg(), # TODO(rcadene): we need to clean this
nn.Conv2d(C, cfg.num_channels, 7, stride=2),
nn.ReLU(),
nn.Conv2d(cfg.num_channels, cfg.num_channels, 5, stride=2),

View File

@@ -86,6 +86,8 @@ def eval_policy(
def maybe_render_frame(env: EnvBase, _):
if save_video or (return_first_video and i == 0): # noqa: B023
# TODO(rcadene): set width and height from config or env
# ep_frames.append(env.render(width=384, height=384)) # noqa: B023
ep_frames.append(env.render()) # noqa: B023
# Clear the policy's action queue before the start of a new rollout.
@@ -266,7 +268,8 @@ if __name__ == "__main__":
)
cfg = hydra.compose(Path(args.config).stem, args.overrides)
# TODO(alexander-soare): Save and load stats in trained model directory.
stats_path = None
# stats_path = None
stats_path = "/home/rcadene/code/lerobot/outputs/pretrained_models/act_aloha_sim_transfer_cube_scripted/stats.pth"
elif args.hub_id is not None:
folder = Path(snapshot_download(args.hub_id, revision=args.revision))
cfg = hydra.initialize(config_path=str(_relative_path_between(folder, Path(__file__).parent)))

View File

@@ -172,7 +172,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
# Note: this helper will be used in offline and online training loops.
def _maybe_eval_and_maybe_save(step):
if step % cfg.eval_freq == 0:
# if step % cfg.eval_freq == 0:
if True:
logging.info(f"Eval policy at step {step}")
eval_info, first_video = eval_policy(
env,
@@ -187,8 +188,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
if cfg.wandb.enable:
logger.log_video(first_video, step, mode="eval")
logging.info("Resume training")
if cfg.save_model and step % cfg.save_freq == 0:
if True:
# if cfg.save_model and step % cfg.save_freq == 0:
logging.info(f"Checkpoint policy after step {step}")
logger.save_model(policy, identifier=step)
logging.info("Resume training")