Compare commits

...

1 Commits

Author SHA1 Message Date
Remi Cadene
10fc36a572 WIP 2024-03-04 17:28:11 +00:00
2 changed files with 16 additions and 1 deletions

View File

@@ -1,5 +1,6 @@
import copy
import time
from collections import OrderedDict
import hydra
import torch
@@ -62,6 +63,19 @@ class DiffusionPolicy(nn.Module):
**kwargs,
)
ckpt = torch.load(
"/admin/home/remi_cadene/code/lerobot/outputs/diffusion_policy/experiments/pusht/policy_cnn_train_0/checkpoints/pusht_vision_100ep.ckpt"
)
ckpt_image = OrderedDict()
ckpt_noise = OrderedDict()
for key in ckpt:
if "vision_encoder." in key:
ckpt_image[key.replace("vision_encoder.", "")] = ckpt[key]
if "noise_pred_net." in key:
ckpt_noise[key.replace("noise_pred_net.", "")] = ckpt[key]
self.diffusion.obs_encoder.key_model_map.image.load_state_dict(ckpt_image)
self.diffusion.model.load_state_dict(ckpt_noise)
self.device = torch.device("cuda")
self.diffusion.cuda()

View File

@@ -120,7 +120,8 @@ def eval(cfg: dict, out_dir=None):
logging.info("make_env")
env = make_env(cfg, transform=offline_buffer._transform)
if cfg.policy.pretrained_model_path:
# if cfg.policy.pretrained_model_path:
if True:
policy = make_policy(cfg)
policy = TensorDictModule(
policy,