wip: still needs batch logic for act and tdmpc
This commit is contained in:
@@ -9,6 +9,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import lerobot.common.policies.tdmpc.helper as h
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
|
||||
FIRST_FRAME = 0
|
||||
|
||||
@@ -85,7 +86,7 @@ class TOLD(nn.Module):
|
||||
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
|
||||
|
||||
|
||||
class TDMPC(nn.Module):
|
||||
class TDMPC(AbstractPolicy):
|
||||
"""Implementation of TD-MPC learning + inference."""
|
||||
|
||||
def __init__(self, cfg, device):
|
||||
@@ -124,7 +125,7 @@ class TDMPC(nn.Module):
|
||||
self.model_target.load_state_dict(d["model_target"])
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, observation, step_count):
|
||||
def select_action(self, observation, step_count):
|
||||
t0 = step_count.item() == 0
|
||||
|
||||
# TODO(rcadene): remove unsqueeze hack...
|
||||
|
||||
Reference in New Issue
Block a user