fix(lerobot/common/policies): remove lint warnings/errors

This commit is contained in:
Steven Palma
2025-03-07 15:06:39 +01:00
parent 652fedf69c
commit 5c6f2d2cd0
14 changed files with 147 additions and 140 deletions

View File

@@ -140,7 +140,7 @@ class ACTConfig(PreTrainedConfig):
def __post_init__(self): def __post_init__(self):
super().__post_init__() super().__post_init__()
"""Input validation (not exhaustive).""" # Input validation (not exhaustive).
if not self.vision_backbone.startswith("resnet"): if not self.vision_backbone.startswith("resnet"):
raise ValueError( raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."

View File

@@ -222,6 +222,8 @@ class ACTTemporalEnsembler:
self.chunk_size = chunk_size self.chunk_size = chunk_size
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)) self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0) self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
self.ensembled_actions = None
self.ensembled_actions_count = None
self.reset() self.reset()
def reset(self): def reset(self):

View File

@@ -162,7 +162,7 @@ class DiffusionConfig(PreTrainedConfig):
def __post_init__(self): def __post_init__(self):
super().__post_init__() super().__post_init__()
"""Input validation (not exhaustive).""" # Input validation (not exhaustive).
if not self.vision_backbone.startswith("resnet"): if not self.vision_backbone.startswith("resnet"):
raise ValueError( raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."

View File

@@ -170,6 +170,7 @@ def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMSche
raise ValueError(f"Unsupported noise scheduler type {name}") raise ValueError(f"Unsupported noise scheduler type {name}")
# TODO(Steven): Missing forward() implementation
class DiffusionModel(nn.Module): class DiffusionModel(nn.Module):
def __init__(self, config: DiffusionConfig): def __init__(self, config: DiffusionConfig):
super().__init__() super().__init__()
@@ -203,6 +204,7 @@ class DiffusionModel(nn.Module):
) )
if config.num_inference_steps is None: if config.num_inference_steps is None:
# TODO(Steven): Consider type check?
self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
else: else:
self.num_inference_steps = config.num_inference_steps self.num_inference_steps = config.num_inference_steps
@@ -333,7 +335,7 @@ class DiffusionModel(nn.Module):
# Sample a random noising timestep for each item in the batch. # Sample a random noising timestep for each item in the batch.
timesteps = torch.randint( timesteps = torch.randint(
low=0, low=0,
high=self.noise_scheduler.config.num_train_timesteps, high=self.noise_scheduler.config.num_train_timesteps, # TODO(Steven): Consider type check?
size=(trajectory.shape[0],), size=(trajectory.shape[0],),
device=trajectory.device, device=trajectory.device,
).long() ).long()

View File

@@ -69,12 +69,12 @@ def create_stats_buffers(
} }
) )
elif norm_mode is NormalizationMode.MIN_MAX: elif norm_mode is NormalizationMode.MIN_MAX:
min = torch.ones(shape, dtype=torch.float32) * torch.inf min_norm = torch.ones(shape, dtype=torch.float32) * torch.inf
max = torch.ones(shape, dtype=torch.float32) * torch.inf max_norm = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict( buffer = nn.ParameterDict(
{ {
"min": nn.Parameter(min, requires_grad=False), "min": nn.Parameter(min_norm, requires_grad=False),
"max": nn.Parameter(max, requires_grad=False), "max": nn.Parameter(max_norm, requires_grad=False),
} }
) )
@@ -170,12 +170,12 @@ class Normalize(nn.Module):
assert not torch.isinf(std).any(), _no_stats_error_str("std") assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = (batch[key] - mean) / (std + 1e-8) batch[key] = (batch[key] - mean) / (std + 1e-8)
elif norm_mode is NormalizationMode.MIN_MAX: elif norm_mode is NormalizationMode.MIN_MAX:
min = buffer["min"] min_norm = buffer["min"]
max = buffer["max"] max_norm = buffer["max"]
assert not torch.isinf(min).any(), _no_stats_error_str("min") assert not torch.isinf(min_norm).any(), _no_stats_error_str("min")
assert not torch.isinf(max).any(), _no_stats_error_str("max") assert not torch.isinf(max_norm).any(), _no_stats_error_str("max")
# normalize to [0,1] # normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min + 1e-8) batch[key] = (batch[key] - min_norm) / (max_norm - min_norm + 1e-8)
# normalize to [-1, 1] # normalize to [-1, 1]
batch[key] = batch[key] * 2 - 1 batch[key] = batch[key] * 2 - 1
else: else:
@@ -243,12 +243,12 @@ class Unnormalize(nn.Module):
assert not torch.isinf(std).any(), _no_stats_error_str("std") assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = batch[key] * std + mean batch[key] = batch[key] * std + mean
elif norm_mode is NormalizationMode.MIN_MAX: elif norm_mode is NormalizationMode.MIN_MAX:
min = buffer["min"] min_norm = buffer["min"]
max = buffer["max"] max_norm = buffer["max"]
assert not torch.isinf(min).any(), _no_stats_error_str("min") assert not torch.isinf(min_norm).any(), _no_stats_error_str("min")
assert not torch.isinf(max).any(), _no_stats_error_str("max") assert not torch.isinf(max_norm).any(), _no_stats_error_str("max")
batch[key] = (batch[key] + 1) / 2 batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min) + min batch[key] = batch[key] * (max_norm - min_norm) + min_norm
else: else:
raise ValueError(norm_mode) raise ValueError(norm_mode)
return batch return batch

