Refactor env queue, Training diffusion works (Still not converging)

This commit is contained in:
Remi Cadene
2024-03-04 10:59:43 +00:00
parent fddd9f0311
commit cfc304e870
11 changed files with 96 additions and 111 deletions

View File

@@ -1,51 +1,11 @@
import contextlib
import logging
import os
from pathlib import Path
import numpy as np
from omegaconf import OmegaConf
from termcolor import colored
def make_dir(dir_path):
    """Create a directory (and any missing parents) if it does not already exist.

    Args:
        dir_path: Target directory. A ``pathlib.Path`` is used as-is; a ``str``
            or other ``os.PathLike`` value is converted to a ``Path`` first.

    Returns:
        The directory path as a ``Path``. When a ``Path`` was passed in, the
        very same object is returned.

    Note:
        Creation is best-effort: any ``OSError`` raised by ``mkdir`` is
        suppressed, so the returned path is not guaranteed to exist.
    """
    # Coerce only when needed so callers passing a Path get the same object back.
    path = dir_path if isinstance(dir_path, Path) else Path(dir_path)
    with contextlib.suppress(OSError):
        path.mkdir(parents=True, exist_ok=True)
    return path
def print_run(cfg, reward=None):
    """Pretty-print a summary of the run. Call at start of training.

    Prints a divider, one aligned ``Key: value`` row per entry, and a closing
    divider to stdout.

    Args:
        cfg: Config object; this function reads ``cfg.env.task``,
            ``cfg.env.action_repeat``, ``cfg.offline_steps`` and
            ``cfg.online_steps``.
        reward: Optional episode reward. When not ``None`` it is truncated with
            ``int()`` and shown as an extra "episode reward" row.
    """
    prefix, color, attrs = " ", "green", ["bold"]

    def limstr(s, maxlen=32):
        # Stringify BEFORE slicing so non-string values (ints, lists, config
        # nodes, ...) are truncated instead of raising on ``s[:maxlen]``.
        text = str(s)
        return text[:maxlen] + "..." if len(text) > maxlen else text

    def pprint(k, v):
        print(
            prefix + colored(f'{k.capitalize() + ":":<16}', color, attrs=attrs),
            limstr(v),
        )

    kvs = [
        ("task", cfg.env.task),
        ("offline_steps", f"{cfg.offline_steps}"),
        ("online_steps", f"{cfg.online_steps}"),
        ("action_repeat", f"{cfg.env.action_repeat}"),
        # ('observations', 'x'.join([str(s) for s in cfg.obs_shape])),
        # ('actions', cfg.action_dim),
        # ('experiment', cfg.exp_name),
    ]
    if reward is not None:
        kvs.append(("episode reward", colored(str(int(reward)), "white", attrs=["bold"])))

    # Divider width = longest (truncated) value plus a fixed allowance of 21
    # for the prefix and key column.
    # NOTE(review): the reward value is already wrapped by ``colored``, so any
    # escape sequences it adds are counted in this width — confirm if intended.
    w = max(len(limstr(v)) for _, v in kvs) + 21
    div = "-" * w

    print(div)
    for k, v in kvs:
        pprint(k, v)
    print(div)
def cfg_to_group(cfg, return_list=False):
"""Return a wandb-safe group name for logging. Optionally returns group name as list."""
# lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
@@ -71,13 +31,12 @@ class Logger:
self._seed = cfg.seed
self._cfg = cfg
self._eval = []
print_run(cfg)
project = cfg.get("wandb", {}).get("project")
entity = cfg.get("wandb", {}).get("entity")
enable_wandb = cfg.get("wandb", {}).get("enable", False)
run_offline = not enable_wandb or not project or not entity
if run_offline:
print(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
self._wandb = None
else:
os.environ["WANDB_SILENT"] = "true"
@@ -134,7 +93,6 @@ class Logger:
self.save_buffer(buffer, identifier="buffer")
if self._wandb:
self._wandb.finish()
print_run(self._cfg, self._eval[-1][-1])
def log_dict(self, d, step, mode="train"):
assert mode in {"train", "eval"}