Hardcoded some normalization parameters. TODO: refactor.

Added action masking at the level of the intervention actions and the offline dataset.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi
2025-02-13 14:27:14 +01:00
parent 98c6557869
commit 5195f40fd3
6 changed files with 36 additions and 8 deletions

View File

@@ -195,6 +195,7 @@ class ReplayBuffer:
device: str = "cuda:0",
state_keys: Optional[Sequence[str]] = None,
capacity: Optional[int] = None,
action_mask: Optional[Sequence[int]] = None,
) -> "ReplayBuffer":
"""
Convert a LeRobotDataset into a ReplayBuffer.
@@ -229,6 +230,12 @@ class ReplayBuffer:
elif isinstance(v, torch.Tensor):
data[k] = v.to(device)
if action_mask is not None:
if data["action"].dim() == 1:
data["action"] = data["action"][action_mask]
else:
data["action"] = data["action"][:, action_mask]
replay_buffer.add(
state=data["state"],
action=data["action"],