View File

@@ -91,7 +91,7 @@ class PI0Config(PreTrainedConfig):
super().__post_init__() super().__post_init__()
# TODO(Steven): Validate device and amp? in all policy configs? # TODO(Steven): Validate device and amp? in all policy configs?
"""Input validation (not exhaustive).""" # Input validation (not exhaustive).
if self.n_action_steps > self.chunk_size: if self.n_action_steps > self.chunk_size:
raise ValueError( raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got " f"The chunk size is the upper bound for the number of action steps per model invocation. Got "

View File

@@ -55,7 +55,7 @@ def main():
with open(save_dir / "noise.pkl", "rb") as f: with open(save_dir / "noise.pkl", "rb") as f:
noise = pickle.load(f) noise = pickle.load(f)
with open(ckpt_jax_dir / "assets/norm_stats.json") as f: with open(ckpt_jax_dir / "assets/norm_stats.json", encoding="utf-8") as f:
norm_stats = json.load(f) norm_stats = json.load(f)
# Override stats # Override stats

View File

@@ -318,7 +318,7 @@ def update_keys_with_prefix(d: dict, prefix: str) -> dict:
return {f"{prefix}{key}": value for key, value in d.items()} return {f"{prefix}{key}": value for key, value in d.items()}
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str): def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, _tokenizer_id: str, output_path: str):
# Break down orbax ckpts - they are in OCDBT # Break down orbax ckpts - they are in OCDBT
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir) initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
# process projection params # process projection params
@@ -432,6 +432,6 @@ if __name__ == "__main__":
convert_pi0_checkpoint( convert_pi0_checkpoint(
checkpoint_dir=args.checkpoint_dir, checkpoint_dir=args.checkpoint_dir,
precision=args.precision, precision=args.precision,
tokenizer_id=args.tokenizer_hub_id, _tokenizer_id=args.tokenizer_hub_id,
output_path=args.output_path, output_path=args.output_path,
) )

View File

