Compare commits
1 Commits
main
...
user/rcade
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
10fc36a572 |
@@ -1,5 +1,6 @@
|
||||
import copy
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
@@ -62,6 +63,19 @@ class DiffusionPolicy(nn.Module):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
ckpt = torch.load(
|
||||
"/admin/home/remi_cadene/code/lerobot/outputs/diffusion_policy/experiments/pusht/policy_cnn_train_0/checkpoints/pusht_vision_100ep.ckpt"
|
||||
)
|
||||
ckpt_image = OrderedDict()
|
||||
ckpt_noise = OrderedDict()
|
||||
for key in ckpt:
|
||||
if "vision_encoder." in key:
|
||||
ckpt_image[key.replace("vision_encoder.", "")] = ckpt[key]
|
||||
if "noise_pred_net." in key:
|
||||
ckpt_noise[key.replace("noise_pred_net.", "")] = ckpt[key]
|
||||
self.diffusion.obs_encoder.key_model_map.image.load_state_dict(ckpt_image)
|
||||
self.diffusion.model.load_state_dict(ckpt_noise)
|
||||
|
||||
self.device = torch.device("cuda")
|
||||
self.diffusion.cuda()
|
||||
|
||||
|
||||
@@ -120,7 +120,8 @@ def eval(cfg: dict, out_dir=None):
|
||||
logging.info("make_env")
|
||||
env = make_env(cfg, transform=offline_buffer._transform)
|
||||
|
||||
if cfg.policy.pretrained_model_path:
|
||||
# if cfg.policy.pretrained_model_path:
|
||||
if True:
|
||||
policy = make_policy(cfg)
|
||||
policy = TensorDictModule(
|
||||
policy,
|
||||
|
||||
Reference in New Issue
Block a user