fix(lerobot/common/policies): remove lint warnings/errors
This commit is contained in:
@@ -140,7 +140,7 @@ class ACTConfig(PreTrainedConfig):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
"""Input validation (not exhaustive)."""
|
# Input validation (not exhaustive).
|
||||||
if not self.vision_backbone.startswith("resnet"):
|
if not self.vision_backbone.startswith("resnet"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||||
|
|||||||
@@ -222,6 +222,8 @@ class ACTTemporalEnsembler:
|
|||||||
self.chunk_size = chunk_size
|
self.chunk_size = chunk_size
|
||||||
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
|
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
|
||||||
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
|
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
|
||||||
|
self.ensembled_actions = None
|
||||||
|
self.ensembled_actions_count = None
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ class DiffusionConfig(PreTrainedConfig):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
"""Input validation (not exhaustive)."""
|
# Input validation (not exhaustive).
|
||||||
if not self.vision_backbone.startswith("resnet"):
|
if not self.vision_backbone.startswith("resnet"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||||
|
|||||||
@@ -170,6 +170,7 @@ def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMSche
|
|||||||
raise ValueError(f"Unsupported noise scheduler type {name}")
|
raise ValueError(f"Unsupported noise scheduler type {name}")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): Missing forward() implementation
|
||||||
class DiffusionModel(nn.Module):
|
class DiffusionModel(nn.Module):
|
||||||
def __init__(self, config: DiffusionConfig):
|
def __init__(self, config: DiffusionConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -203,6 +204,7 @@ class DiffusionModel(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if config.num_inference_steps is None:
|
if config.num_inference_steps is None:
|
||||||
|
# TODO(Steven): Consider type check?
|
||||||
self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
|
self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
|
||||||
else:
|
else:
|
||||||
self.num_inference_steps = config.num_inference_steps
|
self.num_inference_steps = config.num_inference_steps
|
||||||
@@ -333,7 +335,7 @@ class DiffusionModel(nn.Module):
|
|||||||
# Sample a random noising timestep for each item in the batch.
|
# Sample a random noising timestep for each item in the batch.
|
||||||
timesteps = torch.randint(
|
timesteps = torch.randint(
|
||||||
low=0,
|
low=0,
|
||||||
high=self.noise_scheduler.config.num_train_timesteps,
|
high=self.noise_scheduler.config.num_train_timesteps, # TODO(Steven): Consider type check?
|
||||||
size=(trajectory.shape[0],),
|
size=(trajectory.shape[0],),
|
||||||
device=trajectory.device,
|
device=trajectory.device,
|
||||||
).long()
|
).long()
|
||||||
|
|||||||
@@ -69,12 +69,12 @@ def create_stats_buffers(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
min = torch.ones(shape, dtype=torch.float32) * torch.inf
|
min_norm = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||||
max = torch.ones(shape, dtype=torch.float32) * torch.inf
|
max_norm = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||||
buffer = nn.ParameterDict(
|
buffer = nn.ParameterDict(
|
||||||
{
|
{
|
||||||
"min": nn.Parameter(min, requires_grad=False),
|
"min": nn.Parameter(min_norm, requires_grad=False),
|
||||||
"max": nn.Parameter(max, requires_grad=False),
|
"max": nn.Parameter(max_norm, requires_grad=False),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -170,12 +170,12 @@ class Normalize(nn.Module):
|
|||||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
min = buffer["min"]
|
min_norm = buffer["min"]
|
||||||
max = buffer["max"]
|
max_norm = buffer["max"]
|
||||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
assert not torch.isinf(min_norm).any(), _no_stats_error_str("min")
|
||||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
assert not torch.isinf(max_norm).any(), _no_stats_error_str("max")
|
||||||
# normalize to [0,1]
|
# normalize to [0,1]
|
||||||
batch[key] = (batch[key] - min) / (max - min + 1e-8)
|
batch[key] = (batch[key] - min_norm) / (max_norm - min_norm + 1e-8)
|
||||||
# normalize to [-1, 1]
|
# normalize to [-1, 1]
|
||||||
batch[key] = batch[key] * 2 - 1
|
batch[key] = batch[key] * 2 - 1
|
||||||
else:
|
else:
|
||||||
@@ -243,12 +243,12 @@ class Unnormalize(nn.Module):
|
|||||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||||
batch[key] = batch[key] * std + mean
|
batch[key] = batch[key] * std + mean
|
||||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
min = buffer["min"]
|
min_norm = buffer["min"]
|
||||||
max = buffer["max"]
|
max_norm = buffer["max"]
|
||||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
assert not torch.isinf(min_norm).any(), _no_stats_error_str("min")
|
||||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
assert not torch.isinf(max_norm).any(), _no_stats_error_str("max")
|
||||||
batch[key] = (batch[key] + 1) / 2
|
batch[key] = (batch[key] + 1) / 2
|
||||||
batch[key] = batch[key] * (max - min) + min
|
batch[key] = batch[key] * (max_norm - min_norm) + min_norm
|
||||||
else:
|
else:
|
||||||
raise ValueError(norm_mode)
|
raise ValueError(norm_mode)
|
||||||
return batch
|
return batch
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ class PI0Config(PreTrainedConfig):
|
|||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
# TODO(Steven): Validate device and amp? in all policy configs?
|
# TODO(Steven): Validate device and amp? in all policy configs?
|
||||||
"""Input validation (not exhaustive)."""
|
# Input validation (not exhaustive).
|
||||||
if self.n_action_steps > self.chunk_size:
|
if self.n_action_steps > self.chunk_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ def main():
|
|||||||
with open(save_dir / "noise.pkl", "rb") as f:
|
with open(save_dir / "noise.pkl", "rb") as f:
|
||||||
noise = pickle.load(f)
|
noise = pickle.load(f)
|
||||||
|
|
||||||
with open(ckpt_jax_dir / "assets/norm_stats.json") as f:
|
with open(ckpt_jax_dir / "assets/norm_stats.json", encoding="utf-8") as f:
|
||||||
norm_stats = json.load(f)
|
norm_stats = json.load(f)
|
||||||
|
|
||||||
# Override stats
|
# Override stats
|
||||||
|
|||||||
@@ -318,7 +318,7 @@ def update_keys_with_prefix(d: dict, prefix: str) -> dict:
|
|||||||
return {f"{prefix}{key}": value for key, value in d.items()}
|
return {f"{prefix}{key}": value for key, value in d.items()}
|
||||||
|
|
||||||
|
|
||||||
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
|
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, _tokenizer_id: str, output_path: str):
|
||||||
# Break down orbax ckpts - they are in OCDBT
|
# Break down orbax ckpts - they are in OCDBT
|
||||||
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
|
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
|
||||||
# process projection params
|
# process projection params
|
||||||
@@ -432,6 +432,6 @@ if __name__ == "__main__":
|
|||||||
convert_pi0_checkpoint(
|
convert_pi0_checkpoint(
|
||||||
checkpoint_dir=args.checkpoint_dir,
|
checkpoint_dir=args.checkpoint_dir,
|
||||||
precision=args.precision,
|
precision=args.precision,
|
||||||
tokenizer_id=args.tokenizer_hub_id,
|
_tokenizer_id=args.tokenizer_hub_id,
|
||||||
output_path=args.output_path,
|
output_path=args.output_path,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import torch
|
|||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
|
|
||||||
|
# TODO(Steven): Consider settings this a dependency constraint
|
||||||
if Version(torch.__version__) > Version("2.5.0"):
|
if Version(torch.__version__) > Version("2.5.0"):
|
||||||
# Ffex attention is only available from torch 2.5 onwards
|
# Ffex attention is only available from torch 2.5 onwards
|
||||||
from torch.nn.attention.flex_attention import (
|
from torch.nn.attention.flex_attention import (
|
||||||
@@ -121,7 +122,7 @@ def flex_attention_forward(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# mask is applied inside the kernel, ideally more efficiently than score_mod.
|
# mask is applied inside the kernel, ideally more efficiently than score_mod.
|
||||||
attn_output, attention_weights = flex_attention(
|
attn_output, _attention_weights = flex_attention(
|
||||||
query_states,
|
query_states,
|
||||||
key_states,
|
key_states,
|
||||||
value_states,
|
value_states,
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ class TDMPCConfig(PreTrainedConfig):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
"""Input validation (not exhaustive)."""
|
# Input validation (not exhaustive).
|
||||||
if self.n_gaussian_samples <= 0:
|
if self.n_gaussian_samples <= 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The number of gaussian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
|
f"The number of gaussian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
|
||||||
|
|||||||
@@ -88,6 +88,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
|||||||
for param in self.model_target.parameters():
|
for param in self.model_target.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
|
self._queues = None
|
||||||
|
self._prev_mean: torch.Tensor | None = None
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def get_optim_params(self) -> dict:
|
def get_optim_params(self) -> dict:
|
||||||
@@ -108,7 +111,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
|||||||
self._queues["observation.environment_state"] = deque(maxlen=1)
|
self._queues["observation.environment_state"] = deque(maxlen=1)
|
||||||
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
|
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
|
||||||
# CEM for the next step.
|
# CEM for the next step.
|
||||||
self._prev_mean: torch.Tensor | None = None
|
self._prev_mean = None
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
@@ -514,6 +517,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
|||||||
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
|
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): forward implementation missing
|
||||||
class TDMPCTOLD(nn.Module):
|
class TDMPCTOLD(nn.Module):
|
||||||
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
|
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
|
||||||
|
|
||||||
|
|||||||
@@ -144,7 +144,7 @@ class VQBeTConfig(PreTrainedConfig):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
"""Input validation (not exhaustive)."""
|
# Input validation (not exhaustive).
|
||||||
if not self.vision_backbone.startswith("resnet"):
|
if not self.vision_backbone.startswith("resnet"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||||
|
|||||||
@@ -70,6 +70,8 @@ class VQBeTPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
self.vqbet = VQBeTModel(config)
|
self.vqbet = VQBeTModel(config)
|
||||||
|
|
||||||
|
self._queues = None
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def get_optim_params(self) -> dict:
|
def get_optim_params(self) -> dict:
|
||||||
@@ -535,7 +537,7 @@ class VQBeTHead(nn.Module):
|
|||||||
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
|
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
|
||||||
)
|
)
|
||||||
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
|
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
|
||||||
NT, G, choices = cbet_probs.shape
|
NT, _G, choices = cbet_probs.shape
|
||||||
sampled_centers = einops.rearrange(
|
sampled_centers = einops.rearrange(
|
||||||
torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
|
torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
|
||||||
"(NT G) 1 -> NT G",
|
"(NT G) 1 -> NT G",
|
||||||
@@ -578,7 +580,7 @@ class VQBeTHead(nn.Module):
|
|||||||
"decoded_action": decoded_action,
|
"decoded_action": decoded_action,
|
||||||
}
|
}
|
||||||
|
|
||||||
def loss_fn(self, pred, target, **kwargs):
|
def loss_fn(self, pred, target, **_kwargs):
|
||||||
"""
|
"""
|
||||||
for given ground truth action values (target), and prediction (pred) this function calculates the overall loss.
|
for given ground truth action values (target), and prediction (pred) this function calculates the overall loss.
|
||||||
|
|
||||||
@@ -605,7 +607,7 @@ class VQBeTHead(nn.Module):
|
|||||||
# Figure out the loss for the actions.
|
# Figure out the loss for the actions.
|
||||||
# First, we need to find the closest cluster center for each ground truth action.
|
# First, we need to find the closest cluster center for each ground truth action.
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G
|
_state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G
|
||||||
|
|
||||||
# Now we can compute the loss.
|
# Now we can compute the loss.
|
||||||
|
|
||||||
@@ -762,6 +764,7 @@ def _replace_submodules(
|
|||||||
return root_module
|
return root_module
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): Missing implementation of forward, is it maybe vqvae_forward?
|
||||||
class VqVae(nn.Module):
|
class VqVae(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -876,13 +879,13 @@ class FocalLoss(nn.Module):
|
|||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
self.size_average = size_average
|
self.size_average = size_average
|
||||||
|
|
||||||
def forward(self, input, target):
|
def forward(self, forward_input, target):
|
||||||
if len(input.shape) == 3:
|
if len(forward_input.shape) == 3:
|
||||||
N, T, _ = input.shape
|
N, T, _ = forward_input.shape
|
||||||
logpt = F.log_softmax(input, dim=-1)
|
logpt = F.log_softmax(forward_input, dim=-1)
|
||||||
logpt = logpt.gather(-1, target.view(N, T, 1)).view(N, T)
|
logpt = logpt.gather(-1, target.view(N, T, 1)).view(N, T)
|
||||||
elif len(input.shape) == 2:
|
elif len(forward_input.shape) == 2:
|
||||||
logpt = F.log_softmax(input, dim=-1)
|
logpt = F.log_softmax(forward_input, dim=-1)
|
||||||
logpt = logpt.gather(-1, target.view(-1, 1)).view(-1)
|
logpt = logpt.gather(-1, target.view(-1, 1)).view(-1)
|
||||||
pt = logpt.exp()
|
pt = logpt.exp()
|
||||||
|
|
||||||
|
|||||||
@@ -34,63 +34,58 @@ from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
|||||||
|
|
||||||
# ruff: noqa: N806
|
# ruff: noqa: N806
|
||||||
|
|
||||||
"""
|
# This file is part of a VQ-BeT that utilizes code from the following repositories:
|
||||||
This file is part of a VQ-BeT that utilizes code from the following repositories:
|
#
|
||||||
|
# - Vector Quantize PyTorch code is licensed under the MIT License:
|
||||||
|
# Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
||||||
|
#
|
||||||
|
# - nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
|
||||||
|
# Original source: https://github.com/karpathy/nanoGPT
|
||||||
|
#
|
||||||
|
# We also made some changes to the original code to adapt it to our needs. The changes are described in the code below.
|
||||||
|
|
||||||
- Vector Quantize PyTorch code is licensed under the MIT License:
|
# This is a part for nanoGPT that utilizes code from the following repository:
|
||||||
Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
#
|
||||||
|
# - Andrej Karpathy's nanoGPT implementation in PyTorch.
|
||||||
- nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
|
# Original source: https://github.com/karpathy/nanoGPT
|
||||||
Original source: https://github.com/karpathy/nanoGPT
|
#
|
||||||
|
# - The nanoGPT code is licensed under the MIT License:
|
||||||
We also made some changes to the original code to adapt it to our needs. The changes are described in the code below.
|
#
|
||||||
"""
|
# MIT License
|
||||||
|
#
|
||||||
"""
|
# Copyright (c) 2022 Andrej Karpathy
|
||||||
This is a part for nanoGPT that utilizes code from the following repository:
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
- Andrej Karpathy's nanoGPT implementation in PyTorch.
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
Original source: https://github.com/karpathy/nanoGPT
|
# in the Software without restriction, including without limitation the rights
|
||||||
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
- The nanoGPT code is licensed under the MIT License:
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
|
# furnished to do so, subject to the following conditions:
|
||||||
MIT License
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
Copyright (c) 2022 Andrej Karpathy
|
# copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
in the Software without restriction, including without limitation the rights
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
furnished to do so, subject to the following conditions:
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
# SOFTWARE.
|
||||||
The above copyright notice and this permission notice shall be included in all
|
#
|
||||||
copies or substantial portions of the Software.
|
# - We've made some changes to the original code to adapt it to our needs.
|
||||||
|
#
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
# Changed variable names:
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
# - n_head -> gpt_n_head
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
# - n_embd -> gpt_hidden_dim
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
# - block_size -> gpt_block_size
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
# - n_layer -> gpt_n_layer
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
#
|
||||||
SOFTWARE.
|
#
|
||||||
|
# class GPT(nn.Module):
|
||||||
- We've made some changes to the original code to adapt it to our needs.
|
# - removed unused functions `def generate`, `def estimate_mfu`, and `def from_pretrained`
|
||||||
|
# - changed the `configure_optimizers` to `def configure_parameters` and made it to return only the parameters of the model: we use an external optimizer in our training loop.
|
||||||
Changed variable names:
|
# - in the function `forward`, we removed target loss calculation parts, since it will be calculated in the training loop (after passing through bin prediction and offset prediction heads).
|
||||||
- n_head -> gpt_n_head
|
|
||||||
- n_embd -> gpt_hidden_dim
|
|
||||||
- block_size -> gpt_block_size
|
|
||||||
- n_layer -> gpt_n_layer
|
|
||||||
|
|
||||||
|
|
||||||
class GPT(nn.Module):
|
|
||||||
- removed unused functions `def generate`, `def estimate_mfu`, and `def from_pretrained`
|
|
||||||
- changed the `configure_optimizers` to `def configure_parameters` and made it to return only the parameters of the model: we use an external optimizer in our training loop.
|
|
||||||
- in the function `forward`, we removed target loss calculation parts, since it will be calculated in the training loop (after passing through bin prediction and offset prediction heads).
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class CausalSelfAttention(nn.Module):
|
class CausalSelfAttention(nn.Module):
|
||||||
@@ -200,9 +195,9 @@ class GPT(nn.Module):
|
|||||||
n_params = sum(p.numel() for p in self.parameters())
|
n_params = sum(p.numel() for p in self.parameters())
|
||||||
print("number of parameters: {:.2f}M".format(n_params / 1e6))
|
print("number of parameters: {:.2f}M".format(n_params / 1e6))
|
||||||
|
|
||||||
def forward(self, input, targets=None):
|
def forward(self, forward_input):
|
||||||
device = input.device
|
device = forward_input.device
|
||||||
b, t, d = input.size()
|
_, t, _ = forward_input.size()
|
||||||
assert t <= self.config.gpt_block_size, (
|
assert t <= self.config.gpt_block_size, (
|
||||||
f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
|
f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
|
||||||
)
|
)
|
||||||
@@ -211,7 +206,7 @@ class GPT(nn.Module):
|
|||||||
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
|
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
|
||||||
|
|
||||||
# forward the GPT model itself
|
# forward the GPT model itself
|
||||||
tok_emb = self.transformer.wte(input) # token embeddings of shape (b, t, gpt_hidden_dim)
|
tok_emb = self.transformer.wte(forward_input) # token embeddings of shape (b, t, gpt_hidden_dim)
|
||||||
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, gpt_hidden_dim)
|
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, gpt_hidden_dim)
|
||||||
x = self.transformer.drop(tok_emb + pos_emb)
|
x = self.transformer.drop(tok_emb + pos_emb)
|
||||||
for block in self.transformer.h:
|
for block in self.transformer.h:
|
||||||
@@ -285,51 +280,48 @@ class GPT(nn.Module):
|
|||||||
return decay, no_decay
|
return decay, no_decay
|
||||||
|
|
||||||
|
|
||||||
"""
|
# This file is a part for Residual Vector Quantization that utilizes code from the following repository:
|
||||||
This file is a part for Residual Vector Quantization that utilizes code from the following repository:
|
#
|
||||||
|
# - Phil Wang's vector-quantize-pytorch implementation in PyTorch.
|
||||||
- Phil Wang's vector-quantize-pytorch implementation in PyTorch.
|
# Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
||||||
Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
#
|
||||||
|
# - The vector-quantize-pytorch code is licensed under the MIT License:
|
||||||
- The vector-quantize-pytorch code is licensed under the MIT License:
|
#
|
||||||
|
# MIT License
|
||||||
MIT License
|
#
|
||||||
|
# Copyright (c) 2020 Phil Wang
|
||||||
Copyright (c) 2020 Phil Wang
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
# in the Software without restriction, including without limitation the rights
|
||||||
in the Software without restriction, including without limitation the rights
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
# furnished to do so, subject to the following conditions:
|
||||||
furnished to do so, subject to the following conditions:
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
The above copyright notice and this permission notice shall be included in all
|
# copies or substantial portions of the Software.
|
||||||
copies or substantial portions of the Software.
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
# SOFTWARE.
|
||||||
SOFTWARE.
|
#
|
||||||
|
# - We've made some changes to the original code to adapt it to our needs.
|
||||||
- We've made some changes to the original code to adapt it to our needs.
|
#
|
||||||
|
# class ResidualVQ(nn.Module):
|
||||||
class ResidualVQ(nn.Module):
|
# - added `self.register_buffer('freeze_codebook', torch.tensor(False))` to the __init__ method:
|
||||||
- added `self.register_buffer('freeze_codebook', torch.tensor(False))` to the __init__ method:
|
# This enables the user to save an indicator whether the codebook is frozen or not.
|
||||||
This enables the user to save an indicator whether the codebook is frozen or not.
|
# - changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
|
||||||
- changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
|
# This is to make the function name more descriptive.
|
||||||
This is to make the function name more descriptive.
|
#
|
||||||
|
# class VectorQuantize(nn.Module):
|
||||||
class VectorQuantize(nn.Module):
|
# - removed the `use_cosine_sim` and `layernorm_after_project_in` parameters from the __init__ method:
|
||||||
- removed the `use_cosine_sim` and `layernorm_after_project_in` parameters from the __init__ method:
|
# These parameters are not used in the code.
|
||||||
These parameters are not used in the code.
|
# - changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
|
||||||
- changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
|
# This is to make the function name more descriptive.
|
||||||
This is to make the function name more descriptive.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualVQ(nn.Module):
|
class ResidualVQ(nn.Module):
|
||||||
@@ -479,6 +471,9 @@ class ResidualVQ(nn.Module):
|
|||||||
|
|
||||||
should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
|
should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
|
||||||
|
|
||||||
|
null_indices = None
|
||||||
|
null_loss = None
|
||||||
|
|
||||||
# sample a layer index at which to dropout further residual quantization
|
# sample a layer index at which to dropout further residual quantization
|
||||||
# also prepare null indices and loss
|
# also prepare null indices and loss
|
||||||
|
|
||||||
@@ -933,7 +928,7 @@ class VectorQuantize(nn.Module):
|
|||||||
return quantize, embed_ind, loss
|
return quantize, embed_ind, loss
|
||||||
|
|
||||||
|
|
||||||
def noop(*args, **kwargs):
|
def noop(*_args, **_kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user