forked from tangger/lerobot
Compare commits
1 Commit
pre-commit
...
user/rcade
| Author | SHA1 | Date |
|---|---|---|
| | b017bdb00d | |
```diff
@@ -123,6 +123,23 @@ def make_offline_buffer(
     stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
     stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)

     # if (
     #     cfg.env.name == "aloha"
     #     and cfg.env.task == "sim_transfer_cube"
     #     and cfg.policy.name == "act"
     # ):
     #     import pickle

     #     with open(
     #         "/home/rcadene/code/act/tmp/2024_03_10_sim_transfer_cube_scripted/dataset_stats.pkl", "rb"
     #     ) as file:
     #         dataset_stats = pickle.load(file)

     #     stats["action", "mean"] = torch.from_numpy(dataset_stats["action_mean"])
     #     stats["action", "std"] = torch.from_numpy(dataset_stats["action_std"])
     #     stats["observation", "state", "mean"] = torch.from_numpy(dataset_stats["qpos_mean"])
     #     stats["observation", "state", "std"] = torch.from_numpy(dataset_stats["qpos_std"])

     # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
     normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
     transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
```
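The hardcoded bounds above feed the `min_max` branch, while aloha runs switch to `mean_std`. A minimal sketch of what the two modes typically compute, assuming a flat dict of stats tensors; the repo's `NormalizeTransform` operates on TensorDict keys such as `("action", "min")` instead:

```python
# Hedged sketch of the two normalization modes selected above. The key names,
# the [-1, 1] target range for min_max, and the epsilon are assumptions for
# illustration, not the repo's exact NormalizeTransform implementation.
import torch


def normalize(x: torch.Tensor, stats: dict, mode: str) -> torch.Tensor:
    if mode == "mean_std":
        return (x - stats["mean"]) / (stats["std"] + 1e-8)
    if mode == "min_max":
        # rescale [min, max] to [-1, 1]
        return (x - stats["min"]) / (stats["max"] - stats["min"]) * 2.0 - 1.0
    raise ValueError(mode)


action = torch.tensor([100.0, 200.0])
stats = {"min": torch.tensor([12.0, 25.0]), "max": torch.tensor([511.0, 511.0])}
print(normalize(action, stats, "min_max"))
```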
```diff
@@ -138,6 +138,9 @@ class Base(robot_env.MujocoRobotEnv):
         # HACK
         self.model.vis.global_.offwidth = width
         self.model.vis.global_.offheight = height
         if mode in self.mujoco_renderer._viewers:
             self.mujoco_renderer._viewers.get(mode).viewport.width = width
             self.mujoco_renderer._viewers.get(mode).viewport.height = height
         return self.mujoco_renderer.render(mode)

     def close(self):
```
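MuJoCo sizes its offscreen framebuffer from `model.vis.global_.offwidth`/`offheight` when a viewer is created, which is why the hack also patches the viewport of any viewer that already exists. A hedged standalone sketch of the same idea, assuming a gymnasium-robotics `MujocoRobotEnv`-style object that exposes `model` and `mujoco_renderer`:

```python
# Hedged sketch, not the repo's API: resize offscreen rendering for an env that
# exposes MuJoCo's model and a mujoco_renderer with per-mode viewers.
def render_offscreen(env, mode: str = "rgb_array", width: int = 384, height: int = 384):
    env.model.vis.global_.offwidth = width
    env.model.vis.global_.offheight = height
    viewer = env.mujoco_renderer._viewers.get(mode)
    if viewer is not None:
        # the viewer was created with the old size, so patch its viewport too
        viewer.viewport.width = width
        viewer.viewport.height = height
    return env.mujoco_renderer.render(mode)
```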
```diff
@@ -26,15 +26,22 @@ def make_policy(cfg):
         policy = ActionChunkingTransformerPolicy(
             cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
         )

         # if cfg.env.name == "aloha" and cfg.env.task == "sim_transfer_cube_scripted":
         #     import torch
         #     state_dict = torch.load(
         #         "/home/rcadene/code/act/tmp/2024_03_10_sim_transfer_cube_scripted/policy_best.ckpt"
         #     )
         #     policy.load_state_dict(state_dict)
     else:
         raise ValueError(cfg.policy.name)

     if cfg.policy.pretrained_model_path:
         # TODO(rcadene): hack for old pretrained models from fowm
         if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
-            if "offline" in cfg.pretrained_model_path:
+            if "offline" in cfg.policy.pretrained_model_path:
                 policy.step[0] = 25000
-            elif "final" in cfg.pretrained_model_path:
+            elif "final" in cfg.policy.pretrained_model_path:
                 policy.step[0] = 100000
             else:
                 raise NotImplementedError()
```
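The commented-out block sketches loading weights exported by the original act repo straight into the lerobot policy. A hedged version as a helper, assuming the checkpoint is a plain `state_dict` saved with `torch.save` (the path and the compatibility of parameter names are not guaranteed):

```python
# Hedged sketch of the commented-out checkpoint restore above; the checkpoint
# format and matching key layout are assumptions.
import torch


def load_act_checkpoint(policy, ckpt_path: str):
    state_dict = torch.load(ckpt_path, map_location="cpu")
    missing, unexpected = policy.load_state_dict(state_dict, strict=False)
    if missing or unexpected:
        print(f"missing keys: {missing}")
        print(f"unexpected keys: {unexpected}")
    return policy
```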
```diff
@@ -152,7 +152,7 @@ def enc(cfg):
     if cfg.modality in {"pixels", "all"}:
         C = int(3 * cfg.frame_stack)  # noqa: N806
         pixels_enc_layers = [
-            NormalizeImg(),
+            nn.Identity(),  # NormalizeImg(),  # TODO(rcadene): we need to clean this
             nn.Conv2d(C, cfg.num_channels, 7, stride=2),
             nn.ReLU(),
             nn.Conv2d(cfg.num_channels, cfg.num_channels, 5, stride=2),
```
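For context, `NormalizeImg` in the TD-MPC-style helpers rescales raw pixel values into [0, 1]; swapping in `nn.Identity()` assumes the images arrive already normalized upstream. A minimal sketch of the module being bypassed, under that assumption:

```python
# Hedged sketch of the bypassed transform; the original TD-MPC helper divides
# pixel values by 255, which is what this reproduces.
import torch
from torch import nn


class NormalizeImg(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.div(255.0)
```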
```diff
@@ -86,6 +86,8 @@ def eval_policy(

     def maybe_render_frame(env: EnvBase, _):
         if save_video or (return_first_video and i == 0):  # noqa: B023
             # TODO(rcadene): set width and height from config or env
             # ep_frames.append(env.render(width=384, height=384))  # noqa: B023
             ep_frames.append(env.render())  # noqa: B023

     # Clear the policy's action queue before the start of a new rollout.
```
```diff
@@ -266,7 +268,8 @@ if __name__ == "__main__":
         )
         cfg = hydra.compose(Path(args.config).stem, args.overrides)
         # TODO(alexander-soare): Save and load stats in trained model directory.
-        stats_path = None
+        # stats_path = None
+        stats_path = "/home/rcadene/code/lerobot/outputs/pretrained_models/act_aloha_sim_transfer_cube_scripted/stats.pth"
     elif args.hub_id is not None:
         folder = Path(snapshot_download(args.hub_id, revision=args.revision))
         cfg = hydra.initialize(config_path=str(_relative_path_between(folder, Path(__file__).parent)))
```
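The TODO points at the intended fix: persist the normalization stats alongside the trained model instead of hardcoding an absolute path. A hedged sketch of that save/load pair, assuming `stats.pth` is simply a dict of tensors serialized with `torch.save`:

```python
# Hedged sketch of the stats handling hinted at by the TODO above; the file name
# and the dict-of-tensors format are assumptions.
from pathlib import Path

import torch


def save_stats(stats: dict, model_dir: Path) -> Path:
    stats_path = model_dir / "stats.pth"
    torch.save(stats, stats_path)
    return stats_path


def load_stats(model_dir: Path) -> dict:
    return torch.load(model_dir / "stats.pth")
```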
```diff
@@ -172,7 +172,8 @@ def train(cfg: dict, out_dir=None, job_name=None):

     # Note: this helper will be used in offline and online training loops.
     def _maybe_eval_and_maybe_save(step):
-        if step % cfg.eval_freq == 0:
+        # if step % cfg.eval_freq == 0:
+        if True:
             logging.info(f"Eval policy at step {step}")
             eval_info, first_video = eval_policy(
                 env,
```
```diff
@@ -187,8 +188,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
             if cfg.wandb.enable:
                 logger.log_video(first_video, step, mode="eval")
             logging.info("Resume training")

-        if cfg.save_model and step % cfg.save_freq == 0:
+        if True:
+            # if cfg.save_model and step % cfg.save_freq == 0:
             logging.info(f"Checkpoint policy after step {step}")
             logger.save_model(policy, identifier=step)
             logging.info("Resume training")
```
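Both `if True:` branches above are debug shortcuts: they force evaluation and checkpointing on every call instead of honoring `eval_freq`, `save_freq`, and `save_model`. A hedged sketch of the guards they temporarily replace:

```python
# Hedged sketch of the intended (non-hacked) gating that the `if True:` lines bypass.
def should_eval(step: int, eval_freq: int) -> bool:
    return step % eval_freq == 0


def should_save(step: int, save_freq: int, save_model: bool) -> bool:
    return save_model and step % save_freq == 0
```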