ready for review

This commit is contained in:
Alexander Soare
2024-03-21 10:18:50 +00:00
parent d323993569
commit acf1174447
12 changed files with 282 additions and 85 deletions

View File

@@ -16,6 +16,7 @@ def make_policy(cfg):
cfg_rgb_model=cfg.rgb_model,
cfg_obs_encoder=cfg.obs_encoder,
cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema,
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
**cfg.policy,
)
@@ -39,23 +40,4 @@ def make_policy(cfg):
raise NotImplementedError()
policy.load(cfg.policy.pretrained_model_path)
# import torch
# loaded = torch.load('/home/alexander/Downloads/dp.pth')
# aligned = {}
# their_prefix = "obs_encoder.obs_nets.image.backbone"
# our_prefix = "obs_encoder.key_model_map.image.backbone"
# aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
# their_prefix = "obs_encoder.obs_nets.image.pool"
# our_prefix = "obs_encoder.key_model_map.image.pool"
# aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
# their_prefix = "obs_encoder.obs_nets.image.nets.3"
# our_prefix = "obs_encoder.key_model_map.image.out"
# aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
# aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')})
# missing_keys, unexpected_keys = policy.diffusion.load_state_dict(aligned, strict=False)
# assert all('_dummy_variable' in k for k in missing_keys)
# assert len(unexpected_keys) == 0
return policy