wip: still needs batch logic for act and tdmp

This commit is contained in:
Alexander Soare
2024-03-14 15:22:55 +00:00
parent 8c56770318
commit ba91976944
11 changed files with 240 additions and 100 deletions

View File

@@ -2,10 +2,10 @@ import logging
import time
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
import torchvision.transforms as transforms
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.act.detr_vae import build
@@ -40,7 +40,7 @@ def kl_divergence(mu, logvar):
return total_kld, dimension_wise_kld, mean_kld
class ActionChunkingTransformerPolicy(nn.Module):
class ActionChunkingTransformerPolicy(AbstractPolicy):
def __init__(self, cfg, device, n_action_steps=1):
super().__init__()
self.cfg = cfg
@@ -147,7 +147,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
return loss
@torch.no_grad()
def forward(self, observation, step_count):
def select_action(self, observation, step_count):
# TODO(rcadene): remove unused step_count
del step_count