ready for review

This commit is contained in:
Alexander Soare
2024-03-21 10:18:50 +00:00
parent d323993569
commit acf1174447
12 changed files with 282 additions and 85 deletions

View File

@@ -1,3 +1,44 @@
"""Code from the original diffusion policy project.
Notes on how to load a checkpoint from the original repository:
In the original repository, run the eval and use a breakpoint to extract the policy weights.
```
torch.save(policy.state_dict(), "weights.pt")
```
In this repository, add a breakpoint somewhere after creating an equivalent policy and load in the weights:
```
loaded = torch.load("weights.pt")
aligned = {}
their_prefix = "obs_encoder.obs_nets.image.backbone"
our_prefix = "obs_encoder.key_model_map.image.backbone"
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
their_prefix = "obs_encoder.obs_nets.image.pool"
our_prefix = "obs_encoder.key_model_map.image.pool"
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
their_prefix = "obs_encoder.obs_nets.image.nets.3"
our_prefix = "obs_encoder.key_model_map.image.out"
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')})
# Note: here you are loading into the ema model.
missing_keys, unexpected_keys = policy.ema_diffusion.load_state_dict(aligned, strict=False)
assert all('_dummy_variable' in k for k in missing_keys)
assert len(unexpected_keys) == 0
```
Then in that same runtime you can also save the weights with the new aligned state_dict:
```
policy.save("weights.pt")
```
Now you can remove the breakpoint and extra code and load in the weights just like with any other lerobot checkpoint.
"""
from typing import Dict
import torch