Cadene
2024-03-11 14:06:34 +00:00
parent bdd2c801bc
commit 706453ac17
4 changed files with 37 additions and 1 deletion


@@ -110,6 +110,16 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
             transform=transform,
         )
 
+    def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
+        stats_path = self.data_dir / "stats.pth"
+        if stats_path.exists():
+            stats = torch.load(stats_path)
+        else:
+            logging.info(f"compute_stats and save to {stats_path}")
+            stats = self._compute_stats(num_batch, batch_size)
+            torch.save(stats, stats_path)
+        return stats
+
     @property
     def stats_patterns(self) -> dict:
         d = {
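The added method is a plain compute-or-load cache: statistics are read from `stats.pth` when the file exists, otherwise computed once and persisted for subsequent runs. A minimal self-contained sketch of the same idiom (`compute_or_load` and `compute_fn` are hypothetical names, not part of this commit):

```python
import logging
from pathlib import Path

import torch


def compute_or_load(cache_path: Path, compute_fn):
    """Return the cached object if it exists, else compute it and cache it."""
    if cache_path.exists():
        return torch.load(cache_path)
    logging.info(f"computing stats and saving to {cache_path}")
    result = compute_fn()
    torch.save(result, cache_path)
    return result


stats = compute_or_load(Path("stats.pth"), lambda: {"mean": torch.zeros(2)})
```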


@@ -106,6 +106,23 @@ def make_offline_buffer(
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
if (
cfg.env.name == "aloha"
and cfg.env.task == "sim_transfer_cube_scripted"
and cfg.policy.name == "act"
):
import pickle
with open(
"/home/rcadene/code/act/tmp/2024_03_10_sim_transfer_cube_scripted/dataset_stats.pkl", "rb"
) as file:
dataset_stats = pickle.load(file)
stats["action", "mean"] = torch.from_numpy(dataset_stats["action_mean"])[None, :]
stats["action", "std"] = torch.from_numpy(dataset_stats["action_std"])[None, :]
stats["observation", "state", "mean"] = torch.from_numpy(dataset_stats["qpos_mean"])[None, :]
stats["observation", "state", "std"] = torch.from_numpy(dataset_stats["qpos_std"])[None, :]
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
transform = NormalizeTransform(stats, in_keys, mode=normalization_mode)
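The hardcoded block swaps in normalization statistics computed by the upstream ACT repository so results can be compared like for like (the TODO already flags it for removal). The `[None, :]` indexing adds a leading singleton dimension so the 1-D numpy statistics broadcast over a batch. A quick self-contained illustration (the 14-dim state size is an assumption typical of Aloha, not taken from this diff):

```python
import numpy as np
import torch

qpos_mean = np.zeros(14, dtype=np.float32)   # 1-D stats vector, e.g. joint positions
mean = torch.from_numpy(qpos_mean)[None, :]  # shape (1, 14): broadcastable over a batch
batch = torch.randn(32, 14)                  # a batch of 32 states
normalized = batch - mean                    # broadcasts to shape (32, 14)
print(mean.shape, normalized.shape)          # torch.Size([1, 14]) torch.Size([32, 14])
```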


@@ -37,4 +37,12 @@ def make_policy(cfg):
         raise NotImplementedError()
 
     policy.load(cfg.policy.pretrained_model_path)
+
+    if cfg.env.name == "aloha" and cfg.env.task == "sim_transfer_cube_scripted" and cfg.policy.name == "act":
+        import torch
+
+        state_dict = torch.load(
+            "/home/rcadene/code/act/tmp/2024_03_10_sim_transfer_cube_scripted/policy_best.ckpt"
+        )
+        policy.load_state_dict(state_dict)
     return policy
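This overwrites the freshly loaded weights with the best checkpoint from the original ACT run, again for a like-for-like comparison. One caveat worth hedging: a raw `torch.load` assumes the checkpoint's device layout matches the current machine, and passing `map_location` avoids that (a common variant, not used in this commit). A self-contained sketch with a stand-in module:

```python
import torch
import torch.nn as nn

policy = nn.Linear(14, 2)  # stand-in module; the real policy is an ACT network
torch.save(policy.state_dict(), "policy_best.ckpt")  # pretend this is the external checkpoint

# map_location="cpu" makes the load independent of the device the file was saved from.
state_dict = torch.load("policy_best.ckpt", map_location="cpu")
policy.load_state_dict(state_dict)
```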


@@ -121,7 +121,8 @@ def eval(cfg: dict, out_dir=None):
logging.info("make_env")
env = make_env(cfg, transform=offline_buffer.transform)
if cfg.policy.pretrained_model_path:
# if cfg.policy.pretrained_model_path:
if True:
policy = make_policy(cfg)
policy = TensorDictModule(
policy,
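`if True:` forces `make_policy` to run even when `cfg.policy.pretrained_model_path` is unset, with the original condition kept as a comment for easy reversal; the hack pairs with the hardcoded checkpoint load above. A config-driven toggle would make the override explicit; a minimal sketch (`force_make_policy` is hypothetical, and `cfg`/`make_policy` are stand-ins for the repo's Hydra config and policy factory):

```python
from types import SimpleNamespace

# Stand-ins for the real Hydra config and policy factory.
cfg = SimpleNamespace(policy=SimpleNamespace(pretrained_model_path=None, force_make_policy=True))

def make_policy(cfg):
    return object()  # placeholder for the real policy

# Explicit flag instead of a hardcoded `if True:`.
if cfg.policy.pretrained_model_path or cfg.policy.force_make_policy:
    policy = make_policy(cfg)
```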