@@ -16,6 +16,7 @@ import torch
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
from packaging.version import Version from packaging.version import Version
# TODO(Steven): Consider setting this as a dependency constraint
if Version(torch.__version__) > Version("2.5.0"): if Version(torch.__version__) > Version("2.5.0"):
# Flex attention is only available from torch 2.5 onwards # Flex attention is only available from torch 2.5 onwards
from torch.nn.attention.flex_attention import ( from torch.nn.attention.flex_attention import (
@@ -121,7 +122,7 @@ def flex_attention_forward(
) )
# mask is applied inside the kernel, ideally more efficiently than score_mod. # mask is applied inside the kernel, ideally more efficiently than score_mod.
attn_output, attention_weights = flex_attention( attn_output, _attention_weights = flex_attention(
query_states, query_states,
key_states, key_states,
value_states, value_states,

View File

@@ -162,7 +162,7 @@ class TDMPCConfig(PreTrainedConfig):
def __post_init__(self): def __post_init__(self):
super().__post_init__() super().__post_init__()
"""Input validation (not exhaustive).""" # Input validation (not exhaustive).
if self.n_gaussian_samples <= 0: if self.n_gaussian_samples <= 0:
raise ValueError( raise ValueError(
f"The number of gaussian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`" f"The number of gaussian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"

View File

@@ -88,6 +88,9 @@ class TDMPCPolicy(PreTrainedPolicy):
for param in self.model_target.parameters(): for param in self.model_target.parameters():
param.requires_grad = False param.requires_grad = False
self._queues = None
self._prev_mean: torch.Tensor | None = None
self.reset() self.reset()
def get_optim_params(self) -> dict: def get_optim_params(self) -> dict:
@@ -108,7 +111,7 @@ class TDMPCPolicy(PreTrainedPolicy):
self._queues["observation.environment_state"] = deque(maxlen=1) self._queues["observation.environment_state"] = deque(maxlen=1)
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start # Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
# CEM for the next step. # CEM for the next step.
self._prev_mean: torch.Tensor | None = None self._prev_mean = None
@torch.no_grad() @torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor: def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@@ -514,6 +517,7 @@ class TDMPCPolicy(PreTrainedPolicy):
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum) update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
# TODO(Steven): forward implementation missing
class TDMPCTOLD(nn.Module): class TDMPCTOLD(nn.Module):
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC.""" """Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""

View File

@@ -144,7 +144,7 @@ class VQBeTConfig(PreTrainedConfig):
def __post_init__(self): def __post_init__(self):
super().__post_init__() super().__post_init__()
"""Input validation (not exhaustive).""" # Input validation (not exhaustive).
if not self.vision_backbone.startswith("resnet"): if not self.vision_backbone.startswith("resnet"):
raise ValueError( raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."

View File

@@ -70,6 +70,8 @@ class VQBeTPolicy(PreTrainedPolicy):
self.vqbet = VQBeTModel(config) self.vqbet = VQBeTModel(config)
self._queues = None
self.reset() self.reset()
def get_optim_params(self) -> dict: def get_optim_params(self) -> dict:
@@ -535,7 +537,7 @@ class VQBeTHead(nn.Module):
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
) )
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1) cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
NT, G, choices = cbet_probs.shape NT, _G, choices = cbet_probs.shape
sampled_centers = einops.rearrange( sampled_centers = einops.rearrange(
torch.multinomial(cbet_probs.view(-1, choices), num_samples=1), torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
"(NT G) 1 -> NT G", "(NT G) 1 -> NT G",
@@ -578,7 +580,7 @@ class VQBeTHead(nn.Module):
"decoded_action": decoded_action, "decoded_action": decoded_action,
} }
def loss_fn(self, pred, target, **kwargs): def loss_fn(self, pred, target, **_kwargs):
""" """
for given ground truth action values (target), and prediction (pred) this function calculates the overall loss. for given ground truth action values (target), and prediction (pred) this function calculates the overall loss.
@@ -605,7 +607,7 @@ class VQBeTHead(nn.Module):
# Figure out the loss for the actions. # Figure out the loss for the actions.
# First, we need to find the closest cluster center for each ground truth action. # First, we need to find the closest cluster center for each ground truth action.
with torch.no_grad(): with torch.no_grad():
state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G _state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G
# Now we can compute the loss. # Now we can compute the loss.
@@ -762,6 +764,7 @@ def _replace_submodules(
return root_module return root_module
# TODO(Steven): Missing implementation of forward, is it maybe vqvae_forward?
class VqVae(nn.Module): class VqVae(nn.Module):
def __init__( def __init__(
self, self,
@@ -876,13 +879,13 @@ class FocalLoss(nn.Module):
self.gamma = gamma self.gamma = gamma
self.size_average = size_average self.size_average = size_average
def forward(self, input, target): def forward(self, forward_input, target):
if len(input.shape) == 3: if len(forward_input.shape) == 3:
N, T, _ = input.shape N, T, _ = forward_input.shape
logpt = F.log_softmax(input, dim=-1) logpt = F.log_softmax(forward_input, dim=-1)
logpt = logpt.gather(-1, target.view(N, T, 1)).view(N, T) logpt = logpt.gather(-1, target.view(N, T, 1)).view(N, T)
elif len(input.shape) == 2: elif len(forward_input.shape) == 2:
logpt = F.log_softmax(input, dim=-1) logpt = F.log_softmax(forward_input, dim=-1)
logpt = logpt.gather(-1, target.view(-1, 1)).view(-1) logpt = logpt.gather(-1, target.view(-1, 1)).view(-1)
pt = logpt.exp() pt = logpt.exp()

View File

@@ -34,63 +34,58 @@ from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
# ruff: noqa: N806 # ruff: noqa: N806
""" # This file is part of a VQ-BeT that utilizes code from the following repositories:
This file is part of a VQ-BeT that utilizes code from the following repositories: #
# - Vector Quantize PyTorch code is licensed under the MIT License:
# Original source: https://github.com/lucidrains/vector-quantize-pytorch
#
# - nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
# Original source: https://github.com/karpathy/nanoGPT
#
# We also made some changes to the original code to adapt it to our needs. The changes are described in the code below.
- Vector Quantize PyTorch code is licensed under the MIT License: # This is a part for nanoGPT that utilizes code from the following repository:
Original source: https://github.com/lucidrains/vector-quantize-pytorch #
# - Andrej Karpathy's nanoGPT implementation in PyTorch.
- nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch. # Original source: https://github.com/karpathy/nanoGPT
Original source: https://github.com/karpathy/nanoGPT #
# - The nanoGPT code is licensed under the MIT License:
We also made some changes to the original code to adapt it to our needs. The changes are described in the code below. #
""" # MIT License
#
""" # Copyright (c) 2022 Andrej Karpathy
This is a part for nanoGPT that utilizes code from the following repository: #
# Permission is hereby granted, free of charge, to any person obtaining a copy
- Andrej Karpathy's nanoGPT implementation in PyTorch. # of this software and associated documentation files (the "Software"), to deal
Original source: https://github.com/karpathy/nanoGPT # in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- The nanoGPT code is licensed under the MIT License: # copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
MIT License #
# The above copyright notice and this permission notice shall be included in all
Copyright (c) 2022 Andrej Karpathy # copies or substantial portions of the Software.
#
Permission is hereby granted, free of charge, to any person obtaining a copy # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
of this software and associated documentation files (the "Software"), to deal # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
in the Software without restriction, including without limitation the rights # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
copies of the Software, and to permit persons to whom the Software is # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
furnished to do so, subject to the following conditions: # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
The above copyright notice and this permission notice shall be included in all #
copies or substantial portions of the Software. # - We've made some changes to the original code to adapt it to our needs.
#
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # Changed variable names:
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # - n_head -> gpt_n_head
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # - n_embd -> gpt_hidden_dim
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # - block_size -> gpt_block_size
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # - n_layer -> gpt_n_layer
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE #
SOFTWARE. #
# class GPT(nn.Module):
- We've made some changes to the original code to adapt it to our needs. # - removed unused functions `def generate`, `def estimate_mfu`, and `def from_pretrained`
# - changed the `configure_optimizers` to `def configure_parameters` and made it to return only the parameters of the model: we use an external optimizer in our training loop.
Changed variable names: # - in the function `forward`, we removed target loss calculation parts, since it will be calculated in the training loop (after passing through bin prediction and offset prediction heads).
- n_head -> gpt_n_head
- n_embd -> gpt_hidden_dim
- block_size -> gpt_block_size
- n_layer -> gpt_n_layer
class GPT(nn.Module):
- removed unused functions `def generate`, `def estimate_mfu`, and `def from_pretrained`
- changed the `configure_optimizers` to `def configure_parameters` and made it to return only the parameters of the model: we use an external optimizer in our training loop.
- in the function `forward`, we removed target loss calculation parts, since it will be calculated in the training loop (after passing through bin prediction and offset prediction heads).
"""
class CausalSelfAttention(nn.Module): class CausalSelfAttention(nn.Module):
@@ -200,9 +195,9 @@ class GPT(nn.Module):
n_params = sum(p.numel() for p in self.parameters()) n_params = sum(p.numel() for p in self.parameters())
print("number of parameters: {:.2f}M".format(n_params / 1e6)) print("number of parameters: {:.2f}M".format(n_params / 1e6))
def forward(self, input, targets=None): def forward(self, forward_input):
device = input.device device = forward_input.device
b, t, d = input.size() _, t, _ = forward_input.size()
assert t <= self.config.gpt_block_size, ( assert t <= self.config.gpt_block_size, (
f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}" f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
) )
@@ -211,7 +206,7 @@ class GPT(nn.Module):
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
# forward the GPT model itself # forward the GPT model itself
tok_emb = self.transformer.wte(input) # token embeddings of shape (b, t, gpt_hidden_dim) tok_emb = self.transformer.wte(forward_input) # token embeddings of shape (b, t, gpt_hidden_dim)
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, gpt_hidden_dim) pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, gpt_hidden_dim)
x = self.transformer.drop(tok_emb + pos_emb) x = self.transformer.drop(tok_emb + pos_emb)
for block in self.transformer.h: for block in self.transformer.h:
@@ -285,51 +280,48 @@ class GPT(nn.Module):
return decay, no_decay return decay, no_decay
""" # This file is a part for Residual Vector Quantization that utilizes code from the following repository:
This file is a part for Residual Vector Quantization that utilizes code from the following repository: #
# - Phil Wang's vector-quantize-pytorch implementation in PyTorch.
- Phil Wang's vector-quantize-pytorch implementation in PyTorch. # Original source: https://github.com/lucidrains/vector-quantize-pytorch
Original source: https://github.com/lucidrains/vector-quantize-pytorch #
# - The vector-quantize-pytorch code is licensed under the MIT License:
- The vector-quantize-pytorch code is licensed under the MIT License: #
# MIT License
MIT License #
# Copyright (c) 2020 Phil Wang
Copyright (c) 2020 Phil Wang #
# Permission is hereby granted, free of charge, to any person obtaining a copy
Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal
of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights
in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is
copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions:
furnished to do so, subject to the following conditions: #
# The above copyright notice and this permission notice shall be included in all
The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software.
copies or substantial portions of the Software. #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE.
SOFTWARE. #
# - We've made some changes to the original code to adapt it to our needs.
- We've made some changes to the original code to adapt it to our needs. #
# class ResidualVQ(nn.Module):
class ResidualVQ(nn.Module): # - added `self.register_buffer('freeze_codebook', torch.tensor(False))` to the __init__ method:
- added `self.register_buffer('freeze_codebook', torch.tensor(False))` to the __init__ method: # This enables the user to save an indicator whether the codebook is frozen or not.
This enables the user to save an indicator whether the codebook is frozen or not. # - changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
- changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`: # This is to make the function name more descriptive.
This is to make the function name more descriptive. #
# class VectorQuantize(nn.Module):
class VectorQuantize(nn.Module): # - removed the `use_cosine_sim` and `layernorm_after_project_in` parameters from the __init__ method:
- removed the `use_cosine_sim` and `layernorm_after_project_in` parameters from the __init__ method: # These parameters are not used in the code.
These parameters are not used in the code. # - changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
- changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`: # This is to make the function name more descriptive.
This is to make the function name more descriptive.
"""
class ResidualVQ(nn.Module): class ResidualVQ(nn.Module):
@@ -479,6 +471,9 @@ class ResidualVQ(nn.Module):
should_quantize_dropout = self.training and self.quantize_dropout and not return_loss should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
null_indices = None
null_loss = None
# sample a layer index at which to dropout further residual quantization # sample a layer index at which to dropout further residual quantization
# also prepare null indices and loss # also prepare null indices and loss
@@ -933,7 +928,7 @@ class VectorQuantize(nn.Module):
return quantize, embed_ind, loss return quantize, embed_ind, loss
def noop(*args, **kwargs): def noop(*_args, **_kwargs):
pass pass