66 lines
2.1 KiB
Python
66 lines
2.1 KiB
Python
import torch
|
|
|
|
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
|
|
|
|
|
|
class LowdimMaskGenerator(ModuleAttrMixin):
    """Generate boolean conditioning masks for low-dimensional diffusion inputs.

    Produces a (B, T, D) boolean mask where True marks positions whose values
    are conditioned on (visible) rather than predicted. The feature dimension
    D is split as [action | observation]: the first ``action_dim`` channels
    are actions, the remaining ``obs_dim`` channels are observations.
    """

    def __init__(
        self,
        action_dim,
        obs_dim,
        # obs mask setup
        max_n_obs_steps=2,
        fix_obs_steps=True,
        # action mask
        action_visible=False,
    ):
        """
        Args:
            action_dim: number of action channels (first part of D).
            obs_dim: number of observation channels (second part of D).
            max_n_obs_steps: maximum number of leading timesteps whose
                observations are visible.
            fix_obs_steps: if True, always expose exactly ``max_n_obs_steps``
                observation steps; otherwise sample the count uniformly in
                [1, max_n_obs_steps] per batch element.
            action_visible: if True, also expose actions for one step fewer
                than the visible observation steps.
        """
        super().__init__()
        self.action_dim = action_dim
        self.obs_dim = obs_dim
        self.max_n_obs_steps = max_n_obs_steps
        self.fix_obs_steps = fix_obs_steps
        self.action_visible = action_visible

    @torch.no_grad()
    def forward(self, shape, seed=None):
        """Build a conditioning mask of the given shape.

        Args:
            shape: tuple (B, T, D) with D == action_dim + obs_dim.
            seed: optional int seed making the sampled number of visible
                observation steps deterministic (only relevant when
                ``fix_obs_steps`` is False).

        Returns:
            Bool tensor of shape (B, T, D); True = position is conditioned on.

        Raises:
            ValueError: if D != action_dim + obs_dim.
        """
        device = self.device  # provided by ModuleAttrMixin
        B, T, D = shape  # noqa: N806
        if (self.action_dim + self.obs_dim) != D:
            # Explicit check instead of `assert`: asserts are stripped under -O.
            raise ValueError(
                f"feature dim mismatch: action_dim ({self.action_dim}) + obs_dim "
                f"({self.obs_dim}) != D ({D})"
            )

        # RNG for sampling the number of visible observation steps.
        rng = torch.Generator(device=device)
        if seed is not None:
            rng = rng.manual_seed(seed)

        # Channel masks: first `action_dim` channels are actions, rest are obs.
        is_action_dim = torch.zeros(size=shape, dtype=torch.bool, device=device)
        is_action_dim[..., : self.action_dim] = True
        is_obs_dim = ~is_action_dim

        # Number of visible observation steps per batch element.
        if self.fix_obs_steps:
            obs_steps = torch.full((B,), fill_value=self.max_n_obs_steps, device=device)
        else:
            obs_steps = torch.randint(
                low=1, high=self.max_n_obs_steps + 1, size=(B,), generator=rng, device=device
            )

        # steps[b, t] = t; timestep t of batch b is visible when t < obs_steps[b].
        steps = torch.arange(0, T, device=device).reshape(1, T).expand(B, T)
        obs_mask = (steps < obs_steps.reshape(B, 1)).reshape(B, T, 1).expand(B, T, D)
        obs_mask = obs_mask & is_obs_dim

        # Optionally expose actions, one step shorter than the visible obs.
        if self.action_visible:
            action_steps = torch.clamp(obs_steps - 1, min=0)
            action_mask = (steps < action_steps.reshape(B, 1)).reshape(B, T, 1).expand(B, T, D)
            action_mask = action_mask & is_action_dim

        mask = obs_mask
        if self.action_visible:
            mask = mask | action_mask

        return mask
|