diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py
index 3b53fed1..16c0d54c 100644
--- a/lerobot/common/datasets/aloha.py
+++ b/lerobot/common/datasets/aloha.py
@@ -110,6 +110,16 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
             transform=transform,
         )
 
+    def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
+        stats_path = self.data_dir / "stats.pth"
+        if stats_path.exists():
+            stats = torch.load(stats_path)
+        else:
+            logging.info(f"compute_stats and save to {stats_path}")
+            stats = self._compute_stats(num_batch, batch_size)
+            torch.save(stats, stats_path)
+        return stats
+
     @property
     def stats_patterns(self) -> dict:
         d = {
diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index 63bde225..e59c96dc 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -106,6 +106,23 @@ def make_offline_buffer(
         stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
         stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
 
+    if (
+        cfg.env.name == "aloha"
+        and cfg.env.task == "sim_transfer_cube_scripted"
+        and cfg.policy.name == "act"
+    ):
+        import pickle
+
+        with open(
+            "/home/rcadene/code/act/tmp/2024_03_10_sim_transfer_cube_scripted/dataset_stats.pkl", "rb"
+        ) as file:
+            dataset_stats = pickle.load(file)
+
+        stats["action", "mean"] = torch.from_numpy(dataset_stats["action_mean"])[None, :]
+        stats["action", "std"] = torch.from_numpy(dataset_stats["action_std"])[None, :]
+        stats["observation", "state", "mean"] = torch.from_numpy(dataset_stats["qpos_mean"])[None, :]
+        stats["observation", "state", "std"] = torch.from_numpy(dataset_stats["qpos_std"])[None, :]
+
     # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
     normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
     transform = NormalizeTransform(stats, in_keys, mode=normalization_mode)
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index c5e45300..5425bf1b 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -37,4 +37,12 @@ def make_policy(cfg):
             raise NotImplementedError()
         policy.load(cfg.policy.pretrained_model_path)
 
+    if cfg.env.name == "aloha" and cfg.env.task == "sim_transfer_cube_scripted" and cfg.policy.name == "act":
+        import torch
+
+        state_dict = torch.load(
+            "/home/rcadene/code/act/tmp/2024_03_10_sim_transfer_cube_scripted/policy_best.ckpt"
+        )
+        policy.load_state_dict(state_dict)
+
     return policy
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 7ba2812e..ad2eea34 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -121,7 +121,8 @@ def eval(cfg: dict, out_dir=None):
     logging.info("make_env")
     env = make_env(cfg, transform=offline_buffer.transform)
 
-    if cfg.policy.pretrained_model_path:
+    # if cfg.policy.pretrained_model_path:
+    if True:
        policy = make_policy(cfg)
        policy = TensorDictModule(
            policy,
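
For reference, the compute_or_load_stats method added in aloha.py follows a compute-once, cache-on-disk pattern: statistics are computed on the first call, saved with torch.save, and reloaded from stats.pth on subsequent calls. Below is a minimal self-contained sketch of the same pattern; the DummyDataset class and the placeholder _compute_stats body are illustrative assumptions, not repository code.

    import logging
    from pathlib import Path

    import torch


    class DummyDataset:
        """Illustrative stand-in for the dataset class; not repository code."""

        def __init__(self, data_dir):
            self.data_dir = Path(data_dir)
            self.data_dir.mkdir(parents=True, exist_ok=True)

        def _compute_stats(self, num_batch, batch_size):
            # Placeholder: the real method would iterate over batches and
            # accumulate per-key statistics (mean, std, min, max).
            return {"action_mean": torch.zeros(2), "action_std": torch.ones(2)}

        def compute_or_load_stats(self, num_batch=100, batch_size=32):
            stats_path = self.data_dir / "stats.pth"
            if stats_path.exists():
                # Cache hit: reuse statistics computed on a previous run.
                stats = torch.load(stats_path)
            else:
                logging.info(f"compute_stats and save to {stats_path}")
                stats = self._compute_stats(num_batch, batch_size)
                torch.save(stats, stats_path)
            return stats


    ds = DummyDataset("/tmp/aloha_stats_demo")
    stats = ds.compute_or_load_stats()  # computes and saves on the first call, loads afterwards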