backup wip

This commit is contained in:
Alexander Soare
2024-03-19 18:50:04 +00:00
parent ea17f4ce50
commit 896a11f60e
16 changed files with 169 additions and 138 deletions

View File

@@ -34,7 +34,7 @@ class DiffusionPolicy(AbstractPolicy):
# parameters passed to step
**kwargs,
):
super().__init__()
super().__init__(n_action_steps)
self.cfg = cfg
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
@@ -44,7 +44,6 @@ class DiffusionPolicy(AbstractPolicy):
**cfg_obs_encoder,
)
self.n_action_steps = n_action_steps # needed for the parent class
self.diffusion = DiffusionUnetImagePolicy(
shape_meta=shape_meta,
noise_scheduler=noise_scheduler,
@@ -94,7 +93,7 @@ class DiffusionPolicy(AbstractPolicy):
)
@torch.no_grad()
def select_action(self, observation, step_count):
def select_actions(self, observation, step_count):
# TODO(rcadene): remove unused step_count
del step_count