[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
Authored by pre-commit-ci[bot] on 2025-03-04 13:38:47 +00:00; committed by Michel Aractingi
parent bb69cb3c8c
commit 85fe8a3f4e
79 changed files with 2800 additions and 794 deletions


@@ -74,7 +74,9 @@ class ACTPolicy(PreTrainedPolicy):
self.model = ACT(config)
if config.temporal_ensemble_coeff is not None:
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
self.temporal_ensembler = ACTTemporalEnsembler(
config.temporal_ensemble_coeff, config.chunk_size
)
self.reset()
@@ -153,7 +155,8 @@ class ACTPolicy(PreTrainedPolicy):
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
F.l1_loss(batch["action"], actions_hat, reduction="none")
* ~batch["action_is_pad"].unsqueeze(-1)
).mean()
loss_dict = {"l1_loss": l1_loss.item()}
@@ -163,7 +166,12 @@ class ACTPolicy(PreTrainedPolicy):
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
(
-0.5
* (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())
)
.sum(-1)
.mean()
)
loss_dict["kld_loss"] = mean_kld.item()
loss = l1_loss + mean_kld * self.config.kl_weight
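
Side note on the hunk above: `mean_kld` is the closed-form KL divergence between the diagonal-Gaussian VAE posterior N(mu, sigma^2) and a standard normal prior, summed over latent dims and averaged over the batch. A minimal standalone check of that formula (illustrative tensor shapes, not the actual batch):

import torch
from torch.distributions import Normal, kl_divergence

mu_hat = torch.randn(4, 32)            # (batch, latent_dim), illustrative stand-ins
log_sigma_x2_hat = torch.randn(4, 32)  # log(sigma^2)

closed_form = (
    (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - log_sigma_x2_hat.exp())).sum(-1).mean()
)
reference = (
    kl_divergence(Normal(mu_hat, (0.5 * log_sigma_x2_hat).exp()), Normal(0.0, 1.0)).sum(-1).mean()
)
assert torch.allclose(closed_form, reference, atol=1e-4)
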
@@ -217,7 +225,9 @@ class ACTTemporalEnsembler:
```
"""
self.chunk_size = chunk_size
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
self.ensemble_weights = torch.exp(
-temporal_ensemble_coeff * torch.arange(chunk_size)
)
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
self.reset()
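
For reference, the exponential weights set up above drive an online weighted average over overlapping action chunks: older predictions for the same time step get weight exp(-coeff * i). A minimal 1-D sketch (illustrative values) showing that the online update used in `update()` below matches a direct weighted mean:

import torch

coeff, chunk_size = 0.01, 4
w = torch.exp(-coeff * torch.arange(chunk_size))   # ensemble_weights
w_cumsum = torch.cumsum(w, dim=0)                  # ensemble_weights_cumsum

preds = torch.randn(chunk_size)  # successive predictions for a single time step
avg, count = preds[0], 1
for p in preds[1:]:
    # Same online step as ACTTemporalEnsembler.update(): rescale by the old
    # cumulative weight, add the new weighted prediction, renormalize.
    avg = (avg * w_cumsum[count - 1] + p * w[count]) / w_cumsum[count]
    count += 1
assert torch.allclose(avg, (preds * w).sum() / w.sum(), atol=1e-6)
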
@@ -233,7 +243,9 @@ class ACTTemporalEnsembler:
time steps, and pop/return the next batch of actions in the sequence.
"""
self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(
device=actions.device
)
if self.ensembled_actions is None:
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
# time step of the episode.
@@ -241,19 +253,34 @@ class ACTTemporalEnsembler:
# Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
# operations later.
self.ensembled_actions_count = torch.ones(
(self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
(self.chunk_size, 1),
dtype=torch.long,
device=self.ensembled_actions.device,
)
else:
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
# the online update for those entries.
self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
self.ensembled_actions *= self.ensemble_weights_cumsum[
self.ensembled_actions_count - 1
]
self.ensembled_actions += (
actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
)
self.ensembled_actions /= self.ensemble_weights_cumsum[
self.ensembled_actions_count
]
self.ensembled_actions_count = torch.clamp(
self.ensembled_actions_count + 1, max=self.chunk_size
)
# The last action, which has no prior online average, needs to get concatenated onto the end.
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
self.ensembled_actions = torch.cat(
[self.ensembled_actions, actions[:, -1:]], dim=1
)
self.ensembled_actions_count = torch.cat(
[self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
[
self.ensembled_actions_count,
torch.ones_like(self.ensembled_actions_count[-1:]),
]
)
# "Consume" the first action.
action, self.ensembled_actions, self.ensembled_actions_count = (
@@ -319,7 +346,9 @@ class ACT(nn.Module):
config.dim_model,
)
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
self.vae_encoder_latent_output_proj = nn.Linear(
config.dim_model, config.latent_dim * 2
)
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
# dimension.
num_input_token_encoder = 1 + config.chunk_size
@@ -327,20 +356,28 @@ class ACT(nn.Module):
num_input_token_encoder += 1
self.register_buffer(
"vae_encoder_pos_enc",
create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
create_sinusoidal_pos_embedding(
num_input_token_encoder, config.dim_model
).unsqueeze(0),
)
# Backbone for image feature extraction.
if self.config.image_features:
backbone_model = getattr(torchvision.models, config.vision_backbone)(
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
replace_stride_with_dilation=[
False,
False,
config.replace_final_stride_with_dilation,
],
weights=config.pretrained_backbone_weights,
norm_layer=FrozenBatchNorm2d,
)
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
# feature map).
# Note: The forward method of this returns a dict: {"feature_map": output}.
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
self.backbone = IntermediateLayerGetter(
backbone_model, return_layers={"layer4": "feature_map"}
)
# Transformer (acts as VAE decoder when training with the variational objective).
self.encoder = ACTEncoder(config)
@@ -386,7 +423,9 @@ class ACT(nn.Module):
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
def forward(
self, batch: dict[str, Tensor]
) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
`batch` should have the following structure:
@@ -424,7 +463,9 @@ class ACT(nn.Module):
if self.config.robot_state_feature:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
action_embed = self.vae_encoder_action_input_proj(
batch["action"]
) # (B, S, D)
if self.config.robot_state_feature:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
@@ -465,20 +506,24 @@ class ACT(nn.Module):
# When not using the VAE encoder, we set the latent to be all zeros.
mu = log_sigma_x2 = None
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
batch["observation.state"].device
)
latent_sample = torch.zeros(
[batch_size, self.config.latent_dim], dtype=torch.float32
).to(batch["observation.state"].device)
# Prepare transformer encoder inputs.
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
encoder_in_pos_embed = list(
self.encoder_1d_feature_pos_embed.weight.unsqueeze(1)
)
# Robot state token.
if self.config.robot_state_feature:
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
# Environment state token.
if self.config.env_state_feature:
encoder_in_tokens.append(
self.encoder_env_state_input_proj(batch["observation.environment_state"])
self.encoder_env_state_input_proj(
batch["observation.environment_state"]
)
)
# Camera observation features and positional embeddings.
@@ -535,12 +580,21 @@ class ACTEncoder(nn.Module):
def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):
super().__init__()
self.is_vae_encoder = is_vae_encoder
num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers
self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)])
num_layers = (
config.n_vae_encoder_layers
if self.is_vae_encoder
else config.n_encoder_layers
)
self.layers = nn.ModuleList(
[ACTEncoderLayer(config) for _ in range(num_layers)]
)
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
def forward(
self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
self,
x: Tensor,
pos_embed: Tensor | None = None,
key_padding_mask: Tensor | None = None,
) -> Tensor:
for layer in self.layers:
x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
@@ -551,7 +605,9 @@ class ACTEncoder(nn.Module):
class ACTEncoderLayer(nn.Module):
def __init__(self, config: ACTConfig):
super().__init__()
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
self.self_attn = nn.MultiheadAttention(
config.dim_model, config.n_heads, dropout=config.dropout
)
# Feed forward layers.
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
@@ -566,7 +622,9 @@ class ACTEncoderLayer(nn.Module):
self.activation = get_activation_fn(config.feedforward_activation)
self.pre_norm = config.pre_norm
def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
def forward(
self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
) -> Tensor:
skip = x
if self.pre_norm:
x = self.norm1(x)
@@ -591,7 +649,9 @@ class ACTDecoder(nn.Module):
def __init__(self, config: ACTConfig):
"""Convenience module for running multiple decoder layers followed by normalization."""
super().__init__()
self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)])
self.layers = nn.ModuleList(
[ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]
)
self.norm = nn.LayerNorm(config.dim_model)
def forward(
@@ -603,7 +663,10 @@ class ACTDecoder(nn.Module):
) -> Tensor:
for layer in self.layers:
x = layer(
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
x,
encoder_out,
decoder_pos_embed=decoder_pos_embed,
encoder_pos_embed=encoder_pos_embed,
)
if self.norm is not None:
x = self.norm(x)
@@ -613,8 +676,12 @@ class ACTDecoder(nn.Module):
class ACTDecoderLayer(nn.Module):
def __init__(self, config: ACTConfig):
super().__init__()
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
self.self_attn = nn.MultiheadAttention(
config.dim_model, config.n_heads, dropout=config.dropout
)
self.multihead_attn = nn.MultiheadAttention(
config.dim_model, config.n_heads, dropout=config.dropout
)
# Feed forward layers.
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
@@ -655,7 +722,9 @@ class ACTDecoderLayer(nn.Module):
if self.pre_norm:
x = self.norm1(x)
q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
x = self.self_attn(q, k, value=x)[
0
] # select just the output, not the attention weights
x = skip + self.dropout1(x)
if self.pre_norm:
skip = x
@@ -692,9 +761,14 @@ def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tenso
"""
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]
return [
position / np.power(10000, 2 * (hid_j // 2) / dimension)
for hid_j in range(dimension)
]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)])
sinusoid_table = np.array(
[get_position_angle_vec(pos_i) for pos_i in range(num_positions)]
)
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.from_numpy(sinusoid_table).float()
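
The helper above builds the classic sinusoidal table position / 10000^(2i/d) with sine on even dimensions and cosine on odd ones; an equivalent vectorized torch sketch (illustrative sizes):

import torch

num_positions, dimension = 8, 16
pos = torch.arange(num_positions, dtype=torch.float32).unsqueeze(1)   # (P, 1)
dim_idx = torch.arange(dimension, dtype=torch.float32).unsqueeze(0)   # (1, D)
angles = pos / torch.pow(10000.0, 2 * torch.div(dim_idx, 2, rounding_mode="floor") / dimension)
table = torch.where(dim_idx.long() % 2 == 0, angles.sin(), angles.cos())
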
@@ -739,7 +813,9 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
inverse_frequency = self._temperature ** (
2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
2
* (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2)
/ self.dimension
)
x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
@@ -747,9 +823,15 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
# Note: this stack then flatten operation results in interleaved sine and cosine terms.
# pos_embed_x and pos_embed_y are (1, H, W, C // 2).
pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W)
pos_embed_x = torch.stack(
(x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1
).flatten(3)
pos_embed_y = torch.stack(
(y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1
).flatten(3)
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(
0, 3, 1, 2
) # (1, C, H, W)
return pos_embed


@@ -132,7 +132,11 @@ class DiffusionPolicy(PreTrainedPolicy):
if len(self._queues["action"]) == 0:
# stack n latest observations from the queue
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
batch = {
k: torch.stack(list(self._queues[k]), dim=1)
for k in batch
if k in self._queues
}
actions = self.diffusion.generate_actions(batch)
# TODO(rcadene): make above methods return output dictionary?
@@ -189,7 +193,9 @@ class DiffusionModel(nn.Module):
if self.config.env_state_feature:
global_cond_dim += self.config.env_state_feature.shape[0]
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
self.unet = DiffusionConditionalUnet1d(
config, global_cond_dim=global_cond_dim * config.n_obs_steps
)
self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
@@ -209,7 +215,10 @@ class DiffusionModel(nn.Module):
# ========= inference ============
def conditional_sample(
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
self,
batch_size: int,
global_cond: Tensor | None = None,
generator: torch.Generator | None = None,
) -> Tensor:
device = get_device_from_parameters(self)
dtype = get_dtype_from_parameters(self)
@@ -232,7 +241,9 @@ class DiffusionModel(nn.Module):
global_cond=global_cond,
)
# Compute previous image: x_t -> x_t-1
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
sample = self.noise_scheduler.step(
model_output, t, sample, generator=generator
).prev_sample
return sample
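
The sampling loop above follows the usual diffusers pattern: start from Gaussian noise and repeatedly call scheduler.step() to map x_t to x_{t-1}. A stripped-down sketch with a placeholder standing in for the conditional U-Net (shapes and step count are illustrative):

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=100)
scheduler.set_timesteps(100)

sample = torch.randn(1, 16, 7)  # (batch, horizon, action_dim)
for t in scheduler.timesteps:
    model_output = torch.zeros_like(sample)  # placeholder for unet(sample, t, global_cond)
    sample = scheduler.step(model_output, t, sample).prev_sample
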
@@ -244,27 +255,39 @@ class DiffusionModel(nn.Module):
if self.config.image_features:
if self.config.use_separate_rgb_encoder_per_camera:
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
images_per_camera = einops.rearrange(
batch["observation.images"], "b s n ... -> n (b s) ..."
)
img_features_list = torch.cat(
[
encoder(images)
for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
for encoder, images in zip(
self.rgb_encoder, images_per_camera, strict=True
)
]
)
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
img_features = einops.rearrange(
img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
img_features_list,
"(n b s) ... -> b s (n ...)",
b=batch_size,
s=n_obs_steps,
)
else:
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
img_features = self.rgb_encoder(
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
einops.rearrange(
batch["observation.images"], "b s n ... -> (b s n) ..."
)
)
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
img_features = einops.rearrange(
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
img_features,
"(b s n) ... -> b s (n ...)",
b=batch_size,
s=n_obs_steps,
)
global_cond_feats.append(img_features)
@@ -350,7 +373,9 @@ class DiffusionModel(nn.Module):
elif self.config.prediction_type == "sample":
target = batch["action"]
else:
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
raise ValueError(
f"Unsupported prediction type {self.config.prediction_type}"
)
loss = F.mse_loss(pred, target, reduction="none")
@@ -410,7 +435,9 @@ class SpatialSoftmax(nn.Module):
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
# and causes a small degradation in pc_success of pre-trained models.
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
pos_x, pos_y = np.meshgrid(
np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)
)
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
# register as buffer so it's moved to the correct device.
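
Spatial softmax (used here and again in VQ-BeT below) turns a (B, C, H, W) feature map into C expected keypoint coordinates: softmax over the H*W locations per channel, then the expectation of a fixed coordinate grid. A minimal sketch, omitting the optional learned 1x1 projection:

import numpy as np
import torch
import torch.nn.functional as F

feat = torch.randn(2, 32, 12, 12)                  # (B, C, H, W), illustrative
b, c, h, w = feat.shape
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, w), np.linspace(-1.0, 1.0, h))
pos_x = torch.from_numpy(pos_x.reshape(1, h * w)).float()
pos_y = torch.from_numpy(pos_y.reshape(1, h * w)).float()

attention = F.softmax(feat.reshape(b * c, h * w), dim=-1)            # per-channel softmax
keypoints = torch.stack(
    [(attention * pos_x).sum(-1), (attention * pos_y).sum(-1)], dim=-1
).reshape(b, c, 2)                                                    # expected (x, y) per channel
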
@@ -452,7 +479,9 @@ class DiffusionRgbEncoder(nn.Module):
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
self.maybe_random_crop = torchvision.transforms.RandomCrop(
config.crop_shape
)
else:
self.maybe_random_crop = self.center_crop
else:
@@ -473,7 +502,9 @@ class DiffusionRgbEncoder(nn.Module):
self.backbone = _replace_submodules(
root_module=self.backbone,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
func=lambda x: nn.GroupNorm(
num_groups=x.num_features // 16, num_channels=x.num_features
),
)
# Set up pooling and final layers.
@@ -515,7 +546,9 @@ class DiffusionRgbEncoder(nn.Module):
def _replace_submodules(
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
root_module: nn.Module,
predicate: Callable[[nn.Module], bool],
func: Callable[[nn.Module], nn.Module],
) -> nn.Module:
"""
Args:
@@ -528,7 +561,11 @@ def _replace_submodules(
if predicate(root_module):
return func(root_module)
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
replace_list = [
k.split(".")
for k, m in root_module.named_modules(remove_duplicate=True)
if predicate(m)
]
for *parents, k in replace_list:
parent_module = root_module
if len(parents) > 0:
@@ -543,7 +580,9 @@ def _replace_submodules(
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
assert not any(
predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)
)
return root_module
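
The same transformation in compact form: walk the backbone and, wherever a BatchNorm2d is found, install a GroupNorm with num_features // 16 groups on its parent module. Sketch over a torchvision ResNet, assuming the //16 ratio used above:

import torchvision
from torch import nn

backbone = torchvision.models.resnet18(weights=None)
for name, module in list(backbone.named_modules()):
    if isinstance(module, nn.BatchNorm2d):
        *parents, leaf = name.split(".")
        parent = backbone.get_submodule(".".join(parents)) if parents else backbone
        setattr(parent, leaf, nn.GroupNorm(module.num_features // 16, module.num_features))
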
@@ -571,7 +610,9 @@ class DiffusionConv1dBlock(nn.Module):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
nn.Conv1d(
inp_channels, out_channels, kernel_size, padding=kernel_size // 2
),
nn.GroupNorm(n_groups, out_channels),
nn.Mish(),
)
@@ -594,9 +635,13 @@ class DiffusionConditionalUnet1d(nn.Module):
# Encoder for the diffusion timestep.
self.diffusion_step_encoder = nn.Sequential(
DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4),
nn.Linear(
config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4
),
nn.Mish(),
nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim),
nn.Linear(
config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim
),
)
# The FiLM conditioning dimension.
@@ -621,10 +666,16 @@ class DiffusionConditionalUnet1d(nn.Module):
self.down_modules.append(
nn.ModuleList(
[
DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
DiffusionConditionalResidualBlock1d(
dim_in, dim_out, **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(
dim_out, dim_out, **common_res_block_kwargs
),
# Downsample as long as it is not the last block.
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
nn.Conv1d(dim_out, dim_out, 3, 2, 1)
if not is_last
else nn.Identity(),
]
)
)
@@ -633,10 +684,14 @@ class DiffusionConditionalUnet1d(nn.Module):
self.mid_modules = nn.ModuleList(
[
DiffusionConditionalResidualBlock1d(
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
config.down_dims[-1],
config.down_dims[-1],
**common_res_block_kwargs,
),
DiffusionConditionalResidualBlock1d(
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
config.down_dims[-1],
config.down_dims[-1],
**common_res_block_kwargs,
),
]
)
@@ -649,10 +704,16 @@ class DiffusionConditionalUnet1d(nn.Module):
nn.ModuleList(
[
# dim_in * 2, because it takes the encoder's skip connection as well
DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
DiffusionConditionalResidualBlock1d(
dim_in * 2, dim_out, **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(
dim_out, dim_out, **common_res_block_kwargs
),
# Upsample as long as it is not the last block.
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1)
if not is_last
else nn.Identity(),
]
)
)
@@ -726,17 +787,23 @@ class DiffusionConditionalResidualBlock1d(nn.Module):
self.use_film_scale_modulation = use_film_scale_modulation
self.out_channels = out_channels
self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
self.conv1 = DiffusionConv1dBlock(
in_channels, out_channels, kernel_size, n_groups=n_groups
)
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
self.conv2 = DiffusionConv1dBlock(
out_channels, out_channels, kernel_size, n_groups=n_groups
)
# A final convolution for dimension matching the residual (if needed).
self.residual_conv = (
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
nn.Conv1d(in_channels, out_channels, 1)
if in_channels != out_channels
else nn.Identity()
)
def forward(self, x: Tensor, cond: Tensor) -> Tensor:
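
FiLM conditioning, as referenced above, predicts a per-channel scale and bias from the conditioning embedding and applies them to the convolutional features. A minimal sketch with scale modulation enabled (illustrative sizes):

import torch
from torch import nn

out_channels, cond_dim = 8, 16
cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, out_channels * 2))

x = torch.randn(2, out_channels, 24)       # (B, C, T) features from the first conv block
cond = torch.randn(2, cond_dim)            # diffusion-step embedding + global conditioning
scale, bias = cond_encoder(cond).chunk(2, dim=-1)
x = scale.unsqueeze(-1) * x + bias.unsqueeze(-1)
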


@@ -7,7 +7,9 @@ from torch import Tensor, nn
from .configuration_classifier import ClassifierConfig
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
@@ -15,7 +17,10 @@ class ClassifierOutput:
"""Wrapper for classifier outputs with additional metadata."""
def __init__(
self, logits: Tensor, probabilities: Optional[Tensor] = None, hidden_states: Optional[Tensor] = None
self,
logits: Tensor,
probabilities: Optional[Tensor] = None,
hidden_states: Optional[Tensor] = None,
):
self.logits = logits
self.probabilities = probabilities
@@ -43,12 +48,14 @@ class Classifier(
name = "classifier"
def __init__(self, config: ClassifierConfig):
from transformers import AutoImageProcessor, AutoModel
from transformers import AutoModel
super().__init__()
self.config = config
# self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
encoder = AutoModel.from_pretrained(
self.config.model_name, trust_remote_code=True
)
# Extract vision model if we're given a multimodal model
if hasattr(encoder, "vision_model"):
logging.info("Multimodal model detected - using vision encoder only")
@@ -74,7 +81,9 @@ class Classifier(
self.feature_dim = self.encoder.fc.in_features
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
elif hasattr(self.encoder.config, "hidden_sizes"):
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
self.feature_dim = self.encoder.config.hidden_sizes[
-1
] # Last channel dimension
else:
raise ValueError("Unsupported CNN architecture")
@@ -94,14 +103,19 @@ class Classifier(
if hasattr(self.encoder.config, "hidden_size"):
input_dim = self.encoder.config.hidden_size
else:
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
raise ValueError(
"Unsupported transformer architecture since hidden_size is not found"
)
self.classifier_head = nn.Sequential(
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
nn.Dropout(self.config.dropout_rate),
nn.LayerNorm(self.config.hidden_dim),
nn.ReLU(),
nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes),
nn.Linear(
self.config.hidden_dim,
1 if self.config.num_classes == 2 else self.config.num_classes,
),
)
self.classifier_head = self.classifier_head.to(self.config.device)
@@ -127,7 +141,10 @@ class Classifier(
return features
else: # Transformer models
outputs = self.encoder(processed)
if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
if (
hasattr(outputs, "pooler_output")
and outputs.pooler_output is not None
):
return outputs.pooler_output
return outputs.last_hidden_state[:, 0, :]
@@ -143,7 +160,9 @@ class Classifier(
else:
probabilities = torch.softmax(logits, dim=-1)
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
return ClassifierOutput(
logits=logits, probabilities=probabilities, hidden_states=encoder_outputs
)
def predict_reward(self, x, threshold=0.6):
if self.config.num_classes == 2:
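
The head above emits a single logit when num_classes == 2 and num_classes logits otherwise, so probabilities come from a sigmoid in the binary case (implied by the branch above) and a softmax in the multiclass case. In short, with illustrative tensors:

import torch

binary_logits = torch.randn(4, 1)
binary_probs = torch.sigmoid(binary_logits)           # P(positive class)

multi_logits = torch.randn(4, 5)
multi_probs = torch.softmax(multi_logits, dim=-1)     # sums to 1 over the 5 classes
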


@@ -59,7 +59,9 @@ class SACPolicy(
config.input_normalization_params
)
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, input_normalization_params
config.input_shapes,
config.input_normalization_modes,
input_normalization_params,
)
else:
self.normalize_inputs = nn.Identity()
@@ -90,7 +92,8 @@ class SACPolicy(
ensemble=Ensemble(
[
CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
input_dim=encoder_critic.output_dim
+ config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
@@ -104,7 +107,8 @@ class SACPolicy(
ensemble=Ensemble(
[
CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
input_dim=encoder_critic.output_dim
+ config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
@@ -120,13 +124,17 @@ class SACPolicy(
self.actor = Policy(
encoder=encoder_actor,
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
network=MLP(
input_dim=encoder_actor.output_dim, **config.actor_network_kwargs
),
action_dim=config.output_shapes["action"][0],
encoder_is_shared=config.shared_encoder,
**config.policy_kwargs,
)
if config.target_entropy is None:
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
config.target_entropy = (
-np.prod(config.output_shapes["action"][0]) / 2
) # (-dim(A)/2)
# TODO (azouitine): Handle the case where the temparameter is a fixed
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
@@ -153,7 +161,11 @@ class SACPolicy(
return actions
def critic_forward(
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False, observation_features: Tensor | None = None
self,
observations: dict[str, Tensor],
actions: Tensor,
use_target: bool = False,
observation_features: Tensor | None = None,
) -> Tensor:
"""Forward pass through a critic network ensemble
@@ -173,21 +185,37 @@ class SACPolicy(
def update_target_networks(self):
"""Update target networks with exponential moving average"""
for target_param, param in zip(
self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=False
self.critic_target.parameters(),
self.critic_ensemble.parameters(),
strict=False,
):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
def compute_loss_critic(self, observations, actions, rewards, next_observations, done, observation_features: Tensor | None = None, next_observation_features: Tensor | None = None) -> Tensor:
def compute_loss_critic(
self,
observations,
actions,
rewards,
next_observations,
done,
observation_features: Tensor | None = None,
next_observation_features: Tensor | None = None,
) -> Tensor:
temperature = self.log_alpha.exp().item()
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
next_action_preds, next_log_probs, _ = self.actor(
next_observations, next_observation_features
)
# 2- compute q targets
q_targets = self.critic_forward(
observations=next_observations, actions=next_action_preds, use_target=True, observation_features=next_observation_features
observations=next_observations,
actions=next_action_preds,
use_target=True,
observation_features=next_observation_features,
)
# subsample critics to prevent overfitting if use high UTD (update to date)
@@ -204,7 +232,12 @@ class SACPolicy(
td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
q_preds = self.critic_forward(observations, actions, use_target=False, observation_features=observation_features)
q_preds = self.critic_forward(
observations,
actions,
use_target=False,
observation_features=observation_features,
)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
@@ -219,20 +252,31 @@ class SACPolicy(
).sum()
return critics_loss
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
def compute_loss_temperature(
self, observations, observation_features: Tensor | None = None
) -> Tensor:
"""Compute the temperature loss"""
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations, observation_features)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
temperature_loss = (
-self.log_alpha.exp() * (log_probs + self.config.target_entropy)
).mean()
return temperature_loss
def compute_loss_actor(self, observations, observation_features: Tensor | None = None) -> Tensor:
def compute_loss_actor(
self, observations, observation_features: Tensor | None = None
) -> Tensor:
temperature = self.log_alpha.exp().item()
actions_pi, log_probs, _ = self.actor(observations, observation_features)
q_preds = self.critic_forward(observations, actions_pi, use_target=False, observation_features=observation_features)
q_preds = self.critic_forward(
observations,
actions_pi,
use_target=False,
observation_features=observation_features,
)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
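
In compact form, the two objectives above: the temperature alpha = exp(log_alpha) is trained to keep policy entropy near target_entropy, and the actor maximizes the entropy-regularized minimum critic value. Illustrative sketch, with detach standing in for the no_grad/item handling:

import torch

log_alpha = torch.zeros(1, requires_grad=True)
log_probs = torch.randn(8)      # log pi(a|s) for freshly sampled actions
min_q = torch.randn(8)          # min over the critic ensemble
target_entropy = -3.5           # e.g. -dim(A)/2

temperature_loss = (-log_alpha.exp() * (log_probs.detach() + target_entropy)).mean()
actor_loss = (log_alpha.exp().detach() * log_probs - min_q).mean()
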
@@ -259,7 +303,11 @@ class MLP(nn.Module):
if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(hidden_dims[0]))
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
layers.append(
activations
if isinstance(activations, nn.Module)
else getattr(nn, activations)()
)
# Rest of the layers
for i in range(1, len(hidden_dims)):
@@ -270,7 +318,9 @@ class MLP(nn.Module):
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(hidden_dims[i]))
layers.append(
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
activations
if isinstance(activations, nn.Module)
else getattr(nn, activations)()
)
self.net = nn.Sequential(*layers)
@@ -381,7 +431,11 @@ class CriticEnsemble(nn.Module):
actions = self.output_normalization(actions)["action"]
actions = actions.to(device)
obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations))
obs_enc = (
observation_features
if observation_features is not None
else (observations if self.encoder is None else self.encoder(observations))
)
inputs = torch.cat([obs_enc, actions], dim=-1)
q_values = self.ensemble(inputs) # [num_critics, B, 1]
@@ -445,7 +499,11 @@ class Policy(nn.Module):
observation_features: torch.Tensor | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Encode observations if encoder exists
obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations))
obs_enc = (
observation_features
if observation_features is not None
else (observations if self.encoder is None else self.encoder(observations))
)
# Get network outputs
outputs = self.network(obs_enc)
@@ -454,11 +512,15 @@ class Policy(nn.Module):
# Compute standard deviations
if self.fixed_std is None:
log_std = self.std_layer(outputs)
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
assert not torch.isnan(
log_std
).any(), "[ERROR] log_std became NaN after std_layer!"
if self.use_tanh_squash:
log_std = torch.tanh(log_std)
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
log_std = self.log_std_min + 0.5 * (
self.log_std_max - self.log_std_min
) * (log_std + 1.0)
else:
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
else:
@@ -471,7 +533,9 @@ class Policy(nn.Module):
if self.use_tanh_squash:
actions = torch.tanh(x_t)
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh
log_probs -= torch.log(
(1 - actions.pow(2)) + 1e-6
) # Adjust log-probs for Tanh
else:
actions = x_t # No Tanh; raw Gaussian sample
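
The tanh correction above is the standard change-of-variables term for a squashed Gaussian policy: log pi(a|s) = log N(x) - sum log(1 - tanh(x)^2), with 1e-6 added for numerical stability. Standalone sketch:

import torch
from torch.distributions import Normal

dist = Normal(torch.zeros(3), torch.ones(3))
x_t = dist.rsample()                                          # reparameterized Gaussian sample
log_prob = dist.log_prob(x_t)
actions = torch.tanh(x_t)
log_prob = log_prob - torch.log(1 - actions.pow(2) + 1e-6)    # tanh correction
log_prob = log_prob.sum(-1)                                   # total log-prob of the action vector
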
@@ -518,12 +582,15 @@ class SACObservationEncoder(nn.Module):
freeze_image_encoder(self.image_enc_layers)
else:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.all_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
self.all_image_keys = [
k for k in config.input_shapes if k.startswith("observation.image")
]
if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_shapes["observation.state"][0], out_features=config.latent_dim
in_features=config.input_shapes["observation.state"][0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
@@ -544,7 +611,9 @@ class SACObservationEncoder(nn.Module):
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
self.aggregation_layer = nn.Linear(
in_features=self.aggregation_size, out_features=config.latent_dim
)
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
@@ -557,13 +626,19 @@ class SACObservationEncoder(nn.Module):
obs_dict = self.input_normalization(obs_dict)
# Batch all images along the batch dimension, then encode them.
if len(self.all_image_keys) > 0:
images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0)
images_batched = torch.cat(
[obs_dict[key] for key in self.all_image_keys], dim=0
)
images_batched = self.image_enc_layers(images_batched)
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
embeddings_chunks = torch.chunk(
images_batched, dim=0, chunks=len(self.all_image_keys)
)
feat.extend(embeddings_chunks)
if "observation.environment_state" in self.config.input_shapes:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
feat.append(
self.env_state_enc_layers(obs_dict["observation.environment_state"])
)
if "observation.state" in self.config.input_shapes:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
@@ -631,7 +706,9 @@ class PretrainedImageEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
self.image_enc_layers, self.image_enc_out_shape = (
self._load_pretrained_vision_encoder(config)
)
self.image_enc_proj = nn.Sequential(
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
@@ -642,15 +719,21 @@ class PretrainedImageEncoder(nn.Module):
"""Set up CNN encoder"""
from transformers import AutoModel
self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
self.image_enc_layers = AutoModel.from_pretrained(
config.vision_encoder_name, trust_remote_code=True
)
# self.image_enc_layers.pooler = Identity()
if hasattr(self.image_enc_layers.config, "hidden_sizes"):
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[
-1
] # Last channel dimension
elif hasattr(self.image_enc_layers, "fc"):
self.image_enc_out_shape = self.image_enc_layers.fc.in_features
else:
raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
raise ValueError(
"Unsupported vision encoder architecture, make sure you are using a CNN"
)
return self.image_enc_layers, self.image_enc_out_shape
def forward(self, x):
@@ -673,7 +756,7 @@ def orthogonal_init():
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
super().__init__()
def forward(self, x):
return x
@@ -701,7 +784,9 @@ class Ensemble(nn.Module):
return self.module(*args, **kwargs)
def forward(self, *args, **kwargs):
return torch.vmap(self._call, (0, None), randomness="different")(self.params, *args, **kwargs)
return torch.vmap(self._call, (0, None), randomness="different")(
self.params, *args, **kwargs
)
def __repr__(self):
return f"Vectorized {len(self)}x " + self._repr
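
The Ensemble wrapper above vectorizes identical critic heads with torch.vmap over stacked parameters. A minimal sketch of the same pattern using torch.func (PyTorch >= 2.0), without the randomness="different" handling needed for dropout:

import torch
from torch import nn
from torch.func import functional_call, stack_module_state

critics = [nn.Linear(4, 1) for _ in range(3)]       # stand-ins for CriticHead modules
params, buffers = stack_module_state(critics)
base = critics[0]

def call_one(p, b, x):
    return functional_call(base, (p, b), (x,))

x = torch.randn(8, 4)
q_values = torch.vmap(call_one, (0, 0, None))(params, buffers, x)   # (num_critics, B, 1)
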
@@ -710,7 +795,9 @@ class Ensemble(nn.Module):
# TODO (azouitine): I think in our case this function is not usefull we should remove it
# after some investigation
# borrowed from tdmpc
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
def flatten_forward_unflatten(
fn: Callable[[Tensor], Tensor], image_tensor: Tensor
) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.
Args:
@@ -736,7 +823,9 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
for key, value in inner_dict.items():
converted_params[outer_key][key] = torch.tensor(value)
if "image" in outer_key:
converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
converted_params[outer_key][key] = converted_params[outer_key][
key
].view(3, 1, 1)
return converted_params


@@ -183,7 +183,9 @@ class TDMPCConfig(PreTrainedConfig):
"If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1."
)
if not self.use_mpc:
raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
raise ValueError(
"If `n_action_steps > 1`, `use_mpc` must be set to `True`."
)
if self.n_action_steps > self.horizon:
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")


@@ -100,7 +100,9 @@ class TDMPCPolicy(PreTrainedPolicy):
"""
self._queues = {
"observation.state": deque(maxlen=1),
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
"action": deque(
maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)
),
}
if self.config.image_features:
self._queues["observation.image"] = deque(maxlen=1)
@@ -189,7 +191,11 @@ class TDMPCPolicy(PreTrainedPolicy):
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
# trajectories.
z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
z = einops.repeat(
z,
"b d -> n b d",
n=self.config.n_gaussian_samples + self.config.n_pi_samples,
)
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
# algorithm.
@@ -211,35 +217,47 @@ class TDMPCPolicy(PreTrainedPolicy):
self.config.action_feature.shape[0],
device=std.device,
)
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
gaussian_actions = torch.clamp(
mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1
)
# Compute elite actions.
actions = torch.cat([gaussian_actions, pi_actions], dim=1)
value = self.estimate_value(z, actions).nan_to_num_(0)
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices # (n_elites, batch)
elite_idxs = torch.topk(
value, self.config.n_elites, dim=0
).indices # (n_elites, batch)
elite_value = value.take_along_dim(elite_idxs, dim=0) # (n_elites, batch)
# (horizon, n_elites, batch, action_dim)
elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1)
elite_actions = actions.take_along_dim(
einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1
)
# Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites.
max_value = elite_value.max(0, keepdim=True)[0] # (1, batch)
# The weighting is a softmax over trajectory values. Note that this is not the same as the usage
# of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
# makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²).
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
score = torch.exp(
self.config.elite_weighting_temperature * (elite_value - max_value)
)
score /= score.sum(axis=0, keepdim=True)
# (horizon, batch, action_dim)
_mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1)
_mean = torch.sum(
einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1
)
_std = torch.sqrt(
torch.sum(
einops.rearrange(score, "n b -> n b 1")
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2,
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d"))
** 2,
dim=1,
)
)
# Update mean with an exponential moving average, and std with a direct replacement.
mean = (
self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean
self.config.gaussian_mean_momentum * mean
+ (1 - self.config.gaussian_mean_momentum) * _mean
)
std = _std.clamp_(self.config.min_std, self.config.max_std)
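
The elite re-fit above, reduced to one action dimension: scores are a softmax over elite trajectory values (with the weighting temperature), and the new Gaussian parameters are the score-weighted mean and standard deviation of the elite actions (the mean is then blended with momentum, as in the hunk). Sketch with illustrative numbers:

import torch

temperature = 0.5
elite_value = torch.randn(6)       # (n_elites,)
elite_actions = torch.randn(6)     # (n_elites,), 1-D actions for simplicity

score = torch.exp(temperature * (elite_value - elite_value.max()))
score = score / score.sum()
_mean = (score * elite_actions).sum()
_std = torch.sqrt((score * (elite_actions - _mean) ** 2).sum())
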
@@ -248,7 +266,9 @@ class TDMPCPolicy(PreTrainedPolicy):
# Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
# scores from the last iteration.
actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)]
actions = elite_actions[
:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)
]
return actions
@@ -271,7 +291,8 @@ class TDMPCPolicy(PreTrainedPolicy):
# of the FOWM paper.
if self.config.uncertainty_regularizer_coeff > 0:
regularization = -(
self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0)
self.config.uncertainty_regularizer_coeff
* self.model.Qs(z, actions[t]).std(0)
)
else:
regularization = 0
@@ -291,15 +312,22 @@ class TDMPCPolicy(PreTrainedPolicy):
if self.config.q_ensemble_size > 2:
G += (
running_discount
* torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[
0
]
* torch.min(
terminal_values[
torch.randint(0, self.config.q_ensemble_size, size=(2,))
],
dim=0,
)[0]
)
else:
G += running_discount * torch.min(terminal_values, dim=0)[0]
# Finally, also regularize the terminal value.
if self.config.uncertainty_regularizer_coeff > 0:
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
G -= (
running_discount
* self.config.uncertainty_regularizer_coeff
* terminal_values.std(0)
)
return G
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
@@ -329,7 +357,10 @@ class TDMPCPolicy(PreTrainedPolicy):
# Apply random image augmentations.
if self.config.image_features and self.config.max_random_shift_ratio > 0:
observations["observation.image"] = flatten_forward_unflatten(
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
partial(
random_shifts_aug,
max_random_shift_ratio=self.config.max_random_shift_ratio,
),
observations["observation.image"],
)
@@ -347,14 +378,20 @@ class TDMPCPolicy(PreTrainedPolicy):
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
# gives us a next `z`.
batch_size = batch["index"].shape[0]
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
z_preds = torch.empty(
horizon + 1, batch_size, self.config.latent_dim, device=device
)
z_preds[0] = self.model.encode(current_observation)
reward_preds = torch.empty_like(reward, device=device)
for t in range(horizon):
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(
z_preds[t], action[t]
)
# Compute Q and V value predictions based on the latent rollout.
q_preds_ensemble = self.model.Qs(z_preds[:-1], action) # (ensemble, horizon, batch)
q_preds_ensemble = self.model.Qs(
z_preds[:-1], action
) # (ensemble, horizon, batch)
v_preds = self.model.V(z_preds[:-1])
info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()})
@@ -368,10 +405,14 @@ class TDMPCPolicy(PreTrainedPolicy):
# actions (not actions estimated by π).
# Note: Here we do not use self.model_target, but self.model. This is to follow the original code
# and the FOWM paper.
q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations))
q_targets = reward + self.config.discount * self.model.V(
self.model.encode(next_observations)
)
# From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we
# are using them to compute loss for V.
v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True)
v_targets = self.model_target.Qs(
z_preds[:-1].detach(), action, return_min=True
)
# Compute losses.
# Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the
@@ -414,7 +455,9 @@ class TDMPCPolicy(PreTrainedPolicy):
temporal_loss_coeffs
* F.mse_loss(
q_preds_ensemble,
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
einops.repeat(
q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]
),
reduction="none",
).sum(0) # sum over ensemble
# `q_preds_ensemble` depends on the first observation and the actions.
@@ -452,12 +495,14 @@ class TDMPCPolicy(PreTrainedPolicy):
z_preds = z_preds.detach()
# Use stopgrad for the advantage calculation.
with torch.no_grad():
advantage = self.model_target.Qs(z_preds[:-1], action, return_min=True) - self.model.V(
z_preds[:-1]
)
advantage = self.model_target.Qs(
z_preds[:-1], action, return_min=True
) - self.model.V(z_preds[:-1])
info["advantage"] = advantage[0]
# (t, b)
exp_advantage = torch.clamp(torch.exp(advantage * self.config.advantage_scaling), max=100.0)
exp_advantage = torch.clamp(
torch.exp(advantage * self.config.advantage_scaling), max=100.0
)
action_preds = self.model.pi(z_preds[:-1]) # (t, b, a)
# Calculate the MSE between the actions and the action predictions.
# Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
@@ -511,7 +556,9 @@ class TDMPCPolicy(PreTrainedPolicy):
# Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA
# update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code
# we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
update_ema_parameters(
self.model_target, self.model, self.config.target_model_momentum
)
class TDMPCTOLD(nn.Module):
@@ -598,7 +645,9 @@ class TDMPCTOLD(nn.Module):
"Sanity check. The last linear layer needs 0 initialization on weights."
)
nn.init.zeros_(m[-1].weight)
nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure
nn.init.zeros_(
m[-1].bias
) # this has already been done, but keep this line here for good measure
def encode(self, obs: dict[str, Tensor]) -> Tensor:
"""Encodes an observation into its latent representation."""
@@ -702,11 +751,26 @@ class TDMPCObservationEncoder(nn.Module):
stride=2,
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
nn.Conv2d(
config.image_encoder_hidden_dim,
config.image_encoder_hidden_dim,
5,
stride=2,
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.Conv2d(
config.image_encoder_hidden_dim,
config.image_encoder_hidden_dim,
3,
stride=2,
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.Conv2d(
config.image_encoder_hidden_dim,
config.image_encoder_hidden_dim,
3,
stride=2,
),
nn.ReLU(),
)
dummy_shape = (1, *next(iter(config.image_features.values())).shape)
@@ -796,12 +860,17 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
"""Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
for (n_p_ema, p_ema), (n_p, p) in zip(
ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True
ema_module.named_parameters(recurse=False),
module.named_parameters(recurse=False),
strict=True,
):
assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
if isinstance(p, dict):
raise RuntimeError("Dict parameter not supported")
if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad:
if (
isinstance(module, nn.modules.batchnorm._BatchNorm)
or not p.requires_grad
):
# Copy BatchNorm parameters, and non-trainable parameters directly.
p_ema.copy_(p.to(dtype=p_ema.dtype).data)
with torch.no_grad():
@@ -809,7 +878,9 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha)
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
def flatten_forward_unflatten(
fn: Callable[[Tensor], Tensor], image_tensor: Tensor
) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.
Args:


@@ -145,8 +145,14 @@ class VQBeTPolicy(PreTrainedPolicy):
)
if len(self._queues["action"]) == 0:
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
batch = {
k: torch.stack(list(self._queues[k]), dim=1)
for k in batch
if k in self._queues
}
actions = self.vqbet(batch, rollout=True)[
:, : self.config.action_chunk_size
]
# the dimension of returned action is (batch_size, action_chunk_size, action_dim)
actions = self.unnormalize_outputs({"action": actions})["action"]
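
As noted above, rollout returns (batch_size, action_chunk_size, action_dim); at execution time the policy refills an action queue once per chunk and pops one action per environment step. A minimal sketch of that receding-horizon pattern, with a placeholder in place of the real model:

from collections import deque

import torch

action_chunk_size, action_dim = 5, 7
queue = deque(maxlen=action_chunk_size)

def select_action() -> torch.Tensor:
    if len(queue) == 0:
        chunk = torch.randn(1, action_chunk_size, action_dim)  # placeholder for the policy output
        queue.extend(chunk.transpose(0, 1))                    # (chunk, batch, action_dim)
    return queue.popleft()
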
@@ -168,7 +174,9 @@ class VQBeTPolicy(PreTrainedPolicy):
# n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`.
# n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree).
loss, n_different_codes, n_different_combinations, recon_l1_error = (
self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
self.vqbet.action_head.discretize(
self.config.n_vqvae_training_steps, batch["action"]
)
)
return loss, {
"n_different_codes": n_different_codes,
@@ -225,7 +233,9 @@ class SpatialSoftmax(nn.Module):
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
# and causes a small degradation in pc_success of pre-trained models.
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
pos_x, pos_y = np.meshgrid(
np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)
)
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
# register as buffer so it's moved to the correct device.
@@ -339,7 +349,12 @@ class VQBeTModel(nn.Module):
num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1
self.register_buffer(
"select_target_actions_indices",
torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
torch.row_stack(
[
torch.arange(i, i + self.config.action_chunk_size)
for i in range(num_tokens)
]
),
)
def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
@@ -354,7 +369,11 @@ class VQBeTModel(nn.Module):
)
# Separate batch and sequence dims.
img_features = einops.rearrange(
img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
img_features,
"(b s n) ... -> b s n ...",
b=batch_size,
s=n_obs_steps,
n=self.num_images,
)
# Arrange prior and current observation step tokens as shown in the class docstring.
@@ -366,13 +385,19 @@ class VQBeTModel(nn.Module):
input_tokens.append(
self.state_projector(batch["observation.state"])
) # (batch, obs_step, projection dims)
input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps))
input_tokens.append(
einops.repeat(
self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps
)
)
# Interleave tokens by stacking and rearranging.
input_tokens = torch.stack(input_tokens, dim=2)
input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d")
len_additional_action_token = self.config.n_action_pred_token - 1
future_action_tokens = self.action_token.repeat(batch_size, len_additional_action_token, 1)
future_action_tokens = self.action_token.repeat(
batch_size, len_additional_action_token, 1
)
# add additional action query tokens for predicting future action chunks
input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1)
@@ -391,7 +416,11 @@ class VQBeTModel(nn.Module):
# Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional).
if len_additional_action_token > 0:
features = torch.cat(
[features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
[
features[:, historical_act_pred_index],
features[:, -len_additional_action_token:],
],
dim=1,
)
else:
features = features[:, historical_act_pred_index]
@@ -399,13 +428,15 @@ class VQBeTModel(nn.Module):
action_head_output = self.action_head(features)
# if rollout, VQ-BeT don't calculate loss
if rollout:
return action_head_output["predicted_action"][:, n_obs_steps - 1, :].reshape(
batch_size, self.config.action_chunk_size, -1
)
return action_head_output["predicted_action"][
:, n_obs_steps - 1, :
].reshape(batch_size, self.config.action_chunk_size, -1)
# else, it calculate overall loss (bin prediction loss, and offset loss)
else:
output = batch["action"][:, self.select_target_actions_indices]
loss = self.action_head.loss_fn(action_head_output, output, reduction="mean")
loss = self.action_head.loss_fn(
action_head_output, output, reduction="mean"
)
return action_head_output, loss
@@ -440,7 +471,9 @@ class VQBeTHead(nn.Module):
else:
self.map_to_cbet_preds_bin = MLP(
in_channels=config.gpt_output_dim,
hidden_channels=[self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed],
hidden_channels=[
self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed
],
)
self.map_to_cbet_preds_offset = MLP(
in_channels=config.gpt_output_dim,
@@ -467,7 +500,10 @@ class VQBeTHead(nn.Module):
loss, metric = self.vqvae_model.vqvae_forward(actions)
n_different_codes = sum(
[len(torch.unique(metric[2][:, i])) for i in range(self.vqvae_model.vqvae_num_layers)]
[
len(torch.unique(metric[2][:, i]))
for i in range(self.vqvae_model.vqvae_num_layers)
]
)
n_different_combinations = len(torch.unique(metric[2], dim=0))
recon_l1_error = metric[0].detach().cpu().item()
@@ -514,7 +550,13 @@ class VQBeTHead(nn.Module):
cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin(
torch.cat(
(x, F.one_hot(sampled_primary_centers, num_classes=self.config.vqvae_n_embed)),
(
x,
F.one_hot(
sampled_primary_centers,
num_classes=self.config.vqvae_n_embed,
),
),
axis=1,
)
)
@@ -522,19 +564,29 @@ class VQBeTHead(nn.Module):
cbet_secondary_logits / self.config.bet_softmax_temperature, dim=-1
)
sampled_secondary_centers = einops.rearrange(
torch.multinomial(cbet_secondary_probs.view(-1, choices), num_samples=1),
torch.multinomial(
cbet_secondary_probs.view(-1, choices), num_samples=1
),
"(NT) 1 -> NT",
NT=NT,
)
sampled_centers = torch.stack((sampled_primary_centers, sampled_secondary_centers), axis=1)
cbet_logits = torch.stack([cbet_primary_logits, cbet_secondary_logits], dim=1)
sampled_centers = torch.stack(
(sampled_primary_centers, sampled_secondary_centers), axis=1
)
cbet_logits = torch.stack(
[cbet_primary_logits, cbet_secondary_logits], dim=1
)
# if self.config.sequentially_select is False, the bin prediction head samples the primary and secondary codes at once.
else:
cbet_logits = self.map_to_cbet_preds_bin(x)
cbet_logits = einops.rearrange(
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
cbet_logits,
"(NT) (G C) -> (NT) G C",
G=self.vqvae_model.vqvae_num_layers,
)
cbet_probs = torch.softmax(
cbet_logits / self.config.bet_softmax_temperature, dim=-1
)
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
NT, G, choices = cbet_probs.shape
sampled_centers = einops.rearrange(
torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
@@ -554,9 +606,17 @@ class VQBeTHead(nn.Module):
sampled_offsets = sampled_offsets.sum(dim=1)
with torch.no_grad():
# Get the centroids (= vectors corresponding to the codes) of each layer to pass them through the RVQ decoder
return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers).clone().detach()
return_decoder_input = (
self.vqvae_model.get_embeddings_from_code(sampled_centers)
.clone()
.detach()
)
# pass the centroids through the decoder to get actions.
decoded_action = self.vqvae_model.get_action_from_latent(return_decoder_input).clone().detach()
decoded_action = (
self.vqvae_model.get_action_from_latent(return_decoder_input)
.clone()
.detach()
)
# reshape the extracted offsets to match the decoded centroids
sampled_offsets = einops.rearrange(
sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size
@@ -605,7 +665,9 @@ class VQBeTHead(nn.Module):
# Figure out the loss for the actions.
# First, we need to find the closest cluster center for each ground truth action.
with torch.no_grad():
state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G
state_vq, action_bins = self.vqvae_model.get_code(
action_seq
) # action_bins: NT, G
# Now we can compute the loss.
@@ -628,8 +690,12 @@ class VQBeTHead(nn.Module):
+ cbet_loss2 * self.config.secondary_code_loss_weight
)
equal_primary_code_rate = torch.sum((action_bins[:, 0] == sampled_centers[:, 0]).int()) / (NT)
equal_secondary_code_rate = torch.sum((action_bins[:, 1] == sampled_centers[:, 1]).int()) / (NT)
equal_primary_code_rate = torch.sum(
(action_bins[:, 0] == sampled_centers[:, 0]).int()
) / (NT)
equal_secondary_code_rate = torch.sum(
(action_bins[:, 1] == sampled_centers[:, 1]).int()
) / (NT)
action_mse_error = torch.mean((action_seq - predicted_action) ** 2)
vq_action_error = torch.mean(torch.abs(action_seq - decoded_action))
@@ -643,7 +709,9 @@ class VQBeTHead(nn.Module):
"classification_loss": cbet_loss.detach().cpu().item(),
"offset_loss": offset_loss.detach().cpu().item(),
"equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(),
"equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(),
"equal_secondary_code_rate": equal_secondary_code_rate.detach()
.cpu()
.item(),
"vq_action_error": vq_action_error.detach().cpu().item(),
"offset_action_error": offset_action_error.detach().cpu().item(),
"action_error_max": action_error_max.detach().cpu().item(),
@@ -668,7 +736,9 @@ class VQBeTRgbEncoder(nn.Module):
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
self.maybe_random_crop = torchvision.transforms.RandomCrop(
config.crop_shape
)
else:
self.maybe_random_crop = self.center_crop
else:
@@ -689,7 +759,9 @@ class VQBeTRgbEncoder(nn.Module):
self.backbone = _replace_submodules(
root_module=self.backbone,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
func=lambda x: nn.GroupNorm(
num_groups=x.num_features // 16, num_channels=x.num_features
),
)
# Set up pooling and final layers.
@@ -730,7 +802,9 @@ class VQBeTRgbEncoder(nn.Module):
def _replace_submodules(
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
root_module: nn.Module,
predicate: Callable[[nn.Module], bool],
func: Callable[[nn.Module], nn.Module],
) -> nn.Module:
"""
Args:
@@ -743,7 +817,11 @@ def _replace_submodules(
if predicate(root_module):
return func(root_module)
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
replace_list = [
k.split(".")
for k, m in root_module.named_modules(remove_duplicate=True)
if predicate(m)
]
for *parents, k in replace_list:
parent_module = root_module
if len(parents) > 0:
@@ -758,7 +836,9 @@ def _replace_submodules(
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
assert not any(
predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)
)
return root_module
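
As a usage sketch of `_replace_submodules` (mirroring the GroupNorm swap shown earlier in this file; ResNet-18 here is just an example backbone):

```python
import torchvision
from torch import nn

backbone = torchvision.models.resnet18()
backbone = _replace_submodules(
    root_module=backbone,
    predicate=lambda x: isinstance(x, nn.BatchNorm2d),
    func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
)
# No BatchNorm2d modules remain after the swap.
assert not any(isinstance(m, nn.BatchNorm2d) for m in backbone.modules())
```
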

View File

@@ -123,9 +123,15 @@ class CausalSelfAttention(nn.Module):
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v = self.c_attn(x).split(self.gpt_hidden_dim, dim=2)
k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)
q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)
v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
@@ -133,7 +139,9 @@ class CausalSelfAttention(nn.Module):
att = F.softmax(att, dim=-1)
att = self.attn_dropout(att)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
y = (
y.transpose(1, 2).contiguous().view(B, T, C)
) # re-assemble all head outputs side by side
# output projection
y = self.resid_dropout(self.c_proj(y))
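
To make the shape bookkeeping in this attention block concrete, a minimal self-contained sketch with made-up sizes (B=2, T=5, C=16, 4 heads), independent of the module's actual projections, dropout, and registered causal mask:

```python
import math

import torch
import torch.nn.functional as F

B, T, C, n_head = 2, 5, 16, 4
q = torch.randn(B, T, C).view(B, T, n_head, C // n_head).transpose(1, 2)  # (B, nh, T, hs)
k = torch.randn(B, T, C).view(B, T, n_head, C // n_head).transpose(1, 2)
v = torch.randn(B, T, C).view(B, T, n_head, C // n_head).transpose(1, 2)

att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))  # (B, nh, T, T)
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))  # keep only past/current positions
att = att.masked_fill(~causal, float("-inf"))
att = F.softmax(att, dim=-1)
y = att @ v                                               # (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C)          # heads re-assembled side by side
```
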
@@ -189,12 +197,16 @@ class GPT(nn.Module):
"ln_f": nn.LayerNorm(config.gpt_hidden_dim),
}
)
self.lm_head = nn.Linear(config.gpt_hidden_dim, config.gpt_output_dim, bias=False)
self.lm_head = nn.Linear(
config.gpt_hidden_dim, config.gpt_output_dim, bias=False
)
# init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
self.apply(self._init_weights)
for pn, p in self.named_parameters():
if pn.endswith("c_proj.weight"):
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer))
torch.nn.init.normal_(
p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer)
)
# report number of parameters
n_params = sum(p.numel() for p in self.parameters())
@@ -208,11 +220,17 @@ class GPT(nn.Module):
)
# positional encodings that are added to the input embeddings
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(
0
) # shape (1, t)
# forward the GPT model itself
tok_emb = self.transformer.wte(input) # token embeddings of shape (b, t, gpt_hidden_dim)
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, gpt_hidden_dim)
tok_emb = self.transformer.wte(
input
) # token embeddings of shape (b, t, gpt_hidden_dim)
pos_emb = self.transformer.wpe(
pos
) # position embeddings of shape (1, t, gpt_hidden_dim)
x = self.transformer.drop(tok_emb + pos_emb)
for block in self.transformer.h:
x = block(x)
@@ -237,7 +255,9 @@ class GPT(nn.Module):
# but want to use a smaller block size for some smaller, simpler model
assert gpt_block_size <= self.config.gpt_block_size
self.config.gpt_block_size = gpt_block_size
self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:gpt_block_size])
self.transformer.wpe.weight = nn.Parameter(
self.transformer.wpe.weight[:gpt_block_size]
)
for block in self.transformer.h:
block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size]
@@ -270,7 +290,9 @@ class GPT(nn.Module):
param_dict = dict(self.named_parameters())
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
assert (
len(inter_params) == 0
), "parameters {} made it into both decay/no_decay sets!".format(
str(inter_params)
)
assert len(param_dict.keys() - union_params) == 0, (
@@ -368,8 +390,12 @@ class ResidualVQ(nn.Module):
codebook_input_dim = codebook_dim * heads
requires_projection = codebook_input_dim != dim
self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
self.project_in = (
nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
)
self.project_out = (
nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
)
self.num_quantizers = num_quantizers
@@ -377,7 +403,10 @@ class ResidualVQ(nn.Module):
self.layers = nn.ModuleList(
[
VectorQuantize(
dim=codebook_dim, codebook_dim=codebook_dim, accept_image_fmap=accept_image_fmap, **kwargs
dim=codebook_dim,
codebook_dim=codebook_dim,
accept_image_fmap=accept_image_fmap,
**kwargs,
)
for _ in range(num_quantizers)
]
@@ -448,7 +477,9 @@ class ResidualVQ(nn.Module):
return all_codes
def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=None):
def forward(
self, x, indices=None, return_all_codes=False, sample_codebook_temp=None
):
"""
For given input tensor x, this function will return the quantized output, the indices of the quantized output, and the loss.
First, the input tensor x is projected to the codebook dimension. Then, the input tensor x is passed through Nq layers of VectorQuantize.
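
A stripped-down sketch of the residual quantization loop this docstring describes, with a toy nearest-neighbour lookup standing in for each VectorQuantize layer (codebook sizes and dims are arbitrary here):

```python
import torch

def toy_quantize(residual, codebook):
    # Nearest codebook entry per vector: the role each VectorQuantize layer plays.
    indices = torch.cdist(residual, codebook).argmin(dim=-1)
    return codebook[indices], indices

x = torch.randn(4, 8)                               # (N, codebook_dim)
codebooks = [torch.randn(16, 8) for _ in range(3)]  # Nq illustrative codebooks

residual, quantized_out, all_indices = x, torch.zeros_like(x), []
for codebook in codebooks:
    quantized, indices = toy_quantize(residual, codebook)
    residual = residual - quantized       # the next layer quantizes what is left over
    quantized_out = quantized_out + quantized
    all_indices.append(indices)
# quantized_out approximates x more closely as layers (and their indices) accumulate.
```
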
@@ -477,13 +508,17 @@ class ResidualVQ(nn.Module):
)
ce_losses = []
should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
should_quantize_dropout = (
self.training and self.quantize_dropout and not return_loss
)
# sample a layer index at which to drop out further residual quantization
# also prepare null indices and loss
if should_quantize_dropout:
rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant)
rand_quantize_dropout_index = randrange(
self.quantize_dropout_cutoff_index, num_quant
)
if quant_dropout_multiple_of != 1:
rand_quantize_dropout_index = (
@@ -492,14 +527,23 @@ class ResidualVQ(nn.Module):
- 1
)
null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
null_indices = torch.full(null_indices_shape, -1.0, device=device, dtype=torch.long)
null_indices_shape = (
(x.shape[0], *x.shape[-2:])
if self.accept_image_fmap
else tuple(x.shape[:2])
)
null_indices = torch.full(
null_indices_shape, -1.0, device=device, dtype=torch.long
)
null_loss = torch.full((1,), 0.0, device=device, dtype=x.dtype)
# go through the layers
for quantizer_index, layer in enumerate(self.layers):
if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
if (
should_quantize_dropout
and quantizer_index > rand_quantize_dropout_index
):
all_indices.append(null_indices)
all_losses.append(null_loss)
continue
@@ -539,7 +583,9 @@ class ResidualVQ(nn.Module):
# stack all losses and indices
all_losses, all_indices = map(partial(torch.stack, dim=-1), (all_losses, all_indices))
all_losses, all_indices = map(
partial(torch.stack, dim=-1), (all_losses, all_indices)
)
ret = (quantized_out, all_indices, all_losses)
@@ -599,8 +645,12 @@ class VectorQuantize(nn.Module):
codebook_input_dim = codebook_dim * heads
requires_projection = codebook_input_dim != dim
self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
self.project_in = (
nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
)
self.project_out = (
nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
)
self.eps = eps
self.commitment_weight = commitment_weight
@@ -614,10 +664,14 @@ class VectorQuantize(nn.Module):
self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
assert not (ema_update and learnable_codebook), "learnable codebook not compatible with EMA update"
assert not (
ema_update and learnable_codebook
), "learnable codebook not compatible with EMA update"
assert 0 <= sync_update_v <= 1.0
assert not (sync_update_v > 0.0 and not learnable_codebook), "learnable codebook must be turned on"
assert not (
sync_update_v > 0.0 and not learnable_codebook
), "learnable codebook must be turned on"
self.sync_update_v = sync_update_v
@@ -629,7 +683,9 @@ class VectorQuantize(nn.Module):
)
if sync_codebook is None:
sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1
sync_codebook = (
distributed.is_initialized() and distributed.get_world_size() > 1
)
codebook_kwargs = {
"dim": codebook_dim,
@@ -794,11 +850,17 @@ class VectorQuantize(nn.Module):
# quantize again
quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
quantize, embed_ind, distances = self._codebook(
x, **codebook_forward_kwargs
)
if self.training:
# determine code to use for commitment loss
maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity
maybe_detach = (
torch.detach
if not self.learnable_codebook or freeze_codebook
else identity
)
commit_quantize = maybe_detach(quantize)
@@ -808,7 +870,9 @@ class VectorQuantize(nn.Module):
if self.sync_update_v > 0.0:
# (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
quantize = quantize + self.sync_update_v * (quantize - quantize.detach())
quantize = quantize + self.sync_update_v * (
quantize - quantize.detach()
)
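
A small numerical check of the sync_update line above (illustrative value for sync_update_v): since quantize - quantize.detach() is identically zero in value, the forward output is unchanged, while the gradient flowing back into quantize is scaled by 1 + sync_update_v; see the reference linked in the comment for the derivation this is taken from.

```python
import torch

v = 0.25  # stand-in for self.sync_update_v
quantize = torch.randn(3, requires_grad=True)
out = quantize + v * (quantize - quantize.detach())

assert torch.allclose(out, quantize)  # forward value untouched
out.sum().backward()
print(quantize.grad)                  # tensor([1.25, 1.25, 1.25]): gradient scaled by 1 + v
```
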
# function for calculating cross entropy loss to distance matrix
# used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss
@@ -841,7 +905,9 @@ class VectorQuantize(nn.Module):
embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads)
if self.accept_image_fmap:
embed_ind = rearrange(embed_ind, "b (h w) ... -> b h w ...", h=height, w=width)
embed_ind = rearrange(
embed_ind, "b (h w) ... -> b h w ...", h=height, w=width
)
if only_one:
embed_ind = rearrange(embed_ind, "b 1 -> b")
@@ -895,8 +961,12 @@ class VectorQuantize(nn.Module):
num_codes = codebook.shape[-2]
if (self.orthogonal_reg_max_codes is not None) and num_codes > self.orthogonal_reg_max_codes:
rand_ids = torch.randperm(num_codes, device=device)[: self.orthogonal_reg_max_codes]
if (
self.orthogonal_reg_max_codes is not None
) and num_codes > self.orthogonal_reg_max_codes:
rand_ids = torch.randperm(num_codes, device=device)[
: self.orthogonal_reg_max_codes
]
codebook = codebook[:, rand_ids]
orthogonal_reg_loss = orthogonal_loss_fn(codebook)
@@ -928,7 +998,9 @@ class VectorQuantize(nn.Module):
# if masking, only return the quantized values where the mask is True
if mask is not None:
quantize = torch.where(rearrange(mask, "... -> ... 1"), quantize, orig_input)
quantize = torch.where(
rearrange(mask, "... -> ... 1"), quantize, orig_input
)
return quantize, embed_ind, loss
@@ -1038,7 +1110,9 @@ def sample_vectors(samples, num):
def batched_sample_vectors(samples, num):
return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0)
return torch.stack(
[sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0
)
def pad_shape(shape, size, dim=0):
@@ -1089,7 +1163,9 @@ def sample_vectors_distributed(local_samples, num):
all_num_samples = all_gather_sizes(local_samples, dim=0)
if rank == 0:
samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum())
samples_per_rank = sample_multinomial(
num, all_num_samples / all_num_samples.sum()
)
else:
samples_per_rank = torch.empty_like(all_num_samples)
@@ -1202,7 +1278,9 @@ class EuclideanCodebook(nn.Module):
self.eps = eps
self.threshold_ema_dead_code = threshold_ema_dead_code
self.reset_cluster_size = (
reset_cluster_size if (reset_cluster_size is not None) else threshold_ema_dead_code
reset_cluster_size
if (reset_cluster_size is not None)
else threshold_ema_dead_code
)
assert callable(gumbel_sample)
@@ -1213,8 +1291,14 @@ class EuclideanCodebook(nn.Module):
"kmeans init is not compatible with multiple codebooks in distributed environment for now"
)
self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
self.sample_fn = (
sample_vectors_distributed
if use_ddp and sync_kmeans
else batched_sample_vectors
)
self.kmeans_all_reduce_fn = (
distributed.all_reduce if use_ddp and sync_kmeans else noop
)
self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
self.register_buffer("initted", torch.Tensor([not kmeans_init]))
@@ -1353,7 +1437,9 @@ class EuclideanCodebook(nn.Module):
distributed.all_reduce(variance_number)
batch_variance = variance_number / num_vectors
self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay)
self.update_with_decay(
"batch_variance", batch_variance, self.affine_param_batch_decay
)
def replace(self, batch_samples, batch_mask):
for ind, (samples, mask) in enumerate(
@@ -1362,7 +1448,9 @@ class EuclideanCodebook(nn.Module):
if not torch.any(mask):
continue
sampled = self.sample_fn(rearrange(samples, "... -> 1 ..."), mask.sum().item())
sampled = self.sample_fn(
rearrange(samples, "... -> 1 ..."), mask.sum().item()
)
sampled = rearrange(sampled, "1 ... -> ...")
self.embed.data[ind][mask] = sampled
@@ -1386,7 +1474,9 @@ class EuclideanCodebook(nn.Module):
def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
needs_codebook_dim = x.ndim < 4
sample_codebook_temp = (
sample_codebook_temp if (sample_codebook_temp is not None) else self.sample_codebook_temp
sample_codebook_temp
if (sample_codebook_temp is not None)
else self.sample_codebook_temp
)
x = x.float()
@@ -1414,7 +1504,9 @@ class EuclideanCodebook(nn.Module):
if self.affine_param:
codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt()
batch_std = self.batch_variance.clamp(min=1e-5).sqrt()
embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean
embed = (embed - self.codebook_mean) * (
batch_std / codebook_std
) + self.batch_mean
dist = -cdist(flatten, embed)
@@ -1432,7 +1524,9 @@ class EuclideanCodebook(nn.Module):
if self.training and self.ema_update and not freeze_codebook:
if self.affine_param:
flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean
flatten = (flatten - self.batch_mean) * (
codebook_std / batch_std
) + self.codebook_mean
if mask is not None:
embed_onehot[~mask] = 0.0
@@ -1455,7 +1549,9 @@ class EuclideanCodebook(nn.Module):
self.expire_codes_(x)
if needs_codebook_dim:
quantize, embed_ind = tuple(rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind))
quantize, embed_ind = tuple(
rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind)
)
dist = unpack_one(dist, ps, "h * d")