Add pusht dataset (TODO verify reward is aligned), Refactor visualize_dataset, Add video_dir, fps, state_dim, action_dim to config (Training works)
This commit is contained in:
@@ -13,7 +13,7 @@ class TOLD(nn.Module):
|
||||
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
action_dim = 4
|
||||
action_dim = cfg.action_dim
|
||||
|
||||
self.cfg = cfg
|
||||
self._encoder = h.enc(cfg)
|
||||
@@ -82,7 +82,7 @@ class TDMPC(nn.Module):
|
||||
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.action_dim = 4
|
||||
self.action_dim = cfg.action_dim
|
||||
|
||||
self.cfg = cfg
|
||||
self.device = torch.device("cuda")
|
||||
|
||||
Reference in New Issue
Block a user