forked from tangger/lerobot
wip: still needs batch logic for act and tdmp
This commit is contained in:
@@ -2,10 +2,10 @@ import logging
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.policies.act.detr_vae import build
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ def kl_divergence(mu, logvar):
|
||||
return total_kld, dimension_wise_kld, mean_kld
|
||||
|
||||
|
||||
class ActionChunkingTransformerPolicy(nn.Module):
|
||||
class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||
def __init__(self, cfg, device, n_action_steps=1):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
@@ -147,7 +147,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
return loss
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, observation, step_count):
|
||||
def select_action(self, observation, step_count):
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
|
||||
Reference in New Issue
Block a user