forked from tangger/lerobot
Hardcoded some normalization parameters. TODO refactor
Added masking actions on the level of the intervention actions and offline dataset Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
@@ -195,6 +195,7 @@ class ReplayBuffer:
|
||||
device: str = "cuda:0",
|
||||
state_keys: Optional[Sequence[str]] = None,
|
||||
capacity: Optional[int] = None,
|
||||
action_mask: Optional[Sequence[int]] = None,
|
||||
) -> "ReplayBuffer":
|
||||
"""
|
||||
Convert a LeRobotDataset into a ReplayBuffer.
|
||||
@@ -229,6 +230,12 @@ class ReplayBuffer:
|
||||
elif isinstance(v, torch.Tensor):
|
||||
data[k] = v.to(device)
|
||||
|
||||
if action_mask is not None:
|
||||
if data["action"].dim() == 1:
|
||||
data["action"] = data["action"][action_mask]
|
||||
else:
|
||||
data["action"] = data["action"][:, action_mask]
|
||||
|
||||
replay_buffer.add(
|
||||
state=data["state"],
|
||||
action=data["action"],
|
||||
|
||||
Reference in New Issue
Block a user