backup wip
@@ -2,7 +2,6 @@ import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from transformers import DetrForObjectDetection

from .backbone import build_backbone
from .transformer import TransformerEncoder, TransformerEncoderLayer, build_transformer
@@ -74,7 +73,7 @@ class ActionChunkingTransformer(nn.Module):
        hidden_dim = transformer.d_model
        self.action_head = nn.Linear(hidden_dim, action_dim)
        self.is_pad_head = nn.Linear(hidden_dim, 1)
        # Positional embedding to be used as input to the latent vae_encoder (if applicable) and for the
        self.pos_embed = nn.Embedding(horizon, hidden_dim)
        if backbones is not None:
            self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
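For context, here is a minimal sketch of how the pieces in this hunk fit together: the 1x1 conv projects backbone feature maps into the transformer's width, and the two linear heads read actions and padding logits off the decoder output. All dimensions below are illustrative stand-ins, not values from this repo.

```python
import torch
from torch import nn

# Illustrative dims only: hidden size, action dim, and chunk length are made up.
hidden_dim, action_dim, horizon = 512, 14, 100
action_head = nn.Linear(hidden_dim, action_dim)  # regresses one action per decoder token
is_pad_head = nn.Linear(hidden_dim, 1)           # logit for "this timestep is padding"
pos_embed = nn.Embedding(horizon, hidden_dim)    # learned positions for the action queries
input_proj = nn.Conv2d(2048, hidden_dim, kernel_size=1)  # assumes a ResNet-style 2048-channel feature map

feat = torch.randn(2, 2048, 15, 20)                    # (batch, C, H, W) backbone output
tokens = input_proj(feat).flatten(2).permute(0, 2, 1)  # (batch, H*W, hidden_dim) transformer tokens
decoder_out = torch.randn(2, horizon, hidden_dim)      # stand-in for the transformer decoder output
actions = action_head(decoder_out)                     # (batch, horizon, action_dim)
pad_logits = is_pad_head(decoder_out)                  # (batch, horizon, 1)
```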
@@ -134,7 +133,9 @@ class ActionChunkingTransformer(nn.Module):
         pos_embed = self.pos_table.clone().detach()
         pos_embed = pos_embed.permute(1, 0, 2)  # (seq+1, 1, hidden_dim)
         # query model
-        vae_encoder_output = self.vae_encoder(vae_encoder_input, pos=pos_embed)  # , src_key_padding_mask=is_pad)
+        vae_encoder_output = self.vae_encoder(
+            vae_encoder_input, pos=pos_embed
+        )  # , src_key_padding_mask=is_pad)
         vae_encoder_output = vae_encoder_output[0]  # take cls output only
         latent_info = self.latent_proj(vae_encoder_output)
         mu = latent_info[:, : self.latent_dim]
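The `mu` slice above is only half of the usual story: in a CVAE the latent projection typically emits a mean and a log-variance, and sampling uses the reparameterization trick. A sketch under that assumption (the logvar slice and the sampling step are not shown in this hunk):

```python
import torch

latent_dim = 32
latent_info = torch.randn(8, 2 * latent_dim)  # stand-in for self.latent_proj(vae_encoder_output)
mu = latent_info[:, :latent_dim]
logvar = latent_info[:, latent_dim:]          # assumed layout: [mu, logvar] concatenated

# Reparameterization: z = mu + sigma * eps keeps the sample differentiable w.r.t. mu and logvar.
z = mu + (0.5 * logvar).exp() * torch.randn_like(mu)
```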
@@ -219,7 +220,7 @@ def build(args):
    backbones.append(backbone)

    transformer = build_transformer(args)

    vae_encoder = build_vae_encoder(args)

    model = ActionChunkingTransformer(
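A hypothetical sketch of how a factory like this `build` is typically driven; the config fields below are illustrative, not the repo's actual argument names:

```python
from types import SimpleNamespace

args = SimpleNamespace(hidden_dim=512, num_queries=100, camera_names=["top"])  # made-up fields
model = build(args)  # backbone(s) + transformer + VAE encoder composed into one nn.Module
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"model has {n_params / 1e6:.1f}M trainable parameters")
```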
@@ -54,7 +54,7 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
    Args:
        vae: Whether to use the variational objective. TODO(now): Give more details.
        temporal_agg: Whether to do temporal aggregation. For each timestep during rollout, the action
            returned is an exponential moving average of previously generated actions for that timestep.
        n_obs_steps: Number of time steps worth of observation to use as input.
        horizon: The number of actions to generate in one forward pass.
        kl_weight: Weight for KL divergence. Defaults to None. Only applicable when using the variational
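The aggregation that `temporal_agg` refers to can be sketched as follows, assuming the ACT-style exponential weighting w_i = exp(-m * i) with i = 0 indexing the oldest prediction (so older predictions carry slightly more weight); the buffering scheme and the coefficient m are assumptions, not taken from this diff:

```python
import numpy as np

def temporal_ensemble(preds_for_t: list, m: float = 0.01) -> np.ndarray:
    """Blend every action ever predicted for the current timestep.

    preds_for_t[0] is the oldest prediction for timestep t. Weights follow
    w_i = exp(-m * i), the exponential moving average described in the
    docstring above; m controls how fast new predictions are incorporated.
    """
    weights = np.exp(-m * np.arange(len(preds_for_t)))
    weights /= weights.sum()
    return (weights[:, None] * np.stack(preds_for_t)).sum(axis=0)
```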
@@ -120,7 +120,7 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
            "action": action.to(self.device, non_blocking=True),
        }
        return out

        start_time = time.time()

        batch = replay_buffer.sample(batch_size)
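For context on where this fragment likely sits, here is a hedged sketch of a full update step built around it, using the L1 reconstruction plus weighted-KL objective implied by `kl_weight`; the `policy` and `optimizer` interfaces and the batch keys are assumptions for illustration:

```python
import time

import torch.nn.functional as F

def update(policy, optimizer, replay_buffer, batch_size, kl_weight=10.0):
    start_time = time.time()
    batch = replay_buffer.sample(batch_size)

    # Assumed interface: the policy returns predicted actions plus the latent stats.
    actions_hat, (mu, logvar) = policy(batch["observation"], batch["action"])
    l1 = F.l1_loss(actions_hat, batch["action"])
    # KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
    kld = (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp())).sum(-1).mean()
    loss = l1 + kl_weight * kld

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return {"loss": loss.item(), "update_s": time.time() - start_time}
```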