Integrate diffusion policy
This commit is contained in:
286
lerobot/common/policies/diffusion/model/conditional_unet1d.py
Normal file
286
lerobot/common/policies/diffusion/model/conditional_unet1d.py
Normal file
@@ -0,0 +1,286 @@
|
||||
import logging
|
||||
from typing import Union
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
from lerobot.common.policies.diffusion.model.conv1d_components import Conv1dBlock, Downsample1d, Upsample1d
|
||||
from lerobot.common.policies.diffusion.model.positional_embedding import SinusoidalPosEmb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConditionalResidualBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8, cond_predict_scale=False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
|
||||
Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
|
||||
]
|
||||
)
|
||||
|
||||
# FiLM modulation https://arxiv.org/abs/1709.07871
|
||||
# predicts per-channel scale and bias
|
||||
cond_channels = out_channels
|
||||
if cond_predict_scale:
|
||||
cond_channels = out_channels * 2
|
||||
self.cond_predict_scale = cond_predict_scale
|
||||
self.out_channels = out_channels
|
||||
self.cond_encoder = nn.Sequential(
|
||||
nn.Mish(),
|
||||
nn.Linear(cond_dim, cond_channels),
|
||||
Rearrange("batch t -> batch t 1"),
|
||||
)
|
||||
|
||||
# make sure dimensions compatible
|
||||
self.residual_conv = (
|
||||
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x, cond):
|
||||
"""
|
||||
x : [ batch_size x in_channels x horizon ]
|
||||
cond : [ batch_size x cond_dim]
|
||||
|
||||
returns:
|
||||
out : [ batch_size x out_channels x horizon ]
|
||||
"""
|
||||
out = self.blocks[0](x)
|
||||
embed = self.cond_encoder(cond)
|
||||
if self.cond_predict_scale:
|
||||
embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
|
||||
scale = embed[:, 0, ...]
|
||||
bias = embed[:, 1, ...]
|
||||
out = scale * out + bias
|
||||
else:
|
||||
out = out + embed
|
||||
out = self.blocks[1](out)
|
||||
out = out + self.residual_conv(x)
|
||||
return out
|
||||
|
||||
|
||||
class ConditionalUnet1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim,
|
||||
local_cond_dim=None,
|
||||
global_cond_dim=None,
|
||||
diffusion_step_embed_dim=256,
|
||||
down_dims=None,
|
||||
kernel_size=3,
|
||||
n_groups=8,
|
||||
cond_predict_scale=False,
|
||||
):
|
||||
super().__init__()
|
||||
if down_dims is None:
|
||||
down_dims = [256, 512, 1024]
|
||||
|
||||
all_dims = [input_dim] + list(down_dims)
|
||||
start_dim = down_dims[0]
|
||||
|
||||
dsed = diffusion_step_embed_dim
|
||||
diffusion_step_encoder = nn.Sequential(
|
||||
SinusoidalPosEmb(dsed),
|
||||
nn.Linear(dsed, dsed * 4),
|
||||
nn.Mish(),
|
||||
nn.Linear(dsed * 4, dsed),
|
||||
)
|
||||
cond_dim = dsed
|
||||
if global_cond_dim is not None:
|
||||
cond_dim += global_cond_dim
|
||||
|
||||
in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False))
|
||||
|
||||
local_cond_encoder = None
|
||||
if local_cond_dim is not None:
|
||||
_, dim_out = in_out[0]
|
||||
dim_in = local_cond_dim
|
||||
local_cond_encoder = nn.ModuleList(
|
||||
[
|
||||
# down encoder
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in,
|
||||
dim_out,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
# up encoder
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in,
|
||||
dim_out,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
mid_dim = all_dims[-1]
|
||||
self.mid_modules = nn.ModuleList(
|
||||
[
|
||||
ConditionalResidualBlock1D(
|
||||
mid_dim,
|
||||
mid_dim,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
ConditionalResidualBlock1D(
|
||||
mid_dim,
|
||||
mid_dim,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
down_modules = nn.ModuleList([])
|
||||
for ind, (dim_in, dim_out) in enumerate(in_out):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
down_modules.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in,
|
||||
dim_out,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
ConditionalResidualBlock1D(
|
||||
dim_out,
|
||||
dim_out,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
Downsample1d(dim_out) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
up_modules = nn.ModuleList([])
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
up_modules.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
ConditionalResidualBlock1D(
|
||||
dim_out * 2,
|
||||
dim_in,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in,
|
||||
dim_in,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
Upsample1d(dim_in) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
final_conv = nn.Sequential(
|
||||
Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
|
||||
nn.Conv1d(start_dim, input_dim, 1),
|
||||
)
|
||||
|
||||
self.diffusion_step_encoder = diffusion_step_encoder
|
||||
self.local_cond_encoder = local_cond_encoder
|
||||
self.up_modules = up_modules
|
||||
self.down_modules = down_modules
|
||||
self.final_conv = final_conv
|
||||
|
||||
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
local_cond=None,
|
||||
global_cond=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
x: (B,T,input_dim)
|
||||
timestep: (B,) or int, diffusion step
|
||||
local_cond: (B,T,local_cond_dim)
|
||||
global_cond: (B,global_cond_dim)
|
||||
output: (B,T,input_dim)
|
||||
"""
|
||||
sample = einops.rearrange(sample, "b h t -> b t h")
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
global_feature = self.diffusion_step_encoder(timesteps)
|
||||
|
||||
if global_cond is not None:
|
||||
global_feature = torch.cat([global_feature, global_cond], axis=-1)
|
||||
|
||||
# encode local features
|
||||
h_local = []
|
||||
if local_cond is not None:
|
||||
local_cond = einops.rearrange(local_cond, "b h t -> b t h")
|
||||
resnet, resnet2 = self.local_cond_encoder
|
||||
x = resnet(local_cond, global_feature)
|
||||
h_local.append(x)
|
||||
x = resnet2(local_cond, global_feature)
|
||||
h_local.append(x)
|
||||
|
||||
x = sample
|
||||
h = []
|
||||
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
|
||||
x = resnet(x, global_feature)
|
||||
if idx == 0 and len(h_local) > 0:
|
||||
x = x + h_local[0]
|
||||
x = resnet2(x, global_feature)
|
||||
h.append(x)
|
||||
x = downsample(x)
|
||||
|
||||
for mid_module in self.mid_modules:
|
||||
x = mid_module(x, global_feature)
|
||||
|
||||
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
|
||||
x = torch.cat((x, h.pop()), dim=1)
|
||||
x = resnet(x, global_feature)
|
||||
# The correct condition should be:
|
||||
# if idx == (len(self.up_modules)-1) and len(h_local) > 0:
|
||||
# However this change will break compatibility with published checkpoints.
|
||||
# Therefore it is left as a comment.
|
||||
if idx == len(self.up_modules) and len(h_local) > 0:
|
||||
x = x + h_local[1]
|
||||
x = resnet2(x, global_feature)
|
||||
x = upsample(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
|
||||
x = einops.rearrange(x, "b t h -> b h t")
|
||||
return x
|
||||
47
lerobot/common/policies/diffusion/model/conv1d_components.py
Normal file
47
lerobot/common/policies/diffusion/model/conv1d_components.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import torch.nn as nn
|
||||
|
||||
# from einops.layers.torch import Rearrange
|
||||
|
||||
|
||||
class Downsample1d(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Upsample1d(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv1dBlock(nn.Module):
|
||||
"""
|
||||
Conv1d --> GroupNorm --> Mish
|
||||
"""
|
||||
|
||||
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
||||
super().__init__()
|
||||
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
|
||||
# Rearrange('batch channels horizon -> batch channels 1 horizon'),
|
||||
nn.GroupNorm(n_groups, out_channels),
|
||||
# Rearrange('batch channels 1 horizon -> batch channels horizon'),
|
||||
nn.Mish(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
# def test():
|
||||
# cb = Conv1dBlock(256, 128, kernel_size=3)
|
||||
# x = torch.zeros((1,256,16))
|
||||
# o = cb(x)
|
||||
294
lerobot/common/policies/diffusion/model/crop_randomizer.py
Normal file
294
lerobot/common/policies/diffusion/model/crop_randomizer.py
Normal file
@@ -0,0 +1,294 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms.functional as ttf
|
||||
|
||||
import lerobot.common.policies.diffusion.model.tensor_utils as tu
|
||||
|
||||
|
||||
class CropRandomizer(nn.Module):
|
||||
"""
|
||||
Randomly sample crops at input, and then average across crop features at output.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_shape,
|
||||
crop_height,
|
||||
crop_width,
|
||||
num_crops=1,
|
||||
pos_enc=False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
input_shape (tuple, list): shape of input (not including batch dimension)
|
||||
crop_height (int): crop height
|
||||
crop_width (int): crop width
|
||||
num_crops (int): number of random crops to take
|
||||
pos_enc (bool): if True, add 2 channels to the output to encode the spatial
|
||||
location of the cropped pixels in the source image
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
assert len(input_shape) == 3 # (C, H, W)
|
||||
assert crop_height < input_shape[1]
|
||||
assert crop_width < input_shape[2]
|
||||
|
||||
self.input_shape = input_shape
|
||||
self.crop_height = crop_height
|
||||
self.crop_width = crop_width
|
||||
self.num_crops = num_crops
|
||||
self.pos_enc = pos_enc
|
||||
|
||||
def output_shape_in(self, input_shape=None):
|
||||
"""
|
||||
Function to compute output shape from inputs to this module. Corresponds to
|
||||
the @forward_in operation, where raw inputs (usually observation modalities)
|
||||
are passed in.
|
||||
|
||||
Args:
|
||||
input_shape (iterable of int): shape of input. Does not include batch dimension.
|
||||
Some modules may not need this argument, if their output does not depend
|
||||
on the size of the input, or if they assume fixed size input.
|
||||
|
||||
Returns:
|
||||
out_shape ([int]): list of integers corresponding to output shape
|
||||
"""
|
||||
|
||||
# outputs are shape (C, CH, CW), or maybe C + 2 if using position encoding, because
|
||||
# the number of crops are reshaped into the batch dimension, increasing the batch
|
||||
# size from B to B * N
|
||||
out_c = self.input_shape[0] + 2 if self.pos_enc else self.input_shape[0]
|
||||
return [out_c, self.crop_height, self.crop_width]
|
||||
|
||||
def output_shape_out(self, input_shape=None):
|
||||
"""
|
||||
Function to compute output shape from inputs to this module. Corresponds to
|
||||
the @forward_out operation, where processed inputs (usually encoded observation
|
||||
modalities) are passed in.
|
||||
|
||||
Args:
|
||||
input_shape (iterable of int): shape of input. Does not include batch dimension.
|
||||
Some modules may not need this argument, if their output does not depend
|
||||
on the size of the input, or if they assume fixed size input.
|
||||
|
||||
Returns:
|
||||
out_shape ([int]): list of integers corresponding to output shape
|
||||
"""
|
||||
|
||||
# since the forward_out operation splits [B * N, ...] -> [B, N, ...]
|
||||
# and then pools to result in [B, ...], only the batch dimension changes,
|
||||
# and so the other dimensions retain their shape.
|
||||
return list(input_shape)
|
||||
|
||||
def forward_in(self, inputs):
|
||||
"""
|
||||
Samples N random crops for each input in the batch, and then reshapes
|
||||
inputs to [B * N, ...].
|
||||
"""
|
||||
assert len(inputs.shape) >= 3 # must have at least (C, H, W) dimensions
|
||||
if self.training:
|
||||
# generate random crops
|
||||
out, _ = sample_random_image_crops(
|
||||
images=inputs,
|
||||
crop_height=self.crop_height,
|
||||
crop_width=self.crop_width,
|
||||
num_crops=self.num_crops,
|
||||
pos_enc=self.pos_enc,
|
||||
)
|
||||
# [B, N, ...] -> [B * N, ...]
|
||||
return tu.join_dimensions(out, 0, 1)
|
||||
else:
|
||||
# take center crop during eval
|
||||
out = ttf.center_crop(img=inputs, output_size=(self.crop_height, self.crop_width))
|
||||
if self.num_crops > 1:
|
||||
B, C, H, W = out.shape # noqa: N806
|
||||
out = out.unsqueeze(1).expand(B, self.num_crops, C, H, W).reshape(-1, C, H, W)
|
||||
# [B * N, ...]
|
||||
return out
|
||||
|
||||
def forward_out(self, inputs):
|
||||
"""
|
||||
Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N
|
||||
to result in shape [B, ...] to make sure the network output is consistent with
|
||||
what would have happened if there were no randomization.
|
||||
"""
|
||||
if self.num_crops <= 1:
|
||||
return inputs
|
||||
else:
|
||||
batch_size = inputs.shape[0] // self.num_crops
|
||||
out = tu.reshape_dimensions(
|
||||
inputs, begin_axis=0, end_axis=0, target_dims=(batch_size, self.num_crops)
|
||||
)
|
||||
return out.mean(dim=1)
|
||||
|
||||
def forward(self, inputs):
|
||||
return self.forward_in(inputs)
|
||||
|
||||
def __repr__(self):
|
||||
"""Pretty print network."""
|
||||
header = "{}".format(str(self.__class__.__name__))
|
||||
msg = header + "(input_shape={}, crop_size=[{}, {}], num_crops={})".format(
|
||||
self.input_shape, self.crop_height, self.crop_width, self.num_crops
|
||||
)
|
||||
return msg
|
||||
|
||||
|
||||
def crop_image_from_indices(images, crop_indices, crop_height, crop_width):
|
||||
"""
|
||||
Crops images at the locations specified by @crop_indices. Crops will be
|
||||
taken across all channels.
|
||||
|
||||
Args:
|
||||
images (torch.Tensor): batch of images of shape [..., C, H, W]
|
||||
|
||||
crop_indices (torch.Tensor): batch of indices of shape [..., N, 2] where
|
||||
N is the number of crops to take per image and each entry corresponds
|
||||
to the pixel height and width of where to take the crop. Note that
|
||||
the indices can also be of shape [..., 2] if only 1 crop should
|
||||
be taken per image. Leading dimensions must be consistent with
|
||||
@images argument. Each index specifies the top left of the crop.
|
||||
Values must be in range [0, H - CH - 1] x [0, W - CW - 1] where
|
||||
H and W are the height and width of @images and CH and CW are
|
||||
@crop_height and @crop_width.
|
||||
|
||||
crop_height (int): height of crop to take
|
||||
|
||||
crop_width (int): width of crop to take
|
||||
|
||||
Returns:
|
||||
crops (torch.Tesnor): cropped images of shape [..., C, @crop_height, @crop_width]
|
||||
"""
|
||||
|
||||
# make sure length of input shapes is consistent
|
||||
assert crop_indices.shape[-1] == 2
|
||||
ndim_im_shape = len(images.shape)
|
||||
ndim_indices_shape = len(crop_indices.shape)
|
||||
assert (ndim_im_shape == ndim_indices_shape + 1) or (ndim_im_shape == ndim_indices_shape + 2)
|
||||
|
||||
# maybe pad so that @crop_indices is shape [..., N, 2]
|
||||
is_padded = False
|
||||
if ndim_im_shape == ndim_indices_shape + 2:
|
||||
crop_indices = crop_indices.unsqueeze(-2)
|
||||
is_padded = True
|
||||
|
||||
# make sure leading dimensions between images and indices are consistent
|
||||
assert images.shape[:-3] == crop_indices.shape[:-2]
|
||||
|
||||
device = images.device
|
||||
image_c, image_h, image_w = images.shape[-3:]
|
||||
num_crops = crop_indices.shape[-2]
|
||||
|
||||
# make sure @crop_indices are in valid range
|
||||
assert (crop_indices[..., 0] >= 0).all().item()
|
||||
assert (crop_indices[..., 0] < (image_h - crop_height)).all().item()
|
||||
assert (crop_indices[..., 1] >= 0).all().item()
|
||||
assert (crop_indices[..., 1] < (image_w - crop_width)).all().item()
|
||||
|
||||
# convert each crop index (ch, cw) into a list of pixel indices that correspond to the entire window.
|
||||
|
||||
# 2D index array with columns [0, 1, ..., CH - 1] and shape [CH, CW]
|
||||
crop_ind_grid_h = torch.arange(crop_height).to(device)
|
||||
crop_ind_grid_h = tu.unsqueeze_expand_at(crop_ind_grid_h, size=crop_width, dim=-1)
|
||||
# 2D index array with rows [0, 1, ..., CW - 1] and shape [CH, CW]
|
||||
crop_ind_grid_w = torch.arange(crop_width).to(device)
|
||||
crop_ind_grid_w = tu.unsqueeze_expand_at(crop_ind_grid_w, size=crop_height, dim=0)
|
||||
# combine into shape [CH, CW, 2]
|
||||
crop_in_grid = torch.cat((crop_ind_grid_h.unsqueeze(-1), crop_ind_grid_w.unsqueeze(-1)), dim=-1)
|
||||
|
||||
# Add above grid with the offset index of each sampled crop to get 2d indices for each crop.
|
||||
# After broadcasting, this will be shape [..., N, CH, CW, 2] and each crop has a [CH, CW, 2]
|
||||
# shape array that tells us which pixels from the corresponding source image to grab.
|
||||
grid_reshape = [1] * len(crop_indices.shape[:-1]) + [crop_height, crop_width, 2]
|
||||
all_crop_inds = crop_indices.unsqueeze(-2).unsqueeze(-2) + crop_in_grid.reshape(grid_reshape)
|
||||
|
||||
# For using @torch.gather, convert to flat indices from 2D indices, and also
|
||||
# repeat across the channel dimension. To get flat index of each pixel to grab for
|
||||
# each sampled crop, we just use the mapping: ind = h_ind * @image_w + w_ind
|
||||
all_crop_inds = all_crop_inds[..., 0] * image_w + all_crop_inds[..., 1] # shape [..., N, CH, CW]
|
||||
all_crop_inds = tu.unsqueeze_expand_at(all_crop_inds, size=image_c, dim=-3) # shape [..., N, C, CH, CW]
|
||||
all_crop_inds = tu.flatten(all_crop_inds, begin_axis=-2) # shape [..., N, C, CH * CW]
|
||||
|
||||
# Repeat and flatten the source images -> [..., N, C, H * W] and then use gather to index with crop pixel inds
|
||||
images_to_crop = tu.unsqueeze_expand_at(images, size=num_crops, dim=-4)
|
||||
images_to_crop = tu.flatten(images_to_crop, begin_axis=-2)
|
||||
crops = torch.gather(images_to_crop, dim=-1, index=all_crop_inds)
|
||||
# [..., N, C, CH * CW] -> [..., N, C, CH, CW]
|
||||
reshape_axis = len(crops.shape) - 1
|
||||
crops = tu.reshape_dimensions(
|
||||
crops, begin_axis=reshape_axis, end_axis=reshape_axis, target_dims=(crop_height, crop_width)
|
||||
)
|
||||
|
||||
if is_padded:
|
||||
# undo padding -> [..., C, CH, CW]
|
||||
crops = crops.squeeze(-4)
|
||||
return crops
|
||||
|
||||
|
||||
def sample_random_image_crops(images, crop_height, crop_width, num_crops, pos_enc=False):
|
||||
"""
|
||||
For each image, randomly sample @num_crops crops of size (@crop_height, @crop_width), from
|
||||
@images.
|
||||
|
||||
Args:
|
||||
images (torch.Tensor): batch of images of shape [..., C, H, W]
|
||||
|
||||
crop_height (int): height of crop to take
|
||||
|
||||
crop_width (int): width of crop to take
|
||||
|
||||
num_crops (n): number of crops to sample
|
||||
|
||||
pos_enc (bool): if True, also add 2 channels to the outputs that gives a spatial
|
||||
encoding of the original source pixel locations. This means that the
|
||||
output crops will contain information about where in the source image
|
||||
it was sampled from.
|
||||
|
||||
Returns:
|
||||
crops (torch.Tensor): crops of shape (..., @num_crops, C, @crop_height, @crop_width)
|
||||
if @pos_enc is False, otherwise (..., @num_crops, C + 2, @crop_height, @crop_width)
|
||||
|
||||
crop_inds (torch.Tensor): sampled crop indices of shape (..., N, 2)
|
||||
"""
|
||||
device = images.device
|
||||
|
||||
# maybe add 2 channels of spatial encoding to the source image
|
||||
source_im = images
|
||||
if pos_enc:
|
||||
# spatial encoding [y, x] in [0, 1]
|
||||
h, w = source_im.shape[-2:]
|
||||
pos_y, pos_x = torch.meshgrid(torch.arange(h), torch.arange(w))
|
||||
pos_y = pos_y.float().to(device) / float(h)
|
||||
pos_x = pos_x.float().to(device) / float(w)
|
||||
position_enc = torch.stack((pos_y, pos_x)) # shape [C, H, W]
|
||||
|
||||
# unsqueeze and expand to match leading dimensions -> shape [..., C, H, W]
|
||||
leading_shape = source_im.shape[:-3]
|
||||
position_enc = position_enc[(None,) * len(leading_shape)]
|
||||
position_enc = position_enc.expand(*leading_shape, -1, -1, -1)
|
||||
|
||||
# concat across channel dimension with input
|
||||
source_im = torch.cat((source_im, position_enc), dim=-3)
|
||||
|
||||
# make sure sample boundaries ensure crops are fully within the images
|
||||
image_c, image_h, image_w = source_im.shape[-3:]
|
||||
max_sample_h = image_h - crop_height
|
||||
max_sample_w = image_w - crop_width
|
||||
|
||||
# Sample crop locations for all tensor dimensions up to the last 3, which are [C, H, W].
|
||||
# Each gets @num_crops samples - typically this will just be the batch dimension (B), so
|
||||
# we will sample [B, N] indices, but this supports having more than one leading dimension,
|
||||
# or possibly no leading dimension.
|
||||
#
|
||||
# Trick: sample in [0, 1) with rand, then re-scale to [0, M) and convert to long to get sampled ints
|
||||
crop_inds_h = (max_sample_h * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
|
||||
crop_inds_w = (max_sample_w * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
|
||||
crop_inds = torch.cat((crop_inds_h.unsqueeze(-1), crop_inds_w.unsqueeze(-1)), dim=-1) # shape [..., N, 2]
|
||||
|
||||
crops = crop_image_from_indices(
|
||||
images=source_im,
|
||||
crop_indices=crop_inds,
|
||||
crop_height=crop_height,
|
||||
crop_width=crop_width,
|
||||
)
|
||||
|
||||
return crops, crop_inds
|
||||
@@ -0,0 +1,41 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class DictOfTensorMixin(nn.Module):
|
||||
def __init__(self, params_dict=None):
|
||||
super().__init__()
|
||||
if params_dict is None:
|
||||
params_dict = nn.ParameterDict()
|
||||
self.params_dict = params_dict
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
def _load_from_state_dict(
|
||||
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
):
|
||||
def dfs_add(dest, keys, value: torch.Tensor):
|
||||
if len(keys) == 1:
|
||||
dest[keys[0]] = value
|
||||
return
|
||||
|
||||
if keys[0] not in dest:
|
||||
dest[keys[0]] = nn.ParameterDict()
|
||||
dfs_add(dest[keys[0]], keys[1:], value)
|
||||
|
||||
def load_dict(state_dict, prefix):
|
||||
out_dict = nn.ParameterDict()
|
||||
for key, value in state_dict.items():
|
||||
value: torch.Tensor
|
||||
if key.startswith(prefix):
|
||||
param_keys = key[len(prefix) :].split(".")[1:]
|
||||
# if len(param_keys) == 0:
|
||||
# import pdb; pdb.set_trace()
|
||||
dfs_add(out_dict, param_keys, value.clone())
|
||||
return out_dict
|
||||
|
||||
self.params_dict = load_dict(state_dict, prefix + "params_dict")
|
||||
self.params_dict.requires_grad_(False)
|
||||
return
|
||||
46
lerobot/common/policies/diffusion/model/lr_scheduler.py
Normal file
46
lerobot/common/policies/diffusion/model/lr_scheduler.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from diffusers.optimization import TYPE_TO_SCHEDULER_FUNCTION, Optimizer, Optional, SchedulerType, Union
|
||||
|
||||
|
||||
def get_scheduler(
|
||||
name: Union[str, SchedulerType],
|
||||
optimizer: Optimizer,
|
||||
num_warmup_steps: Optional[int] = None,
|
||||
num_training_steps: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Added kwargs vs diffuser's original implementation
|
||||
|
||||
Unified API to get any scheduler from its name.
|
||||
|
||||
Args:
|
||||
name (`str` or `SchedulerType`):
|
||||
The name of the scheduler to use.
|
||||
optimizer (`torch.optim.Optimizer`):
|
||||
The optimizer that will be used during training.
|
||||
num_warmup_steps (`int`, *optional*):
|
||||
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
||||
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
||||
num_training_steps (`int``, *optional*):
|
||||
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
||||
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
||||
"""
|
||||
name = SchedulerType(name)
|
||||
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||
if name == SchedulerType.CONSTANT:
|
||||
return schedule_func(optimizer, **kwargs)
|
||||
|
||||
# All other schedulers require `num_warmup_steps`
|
||||
if num_warmup_steps is None:
|
||||
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
||||
|
||||
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **kwargs)
|
||||
|
||||
# All other schedulers require `num_training_steps`
|
||||
if num_training_steps is None:
|
||||
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
||||
|
||||
return schedule_func(
|
||||
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **kwargs
|
||||
)
|
||||
65
lerobot/common/policies/diffusion/model/mask_generator.py
Normal file
65
lerobot/common/policies/diffusion/model/mask_generator.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import torch
|
||||
|
||||
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
|
||||
|
||||
|
||||
class LowdimMaskGenerator(ModuleAttrMixin):
|
||||
def __init__(
|
||||
self,
|
||||
action_dim,
|
||||
obs_dim,
|
||||
# obs mask setup
|
||||
max_n_obs_steps=2,
|
||||
fix_obs_steps=True,
|
||||
# action mask
|
||||
action_visible=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.action_dim = action_dim
|
||||
self.obs_dim = obs_dim
|
||||
self.max_n_obs_steps = max_n_obs_steps
|
||||
self.fix_obs_steps = fix_obs_steps
|
||||
self.action_visible = action_visible
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, shape, seed=None):
|
||||
device = self.device
|
||||
B, T, D = shape # noqa: N806
|
||||
assert (self.action_dim + self.obs_dim) == D
|
||||
|
||||
# create all tensors on this device
|
||||
rng = torch.Generator(device=device)
|
||||
if seed is not None:
|
||||
rng = rng.manual_seed(seed)
|
||||
|
||||
# generate dim mask
|
||||
dim_mask = torch.zeros(size=shape, dtype=torch.bool, device=device)
|
||||
is_action_dim = dim_mask.clone()
|
||||
is_action_dim[..., : self.action_dim] = True
|
||||
is_obs_dim = ~is_action_dim
|
||||
|
||||
# generate obs mask
|
||||
if self.fix_obs_steps:
|
||||
obs_steps = torch.full((B,), fill_value=self.max_n_obs_steps, device=device)
|
||||
else:
|
||||
obs_steps = torch.randint(
|
||||
low=1, high=self.max_n_obs_steps + 1, size=(B,), generator=rng, device=device
|
||||
)
|
||||
|
||||
steps = torch.arange(0, T, device=device).reshape(1, T).expand(B, T)
|
||||
obs_mask = (obs_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D)
|
||||
obs_mask = obs_mask & is_obs_dim
|
||||
|
||||
# generate action mask
|
||||
if self.action_visible:
|
||||
action_steps = torch.maximum(
|
||||
obs_steps - 1, torch.tensor(0, dtype=obs_steps.dtype, device=obs_steps.device)
|
||||
)
|
||||
action_mask = (action_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D)
|
||||
action_mask = action_mask & is_action_dim
|
||||
|
||||
mask = obs_mask
|
||||
if self.action_visible:
|
||||
mask = mask | action_mask
|
||||
|
||||
return mask
|
||||
15
lerobot/common/policies/diffusion/model/module_attr_mixin.py
Normal file
15
lerobot/common/policies/diffusion/model/module_attr_mixin.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ModuleAttrMixin(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._dummy_variable = nn.Parameter()
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(iter(self.parameters())).dtype
|
||||
@@ -0,0 +1,189 @@
|
||||
import copy
|
||||
from typing import Dict, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
|
||||
from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
|
||||
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
|
||||
from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
|
||||
|
||||
|
||||
class MultiImageObsEncoder(ModuleAttrMixin):
|
||||
def __init__(
|
||||
self,
|
||||
shape_meta: dict,
|
||||
rgb_model: Union[nn.Module, Dict[str, nn.Module]],
|
||||
resize_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None,
|
||||
crop_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None,
|
||||
random_crop: bool = True,
|
||||
# replace BatchNorm with GroupNorm
|
||||
use_group_norm: bool = False,
|
||||
# use single rgb model for all rgb inputs
|
||||
share_rgb_model: bool = False,
|
||||
# renormalize rgb input with imagenet normalization
|
||||
# assuming input in [0,1]
|
||||
imagenet_norm: bool = False,
|
||||
):
|
||||
"""
|
||||
Assumes rgb input: B,C,H,W
|
||||
Assumes low_dim input: B,D
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
rgb_keys = []
|
||||
low_dim_keys = []
|
||||
key_model_map = nn.ModuleDict()
|
||||
key_transform_map = nn.ModuleDict()
|
||||
key_shape_map = {}
|
||||
|
||||
# handle sharing vision backbone
|
||||
if share_rgb_model:
|
||||
assert isinstance(rgb_model, nn.Module)
|
||||
key_model_map["rgb"] = rgb_model
|
||||
|
||||
obs_shape_meta = shape_meta["obs"]
|
||||
for key, attr in obs_shape_meta.items():
|
||||
shape = tuple(attr["shape"])
|
||||
type = attr.get("type", "low_dim")
|
||||
key_shape_map[key] = shape
|
||||
if type == "rgb":
|
||||
rgb_keys.append(key)
|
||||
# configure model for this key
|
||||
this_model = None
|
||||
if not share_rgb_model:
|
||||
if isinstance(rgb_model, dict):
|
||||
# have provided model for each key
|
||||
this_model = rgb_model[key]
|
||||
else:
|
||||
assert isinstance(rgb_model, nn.Module)
|
||||
# have a copy of the rgb model
|
||||
this_model = copy.deepcopy(rgb_model)
|
||||
|
||||
if this_model is not None:
|
||||
if use_group_norm:
|
||||
this_model = replace_submodules(
|
||||
root_module=this_model,
|
||||
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||
func=lambda x: nn.GroupNorm(
|
||||
num_groups=x.num_features // 16, num_channels=x.num_features
|
||||
),
|
||||
)
|
||||
key_model_map[key] = this_model
|
||||
|
||||
# configure resize
|
||||
input_shape = shape
|
||||
this_resizer = nn.Identity()
|
||||
if resize_shape is not None:
|
||||
if isinstance(resize_shape, dict):
|
||||
h, w = resize_shape[key]
|
||||
else:
|
||||
h, w = resize_shape
|
||||
this_resizer = torchvision.transforms.Resize(size=(h, w))
|
||||
input_shape = (shape[0], h, w)
|
||||
|
||||
# configure randomizer
|
||||
this_randomizer = nn.Identity()
|
||||
if crop_shape is not None:
|
||||
if isinstance(crop_shape, dict):
|
||||
h, w = crop_shape[key]
|
||||
else:
|
||||
h, w = crop_shape
|
||||
if random_crop:
|
||||
this_randomizer = CropRandomizer(
|
||||
input_shape=input_shape, crop_height=h, crop_width=w, num_crops=1, pos_enc=False
|
||||
)
|
||||
else:
|
||||
this_normalizer = torchvision.transforms.CenterCrop(size=(h, w))
|
||||
# configure normalizer
|
||||
this_normalizer = nn.Identity()
|
||||
if imagenet_norm:
|
||||
# TODO(rcadene): move normalizer to dataset and env
|
||||
this_normalizer = torchvision.transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
)
|
||||
|
||||
this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)
|
||||
key_transform_map[key] = this_transform
|
||||
elif type == "low_dim":
|
||||
low_dim_keys.append(key)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported obs type: {type}")
|
||||
rgb_keys = sorted(rgb_keys)
|
||||
low_dim_keys = sorted(low_dim_keys)
|
||||
|
||||
self.shape_meta = shape_meta
|
||||
self.key_model_map = key_model_map
|
||||
self.key_transform_map = key_transform_map
|
||||
self.share_rgb_model = share_rgb_model
|
||||
self.rgb_keys = rgb_keys
|
||||
self.low_dim_keys = low_dim_keys
|
||||
self.key_shape_map = key_shape_map
|
||||
|
||||
def forward(self, obs_dict):
|
||||
batch_size = None
|
||||
features = []
|
||||
# process rgb input
|
||||
if self.share_rgb_model:
|
||||
# pass all rgb obs to rgb model
|
||||
imgs = []
|
||||
for key in self.rgb_keys:
|
||||
img = obs_dict[key]
|
||||
if batch_size is None:
|
||||
batch_size = img.shape[0]
|
||||
else:
|
||||
assert batch_size == img.shape[0]
|
||||
assert img.shape[1:] == self.key_shape_map[key]
|
||||
img = self.key_transform_map[key](img)
|
||||
imgs.append(img)
|
||||
# (N*B,C,H,W)
|
||||
imgs = torch.cat(imgs, dim=0)
|
||||
# (N*B,D)
|
||||
feature = self.key_model_map["rgb"](imgs)
|
||||
# (N,B,D)
|
||||
feature = feature.reshape(-1, batch_size, *feature.shape[1:])
|
||||
# (B,N,D)
|
||||
feature = torch.moveaxis(feature, 0, 1)
|
||||
# (B,N*D)
|
||||
feature = feature.reshape(batch_size, -1)
|
||||
features.append(feature)
|
||||
else:
|
||||
# run each rgb obs to independent models
|
||||
for key in self.rgb_keys:
|
||||
img = obs_dict[key]
|
||||
if batch_size is None:
|
||||
batch_size = img.shape[0]
|
||||
else:
|
||||
assert batch_size == img.shape[0]
|
||||
assert img.shape[1:] == self.key_shape_map[key]
|
||||
img = self.key_transform_map[key](img)
|
||||
feature = self.key_model_map[key](img)
|
||||
features.append(feature)
|
||||
|
||||
# process lowdim input
|
||||
for key in self.low_dim_keys:
|
||||
data = obs_dict[key]
|
||||
if batch_size is None:
|
||||
batch_size = data.shape[0]
|
||||
else:
|
||||
assert batch_size == data.shape[0]
|
||||
assert data.shape[1:] == self.key_shape_map[key]
|
||||
features.append(data)
|
||||
|
||||
# concatenate all features
|
||||
result = torch.cat(features, dim=-1)
|
||||
return result
|
||||
|
||||
@torch.no_grad()
|
||||
def output_shape(self):
|
||||
example_obs_dict = {}
|
||||
obs_shape_meta = self.shape_meta["obs"]
|
||||
batch_size = 1
|
||||
for key, attr in obs_shape_meta.items():
|
||||
shape = tuple(attr["shape"])
|
||||
this_obs = torch.zeros((batch_size,) + shape, dtype=self.dtype, device=self.device)
|
||||
example_obs_dict[key] = this_obs
|
||||
example_output = self.forward(example_obs_dict)
|
||||
output_shape = example_output.shape[1:]
|
||||
return output_shape
|
||||
358
lerobot/common/policies/diffusion/model/normalizer.py
Normal file
358
lerobot/common/policies/diffusion/model/normalizer.py
Normal file
@@ -0,0 +1,358 @@
|
||||
from typing import Dict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import zarr
|
||||
|
||||
from lerobot.common.policies.diffusion.model.dict_of_tensor_mixin import DictOfTensorMixin
|
||||
from lerobot.common.policies.diffusion.pytorch_utils import dict_apply
|
||||
|
||||
|
||||
class LinearNormalizer(DictOfTensorMixin):
|
||||
avaliable_modes = ["limits", "gaussian"]
|
||||
|
||||
@torch.no_grad()
|
||||
def fit(
|
||||
self,
|
||||
data: Union[Dict, torch.Tensor, np.ndarray, zarr.Array],
|
||||
last_n_dims=1,
|
||||
dtype=torch.float32,
|
||||
mode="limits",
|
||||
output_max=1.0,
|
||||
output_min=-1.0,
|
||||
range_eps=1e-4,
|
||||
fit_offset=True,
|
||||
):
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
self.params_dict[key] = _fit(
|
||||
value,
|
||||
last_n_dims=last_n_dims,
|
||||
dtype=dtype,
|
||||
mode=mode,
|
||||
output_max=output_max,
|
||||
output_min=output_min,
|
||||
range_eps=range_eps,
|
||||
fit_offset=fit_offset,
|
||||
)
|
||||
else:
|
||||
self.params_dict["_default"] = _fit(
|
||||
data,
|
||||
last_n_dims=last_n_dims,
|
||||
dtype=dtype,
|
||||
mode=mode,
|
||||
output_max=output_max,
|
||||
output_min=output_min,
|
||||
range_eps=range_eps,
|
||||
fit_offset=fit_offset,
|
||||
)
|
||||
|
||||
def __call__(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
|
||||
return self.normalize(x)
|
||||
|
||||
def __getitem__(self, key: str):
|
||||
return SingleFieldLinearNormalizer(self.params_dict[key])
|
||||
|
||||
def __setitem__(self, key: str, value: "SingleFieldLinearNormalizer"):
|
||||
self.params_dict[key] = value.params_dict
|
||||
|
||||
def _normalize_impl(self, x, forward=True):
|
||||
if isinstance(x, dict):
|
||||
result = {}
|
||||
for key, value in x.items():
|
||||
params = self.params_dict[key]
|
||||
result[key] = _normalize(value, params, forward=forward)
|
||||
return result
|
||||
else:
|
||||
if "_default" not in self.params_dict:
|
||||
raise RuntimeError("Not initialized")
|
||||
params = self.params_dict["_default"]
|
||||
return _normalize(x, params, forward=forward)
|
||||
|
||||
def normalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
|
||||
return self._normalize_impl(x, forward=True)
|
||||
|
||||
def unnormalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
|
||||
return self._normalize_impl(x, forward=False)
|
||||
|
||||
def get_input_stats(self) -> Dict:
|
||||
if len(self.params_dict) == 0:
|
||||
raise RuntimeError("Not initialized")
|
||||
if len(self.params_dict) == 1 and "_default" in self.params_dict:
|
||||
return self.params_dict["_default"]["input_stats"]
|
||||
|
||||
result = {}
|
||||
for key, value in self.params_dict.items():
|
||||
if key != "_default":
|
||||
result[key] = value["input_stats"]
|
||||
return result
|
||||
|
||||
def get_output_stats(self, key="_default"):
|
||||
input_stats = self.get_input_stats()
|
||||
if "min" in input_stats:
|
||||
# no dict
|
||||
return dict_apply(input_stats, self.normalize)
|
||||
|
||||
result = {}
|
||||
for key, group in input_stats.items():
|
||||
this_dict = {}
|
||||
for name, value in group.items():
|
||||
this_dict[name] = self.normalize({key: value})[key]
|
||||
result[key] = this_dict
|
||||
return result
|
||||
|
||||
|
||||
class SingleFieldLinearNormalizer(DictOfTensorMixin):
|
||||
avaliable_modes = ["limits", "gaussian"]
|
||||
|
||||
@torch.no_grad()
|
||||
def fit(
|
||||
self,
|
||||
data: Union[torch.Tensor, np.ndarray, zarr.Array],
|
||||
last_n_dims=1,
|
||||
dtype=torch.float32,
|
||||
mode="limits",
|
||||
output_max=1.0,
|
||||
output_min=-1.0,
|
||||
range_eps=1e-4,
|
||||
fit_offset=True,
|
||||
):
|
||||
self.params_dict = _fit(
|
||||
data,
|
||||
last_n_dims=last_n_dims,
|
||||
dtype=dtype,
|
||||
mode=mode,
|
||||
output_max=output_max,
|
||||
output_min=output_min,
|
||||
range_eps=range_eps,
|
||||
fit_offset=fit_offset,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_fit(cls, data: Union[torch.Tensor, np.ndarray, zarr.Array], **kwargs):
|
||||
obj = cls()
|
||||
obj.fit(data, **kwargs)
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
def create_manual(
|
||||
cls,
|
||||
scale: Union[torch.Tensor, np.ndarray],
|
||||
offset: Union[torch.Tensor, np.ndarray],
|
||||
input_stats_dict: Dict[str, Union[torch.Tensor, np.ndarray]],
|
||||
):
|
||||
def to_tensor(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
x = torch.from_numpy(x)
|
||||
x = x.flatten()
|
||||
return x
|
||||
|
||||
# check
|
||||
for x in [offset] + list(input_stats_dict.values()):
|
||||
assert x.shape == scale.shape
|
||||
assert x.dtype == scale.dtype
|
||||
|
||||
params_dict = nn.ParameterDict(
|
||||
{
|
||||
"scale": to_tensor(scale),
|
||||
"offset": to_tensor(offset),
|
||||
"input_stats": nn.ParameterDict(dict_apply(input_stats_dict, to_tensor)),
|
||||
}
|
||||
)
|
||||
return cls(params_dict)
|
||||
|
||||
@classmethod
|
||||
def create_identity(cls, dtype=torch.float32):
|
||||
scale = torch.tensor([1], dtype=dtype)
|
||||
offset = torch.tensor([0], dtype=dtype)
|
||||
input_stats_dict = {
|
||||
"min": torch.tensor([-1], dtype=dtype),
|
||||
"max": torch.tensor([1], dtype=dtype),
|
||||
"mean": torch.tensor([0], dtype=dtype),
|
||||
"std": torch.tensor([1], dtype=dtype),
|
||||
}
|
||||
return cls.create_manual(scale, offset, input_stats_dict)
|
||||
|
||||
def normalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
|
||||
return _normalize(x, self.params_dict, forward=True)
|
||||
|
||||
def unnormalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
|
||||
return _normalize(x, self.params_dict, forward=False)
|
||||
|
||||
def get_input_stats(self):
|
||||
return self.params_dict["input_stats"]
|
||||
|
||||
def get_output_stats(self):
|
||||
return dict_apply(self.params_dict["input_stats"], self.normalize)
|
||||
|
||||
def __call__(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
|
||||
return self.normalize(x)
|
||||
|
||||
|
||||
def _fit(
|
||||
data: Union[torch.Tensor, np.ndarray, zarr.Array],
|
||||
last_n_dims=1,
|
||||
dtype=torch.float32,
|
||||
mode="limits",
|
||||
output_max=1.0,
|
||||
output_min=-1.0,
|
||||
range_eps=1e-4,
|
||||
fit_offset=True,
|
||||
):
|
||||
assert mode in ["limits", "gaussian"]
|
||||
assert last_n_dims >= 0
|
||||
assert output_max > output_min
|
||||
|
||||
# convert data to torch and type
|
||||
if isinstance(data, zarr.Array):
|
||||
data = data[:]
|
||||
if isinstance(data, np.ndarray):
|
||||
data = torch.from_numpy(data)
|
||||
if dtype is not None:
|
||||
data = data.type(dtype)
|
||||
|
||||
# convert shape
|
||||
dim = 1
|
||||
if last_n_dims > 0:
|
||||
dim = np.prod(data.shape[-last_n_dims:])
|
||||
data = data.reshape(-1, dim)
|
||||
|
||||
# compute input stats min max mean std
|
||||
input_min, _ = data.min(axis=0)
|
||||
input_max, _ = data.max(axis=0)
|
||||
input_mean = data.mean(axis=0)
|
||||
input_std = data.std(axis=0)
|
||||
|
||||
# compute scale and offset
|
||||
if mode == "limits":
|
||||
if fit_offset:
|
||||
# unit scale
|
||||
input_range = input_max - input_min
|
||||
ignore_dim = input_range < range_eps
|
||||
input_range[ignore_dim] = output_max - output_min
|
||||
scale = (output_max - output_min) / input_range
|
||||
offset = output_min - scale * input_min
|
||||
offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim]
|
||||
# ignore dims scaled to mean of output max and min
|
||||
else:
|
||||
# use this when data is pre-zero-centered.
|
||||
assert output_max > 0
|
||||
assert output_min < 0
|
||||
# unit abs
|
||||
output_abs = min(abs(output_min), abs(output_max))
|
||||
input_abs = torch.maximum(torch.abs(input_min), torch.abs(input_max))
|
||||
ignore_dim = input_abs < range_eps
|
||||
input_abs[ignore_dim] = output_abs
|
||||
# don't scale constant channels
|
||||
scale = output_abs / input_abs
|
||||
offset = torch.zeros_like(input_mean)
|
||||
elif mode == "gaussian":
|
||||
ignore_dim = input_std < range_eps
|
||||
scale = input_std.clone()
|
||||
scale[ignore_dim] = 1
|
||||
scale = 1 / scale
|
||||
|
||||
offset = -input_mean * scale if fit_offset else torch.zeros_like(input_mean)
|
||||
|
||||
# save
|
||||
this_params = nn.ParameterDict(
|
||||
{
|
||||
"scale": scale,
|
||||
"offset": offset,
|
||||
"input_stats": nn.ParameterDict(
|
||||
{"min": input_min, "max": input_max, "mean": input_mean, "std": input_std}
|
||||
),
|
||||
}
|
||||
)
|
||||
for p in this_params.parameters():
|
||||
p.requires_grad_(False)
|
||||
return this_params
|
||||
|
||||
|
||||
def _normalize(x, params, forward=True):
|
||||
assert "scale" in params
|
||||
if isinstance(x, np.ndarray):
|
||||
x = torch.from_numpy(x)
|
||||
scale = params["scale"]
|
||||
offset = params["offset"]
|
||||
x = x.to(device=scale.device, dtype=scale.dtype)
|
||||
src_shape = x.shape
|
||||
x = x.reshape(-1, scale.shape[0])
|
||||
x = x * scale + offset if forward else (x - offset) / scale
|
||||
x = x.reshape(src_shape)
|
||||
return x
|
||||
|
||||
|
||||
def test():
|
||||
data = torch.zeros((100, 10, 9, 2)).uniform_()
|
||||
data[..., 0, 0] = 0
|
||||
|
||||
normalizer = SingleFieldLinearNormalizer()
|
||||
normalizer.fit(data, mode="limits", last_n_dims=2)
|
||||
datan = normalizer.normalize(data)
|
||||
assert datan.shape == data.shape
|
||||
assert np.allclose(datan.max(), 1.0)
|
||||
assert np.allclose(datan.min(), -1.0)
|
||||
dataun = normalizer.unnormalize(datan)
|
||||
assert torch.allclose(data, dataun, atol=1e-7)
|
||||
|
||||
_ = normalizer.get_input_stats()
|
||||
_ = normalizer.get_output_stats()
|
||||
|
||||
normalizer = SingleFieldLinearNormalizer()
|
||||
normalizer.fit(data, mode="limits", last_n_dims=1, fit_offset=False)
|
||||
datan = normalizer.normalize(data)
|
||||
assert datan.shape == data.shape
|
||||
assert np.allclose(datan.max(), 1.0, atol=1e-3)
|
||||
assert np.allclose(datan.min(), 0.0, atol=1e-3)
|
||||
dataun = normalizer.unnormalize(datan)
|
||||
assert torch.allclose(data, dataun, atol=1e-7)
|
||||
|
||||
data = torch.zeros((100, 10, 9, 2)).uniform_()
|
||||
normalizer = SingleFieldLinearNormalizer()
|
||||
normalizer.fit(data, mode="gaussian", last_n_dims=0)
|
||||
datan = normalizer.normalize(data)
|
||||
assert datan.shape == data.shape
|
||||
assert np.allclose(datan.mean(), 0.0, atol=1e-3)
|
||||
assert np.allclose(datan.std(), 1.0, atol=1e-3)
|
||||
dataun = normalizer.unnormalize(datan)
|
||||
assert torch.allclose(data, dataun, atol=1e-7)
|
||||
|
||||
# dict
|
||||
data = torch.zeros((100, 10, 9, 2)).uniform_()
|
||||
data[..., 0, 0] = 0
|
||||
|
||||
normalizer = LinearNormalizer()
|
||||
normalizer.fit(data, mode="limits", last_n_dims=2)
|
||||
datan = normalizer.normalize(data)
|
||||
assert datan.shape == data.shape
|
||||
assert np.allclose(datan.max(), 1.0)
|
||||
assert np.allclose(datan.min(), -1.0)
|
||||
dataun = normalizer.unnormalize(datan)
|
||||
assert torch.allclose(data, dataun, atol=1e-7)
|
||||
|
||||
_ = normalizer.get_input_stats()
|
||||
_ = normalizer.get_output_stats()
|
||||
|
||||
data = {
|
||||
"obs": torch.zeros((1000, 128, 9, 2)).uniform_() * 512,
|
||||
"action": torch.zeros((1000, 128, 2)).uniform_() * 512,
|
||||
}
|
||||
normalizer = LinearNormalizer()
|
||||
normalizer.fit(data)
|
||||
datan = normalizer.normalize(data)
|
||||
dataun = normalizer.unnormalize(datan)
|
||||
for key in data:
|
||||
assert torch.allclose(data[key], dataun[key], atol=1e-4)
|
||||
|
||||
_ = normalizer.get_input_stats()
|
||||
_ = normalizer.get_output_stats()
|
||||
|
||||
state_dict = normalizer.state_dict()
|
||||
n = LinearNormalizer()
|
||||
n.load_state_dict(state_dict)
|
||||
datan = n.normalize(data)
|
||||
dataun = n.unnormalize(datan)
|
||||
for key in data:
|
||||
assert torch.allclose(data[key], dataun[key], atol=1e-4)
|
||||
@@ -0,0 +1,19 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
||||
emb = x[:, None] * emb[None, :]
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
971
lerobot/common/policies/diffusion/model/tensor_utils.py
Normal file
971
lerobot/common/policies/diffusion/model/tensor_utils.py
Normal file
@@ -0,0 +1,971 @@
|
||||
"""
|
||||
A collection of utilities for working with nested tensor structures consisting
|
||||
of numpy arrays and torch tensors.
|
||||
"""
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def recursive_dict_list_tuple_apply(x, type_func_dict):
|
||||
"""
|
||||
Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of
|
||||
{data_type: function_to_apply}.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
type_func_dict (dict): a mapping from data types to the functions to be
|
||||
applied for each data type.
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
assert list not in type_func_dict
|
||||
assert tuple not in type_func_dict
|
||||
assert dict not in type_func_dict
|
||||
|
||||
if isinstance(x, (dict, collections.OrderedDict)):
|
||||
new_x = collections.OrderedDict() if isinstance(x, collections.OrderedDict) else {}
|
||||
for k, v in x.items():
|
||||
new_x[k] = recursive_dict_list_tuple_apply(v, type_func_dict)
|
||||
return new_x
|
||||
elif isinstance(x, (list, tuple)):
|
||||
ret = [recursive_dict_list_tuple_apply(v, type_func_dict) for v in x]
|
||||
if isinstance(x, tuple):
|
||||
ret = tuple(ret)
|
||||
return ret
|
||||
else:
|
||||
for t, f in type_func_dict.items():
|
||||
if isinstance(x, t):
|
||||
return f(x)
|
||||
else:
|
||||
raise NotImplementedError("Cannot handle data type %s" % str(type(x)))
|
||||
|
||||
|
||||
def map_tensor(x, func):
|
||||
"""
|
||||
Apply function @func to torch.Tensor objects in a nested dictionary or
|
||||
list or tuple.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
func (function): function to apply to each tensor
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: func,
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def map_ndarray(x, func):
|
||||
"""
|
||||
Apply function @func to np.ndarray objects in a nested dictionary or
|
||||
list or tuple.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
func (function): function to apply to each array
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
np.ndarray: func,
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def map_tensor_ndarray(x, tensor_func, ndarray_func):
|
||||
"""
|
||||
Apply function @tensor_func to torch.Tensor objects and @ndarray_func to
|
||||
np.ndarray objects in a nested dictionary or list or tuple.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
tensor_func (function): function to apply to each tensor
|
||||
ndarray_Func (function): function to apply to each array
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: tensor_func,
|
||||
np.ndarray: ndarray_func,
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def clone(x):
|
||||
"""
|
||||
Clones all torch tensors and numpy arrays in nested dictionary or list
|
||||
or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x.clone(),
|
||||
np.ndarray: lambda x: x.copy(),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def detach(x):
|
||||
"""
|
||||
Detaches all torch tensors in nested dictionary or list
|
||||
or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x.detach(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_batch(x):
|
||||
"""
|
||||
Introduces a leading batch dimension of 1 for all torch tensors and numpy
|
||||
arrays in nested dictionary or list or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x[None, ...],
|
||||
np.ndarray: lambda x: x[None, ...],
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_sequence(x):
|
||||
"""
|
||||
Introduces a time dimension of 1 at dimension 1 for all torch tensors and numpy
|
||||
arrays in nested dictionary or list or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x[:, None, ...],
|
||||
np.ndarray: lambda x: x[:, None, ...],
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def index_at_time(x, ind):
|
||||
"""
|
||||
Indexes all torch tensors and numpy arrays in dimension 1 with index @ind in
|
||||
nested dictionary or list or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
ind (int): index
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x[:, ind, ...],
|
||||
np.ndarray: lambda x: x[:, ind, ...],
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def unsqueeze(x, dim):
|
||||
"""
|
||||
Adds dimension of size 1 at dimension @dim in all torch tensors and numpy arrays
|
||||
in nested dictionary or list or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
dim (int): dimension
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x.unsqueeze(dim=dim),
|
||||
np.ndarray: lambda x: np.expand_dims(x, axis=dim),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def contiguous(x):
|
||||
"""
|
||||
Makes all torch tensors and numpy arrays contiguous in nested dictionary or
|
||||
list or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x.contiguous(),
|
||||
np.ndarray: lambda x: np.ascontiguousarray(x),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_device(x, device):
|
||||
"""
|
||||
Sends all torch tensors in nested dictionary or list or tuple to device
|
||||
@device, and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
device (torch.Device): device to send tensors to
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x, d=device: x.to(d),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_tensor(x):
|
||||
"""
|
||||
Converts all numpy arrays in nested dictionary or list or tuple to
|
||||
torch tensors (and leaves existing torch Tensors as-is), and returns
|
||||
a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x,
|
||||
np.ndarray: lambda x: torch.from_numpy(x),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_numpy(x):
|
||||
"""
|
||||
Converts all torch tensors in nested dictionary or list or tuple to
|
||||
numpy (and leaves existing numpy arrays as-is), and returns
|
||||
a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
|
||||
def f(tensor):
|
||||
if tensor.is_cuda:
|
||||
return tensor.detach().cpu().numpy()
|
||||
else:
|
||||
return tensor.detach().numpy()
|
||||
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: f,
|
||||
np.ndarray: lambda x: x,
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_list(x):
|
||||
"""
|
||||
Converts all torch tensors and numpy arrays in nested dictionary or list
|
||||
or tuple to a list, and returns a new nested structure. Useful for
|
||||
json encoding.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
|
||||
def f(tensor):
|
||||
if tensor.is_cuda:
|
||||
return tensor.detach().cpu().numpy().tolist()
|
||||
else:
|
||||
return tensor.detach().numpy().tolist()
|
||||
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: f,
|
||||
np.ndarray: lambda x: x.tolist(),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_float(x):
|
||||
"""
|
||||
Converts all torch tensors and numpy arrays in nested dictionary or list
|
||||
or tuple to float type entries, and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x.float(),
|
||||
np.ndarray: lambda x: x.astype(np.float32),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_uint8(x):
|
||||
"""
|
||||
Converts all torch tensors and numpy arrays in nested dictionary or list
|
||||
or tuple to uint8 type entries, and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x.byte(),
|
||||
np.ndarray: lambda x: x.astype(np.uint8),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_torch(x, device):
|
||||
"""
|
||||
Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to
|
||||
torch tensors on device @device and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
device (torch.Device): device to send tensors to
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return to_device(to_float(to_tensor(x)), device)
|
||||
|
||||
|
||||
def to_one_hot_single(tensor, num_class):
|
||||
"""
|
||||
Convert tensor to one-hot representation, assuming a certain number of total class labels.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): tensor containing integer labels
|
||||
num_class (int): number of classes
|
||||
|
||||
Returns:
|
||||
x (torch.Tensor): tensor containing one-hot representation of labels
|
||||
"""
|
||||
x = torch.zeros(tensor.size() + (num_class,)).to(tensor.device)
|
||||
x.scatter_(-1, tensor.unsqueeze(-1), 1)
|
||||
return x
|
||||
|
||||
|
||||
def to_one_hot(tensor, num_class):
|
||||
"""
|
||||
Convert all tensors in nested dictionary or list or tuple to one-hot representation,
|
||||
assuming a certain number of total class labels.
|
||||
|
||||
Args:
|
||||
tensor (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
num_class (int): number of classes
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return map_tensor(tensor, func=lambda x, nc=num_class: to_one_hot_single(x, nc))
|
||||
|
||||
|
||||
def flatten_single(x, begin_axis=1):
|
||||
"""
|
||||
Flatten a tensor in all dimensions from @begin_axis onwards.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): tensor to flatten
|
||||
begin_axis (int): which axis to flatten from
|
||||
|
||||
Returns:
|
||||
y (torch.Tensor): flattened tensor
|
||||
"""
|
||||
fixed_size = x.size()[:begin_axis]
|
||||
_s = list(fixed_size) + [-1]
|
||||
return x.reshape(*_s)
|
||||
|
||||
|
||||
def flatten(x, begin_axis=1):
|
||||
"""
|
||||
Flatten all tensors in nested dictionary or list or tuple, from @begin_axis onwards.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
begin_axis (int): which axis to flatten from
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x, b=begin_axis: flatten_single(x, begin_axis=b),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def reshape_dimensions_single(x, begin_axis, end_axis, target_dims):
|
||||
"""
|
||||
Reshape selected dimensions in a tensor to a target dimension.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): tensor to reshape
|
||||
begin_axis (int): begin dimension
|
||||
end_axis (int): end dimension
|
||||
target_dims (tuple or list): target shape for the range of dimensions
|
||||
(@begin_axis, @end_axis)
|
||||
|
||||
Returns:
|
||||
y (torch.Tensor): reshaped tensor
|
||||
"""
|
||||
assert begin_axis <= end_axis
|
||||
assert begin_axis >= 0
|
||||
assert end_axis < len(x.shape)
|
||||
assert isinstance(target_dims, (tuple, list))
|
||||
s = x.shape
|
||||
final_s = []
|
||||
for i in range(len(s)):
|
||||
if i == begin_axis:
|
||||
final_s.extend(target_dims)
|
||||
elif i < begin_axis or i > end_axis:
|
||||
final_s.append(s[i])
|
||||
return x.reshape(*final_s)
|
||||
|
||||
|
||||
def reshape_dimensions(x, begin_axis, end_axis, target_dims):
|
||||
"""
|
||||
Reshape selected dimensions for all tensors in nested dictionary or list or tuple
|
||||
to a target dimension.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
begin_axis (int): begin dimension
|
||||
end_axis (int): end dimension
|
||||
target_dims (tuple or list): target shape for the range of dimensions
|
||||
(@begin_axis, @end_axis)
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single(
|
||||
x, begin_axis=b, end_axis=e, target_dims=t
|
||||
),
|
||||
np.ndarray: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single(
|
||||
x, begin_axis=b, end_axis=e, target_dims=t
|
||||
),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def join_dimensions(x, begin_axis, end_axis):
|
||||
"""
|
||||
Joins all dimensions between dimensions (@begin_axis, @end_axis) into a flat dimension, for
|
||||
all tensors in nested dictionary or list or tuple.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
begin_axis (int): begin dimension
|
||||
end_axis (int): end dimension
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
|
||||
x, begin_axis=b, end_axis=e, target_dims=[-1]
|
||||
),
|
||||
np.ndarray: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
|
||||
x, begin_axis=b, end_axis=e, target_dims=[-1]
|
||||
),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def expand_at_single(x, size, dim):
|
||||
"""
|
||||
Expand a tensor at a single dimension @dim by @size
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): input tensor
|
||||
size (int): size to expand
|
||||
dim (int): dimension to expand
|
||||
|
||||
Returns:
|
||||
y (torch.Tensor): expanded tensor
|
||||
"""
|
||||
assert dim < x.ndimension()
|
||||
assert x.shape[dim] == 1
|
||||
expand_dims = [-1] * x.ndimension()
|
||||
expand_dims[dim] = size
|
||||
return x.expand(*expand_dims)
|
||||
|
||||
|
||||
def expand_at(x, size, dim):
|
||||
"""
|
||||
Expand all tensors in nested dictionary or list or tuple at a single
|
||||
dimension @dim by @size.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
size (int): size to expand
|
||||
dim (int): dimension to expand
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return map_tensor(x, lambda t, s=size, d=dim: expand_at_single(t, s, d))
|
||||
|
||||
|
||||
def unsqueeze_expand_at(x, size, dim):
|
||||
"""
|
||||
Unsqueeze and expand a tensor at a dimension @dim by @size.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
size (int): size to expand
|
||||
dim (int): dimension to unsqueeze and expand
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
x = unsqueeze(x, dim)
|
||||
return expand_at(x, size, dim)
|
||||
|
||||
|
||||
def repeat_by_expand_at(x, repeats, dim):
|
||||
"""
|
||||
Repeat a dimension by combining expand and reshape operations.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
repeats (int): number of times to repeat the target dimension
|
||||
dim (int): dimension to repeat on
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
x = unsqueeze_expand_at(x, repeats, dim + 1)
|
||||
return join_dimensions(x, dim, dim + 1)
|
||||
|
||||
|
||||
def named_reduce_single(x, reduction, dim):
|
||||
"""
|
||||
Reduce tensor at a dimension by named reduction functions.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): tensor to be reduced
|
||||
reduction (str): one of ["sum", "max", "mean", "flatten"]
|
||||
dim (int): dimension to be reduced (or begin axis for flatten)
|
||||
|
||||
Returns:
|
||||
y (torch.Tensor): reduced tensor
|
||||
"""
|
||||
assert x.ndimension() > dim
|
||||
assert reduction in ["sum", "max", "mean", "flatten"]
|
||||
if reduction == "flatten":
|
||||
x = flatten(x, begin_axis=dim)
|
||||
elif reduction == "max":
|
||||
x = torch.max(x, dim=dim)[0] # [B, D]
|
||||
elif reduction == "sum":
|
||||
x = torch.sum(x, dim=dim)
|
||||
else:
|
||||
x = torch.mean(x, dim=dim)
|
||||
return x
|
||||
|
||||
|
||||
def named_reduce(x, reduction, dim):
|
||||
"""
|
||||
Reduces all tensors in nested dictionary or list or tuple at a dimension
|
||||
using a named reduction function.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
reduction (str): one of ["sum", "max", "mean", "flatten"]
|
||||
dim (int): dimension to be reduced (or begin axis for flatten)
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return map_tensor(x, func=lambda t, r=reduction, d=dim: named_reduce_single(t, r, d))
|
||||
|
||||
|
||||
def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices):
|
||||
"""
|
||||
This function indexes out a target dimension of a tensor in a structured way,
|
||||
by allowing a different value to be selected for each member of a flat index
|
||||
tensor (@indices) corresponding to a source dimension. This can be interpreted
|
||||
as moving along the source dimension, using the corresponding index value
|
||||
in @indices to select values for all other dimensions outside of the
|
||||
source and target dimensions. A common use case is to gather values
|
||||
in target dimension 1 for each batch member (target dimension 0).
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): tensor to gather values for
|
||||
target_dim (int): dimension to gather values along
|
||||
source_dim (int): dimension to hold constant and use for gathering values
|
||||
from the other dimensions
|
||||
indices (torch.Tensor): flat index tensor with same shape as tensor @x along
|
||||
@source_dim
|
||||
|
||||
Returns:
|
||||
y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out
|
||||
"""
|
||||
assert len(indices.shape) == 1
|
||||
assert x.shape[source_dim] == indices.shape[0]
|
||||
|
||||
# unsqueeze in all dimensions except the source dimension
|
||||
new_shape = [1] * x.ndimension()
|
||||
new_shape[source_dim] = -1
|
||||
indices = indices.reshape(*new_shape)
|
||||
|
||||
# repeat in all dimensions - but preserve shape of source dimension,
|
||||
# and make sure target_dimension has singleton dimension
|
||||
expand_shape = list(x.shape)
|
||||
expand_shape[source_dim] = -1
|
||||
expand_shape[target_dim] = 1
|
||||
indices = indices.expand(*expand_shape)
|
||||
|
||||
out = x.gather(dim=target_dim, index=indices)
|
||||
return out.squeeze(target_dim)
|
||||
|
||||
|
||||
def gather_along_dim_with_dim(x, target_dim, source_dim, indices):
|
||||
"""
|
||||
Apply @gather_along_dim_with_dim_single to all tensors in a nested
|
||||
dictionary or list or tuple.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
target_dim (int): dimension to gather values along
|
||||
source_dim (int): dimension to hold constant and use for gathering values
|
||||
from the other dimensions
|
||||
indices (torch.Tensor): flat index tensor with same shape as tensor @x along
|
||||
@source_dim
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return map_tensor(
|
||||
x, lambda y, t=target_dim, s=source_dim, i=indices: gather_along_dim_with_dim_single(y, t, s, i)
|
||||
)
|
||||
|
||||
|
||||
def gather_sequence_single(seq, indices):
|
||||
"""
|
||||
Given a tensor with leading dimensions [B, T, ...], gather an element from each sequence in
|
||||
the batch given an index for each sequence.
|
||||
|
||||
Args:
|
||||
seq (torch.Tensor): tensor with leading dimensions [B, T, ...]
|
||||
indices (torch.Tensor): tensor indices of shape [B]
|
||||
|
||||
Return:
|
||||
y (torch.Tensor): indexed tensor of shape [B, ....]
|
||||
"""
|
||||
return gather_along_dim_with_dim_single(seq, target_dim=1, source_dim=0, indices=indices)
|
||||
|
||||
|
||||
def gather_sequence(seq, indices):
|
||||
"""
|
||||
Given a nested dictionary or list or tuple, gathers an element from each sequence of the batch
|
||||
for tensors with leading dimensions [B, T, ...].
|
||||
|
||||
Args:
|
||||
seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
|
||||
of leading dimensions [B, T, ...]
|
||||
indices (torch.Tensor): tensor indices of shape [B]
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple with tensors of shape [B, ...]
|
||||
"""
|
||||
return gather_along_dim_with_dim(seq, target_dim=1, source_dim=0, indices=indices)
|
||||
|
||||
|
||||
def pad_sequence_single(seq, padding, batched=False, pad_same=True, pad_values=None):
|
||||
"""
|
||||
Pad input tensor or array @seq in the time dimension (dimension 1).
|
||||
|
||||
Args:
|
||||
seq (np.ndarray or torch.Tensor): sequence to be padded
|
||||
padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
|
||||
batched (bool): if sequence has the batch dimension
|
||||
pad_same (bool): if pad by duplicating
|
||||
pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same
|
||||
|
||||
Returns:
|
||||
padded sequence (np.ndarray or torch.Tensor)
|
||||
"""
|
||||
assert isinstance(seq, (np.ndarray, torch.Tensor))
|
||||
assert pad_same or pad_values is not None
|
||||
if pad_values is not None:
|
||||
assert isinstance(pad_values, float)
|
||||
repeat_func = np.repeat if isinstance(seq, np.ndarray) else torch.repeat_interleave
|
||||
concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat
|
||||
ones_like_func = np.ones_like if isinstance(seq, np.ndarray) else torch.ones_like
|
||||
seq_dim = 1 if batched else 0
|
||||
|
||||
begin_pad = []
|
||||
end_pad = []
|
||||
|
||||
if padding[0] > 0:
|
||||
pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values
|
||||
begin_pad.append(repeat_func(pad, padding[0], seq_dim))
|
||||
if padding[1] > 0:
|
||||
pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values
|
||||
end_pad.append(repeat_func(pad, padding[1], seq_dim))
|
||||
|
||||
return concat_func(begin_pad + [seq] + end_pad, seq_dim)
|
||||
|
||||
|
||||
def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None):
|
||||
"""
|
||||
Pad a nested dictionary or list or tuple of sequence tensors in the time dimension (dimension 1).
|
||||
|
||||
Args:
|
||||
seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
|
||||
of leading dimensions [B, T, ...]
|
||||
padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
|
||||
batched (bool): if sequence has the batch dimension
|
||||
pad_same (bool): if pad by duplicating
|
||||
pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same
|
||||
|
||||
Returns:
|
||||
padded sequence (dict or list or tuple)
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
seq,
|
||||
{
|
||||
torch.Tensor: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single(
|
||||
x, p, b, ps, pv
|
||||
),
|
||||
np.ndarray: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single(
|
||||
x, p, b, ps, pv
|
||||
),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def assert_size_at_dim_single(x, size, dim, msg):
|
||||
"""
|
||||
Ensure that array or tensor @x has size @size in dim @dim.
|
||||
|
||||
Args:
|
||||
x (np.ndarray or torch.Tensor): input array or tensor
|
||||
size (int): size that tensors should have at @dim
|
||||
dim (int): dimension to check
|
||||
msg (str): text to display if assertion fails
|
||||
"""
|
||||
assert x.shape[dim] == size, msg
|
||||
|
||||
|
||||
def assert_size_at_dim(x, size, dim, msg):
|
||||
"""
|
||||
Ensure that arrays and tensors in nested dictionary or list or tuple have
|
||||
size @size in dim @dim.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
size (int): size that tensors should have at @dim
|
||||
dim (int): dimension to check
|
||||
"""
|
||||
map_tensor(x, lambda t, s=size, d=dim, m=msg: assert_size_at_dim_single(t, s, d, m))
|
||||
|
||||
|
||||
def get_shape(x):
|
||||
"""
|
||||
Get all shapes of arrays and tensors in nested dictionary or list or tuple.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple that contains each array or
|
||||
tensor's shape
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x.shape,
|
||||
np.ndarray: lambda x: x.shape,
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def list_of_flat_dict_to_dict_of_list(list_of_dict):
|
||||
"""
|
||||
Helper function to go from a list of flat dictionaries to a dictionary of lists.
|
||||
By "flat" we mean that none of the values are dictionaries, but are numpy arrays,
|
||||
floats, etc.
|
||||
|
||||
Args:
|
||||
list_of_dict (list): list of flat dictionaries
|
||||
|
||||
Returns:
|
||||
dict_of_list (dict): dictionary of lists
|
||||
"""
|
||||
assert isinstance(list_of_dict, list)
|
||||
dic = collections.OrderedDict()
|
||||
for i in range(len(list_of_dict)):
|
||||
for k in list_of_dict[i]:
|
||||
if k not in dic:
|
||||
dic[k] = []
|
||||
dic[k].append(list_of_dict[i][k])
|
||||
return dic
|
||||
|
||||
|
||||
def flatten_nested_dict_list(d, parent_key="", sep="_", item_key=""):
|
||||
"""
|
||||
Flatten a nested dict or list to a list.
|
||||
|
||||
For example, given a dict
|
||||
{
|
||||
a: 1
|
||||
b: {
|
||||
c: 2
|
||||
}
|
||||
c: 3
|
||||
}
|
||||
|
||||
the function would return [(a, 1), (b_c, 2), (c, 3)]
|
||||
|
||||
Args:
|
||||
d (dict, list): a nested dict or list to be flattened
|
||||
parent_key (str): recursion helper
|
||||
sep (str): separator for nesting keys
|
||||
item_key (str): recursion helper
|
||||
Returns:
|
||||
list: a list of (key, value) tuples
|
||||
"""
|
||||
items = []
|
||||
if isinstance(d, (tuple, list)):
|
||||
new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
|
||||
for i, v in enumerate(d):
|
||||
items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i)))
|
||||
return items
|
||||
elif isinstance(d, dict):
|
||||
new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
|
||||
for k, v in d.items():
|
||||
assert isinstance(k, str)
|
||||
items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=k))
|
||||
return items
|
||||
else:
|
||||
new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
|
||||
return [(new_key, d)]
|
||||
|
||||
|
||||
def time_distributed(inputs, op, activation=None, inputs_as_kwargs=False, inputs_as_args=False, **kwargs):
|
||||
"""
|
||||
Apply function @op to all tensors in nested dictionary or list or tuple @inputs in both the
|
||||
batch (B) and time (T) dimension, where the tensors are expected to have shape [B, T, ...].
|
||||
Will do this by reshaping tensors to [B * T, ...], passing through the op, and then reshaping
|
||||
outputs to [B, T, ...].
|
||||
|
||||
Args:
|
||||
inputs (list or tuple or dict): a possibly nested dictionary or list or tuple with tensors
|
||||
of leading dimensions [B, T, ...]
|
||||
op: a layer op that accepts inputs
|
||||
activation: activation to apply at the output
|
||||
inputs_as_kwargs (bool): whether to feed input as a kwargs dict to the op
|
||||
inputs_as_args (bool) whether to feed input as a args list to the op
|
||||
kwargs (dict): other kwargs to supply to the op
|
||||
|
||||
Returns:
|
||||
outputs (dict or list or tuple): new nested dict-list-tuple with tensors of leading dimension [B, T].
|
||||
"""
|
||||
batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2]
|
||||
inputs = join_dimensions(inputs, 0, 1)
|
||||
if inputs_as_kwargs:
|
||||
outputs = op(**inputs, **kwargs)
|
||||
elif inputs_as_args:
|
||||
outputs = op(*inputs, **kwargs)
|
||||
else:
|
||||
outputs = op(inputs, **kwargs)
|
||||
|
||||
if activation is not None:
|
||||
outputs = map_tensor(outputs, activation)
|
||||
outputs = reshape_dimensions(outputs, begin_axis=0, end_axis=0, target_dims=(batch_size, seq_len))
|
||||
return outputs
|
||||
Reference in New Issue
Block a user