This commit is contained in:
Thomas Wolf
2024-06-18 13:44:28 +02:00
parent c9150c361b
commit 1cd7ca71a1

View File

@@ -162,6 +162,9 @@ class Normalize(nn.Module):
output_batch[key] = output_batch[key] * 2 - 1
else:
raise ValueError(mode)
for key in batch:
if key not in output_batch:
output_batch[key] = batch[key]
return output_batch
@@ -231,4 +234,7 @@ class Unnormalize(nn.Module):
output_batch[key] = output_batch[key] * (max - min) + min
else:
raise ValueError(mode)
for key in batch:
if key not in output_batch:
output_batch[key] = batch[key]
return output_batch