diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index bcbb68a0..9cc94b92 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -338,6 +338,7 @@ class NormalizeBuffer(nn.Module): _initialize_stats_buffers(self, features, norm_map, stats) def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + batch = dict(batch) for key, ft in self.features.items(): if key not in batch: continue @@ -386,6 +387,7 @@ class UnnormalizeBuffer(nn.Module): _initialize_stats_buffers(self, features, norm_map, stats) def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + # batch = dict(batch) for key, ft in self.features.items(): if key not in batch: continue