ready for review
This commit is contained in:
@@ -16,6 +16,7 @@ def make_policy(cfg):
|
||||
cfg_rgb_model=cfg.rgb_model,
|
||||
cfg_obs_encoder=cfg.obs_encoder,
|
||||
cfg_optimizer=cfg.optimizer,
|
||||
cfg_ema=cfg.ema,
|
||||
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
|
||||
**cfg.policy,
|
||||
)
|
||||
@@ -39,23 +40,4 @@ def make_policy(cfg):
|
||||
raise NotImplementedError()
|
||||
policy.load(cfg.policy.pretrained_model_path)
|
||||
|
||||
# import torch
|
||||
# loaded = torch.load('/home/alexander/Downloads/dp.pth')
|
||||
# aligned = {}
|
||||
|
||||
# their_prefix = "obs_encoder.obs_nets.image.backbone"
|
||||
# our_prefix = "obs_encoder.key_model_map.image.backbone"
|
||||
# aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
|
||||
# their_prefix = "obs_encoder.obs_nets.image.pool"
|
||||
# our_prefix = "obs_encoder.key_model_map.image.pool"
|
||||
# aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
|
||||
# their_prefix = "obs_encoder.obs_nets.image.nets.3"
|
||||
# our_prefix = "obs_encoder.key_model_map.image.out"
|
||||
# aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
|
||||
|
||||
# aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')})
|
||||
# missing_keys, unexpected_keys = policy.diffusion.load_state_dict(aligned, strict=False)
|
||||
# assert all('_dummy_variable' in k for k in missing_keys)
|
||||
# assert len(unexpected_keys) == 0
|
||||
|
||||
return policy
|
||||
|
||||
Reference in New Issue
Block a user