Integrate diffusion policy

This commit is contained in:
Simon Alibert
2024-03-10 15:31:17 +01:00
parent 59397fb44a
commit 302b78962c
16 changed files with 2850 additions and 13 deletions

View File

@@ -0,0 +1,15 @@
import torch.nn as nn
class ModuleAttrMixin(nn.Module):
def __init__(self):
super().__init__()
self._dummy_variable = nn.Parameter()
@property
def device(self):
return next(iter(self.parameters())).device
@property
def dtype(self):
return next(iter(self.parameters())).dtype