Add pusht dataset (TODO verify reward is aligned), Refactor visualize_dataset, Add video_dir, fps, state_dim, action_dim to config (Training works)
This commit is contained in:
@@ -130,7 +130,7 @@ class Flatten(nn.Module):
|
||||
def enc(cfg):
|
||||
obs_shape = {
|
||||
"rgb": (3, cfg.img_size, cfg.img_size),
|
||||
"state": (4,),
|
||||
"state": (cfg.state_dim,),
|
||||
}
|
||||
|
||||
"""Returns a TOLD encoder."""
|
||||
@@ -209,7 +209,7 @@ def dynamics(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()):
|
||||
|
||||
|
||||
def q(cfg):
|
||||
action_dim = 4
|
||||
action_dim = cfg.action_dim
|
||||
"""Returns a Q-function that uses Layer Normalization."""
|
||||
return nn.Sequential(
|
||||
nn.Linear(cfg.latent_dim + action_dim, cfg.mlp_dim),
|
||||
@@ -331,7 +331,7 @@ class Episode(object):
|
||||
"""Storage object for a single episode."""
|
||||
|
||||
def __init__(self, cfg, init_obs):
|
||||
action_dim = 4
|
||||
action_dim = cfg.action_dim
|
||||
|
||||
self.cfg = cfg
|
||||
self.device = torch.device(cfg.buffer_device)
|
||||
@@ -447,8 +447,8 @@ class ReplayBuffer:
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, dataset=None):
|
||||
action_dim = 4
|
||||
obs_shape = {"rgb": (3, cfg.img_size, cfg.img_size), "state": (4,)}
|
||||
action_dim = cfg.action_dim
|
||||
obs_shape = {"rgb": (3, cfg.img_size, cfg.img_size), "state": (cfg.state_dim,)}
|
||||
|
||||
self.cfg = cfg
|
||||
self.device = torch.device(cfg.buffer_device)
|
||||
|
||||
Reference in New Issue
Block a user