forked from tangger/lerobot
WIP
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
import copy
|
import copy
|
||||||
import time
|
import time
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
import torch
|
import torch
|
||||||
@@ -62,6 +63,19 @@ class DiffusionPolicy(nn.Module):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ckpt = torch.load(
|
||||||
|
"/admin/home/remi_cadene/code/lerobot/outputs/diffusion_policy/experiments/pusht/policy_cnn_train_0/checkpoints/pusht_vision_100ep.ckpt"
|
||||||
|
)
|
||||||
|
ckpt_image = OrderedDict()
|
||||||
|
ckpt_noise = OrderedDict()
|
||||||
|
for key in ckpt:
|
||||||
|
if "vision_encoder." in key:
|
||||||
|
ckpt_image[key.replace("vision_encoder.", "")] = ckpt[key]
|
||||||
|
if "noise_pred_net." in key:
|
||||||
|
ckpt_noise[key.replace("noise_pred_net.", "")] = ckpt[key]
|
||||||
|
self.diffusion.obs_encoder.key_model_map.image.load_state_dict(ckpt_image)
|
||||||
|
self.diffusion.model.load_state_dict(ckpt_noise)
|
||||||
|
|
||||||
self.device = torch.device("cuda")
|
self.device = torch.device("cuda")
|
||||||
self.diffusion.cuda()
|
self.diffusion.cuda()
|
||||||
|
|
||||||
|
|||||||
@@ -120,7 +120,8 @@ def eval(cfg: dict, out_dir=None):
|
|||||||
logging.info("make_env")
|
logging.info("make_env")
|
||||||
env = make_env(cfg, transform=offline_buffer._transform)
|
env = make_env(cfg, transform=offline_buffer._transform)
|
||||||
|
|
||||||
if cfg.policy.pretrained_model_path:
|
# if cfg.policy.pretrained_model_path:
|
||||||
|
if True:
|
||||||
policy = make_policy(cfg)
|
policy = make_policy(cfg)
|
||||||
policy = TensorDictModule(
|
policy = TensorDictModule(
|
||||||
policy,
|
policy,
|
||||||
|
|||||||
Reference in New Issue
Block a user