Integrate diffusion policy
This commit is contained in:
65
lerobot/common/policies/diffusion/model/mask_generator.py
Normal file
65
lerobot/common/policies/diffusion/model/mask_generator.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import torch
|
||||
|
||||
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
|
||||
|
||||
|
||||
class LowdimMaskGenerator(ModuleAttrMixin):
|
||||
def __init__(
|
||||
self,
|
||||
action_dim,
|
||||
obs_dim,
|
||||
# obs mask setup
|
||||
max_n_obs_steps=2,
|
||||
fix_obs_steps=True,
|
||||
# action mask
|
||||
action_visible=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.action_dim = action_dim
|
||||
self.obs_dim = obs_dim
|
||||
self.max_n_obs_steps = max_n_obs_steps
|
||||
self.fix_obs_steps = fix_obs_steps
|
||||
self.action_visible = action_visible
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, shape, seed=None):
|
||||
device = self.device
|
||||
B, T, D = shape # noqa: N806
|
||||
assert (self.action_dim + self.obs_dim) == D
|
||||
|
||||
# create all tensors on this device
|
||||
rng = torch.Generator(device=device)
|
||||
if seed is not None:
|
||||
rng = rng.manual_seed(seed)
|
||||
|
||||
# generate dim mask
|
||||
dim_mask = torch.zeros(size=shape, dtype=torch.bool, device=device)
|
||||
is_action_dim = dim_mask.clone()
|
||||
is_action_dim[..., : self.action_dim] = True
|
||||
is_obs_dim = ~is_action_dim
|
||||
|
||||
# generate obs mask
|
||||
if self.fix_obs_steps:
|
||||
obs_steps = torch.full((B,), fill_value=self.max_n_obs_steps, device=device)
|
||||
else:
|
||||
obs_steps = torch.randint(
|
||||
low=1, high=self.max_n_obs_steps + 1, size=(B,), generator=rng, device=device
|
||||
)
|
||||
|
||||
steps = torch.arange(0, T, device=device).reshape(1, T).expand(B, T)
|
||||
obs_mask = (obs_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D)
|
||||
obs_mask = obs_mask & is_obs_dim
|
||||
|
||||
# generate action mask
|
||||
if self.action_visible:
|
||||
action_steps = torch.maximum(
|
||||
obs_steps - 1, torch.tensor(0, dtype=obs_steps.dtype, device=obs_steps.device)
|
||||
)
|
||||
action_mask = (action_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D)
|
||||
action_mask = action_mask & is_action_dim
|
||||
|
||||
mask = obs_mask
|
||||
if self.action_visible:
|
||||
mask = mask | action_mask
|
||||
|
||||
return mask
|
||||
Reference in New Issue
Block a user