backup wip
This commit is contained in:
@@ -42,7 +42,7 @@ def kl_divergence(mu, logvar):
|
||||
|
||||
class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||
def __init__(self, cfg, device, n_action_steps=1):
|
||||
super().__init__()
|
||||
super().__init__(n_action_steps)
|
||||
self.cfg = cfg
|
||||
self.n_action_steps = n_action_steps
|
||||
self.device = device
|
||||
@@ -147,7 +147,10 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||
return loss
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, observation, step_count):
|
||||
def select_actions(self, observation, step_count):
|
||||
if observation["image"].shape[0] != 1:
|
||||
raise NotImplementedError("Batch size > 1 not handled")
|
||||
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
|
||||
Reference in New Issue
Block a user