Add pusht dataset (TODO verify reward is aligned), Refactor visualize_dataset, Add video_dir, fps, state_dim, action_dim to config (Training works)

This commit is contained in:
Cadene
2024-02-21 00:49:40 +00:00
parent 3dc14b5576
commit ece89730e6
8 changed files with 222 additions and 111 deletions

View File

@@ -130,7 +130,7 @@ class Flatten(nn.Module):
def enc(cfg):
obs_shape = {
"rgb": (3, cfg.img_size, cfg.img_size),
"state": (4,),
"state": (cfg.state_dim,),
}
"""Returns a TOLD encoder."""
@@ -209,7 +209,7 @@ def dynamics(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()):
def q(cfg):
action_dim = 4
action_dim = cfg.action_dim
"""Returns a Q-function that uses Layer Normalization."""
return nn.Sequential(
nn.Linear(cfg.latent_dim + action_dim, cfg.mlp_dim),
@@ -331,7 +331,7 @@ class Episode(object):
"""Storage object for a single episode."""
def __init__(self, cfg, init_obs):
action_dim = 4
action_dim = cfg.action_dim
self.cfg = cfg
self.device = torch.device(cfg.buffer_device)
@@ -447,8 +447,8 @@ class ReplayBuffer:
"""
def __init__(self, cfg, dataset=None):
action_dim = 4
obs_shape = {"rgb": (3, cfg.img_size, cfg.img_size), "state": (4,)}
action_dim = cfg.action_dim
obs_shape = {"rgb": (3, cfg.img_size, cfg.img_size), "state": (cfg.state_dim,)}
self.cfg = cfg
self.device = torch.device(cfg.buffer_device)