Add pusht dataset (TODO verify reward is aligned), Refactor visualize_dataset, Add video_dir, fps, state_dim, action_dim to config (Training works)

This commit is contained in:
Cadene
2024-02-21 00:49:40 +00:00
parent 3dc14b5576
commit ece89730e6
8 changed files with 222 additions and 111 deletions

View File

@@ -13,7 +13,7 @@ class TOLD(nn.Module):
def __init__(self, cfg):
super().__init__()
action_dim = 4
action_dim = cfg.action_dim
self.cfg = cfg
self._encoder = h.enc(cfg)
@@ -82,7 +82,7 @@ class TDMPC(nn.Module):
def __init__(self, cfg):
super().__init__()
self.action_dim = 4
self.action_dim = cfg.action_dim
self.cfg = cfg
self.device = torch.device("cuda")