This commit is contained in:
Thomas Wolf
2024-06-14 10:19:16 +02:00
parent c108bfe840
commit 594acbf136
6 changed files with 57 additions and 78 deletions

View File

@@ -1,14 +0,0 @@
inp = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_inp.pt')
conv = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_conv.pt')
out = torch.nn.functional.conv2d(inp, conv, bias=None, stride=1, padding=1, dilation=1, groups=1)
d = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_out.pt')
print((out-d).abs().max())
tensor(0.0044, device='cuda:0', grad_fn=<MaxBackward1>)
inp = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_inp.pt').to('cpu')
conv = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_conv.pt').to('cpu')
out = torch.nn.functional.conv2d(inp, conv, bias=None, stride=1, padding=1, dilation=1, groups=1)
d = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_out.pt')
print((out-d).abs().max())
tensor(0., grad_fn=<MaxBackward1>)
out = torch.nn.functional.conv2d(inp, conv, bias=None, stride=1, padding=1, dilation=1, groups=1)
torch.save(out, '/home/thomwolf/Documents/Github/ACT/tensor_out_lerobot.pt')

View File

@@ -754,7 +754,6 @@
" 'model.transformer.decoder.layers.4.',\n",
" 'model.transformer.decoder.layers.5.',\n",
" 'model.transformer.decoder.layers.6.',\n",
" 'model.transformer.decoder.norm.',\n",
" 'model.is_pad_head']\n",
"\n",
"to_remove_in = ['num_batches_tracked',]\n",
@@ -773,6 +772,8 @@
" conv[k.replace('transformer.', '')] = a.pop(k)\n",
" if k.startswith('model.transformer.decoder.layers.0.'):\n",
" conv[k.replace('transformer.', '')] = a.pop(k)\n",
" if k.startswith('model.transformer.decoder.norm.'):\n",
" conv[k.replace('transformer.', '')] = a.pop(k)\n",
" if k.startswith('model.encoder.layers.'):\n",
" conv[k.replace('encoder.', 'vae_encoder.')] = a.pop(k)\n",
" if k.startswith('model.action_head.'):\n",
@@ -1008,7 +1009,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.1.-1"
"version": "3.10.14"
}
},
"nbformat": 4,

View File

@@ -48,6 +48,7 @@ training:
eval:
n_episodes: 1
batch_size: 1
max_episodes_rendered: 0
# See `configuration_act.py` for more details.
policy:

View File

@@ -131,7 +131,7 @@ class ACTConfig:
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
# As a consequence we also remove the final, unused layer normalization, by default
n_decoder_layers: int = 1
decoder_norm: bool = False
decoder_norm: bool = True
# VAE.
use_vae: bool = True
latent_dim: int = 32

View File

@@ -238,11 +238,10 @@ class ACT(nn.Module):
# map).
# Note: The forward method of this returns a dict: {"feature_map": output}.
# TODO thom fix this
# self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
self.backbone = IntermediateLayerGetter(
backbone_model, return_layers={"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
)
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
# self.backbone = IntermediateLayerGetter(
# backbone_model, return_layers={"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
# )
# Transformer (acts as VAE decoder when training with the variational objective).
self.encoder = ACTEncoder(config)
@@ -302,7 +301,7 @@ class ACT(nn.Module):
batch_size = batch["observation.images"].shape[0]
# Prepare the latent for input to the transformer encoder.
if False: ###### TODO(thom) remove this self.config.use_vae and "action" in batch:
if self.config.use_vae and "action" in batch:
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
@@ -356,7 +355,7 @@ class ACT(nn.Module):
for cam_index in range(images.shape[-4]):
torch.backends.cudnn.deterministic = True
cam_features = self.backbone(images[:, cam_index])
cam_features = cam_features[3]
cam_features = cam_features["feature_map"]
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)

View File

@@ -44,27 +44,13 @@ from lerobot.common.utils.utils import (
)
from lerobot.scripts.eval import eval_policy
################## TODO remove this part
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
##################
def make_optimizer_and_scheduler(cfg, policy):
if cfg.policy.name == "act":
optimizer_params_dicts = [
{"params": [p for n, p in policy.named_parameters() if "backbone" not in n and p.requires_grad]},
{
"params": [
p
for n, p in policy.named_parameters()
if not n.startswith("backbone") and p.requires_grad
]
},
{
"params": [
p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
],
"params": [p for n, p in policy.named_parameters() if "backbone" in n and p.requires_grad],
"lr": cfg.training.lr_backbone,
},
]
@@ -107,57 +93,63 @@ def update_policy(
use_amp: bool = False,
):
"""Returns a dictionary of items for logging."""
################## TODO remove this part
torch.backends.cudnn.deterministic = True
# torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = True
##################
start_time = time.perf_counter()
device = get_device_from_parameters(policy)
policy.train()
################## TODO remove this part
pretrained_policy_name_or_path = (
"/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort_raw/initial_state"
)
from lerobot.common.policies.act.modeling_act import ACTPolicy
# ################## TODO remove this part
# pretrained_policy_name_or_path = (
# "/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort_raw/initial_state"
# )
# from lerobot.common.policies.act.modeling_act import ACTPolicy
policy_cls = ACTPolicy
policy_cfg = policy.config
policy = policy_cls(policy_cfg)
policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())
policy.to(device)
# policy_cls = ACTPolicy
# policy_cfg = policy.config
# policy = policy_cls(policy_cfg)
# policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())
# policy.to(device)
policy.eval() # No dropout
##################
# policy.eval() # No dropout
# ##################
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
########################### TODO remove this part
batch = torch.load("/home/thomwolf/Documents/Github/ACT/batch_save_converted.pt", map_location=device)
# ########################### TODO remove this part
# batch = torch.load("/home/thomwolf/Documents/Github/ACT/batch_save_converted.pt", map_location=device)
# print some stats
def model_stats(model):
na = [n for n, a in model.named_parameters() if "normalize_" not in n]
me = [a.mean().item() for n, a in model.named_parameters() if "normalize_" not in n]
print(na[me.index(min(me))], min(me))
print(sum(me))
mi = [a.min().item() for n, a in model.named_parameters() if "normalize_" not in n]
print(na[mi.index(min(mi))], min(mi))
print(sum(mi))
ma = [a.max().item() for n, a in model.named_parameters() if "normalize_" not in n]
print(na[ma.index(max(ma))], max(ma))
print(sum(ma))
# # print some stats
# def model_stats(model):
# na = [n for n, a in model.named_parameters() if "normalize_" not in n]
# me = [a.mean().item() for n, a in model.named_parameters() if "normalize_" not in n]
# print(na[me.index(min(me))], min(me))
# print(sum(me))
# mi = [a.min().item() for n, a in model.named_parameters() if "normalize_" not in n]
# print(na[mi.index(min(mi))], min(mi))
# print(sum(mi))
# ma = [a.max().item() for n, a in model.named_parameters() if "normalize_" not in n]
# print(na[ma.index(max(ma))], max(ma))
# print(sum(ma))
model_stats(policy)
# model_stats(policy)
def batch_stats(data):
print(min(d.min() for d in data))
print(max(d.max() for d in data))
# def batch_stats(data):
# print(min(d.min() for d in data))
# print(max(d.max() for d in data))
data = (
batch["observation.images.front"],
batch["observation.images.top"],
batch["observation.state"],
batch["action"],
)
batch_stats(data)
# data = (
# batch["observation.images.front"],
# batch["observation.images.top"],
# batch["observation.state"],
# batch["action"],
# )
# batch_stats(data)
###########################
# ###########################
output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)