Add tasks without end_effector that are compatible with the dataset; eval can run (TODO: training and pretrained model)

This commit is contained in:
Cadene
2024-03-10 10:52:12 +00:00
parent f1230cdac0
commit b49f7b70e2
11 changed files with 577 additions and 388 deletions

View File

@@ -27,7 +27,7 @@ def get_sinusoid_encoding_table(n_position, d_hid):
class DETRVAE(nn.Module):
"""This is the DETR module that performs object detection"""
def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names):
def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names):
"""Initializes the model.
Parameters:
backbones: torch module of the backbone to be used. See backbone.py
@@ -43,17 +43,18 @@ class DETRVAE(nn.Module):
self.transformer = transformer
self.encoder = encoder
hidden_dim = transformer.d_model
self.action_head = nn.Linear(hidden_dim, state_dim)
self.action_head = nn.Linear(hidden_dim, action_dim)
self.is_pad_head = nn.Linear(hidden_dim, 1)
self.query_embed = nn.Embedding(num_queries, hidden_dim)
if backbones is not None:
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
self.backbones = nn.ModuleList(backbones)
self.input_proj_robot_state = nn.Linear(14, hidden_dim)
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
else:
# input_dim = 14 + 7 # robot_state + env_state
self.input_proj_robot_state = nn.Linear(14, hidden_dim)
self.input_proj_env_state = nn.Linear(7, hidden_dim)
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
# TODO(rcadene): understand what is env_state, and why it needs to be 7
self.input_proj_env_state = nn.Linear(state_dim // 2, hidden_dim)
self.pos = torch.nn.Embedding(2, hidden_dim)
self.backbones = None
@@ -180,8 +181,6 @@ def build_encoder(args):
def build(args):
state_dim = 14 # TODO hardcode
# From state
# backbone = None # from state for now, no need for conv nets
# From image
@@ -197,7 +196,8 @@ def build(args):
backbones,
transformer,
encoder,
state_dim=state_dim,
state_dim=args.state_dim,
action_dim=args.action_dim,
num_queries=args.num_queries,
camera_names=args.camera_names,
)

View File

@@ -25,29 +25,6 @@ def build_act_model_and_optimizer(cfg):
return model, optimizer
# def build_CNNMLP_model_and_optimizer(cfg):
# parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
# args = parser.parse_args()
# for k, v in cfg.items():
# setattr(args, k, v)
# model = build_CNNMLP_model(args)
# model.cuda()
# param_dicts = [
# {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
# {
# "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
# "lr": args.lr_backbone,
# },
# ]
# optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
# weight_decay=args.weight_decay)
# return model, optimizer
def kl_divergence(mu, logvar):
batch_size = mu.size(0)
assert batch_size != 0
@@ -65,9 +42,10 @@ def kl_divergence(mu, logvar):
class ActionChunkingTransformerPolicy(nn.Module):
def __init__(self, cfg, device):
def __init__(self, cfg, device, n_action_steps=1):
super().__init__()
self.cfg = cfg
self.n_action_steps = n_action_steps
self.device = device
self.model, self.optimizer = build_act_model_and_optimizer(cfg)
self.kl_weight = self.cfg.kl_weight
@@ -179,11 +157,34 @@ class ActionChunkingTransformerPolicy(nn.Module):
observation["image"] = observation["image"].unsqueeze(0)
observation["state"] = observation["state"].unsqueeze(0)
# TODO(rcadene): remove hack
# add 1 camera dimension
observation["image"] = observation["image"].unsqueeze(1)
obs_dict = {
"image": observation["image"],
"agent_pos": observation["state"],
}
action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"])
if self.cfg.temporal_agg:
# TODO(rcadene): implement temporal aggregation
raise NotImplementedError()
# all_time_actions[[t], t:t+num_queries] = action
# actions_for_curr_step = all_time_actions[:, t]
# actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
# actions_for_curr_step = actions_for_curr_step[actions_populated]
# k = 0.01
# exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
# exp_weights = exp_weights / exp_weights.sum()
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
# remove bsize=1
action = action.squeeze(0)
# take first predicted action or n first actions
action = action[0] if self.n_action_steps == 1 else action[: self.n_action_steps]
return action
def _forward(self, qpos, image, actions=None, is_pad=None):
@@ -209,46 +210,3 @@ class ActionChunkingTransformerPolicy(nn.Module):
else:
action, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
return action
# class CNNMLPPolicy(nn.Module):
# def __init__(self, cfg):
# super().__init__()
# model, optimizer = build_CNNMLP_model_and_optimizer(cfg)
# self.model = model # decoder
# self.optimizer = optimizer
# def __call__(self, qpos, image, actions=None, is_pad=None):
# env_state = None # TODO
# normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
# std=[0.229, 0.224, 0.225])
# image = normalize(image)
# if actions is not None: # training time
# actions = actions[:, 0]
# a_hat = self.model(qpos, image, env_state, actions)
# mse = F.mse_loss(actions, a_hat)
# loss_dict = dict()
# loss_dict['mse'] = mse
# loss_dict['loss'] = loss_dict['mse']
# return loss_dict
# else: # inference time
# a_hat = self.model(qpos, image, env_state) # no action, sample from prior
# return a_hat
# def configure_optimizers(self):
# return self.optimizer
# def kl_divergence(mu, logvar):
# batch_size = mu.size(0)
# assert batch_size != 0
# if mu.data.ndimension() == 4:
# mu = mu.view(mu.size(0), mu.size(1))
# if logvar.data.ndimension() == 4:
# logvar = logvar.view(logvar.size(0), logvar.size(1))
# klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
# total_kld = klds.sum(1).mean(0, True)
# dimension_wise_kld = klds.mean(0)
# mean_kld = klds.mean(1).mean(0, True)
# return total_kld, dimension_wise_kld, mean_kld

View File

@@ -20,7 +20,9 @@ def make_policy(cfg):
elif cfg.policy.name == "act":
from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
policy = ActionChunkingTransformerPolicy(cfg.policy, cfg.device)
policy = ActionChunkingTransformerPolicy(
cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
)
else:
raise ValueError(cfg.policy.name)