backup wip
This commit is contained in:
@@ -11,6 +11,19 @@ policy = make_policy(cfg)
|
||||
|
||||
state_dict = torch.load("/home/alexander/Projects/act/outputs/sim_transfer_cube_human_vae/policy_last.ckpt")
|
||||
|
||||
# Remove keys based on what they start with.
|
||||
|
||||
start_removals = [
|
||||
# There is a bug that means the pretrained model doesn't even use the final decoder layers.
|
||||
*[f"model.transformer.decoder.layers.{i}" for i in range(1, 7)],
|
||||
"model.is_pad_head.",
|
||||
]
|
||||
|
||||
for to_remove in start_removals:
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith(to_remove):
|
||||
del state_dict[k]
|
||||
|
||||
|
||||
# Replace keys based on what they start with.
|
||||
|
||||
@@ -26,6 +39,9 @@ start_replacements = [
|
||||
("model.input_proj.", "model.encoder_img_feat_input_proj."),
|
||||
("model.input_proj_robot_state", "model.encoder_robot_state_input_proj"),
|
||||
("model.latent_out_proj.", "model.encoder_latent_input_proj."),
|
||||
("model.transformer.encoder.", "model.encoder."),
|
||||
("model.transformer.decoder.", "model.decoder."),
|
||||
("model.backbones.0.0.body.", "model.backbone."),
|
||||
]
|
||||
|
||||
for to_replace, replace_with in start_replacements:
|
||||
@@ -35,18 +51,6 @@ for to_replace, replace_with in start_replacements:
|
||||
state_dict[k_] = state_dict[k]
|
||||
del state_dict[k]
|
||||
|
||||
# Remove keys based on what they start with.
|
||||
|
||||
start_removals = [
|
||||
# There is a bug that means the pretrained model doesn't even use the final decoder layers.
|
||||
*[f"model.transformer.decoder.layers.{i}" for i in range(1, 7)],
|
||||
"model.is_pad_head.",
|
||||
]
|
||||
|
||||
for to_remove in start_removals:
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith(to_remove):
|
||||
del state_dict[k]
|
||||
|
||||
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user