forked from tangger/lerobot
WIP
This commit is contained in:
@@ -123,6 +123,23 @@ def make_offline_buffer(
|
|||||||
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||||
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||||
|
|
||||||
|
# if (
|
||||||
|
# cfg.env.name == "aloha"
|
||||||
|
# and cfg.env.task == "sim_transfer_cube"
|
||||||
|
# and cfg.policy.name == "act"
|
||||||
|
# ):
|
||||||
|
# import pickle
|
||||||
|
|
||||||
|
# with open(
|
||||||
|
# "/home/rcadene/code/act/tmp/2024_03_10_sim_transfer_cube_scripted/dataset_stats.pkl", "rb"
|
||||||
|
# ) as file:
|
||||||
|
# dataset_stats = pickle.load(file)
|
||||||
|
|
||||||
|
# stats["action", "mean"] = torch.from_numpy(dataset_stats["action_mean"])
|
||||||
|
# stats["action", "std"] = torch.from_numpy(dataset_stats["action_std"])
|
||||||
|
# stats["observation", "state", "mean"] = torch.from_numpy(dataset_stats["qpos_mean"])
|
||||||
|
# stats["observation", "state", "std"] = torch.from_numpy(dataset_stats["qpos_std"])
|
||||||
|
|
||||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||||
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
|
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
|
||||||
|
|||||||
@@ -138,6 +138,9 @@ class Base(robot_env.MujocoRobotEnv):
|
|||||||
# HACK
|
# HACK
|
||||||
self.model.vis.global_.offwidth = width
|
self.model.vis.global_.offwidth = width
|
||||||
self.model.vis.global_.offheight = height
|
self.model.vis.global_.offheight = height
|
||||||
|
if mode in self.mujoco_renderer._viewers:
|
||||||
|
self.mujoco_renderer._viewers.get(mode).viewport.width = width
|
||||||
|
self.mujoco_renderer._viewers.get(mode).viewport.height = height
|
||||||
return self.mujoco_renderer.render(mode)
|
return self.mujoco_renderer.render(mode)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
|||||||
@@ -26,15 +26,22 @@ def make_policy(cfg):
|
|||||||
policy = ActionChunkingTransformerPolicy(
|
policy = ActionChunkingTransformerPolicy(
|
||||||
cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
|
cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# if cfg.env.name == "aloha" and cfg.env.task == "sim_transfer_cube_scripted":
|
||||||
|
# import torch
|
||||||
|
# state_dict = torch.load(
|
||||||
|
# "/home/rcadene/code/act/tmp/2024_03_10_sim_transfer_cube_scripted/policy_best.ckpt"
|
||||||
|
# )
|
||||||
|
# policy.load_state_dict(state_dict)
|
||||||
else:
|
else:
|
||||||
raise ValueError(cfg.policy.name)
|
raise ValueError(cfg.policy.name)
|
||||||
|
|
||||||
if cfg.policy.pretrained_model_path:
|
if cfg.policy.pretrained_model_path:
|
||||||
# TODO(rcadene): hack for old pretrained models from fowm
|
# TODO(rcadene): hack for old pretrained models from fowm
|
||||||
if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
|
if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
|
||||||
if "offline" in cfg.pretrained_model_path:
|
if "offline" in cfg.policy.pretrained_model_path:
|
||||||
policy.step[0] = 25000
|
policy.step[0] = 25000
|
||||||
elif "final" in cfg.pretrained_model_path:
|
elif "final" in cfg.policy.pretrained_model_path:
|
||||||
policy.step[0] = 100000
|
policy.step[0] = 100000
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|||||||
@@ -152,7 +152,7 @@ def enc(cfg):
|
|||||||
if cfg.modality in {"pixels", "all"}:
|
if cfg.modality in {"pixels", "all"}:
|
||||||
C = int(3 * cfg.frame_stack) # noqa: N806
|
C = int(3 * cfg.frame_stack) # noqa: N806
|
||||||
pixels_enc_layers = [
|
pixels_enc_layers = [
|
||||||
NormalizeImg(),
|
nn.Identity(), # NormalizeImg(), # TODO(rcadene): we need to clean this
|
||||||
nn.Conv2d(C, cfg.num_channels, 7, stride=2),
|
nn.Conv2d(C, cfg.num_channels, 7, stride=2),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Conv2d(cfg.num_channels, cfg.num_channels, 5, stride=2),
|
nn.Conv2d(cfg.num_channels, cfg.num_channels, 5, stride=2),
|
||||||
|
|||||||
@@ -86,6 +86,8 @@ def eval_policy(
|
|||||||
|
|
||||||
def maybe_render_frame(env: EnvBase, _):
|
def maybe_render_frame(env: EnvBase, _):
|
||||||
if save_video or (return_first_video and i == 0): # noqa: B023
|
if save_video or (return_first_video and i == 0): # noqa: B023
|
||||||
|
# TODO(rcadene): set width and height from config or env
|
||||||
|
# ep_frames.append(env.render(width=384, height=384)) # noqa: B023
|
||||||
ep_frames.append(env.render()) # noqa: B023
|
ep_frames.append(env.render()) # noqa: B023
|
||||||
|
|
||||||
# Clear the policy's action queue before the start of a new rollout.
|
# Clear the policy's action queue before the start of a new rollout.
|
||||||
@@ -266,7 +268,8 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
cfg = hydra.compose(Path(args.config).stem, args.overrides)
|
cfg = hydra.compose(Path(args.config).stem, args.overrides)
|
||||||
# TODO(alexander-soare): Save and load stats in trained model directory.
|
# TODO(alexander-soare): Save and load stats in trained model directory.
|
||||||
stats_path = None
|
# stats_path = None
|
||||||
|
stats_path = "/home/rcadene/code/lerobot/outputs/pretrained_models/act_aloha_sim_transfer_cube_scripted/stats.pth"
|
||||||
elif args.hub_id is not None:
|
elif args.hub_id is not None:
|
||||||
folder = Path(snapshot_download(args.hub_id, revision=args.revision))
|
folder = Path(snapshot_download(args.hub_id, revision=args.revision))
|
||||||
cfg = hydra.initialize(config_path=str(_relative_path_between(folder, Path(__file__).parent)))
|
cfg = hydra.initialize(config_path=str(_relative_path_between(folder, Path(__file__).parent)))
|
||||||
|
|||||||
@@ -172,7 +172,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||||||
|
|
||||||
# Note: this helper will be used in offline and online training loops.
|
# Note: this helper will be used in offline and online training loops.
|
||||||
def _maybe_eval_and_maybe_save(step):
|
def _maybe_eval_and_maybe_save(step):
|
||||||
if step % cfg.eval_freq == 0:
|
# if step % cfg.eval_freq == 0:
|
||||||
|
if True:
|
||||||
logging.info(f"Eval policy at step {step}")
|
logging.info(f"Eval policy at step {step}")
|
||||||
eval_info, first_video = eval_policy(
|
eval_info, first_video = eval_policy(
|
||||||
env,
|
env,
|
||||||
@@ -187,8 +188,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||||||
if cfg.wandb.enable:
|
if cfg.wandb.enable:
|
||||||
logger.log_video(first_video, step, mode="eval")
|
logger.log_video(first_video, step, mode="eval")
|
||||||
logging.info("Resume training")
|
logging.info("Resume training")
|
||||||
|
if True:
|
||||||
if cfg.save_model and step % cfg.save_freq == 0:
|
# if cfg.save_model and step % cfg.save_freq == 0:
|
||||||
logging.info(f"Checkpoint policy after step {step}")
|
logging.info(f"Checkpoint policy after step {step}")
|
||||||
logger.save_model(policy, identifier=step)
|
logger.save_model(policy, identifier=step)
|
||||||
logging.info("Resume training")
|
logging.info("Resume training")
|
||||||
|
|||||||
Reference in New Issue
Block a user