WIP
This commit is contained in:
@@ -110,6 +110,16 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
    """Return the dataset statistics, reusing the on-disk cache when present.

    The statistics are cached in a ``stats.pth`` file inside
    ``self.data_dir``. When that file exists it is read with ``torch.load``
    and returned directly. Otherwise the statistics are produced by
    ``self._compute_stats``, written to the cache file with ``torch.save``,
    and then returned.

    Args:
        num_batch: Forwarded to ``self._compute_stats`` on a cache miss.
        batch_size: Forwarded to ``self._compute_stats`` on a cache miss.

    Returns:
        The loaded or freshly computed statistics (annotated as a
        ``TensorDict``; the concrete type is whatever ``_compute_stats``
        returned when the cache was written).
    """
    cache_file = self.data_dir / "stats.pth"

    # Cache hit: the stored file is returned as-is. Note that the cache is
    # keyed only by its path, so ``num_batch`` / ``batch_size`` are ignored
    # whenever the file already exists.
    if cache_file.exists():
        return torch.load(cache_file)

    # Cache miss: compute once, persist for later runs, then hand back.
    logging.info(f"compute_stats and save to {cache_file}")
    computed = self._compute_stats(num_batch, batch_size)
    torch.save(computed, cache_file)
    return computed
|
||||
|
||||
@property
|
||||
def stats_patterns(self) -> dict:
|
||||
d = {
|
||||
|
||||
@@ -106,6 +106,23 @@ def make_offline_buffer(
|
||||
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
|
||||
if (
|
||||
cfg.env.name == "aloha"
|
||||
and cfg.env.task == "sim_transfer_cube_scripted"
|
||||
and cfg.policy.name == "act"
|
||||
):
|
||||
import pickle
|
||||
|
||||
with open(
|
||||
"/home/rcadene/code/act/tmp/2024_03_10_sim_transfer_cube_scripted/dataset_stats.pkl", "rb"
|
||||
) as file:
|
||||
dataset_stats = pickle.load(file)
|
||||
|
||||
stats["action", "mean"] = torch.from_numpy(dataset_stats["action_mean"])[None, :]
|
||||
stats["action", "std"] = torch.from_numpy(dataset_stats["action_std"])[None, :]
|
||||
stats["observation", "state", "mean"] = torch.from_numpy(dataset_stats["qpos_mean"])[None, :]
|
||||
stats["observation", "state", "std"] = torch.from_numpy(dataset_stats["qpos_std"])[None, :]
|
||||
|
||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
transform = NormalizeTransform(stats, in_keys, mode=normalization_mode)
|
||||
|
||||
@@ -37,4 +37,12 @@ def make_policy(cfg):
|
||||
raise NotImplementedError()
|
||||
policy.load(cfg.policy.pretrained_model_path)
|
||||
|
||||
if cfg.env.name == "aloha" and cfg.env.task == "sim_transfer_cube_scripted" and cfg.policy.name == "act":
|
||||
import torch
|
||||
|
||||
state_dict = torch.load(
|
||||
"/home/rcadene/code/act/tmp/2024_03_10_sim_transfer_cube_scripted/policy_best.ckpt"
|
||||
)
|
||||
policy.load_state_dict(state_dict)
|
||||
|
||||
return policy
|
||||
|
||||
@@ -121,7 +121,8 @@ def eval(cfg: dict, out_dir=None):
|
||||
logging.info("make_env")
|
||||
env = make_env(cfg, transform=offline_buffer.transform)
|
||||
|
||||
if cfg.policy.pretrained_model_path:
|
||||
# if cfg.policy.pretrained_model_path:
|
||||
if True:
|
||||
policy = make_policy(cfg)
|
||||
policy = TensorDictModule(
|
||||
policy,
|
||||
|
||||
Reference in New Issue
Block a user