backup wip

This commit is contained in:
Alexander Soare
2024-04-05 18:46:30 +01:00
parent ecc7dd3b17
commit 8d2463f45b
5 changed files with 105 additions and 120 deletions

View File

@@ -28,22 +28,23 @@ for to_remove in start_removals:
# Replace keys based on what they start with.
start_replacements = [
("model.query_embed.weight", "model.pos_embed.weight"),
("model.pos_table", "model.vae_encoder_pos_enc"),
("model.pos_embed.weight", "model.decoder_pos_embed.weight"),
("model.encoder.", "model.vae_encoder."),
("model.encoder_action_proj.", "model.vae_encoder_action_input_proj."),
("model.encoder_joint_proj.", "model.vae_encoder_robot_state_input_proj."),
("model.latent_proj.", "model.vae_encoder_latent_output_proj."),
("model.latent_proj.", "model.vae_encoder_latent_output_proj."),
("model.input_proj.", "model.encoder_img_feat_input_proj."),
("model.input_proj_robot_state", "model.encoder_robot_state_input_proj"),
("model.latent_out_proj.", "model.encoder_latent_input_proj."),
("model.transformer.encoder.", "model.encoder."),
("model.transformer.decoder.", "model.decoder."),
("model.backbones.0.0.body.", "model.backbone."),
("model.additional_pos_embed.weight", "model.encoder_robot_and_latent_pos_embed.weight"),
("model.cls_embed.weight", "model.vae_encoder_cls_embed.weight"),
("model.", ""),
("query_embed.weight", "pos_embed.weight"),
("pos_table", "vae_encoder_pos_enc"),
("pos_embed.weight", "decoder_pos_embed.weight"),
("encoder.", "vae_encoder."),
("encoder_action_proj.", "vae_encoder_action_input_proj."),
("encoder_joint_proj.", "vae_encoder_robot_state_input_proj."),
("latent_proj.", "vae_encoder_latent_output_proj."),
("latent_proj.", "vae_encoder_latent_output_proj."),
("input_proj.", "encoder_img_feat_input_proj."),
("input_proj_robot_state", "encoder_robot_state_input_proj"),
("latent_out_proj.", "encoder_latent_input_proj."),
("transformer.encoder.", "encoder."),
("transformer.decoder.", "decoder."),
("backbones.0.0.body.", "backbone."),
("additional_pos_embed.weight", "encoder_robot_and_latent_pos_embed.weight"),
("cls_embed.weight", "vae_encoder_cls_embed.weight"),
]
for to_replace, replace_with in start_replacements: