WIP
This commit is contained in:
@@ -110,6 +110,16 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
|
||||
stats_path = self.data_dir / "stats.pth"
|
||||
if stats_path.exists():
|
||||
stats = torch.load(stats_path)
|
||||
else:
|
||||
logging.info(f"compute_stats and save to {stats_path}")
|
||||
stats = self._compute_stats(num_batch, batch_size)
|
||||
torch.save(stats, stats_path)
|
||||
return stats
|
||||
|
||||
@property
|
||||
def stats_patterns(self) -> dict:
|
||||
d = {
|
||||
|
||||
@@ -106,6 +106,23 @@ def make_offline_buffer(
|
||||
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
|
||||
if (
|
||||
cfg.env.name == "aloha"
|
||||
and cfg.env.task == "sim_transfer_cube_scripted"
|
||||
and cfg.policy.name == "act"
|
||||
):
|
||||
import pickle
|
||||
|
||||
with open(
|
||||
"/home/rcadene/code/act/tmp/2024_03_10_sim_transfer_cube_scripted/dataset_stats.pkl", "rb"
|
||||
) as file:
|
||||
dataset_stats = pickle.load(file)
|
||||
|
||||
stats["action", "mean"] = torch.from_numpy(dataset_stats["action_mean"])[None, :]
|
||||
stats["action", "std"] = torch.from_numpy(dataset_stats["action_std"])[None, :]
|
||||
stats["observation", "state", "mean"] = torch.from_numpy(dataset_stats["qpos_mean"])[None, :]
|
||||
stats["observation", "state", "std"] = torch.from_numpy(dataset_stats["qpos_std"])[None, :]
|
||||
|
||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
transform = NormalizeTransform(stats, in_keys, mode=normalization_mode)
|
||||
|
||||
Reference in New Issue
Block a user