Shallow copy
This commit is contained in:
@@ -338,6 +338,7 @@ class NormalizeBuffer(nn.Module):
|
|||||||
_initialize_stats_buffers(self, features, norm_map, stats)
|
_initialize_stats_buffers(self, features, norm_map, stats)
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
|
batch = dict(batch)
|
||||||
for key, ft in self.features.items():
|
for key, ft in self.features.items():
|
||||||
if key not in batch:
|
if key not in batch:
|
||||||
continue
|
continue
|
||||||
@@ -386,6 +387,7 @@ class UnnormalizeBuffer(nn.Module):
|
|||||||
_initialize_stats_buffers(self, features, norm_map, stats)
|
_initialize_stats_buffers(self, features, norm_map, stats)
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
|
# batch = dict(batch)
|
||||||
for key, ft in self.features.items():
|
for key, ft in self.features.items():
|
||||||
if key not in batch:
|
if key not in batch:
|
||||||
continue
|
continue
|
||||||
|
|||||||
Reference in New Issue
Block a user