From 10fc36a572a09615f088c5abe9230754660963f1 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Mon, 4 Mar 2024 17:28:11 +0000 Subject: [PATCH] WIP --- lerobot/common/policies/diffusion/policy.py | 14 ++++++++++++++ lerobot/scripts/eval.py | 3 ++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index 7ae0a529d..5ac52005c 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -1,5 +1,6 @@ import copy import time +from collections import OrderedDict import hydra import torch @@ -62,6 +63,19 @@ class DiffusionPolicy(nn.Module): **kwargs, ) + ckpt = torch.load( + "/admin/home/remi_cadene/code/lerobot/outputs/diffusion_policy/experiments/pusht/policy_cnn_train_0/checkpoints/pusht_vision_100ep.ckpt" + ) + ckpt_image = OrderedDict() + ckpt_noise = OrderedDict() + for key in ckpt: + if "vision_encoder." in key: + ckpt_image[key.replace("vision_encoder.", "")] = ckpt[key] + if "noise_pred_net." in key: + ckpt_noise[key.replace("noise_pred_net.", "")] = ckpt[key] + self.diffusion.obs_encoder.key_model_map.image.load_state_dict(ckpt_image) + self.diffusion.model.load_state_dict(ckpt_noise) + self.device = torch.device("cuda") self.diffusion.cuda() diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index abe4645a8..5321d4921 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -120,7 +120,8 @@ def eval(cfg: dict, out_dir=None): logging.info("make_env") env = make_env(cfg, transform=offline_buffer._transform) - if cfg.policy.pretrained_model_path: + # if cfg.policy.pretrained_model_path: + if True: policy = make_policy(cfg) policy = TensorDictModule( policy,