diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index a711dcaa8..8f9096a6d 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -162,6 +162,9 @@ class Normalize(nn.Module): output_batch[key] = output_batch[key] * 2 - 1 else: raise ValueError(mode) + for key in batch: + if key not in output_batch: + output_batch[key] = batch[key] return output_batch @@ -231,4 +234,7 @@ class Unnormalize(nn.Module): output_batch[key] = output_batch[key] * (max - min) + min else: raise ValueError(mode) + for key in batch: + if key not in output_batch: + output_batch[key] = batch[key] return output_batch