wip: still needs batch logic for act and tdmpc

This commit is contained in:
Alexander Soare
2024-03-14 15:22:55 +00:00
parent 8c56770318
commit ba91976944
11 changed files with 240 additions and 100 deletions

View File

@@ -9,6 +9,7 @@ import torch
import torch.nn as nn
import lerobot.common.policies.tdmpc.helper as h
from lerobot.common.policies.abstract import AbstractPolicy
FIRST_FRAME = 0
@@ -85,7 +86,7 @@ class TOLD(nn.Module):
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
class TDMPC(nn.Module):
class TDMPC(AbstractPolicy):
"""Implementation of TD-MPC learning + inference."""
def __init__(self, cfg, device):
@@ -124,7 +125,7 @@ class TDMPC(nn.Module):
self.model_target.load_state_dict(d["model_target"])
@torch.no_grad()
def forward(self, observation, step_count):
def select_action(self, observation, step_count):
t0 = step_count.item() == 0
# TODO(rcadene): remove unsqueeze hack...