Refactor env queue, Training diffusion works (Still not converging)
This commit is contained in:
@@ -1,51 +1,11 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from omegaconf import OmegaConf
|
||||
from termcolor import colored
|
||||
|
||||
|
||||
def make_dir(dir_path):
|
||||
"""Create directory if it does not already exist."""
|
||||
with contextlib.suppress(OSError):
|
||||
dir_path.mkdir(parents=True, exist_ok=True)
|
||||
return dir_path
|
||||
|
||||
|
||||
def print_run(cfg, reward=None):
|
||||
"""Pretty-printing of run information. Call at start of training."""
|
||||
prefix, color, attrs = " ", "green", ["bold"]
|
||||
|
||||
def limstr(s, maxlen=32):
|
||||
return str(s[:maxlen]) + "..." if len(str(s)) > maxlen else s
|
||||
|
||||
def pprint(k, v):
|
||||
print(
|
||||
prefix + colored(f'{k.capitalize() + ":":<16}', color, attrs=attrs),
|
||||
limstr(v),
|
||||
)
|
||||
|
||||
kvs = [
|
||||
("task", cfg.env.task),
|
||||
("offline_steps", f"{cfg.offline_steps}"),
|
||||
("online_steps", f"{cfg.online_steps}"),
|
||||
("action_repeat", f"{cfg.env.action_repeat}"),
|
||||
# ('observations', 'x'.join([str(s) for s in cfg.obs_shape])),
|
||||
# ('actions', cfg.action_dim),
|
||||
# ('experiment', cfg.exp_name),
|
||||
]
|
||||
if reward is not None:
|
||||
kvs.append(("episode reward", colored(str(int(reward)), "white", attrs=["bold"])))
|
||||
w = np.max([len(limstr(str(kv[1]))) for kv in kvs]) + 21
|
||||
div = "-" * w
|
||||
print(div)
|
||||
for k, v in kvs:
|
||||
pprint(k, v)
|
||||
print(div)
|
||||
|
||||
|
||||
def cfg_to_group(cfg, return_list=False):
|
||||
"""Return a wandb-safe group name for logging. Optionally returns group name as list."""
|
||||
# lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
|
||||
@@ -71,13 +31,12 @@ class Logger:
|
||||
self._seed = cfg.seed
|
||||
self._cfg = cfg
|
||||
self._eval = []
|
||||
print_run(cfg)
|
||||
project = cfg.get("wandb", {}).get("project")
|
||||
entity = cfg.get("wandb", {}).get("entity")
|
||||
enable_wandb = cfg.get("wandb", {}).get("enable", False)
|
||||
run_offline = not enable_wandb or not project or not entity
|
||||
if run_offline:
|
||||
print(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||
self._wandb = None
|
||||
else:
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
@@ -134,7 +93,6 @@ class Logger:
|
||||
self.save_buffer(buffer, identifier="buffer")
|
||||
if self._wandb:
|
||||
self._wandb.finish()
|
||||
print_run(self._cfg, self._eval[-1][-1])
|
||||
|
||||
def log_dict(self, d, step, mode="train"):
|
||||
assert mode in {"train", "eval"}
|
||||
|
||||
Reference in New Issue
Block a user