backup wip
This commit is contained in:
@@ -34,7 +34,7 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
# parameters passed to step
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(n_action_steps)
|
||||
self.cfg = cfg
|
||||
|
||||
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
|
||||
@@ -44,7 +44,6 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
**cfg_obs_encoder,
|
||||
)
|
||||
|
||||
self.n_action_steps = n_action_steps # needed for the parent class
|
||||
self.diffusion = DiffusionUnetImagePolicy(
|
||||
shape_meta=shape_meta,
|
||||
noise_scheduler=noise_scheduler,
|
||||
@@ -94,7 +93,7 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, observation, step_count):
|
||||
def select_actions(self, observation, step_count):
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
|
||||
Reference in New Issue
Block a user