backup wip

This commit is contained in:
Alexander Soare
2024-04-04 18:34:41 +01:00
parent 278336a39a
commit 3a4dfa82fe
8 changed files with 538 additions and 1227 deletions

View File

@@ -11,6 +11,19 @@ policy = make_policy(cfg)
state_dict = torch.load("/home/alexander/Projects/act/outputs/sim_transfer_cube_human_vae/policy_last.ckpt")
# Remove keys based on what they start with.
start_removals = [
# There is a bug that means the pretrained model doesn't even use the final decoder layers.
*[f"model.transformer.decoder.layers.{i}" for i in range(1, 7)],
"model.is_pad_head.",
]
for to_remove in start_removals:
for k in list(state_dict.keys()):
if k.startswith(to_remove):
del state_dict[k]
# Replace keys based on what they start with.
@@ -26,6 +39,9 @@ start_replacements = [
("model.input_proj.", "model.encoder_img_feat_input_proj."),
("model.input_proj_robot_state", "model.encoder_robot_state_input_proj"),
("model.latent_out_proj.", "model.encoder_latent_input_proj."),
("model.transformer.encoder.", "model.encoder."),
("model.transformer.decoder.", "model.decoder."),
("model.backbones.0.0.body.", "model.backbone."),
]
for to_replace, replace_with in start_replacements:
@@ -35,18 +51,6 @@ for to_replace, replace_with in start_replacements:
state_dict[k_] = state_dict[k]
del state_dict[k]
# Remove keys based on what they start with.
start_removals = [
# There is a bug that means the pretrained model doesn't even use the final decoder layers.
*[f"model.transformer.decoder.layers.{i}" for i in range(1, 7)],
"model.is_pad_head.",
]
for to_remove in start_removals:
for k in list(state_dict.keys()):
if k.startswith(to_remove):
del state_dict[k]
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)