Small fix, Refactor diffusion, Diffusion runs (TODO: remove normalization in diffusion)

This commit is contained in:
Remi Cadene
2024-03-02 17:04:39 +00:00
parent 45b4ecb727
commit 80785f8d0e
6 changed files with 449 additions and 10 deletions

View File

@@ -147,7 +147,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
env = make_env(cfg, transform=offline_buffer._transform)
logging.info("make_policy")
policy = make_policy(cfg, transform=offline_buffer._transform)
policy = make_policy(cfg)
td_policy = TensorDictModule(
policy,