Shallow copy

This commit is contained in:
AdilZouitine
2025-05-16 18:38:15 +02:00
parent adb1d08cc2
commit b166296ba5

View File

@@ -338,6 +338,7 @@ class NormalizeBuffer(nn.Module):
_initialize_stats_buffers(self, features, norm_map, stats)
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch)
for key, ft in self.features.items():
if key not in batch:
continue
@@ -386,6 +387,7 @@ class UnnormalizeBuffer(nn.Module):
_initialize_stats_buffers(self, features, norm_map, stats)
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# batch = dict(batch)
for key, ft in self.features.items():
if key not in batch:
continue