This commit is contained in:
Cadene
2024-03-11 14:06:34 +00:00
parent bdd2c801bc
commit 706453ac17
4 changed files with 37 additions and 1 deletions

View File

@@ -110,6 +110,16 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
transform=transform,
)
def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
stats_path = self.data_dir / "stats.pth"
if stats_path.exists():
stats = torch.load(stats_path)
else:
logging.info(f"compute_stats and save to {stats_path}")
stats = self._compute_stats(num_batch, batch_size)
torch.save(stats, stats_path)
return stats
@property
def stats_patterns(self) -> dict:
d = {

View File

@@ -106,6 +106,23 @@ def make_offline_buffer(
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
if (
cfg.env.name == "aloha"
and cfg.env.task == "sim_transfer_cube_scripted"
and cfg.policy.name == "act"
):
import pickle
with open(
"/home/rcadene/code/act/tmp/2024_03_10_sim_transfer_cube_scripted/dataset_stats.pkl", "rb"
) as file:
dataset_stats = pickle.load(file)
stats["action", "mean"] = torch.from_numpy(dataset_stats["action_mean"])[None, :]
stats["action", "std"] = torch.from_numpy(dataset_stats["action_std"])[None, :]
stats["observation", "state", "mean"] = torch.from_numpy(dataset_stats["qpos_mean"])[None, :]
stats["observation", "state", "std"] = torch.from_numpy(dataset_stats["qpos_std"])[None, :]
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
transform = NormalizeTransform(stats, in_keys, mode=normalization_mode)