# File: lerobot/lerobot/common/policies/act/policy.py
# Snapshot: 2024-03-08 16:54:43 +00:00 (139 lines, 4.9 KiB, Python)
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
import torchvision.transforms as transforms
from lerobot.common.policies.act.detr_vae import build
def build_act_model_and_optimizer(cfg):
    """Instantiate the ACT model on GPU and its AdamW optimizer.

    Two parameter groups are used: backbone parameters train at
    ``cfg.lr_backbone`` while every other trainable parameter trains at
    the default ``cfg.lr``.

    Args:
        cfg: Configuration object providing ``lr``, ``lr_backbone`` and
            ``weight_decay`` (plus whatever ``build`` consumes).

    Returns:
        ``(model, optimizer)`` tuple.
    """
    model = build(cfg)
    model.cuda()

    # Partition trainable parameters in a single pass over the model.
    backbone_params = []
    other_params = []
    for param_name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "backbone" in param_name:
            backbone_params.append(param)
        else:
            other_params.append(param)

    param_groups = [
        {"params": other_params},
        {"params": backbone_params, "lr": cfg.lr_backbone},
    ]
    optimizer = torch.optim.AdamW(param_groups, lr=cfg.lr, weight_decay=cfg.weight_decay)
    return model, optimizer
# def build_CNNMLP_model_and_optimizer(cfg):
# parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
# args = parser.parse_args()
# for k, v in cfg.items():
# setattr(args, k, v)
# model = build_CNNMLP_model(args)
# model.cuda()
# param_dicts = [
# {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
# {
# "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
# "lr": args.lr_backbone,
# },
# ]
# optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
# weight_decay=args.weight_decay)
# return model, optimizer
def kl_divergence(mu, logvar):
    """KL divergence between a diagonal Gaussian N(mu, exp(logvar)) and N(0, I).

    Args:
        mu: Latent means of shape (batch, latent_dim); 4D inputs of shape
            (batch, latent_dim, 1, 1) from conv-style encoders are flattened
            to 2D first.
        logvar: Latent log-variances, same shape conventions as ``mu``.

    Returns:
        total_kld: KL summed over latent dims, averaged over batch; shape (1,).
        dimension_wise_kld: Per-dimension KL averaged over batch; shape (latent_dim,).
        mean_kld: KL averaged over both dims and batch; shape (1,).
    """
    batch_size = mu.size(0)
    assert batch_size != 0
    # Collapse trailing singleton spatial dims. Use Tensor.dim() directly:
    # going through the deprecated `.data` attribute is discouraged and
    # silently detaches from autograd tracking.
    if mu.dim() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.dim() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))
    # Closed-form per-element KL(N(mu, sigma^2) || N(0, 1)).
    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    total_kld = klds.sum(1).mean(0, True)
    dimension_wise_kld = klds.mean(0)
    mean_kld = klds.mean(1).mean(0, True)
    return total_kld, dimension_wise_kld, mean_kld
class ACTPolicy(nn.Module):
    """Action Chunking Transformer (ACT) policy.

    Wraps the CVAE returned by ``build_act_model_and_optimizer``. In training
    mode (``actions`` provided) the call returns a loss dict with ``l1``,
    ``kl`` and the weighted combined ``loss``; in inference mode it returns
    the predicted action chunk sampled from the prior.
    """

    def __init__(self, cfg):
        """Build the model and optimizer from ``cfg``; reads ``cfg.kl_weight``."""
        super().__init__()
        model, optimizer = build_act_model_and_optimizer(cfg)
        self.model = model  # CVAE decoder
        self.optimizer = optimizer
        self.kl_weight = cfg.kl_weight
        # Hoisted out of __call__: the ImageNet normalization transform is
        # constant, so build it once instead of on every forward pass.
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        print(f"KL Weight {self.kl_weight}")

    # NOTE(review): overriding __call__ (rather than forward) bypasses
    # nn.Module's hook machinery — kept as-is to match upstream ACT behavior.
    def __call__(self, qpos, image, actions=None, is_pad=None):
        """Run the policy.

        Args:
            qpos: Batch of proprioceptive states.
            image: Batch of images; normalized here with ImageNet statistics.
            actions: Ground-truth action chunk; its presence selects training mode.
            is_pad: Boolean padding mask for ``actions`` (training mode only).

        Returns:
            Training mode: dict with ``l1``, ``kl`` and ``loss``.
            Inference mode: predicted actions ``a_hat``.
        """
        env_state = None
        image = self.normalize(image)
        if actions is not None:  # training time
            # Truncate targets and mask to the model's chunk length.
            actions = actions[:, : self.model.num_queries]
            is_pad = is_pad[:, : self.model.num_queries]
            a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
            total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
            loss_dict = {}
            # Zero out padded timesteps; note the mean still divides by the
            # full element count (upstream behavior, preserved here).
            all_l1 = F.l1_loss(actions, a_hat, reduction="none")
            l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
            loss_dict["l1"] = l1
            loss_dict["kl"] = total_kld[0]
            loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
            return loss_dict
        else:  # inference time
            a_hat, _, (_, _) = self.model(qpos, image, env_state)  # no action, sample from prior
            return a_hat

    def configure_optimizers(self):
        """Return the optimizer created in ``__init__``."""
        return self.optimizer
# class CNNMLPPolicy(nn.Module):
# def __init__(self, cfg):
# super().__init__()
# model, optimizer = build_CNNMLP_model_and_optimizer(cfg)
# self.model = model # decoder
# self.optimizer = optimizer
# def __call__(self, qpos, image, actions=None, is_pad=None):
# env_state = None # TODO
# normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
# std=[0.229, 0.224, 0.225])
# image = normalize(image)
# if actions is not None: # training time
# actions = actions[:, 0]
# a_hat = self.model(qpos, image, env_state, actions)
# mse = F.mse_loss(actions, a_hat)
# loss_dict = dict()
# loss_dict['mse'] = mse
# loss_dict['loss'] = loss_dict['mse']
# return loss_dict
# else: # inference time
# a_hat = self.model(qpos, image, env_state) # no action, sample from prior
# return a_hat
# def configure_optimizers(self):
# return self.optimizer
# def kl_divergence(mu, logvar):
# batch_size = mu.size(0)
# assert batch_size != 0
# if mu.data.ndimension() == 4:
# mu = mu.view(mu.size(0), mu.size(1))
# if logvar.data.ndimension() == 4:
# logvar = logvar.view(logvar.size(0), logvar.size(1))
# klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
# total_kld = klds.sum(1).mean(0, True)
# dimension_wise_kld = klds.mean(0)
# mean_kld = klds.mean(1).mean(0, True)
# return total_kld, dimension_wise_kld, mean_kld