save state

Thomas Wolf
2024-06-12 20:46:26 +02:00
parent a7c030076f
commit c108bfe840
6 changed files with 385 additions and 39 deletions

View File

@@ -0,0 +1,22 @@
# Scratch check: compare the conv2d output saved from the original ACT repo against
# the same computation here, first on CUDA, then on CPU.
import torch

inp = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_inp.pt')
conv = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_conv.pt')
out = torch.nn.functional.conv2d(inp, conv, bias=None, stride=1, padding=1, dilation=1, groups=1)
d = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_out.pt')
print((out - d).abs().max())
# -> tensor(0.0044, device='cuda:0', grad_fn=<MaxBackward1>)

# Same comparison on CPU: the difference disappears.
inp = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_inp.pt').to('cpu')
conv = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_conv.pt').to('cpu')
out = torch.nn.functional.conv2d(inp, conv, bias=None, stride=1, padding=1, dilation=1, groups=1)
d = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_out.pt').to('cpu')
print((out - d).abs().max())
# -> tensor(0., grad_fn=<MaxBackward1>)

# Save the lerobot-side conv output so the original repo can be compared against it.
out = torch.nn.functional.conv2d(inp, conv, bias=None, stride=1, padding=1, dilation=1, groups=1)
torch.save(out, '/home/thomwolf/Documents/Github/ACT/tensor_out_lerobot.pt')
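
Note: the transcript above suggests the 0.0044 gap against the saved ACT reference output is a CUDA convolution artifact rather than a modeling difference, since the same computation on CPU reproduces the reference exactly. Not part of the commit: a minimal sketch of the same check with deterministic CUDA kernels and a tolerance-based comparison, assuming the same tensor_*.pt files are available.

import os
import torch

os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")  # required by some deterministic CUDA ops
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)

inp = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_inp.pt')
conv = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_conv.pt')
ref = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_out.pt')

out = torch.nn.functional.conv2d(inp, conv, bias=None, stride=1, padding=1, dilation=1, groups=1)
# float32 CUDA convolutions are not bit-identical across algorithms, so compare with a tolerance
print(torch.allclose(out, ref, atol=1e-3), (out - ref).abs().max())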

View File

@@ -0,0 +1,92 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from pprint import pprint\n",
    "import pickle\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "original_batch_file = \"/home/thomwolf/Documents/Github/ACT/batch_save.pt\"\n",
    "data = torch.load(original_batch_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "#orig: image_data, qpos_data, action_data, is_pad\n",
    "#target: ['observation.images.front', 'observation.images.top', 'observation.state', 'action', 'episode_index', 'frame_index', 'timestamp', 'next.done', 'index', 'action_is_pad']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "conv = {}\n",
    "conv['observation.images.front'] = data[0][:, 0]\n",
    "conv['observation.images.top'] = data[0][:, 1]\n",
    "conv['observation.state'] = data[1]\n",
    "conv['action'] = data[2]\n",
    "conv['episode_index'] = np.zeros(data[0].shape[0])\n",
    "conv['frame_index'] = np.zeros(data[0].shape[0])\n",
    "conv['timestamp'] = np.zeros(data[0].shape[0])\n",
    "conv['next.done'] = np.zeros(data[0].shape[0])\n",
    "conv['index'] = np.arange(data[0].shape[0])\n",
    "conv['action_is_pad'] = data[3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(conv, \"/home/thomwolf/Documents/Github/ACT/batch_save_converted.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lerobot",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.1.-1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
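
Note: the notebook above repacks the original ACT batch tuple (image_data, qpos_data, action_data, is_pad) into the flat dict format used by lerobot, filling the bookkeeping fields the original batch lacks (episode_index, frame_index, timestamp, next.done) with zeros and index with a running counter. Not part of the commit: a small sanity-check sketch one might run after the conversion; it only assumes the file written by the notebook.

import torch

batch = torch.load("/home/thomwolf/Documents/Github/ACT/batch_save_converted.pt")
expected_keys = [
    "observation.images.front", "observation.images.top", "observation.state", "action",
    "episode_index", "frame_index", "timestamp", "next.done", "index", "action_is_pad",
]
assert sorted(batch) == sorted(expected_keys), f"missing/extra keys: {set(batch) ^ set(expected_keys)}"
for key, value in batch.items():
    shape = tuple(value.shape) if hasattr(value, "shape") else None
    print(f"{key:32s} {type(value).__name__:10s} {shape}")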

File diff suppressed because one or more lines are too long

View File

@@ -32,7 +32,6 @@ import torchvision
 from huggingface_hub import PyTorchModelHubMixin
 from torch import Tensor, nn
 from torchvision.models._utils import IntermediateLayerGetter
-from torchvision.ops.misc import FrozenBatchNorm2d

 from lerobot.common.policies.act.configuration_act import ACTConfig
 from lerobot.common.policies.normalize import Normalize, Unnormalize
@@ -75,7 +74,9 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         self.model = ACT(config)

-        self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+        self.expected_image_keys = [
+            k for k in sorted(config.input_shapes) if k.startswith("observation.image")
+        ]

         self.reset()
@@ -135,7 +136,9 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
-        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
+        batch["observation.images"] = torch.stack(
+            [batch[k] for k in sorted(self.expected_image_keys)], dim=-4
+        )
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
@@ -228,13 +231,18 @@ class ACT(nn.Module):
         # Backbone for image feature extraction.
         backbone_model = getattr(torchvision.models, config.vision_backbone)(
             replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
-            weights=config.pretrained_backbone_weights,
-            norm_layer=FrozenBatchNorm2d,
+            weights="DEFAULT", # config.pretrained_backbone_weights,
+            # norm_layer=FrozenBatchNorm2d,
         )
         # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature
         # map).
         # Note: The forward method of this returns a dict: {"feature_map": output}.
-        self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
+        # TODO thom fix this
+        # self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
+        self.backbone = IntermediateLayerGetter(
+            backbone_model, return_layers={"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
+        )

         # Transformer (acts as VAE decoder when training with the variational objective).
         self.encoder = ACTEncoder(config)
@@ -294,7 +302,7 @@ class ACT(nn.Module):
         batch_size = batch["observation.images"].shape[0]

         # Prepare the latent for input to the transformer encoder.
-        if self.config.use_vae and "action" in batch:
+        if False: ###### TODO(thom) remove this self.config.use_vae and "action" in batch:
             # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
             cls_embed = einops.repeat(
                 self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
@@ -346,7 +354,9 @@ class ACT(nn.Module):
         images = batch["observation.images"]

         for cam_index in range(images.shape[-4]):
-            cam_features = self.backbone(images[:, cam_index])["feature_map"]
+            torch.backends.cudnn.deterministic = True
+            cam_features = self.backbone(images[:, cam_index])
+            cam_features = cam_features[3]
             # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
             cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
             cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
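
Note on the diff above: sorting the image keys pins the camera stacking order, and the backbone is rebuilt to expose all four ResNet stages instead of only layer4. torchvision's IntermediateLayerGetter returns an OrderedDict keyed by the values of return_layers, so with this mapping the stages come back under the string keys "0" through "3" (the final feature map is features["3"], not features[3]). Not part of the commit: a minimal sketch of what the new wiring returns, assuming a plain torchvision resnet18.

import torch
import torchvision
from torchvision.models._utils import IntermediateLayerGetter

backbone_model = torchvision.models.resnet18(weights="DEFAULT")
backbone = IntermediateLayerGetter(
    backbone_model, return_layers={"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
)

features = backbone(torch.randn(1, 3, 224, 224))
for name, feature_map in features.items():  # keys are the strings "0".."3"
    print(name, tuple(feature_map.shape))
# resnet18 at 224x224 input gives (1, 64, 56, 56), (1, 128, 28, 28), (1, 256, 14, 14), (1, 512, 7, 7);
# the stage equivalent to the old "feature_map" output is the last one, features["3"]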

View File

@@ -140,14 +140,14 @@ class Normalize(nn.Module):
                 std = buffer["std"]
                 assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                 assert not torch.isinf(std).any(), _no_stats_error_str("std")
-                batch[key] = (batch[key] - mean) / (std + 1e-8)
+                batch[key].sub_(mean).div_(std)
             elif mode == "min_max":
                 min = buffer["min"]
                 max = buffer["max"]
                 assert not torch.isinf(min).any(), _no_stats_error_str("min")
                 assert not torch.isinf(max).any(), _no_stats_error_str("max")
                 # normalize to [0,1]
-                batch[key] = (batch[key] - min) / (max - min + 1e-8)
+                batch[key] = (batch[key] - min) / (max - min)
                 # normalize to [-1, 1]
                 batch[key] = batch[key] * 2 - 1
             else:
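
Note on the diff above: both branches drop the 1e-8 epsilon, presumably so the normalized values match the original ACT preprocessing exactly during this debugging session, and the mean/std branch now normalizes in place (sub_/div_ mutate the stored tensor instead of allocating a new one). Not part of the commit: the two modes written out as plain functions, keeping the epsilon as an optional guard against zero std or constant features.

import torch

def normalize_mean_std(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
    # eps > 0 avoids division by zero for constant features; eps = 0 matches the original ACT behavior
    return (x - mean) / (std + eps)

def normalize_min_max(x: torch.Tensor, vmin: torch.Tensor, vmax: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
    x = (x - vmin) / (vmax - vmin + eps)  # map to [0, 1]
    return x * 2 - 1  # then to [-1, 1]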

View File

@@ -44,6 +44,12 @@ from lerobot.common.utils.utils import (
 )
 from lerobot.scripts.eval import eval_policy

+################## TODO remove this part
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+torch.use_deterministic_algorithms(True)
+##################
+

 def make_optimizer_and_scheduler(cfg, policy):
     if cfg.policy.name == "act":
@@ -104,7 +110,55 @@ def update_policy(
     start_time = time.perf_counter()
     device = get_device_from_parameters(policy)
     policy.train()
+
+    ################## TODO remove this part
+    pretrained_policy_name_or_path = (
+        "/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort_raw/initial_state"
+    )
+    from lerobot.common.policies.act.modeling_act import ACTPolicy
+
+    policy_cls = ACTPolicy
+    policy_cfg = policy.config
+    policy = policy_cls(policy_cfg)
+    policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())
+    policy.to(device)
+    policy.eval() # No dropout
+    ##################
+
     with torch.autocast(device_type=device.type) if use_amp else nullcontext():
+        ########################### TODO remove this part
+        batch = torch.load("/home/thomwolf/Documents/Github/ACT/batch_save_converted.pt", map_location=device)
+
+        # print some stats
+        def model_stats(model):
+            na = [n for n, a in model.named_parameters() if "normalize_" not in n]
+            me = [a.mean().item() for n, a in model.named_parameters() if "normalize_" not in n]
+            print(na[me.index(min(me))], min(me))
+            print(sum(me))
+            mi = [a.min().item() for n, a in model.named_parameters() if "normalize_" not in n]
+            print(na[mi.index(min(mi))], min(mi))
+            print(sum(mi))
+            ma = [a.max().item() for n, a in model.named_parameters() if "normalize_" not in n]
+            print(na[ma.index(max(ma))], max(ma))
+            print(sum(ma))
+
+        model_stats(policy)
+
+        def batch_stats(data):
+            print(min(d.min() for d in data))
+            print(max(d.max() for d in data))
+
+        data = (
+            batch["observation.images.front"],
+            batch["observation.images.top"],
+            batch["observation.state"],
+            batch["action"],
+        )
+        batch_stats(data)
+        ###########################
+
         output_dict = policy.forward(batch)
         # TODO(rcadene): policy.unnormalize_outputs(out_dict)
         loss = output_dict["loss"]
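
Note: the two injected blocks above swap the training policy for a freshly loaded checkpoint, feed it a fixed converted batch, and print parameter and batch statistics, so the forward pass can be compared against the original ACT repository under deterministic kernels. Not part of the commit: a standalone sketch of that parity check, using the same placeholder paths and the ACTPolicy.from_pretrained loader already used in the diff.

import torch
from lerobot.common.policies.act.modeling_act import ACTPolicy

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy = ACTPolicy.from_pretrained(
    "/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort_raw/initial_state"
).to(device)
policy.eval()  # no dropout, so repeated runs on the same batch are comparable

batch = torch.load("/home/thomwolf/Documents/Github/ACT/batch_save_converted.pt", map_location=device)
with torch.no_grad():
    output_dict = policy.forward(batch)
print(output_dict["loss"])  # compare against the loss reported by the original ACT repo on the same batch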