Refactor train, eval_policy, logger, Add diffusion.yaml (WIP)

This commit is contained in:
Cadene
2024-02-26 01:10:09 +00:00
parent 5a219fed6e
commit 21670dce90
12 changed files with 306 additions and 443 deletions

View File

@@ -10,10 +10,10 @@ from termcolor import colored
CONSOLE_FORMAT = [
("episode", "E", "int"),
("env_step", "S", "int"),
("step", "S", "int"),
("avg_sum_reward", "RS", "float"),
("avg_max_reward", "RM", "float"),
("pc_success", "S", "float"),
("pc_success", "SR", "float"),
("total_time", "T", "time"),
]
AGENT_METRICS = [
@@ -51,7 +51,9 @@ def print_run(cfg, reward=None):
kvs = [
("task", cfg.env.task),
("train steps", f"{int(cfg.train_steps * cfg.env.action_repeat):,}"),
("offline_steps", f"{cfg.offline_steps}"),
("online_steps", f"{cfg.online_steps}"),
("action_repeat", f"{cfg.env.action_repeat}"),
# ('observations', 'x'.join([str(s) for s in cfg.obs_shape])),
# ('actions', cfg.action_dim),
# ('experiment', cfg.exp_name),
@@ -78,54 +80,6 @@ def cfg_to_group(cfg, return_list=False):
return lst if return_list else "-".join(lst)
class VideoRecorder:
"""Utility class for logging evaluation videos."""
def __init__(self, root_dir, wandb, render_size=384, fps=15):
self.save_dir = (root_dir / "eval_video") if root_dir else None
self._wandb = wandb
self.render_size = render_size
self.fps = fps
self.frames = []
self.enabled = False
self.camera_id = 0
def init(self, env, enabled=True):
self.frames = []
self.enabled = self.save_dir and self._wandb and enabled
try:
env_name = env.unwrapped.spec.id
except:
env_name = ""
if "maze2d" in env_name:
self.camera_id = -1
elif "quadruped" in env_name:
self.camera_id = 2
self.record(env)
def record(self, env):
if self.enabled:
frame = env.render(
mode="rgb_array",
height=self.render_size,
width=self.render_size,
camera_id=self.camera_id,
)
self.frames.append(frame)
def save(self, step):
if self.enabled:
frames = np.stack(self.frames).transpose(0, 3, 1, 2)
self._wandb.log(
{
"eval_video": self._wandb.Video(
frames, fps=self.env.fps, format="mp4"
)
},
step=step,
)
class Logger(object):
"""Primary logger object. Logs either locally or using wandb."""
@@ -170,15 +124,6 @@ class Logger(object):
)
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
self._wandb = wandb
self._video = (
VideoRecorder(self._log_dir, self._wandb)
if self._wandb and cfg.save_video
else None
)
@property
def video(self):
return self._video
def save_model(self, agent, identifier):
if self._save_model:
@@ -214,12 +159,12 @@ class Logger(object):
def _format(self, key, value, ty):
if ty == "int":
return f'{colored(key + ":", "grey")} {int(value):,}'
return f'{colored(key + ":", "yellow")} {int(value):,}'
elif ty == "float":
return f'{colored(key + ":", "grey")} {value:.01f}'
return f'{colored(key + ":", "yellow")} {value:.01f}'
elif ty == "time":
value = str(datetime.timedelta(seconds=int(value)))
return f'{colored(key + ":", "grey")} {value}'
return f'{colored(key + ":", "yellow")} {value}'
else:
raise f"invalid log format type: {ty}"
@@ -234,10 +179,9 @@ class Logger(object):
assert category in {"train", "eval"}
if self._wandb is not None:
for k, v in d.items():
self._wandb.log({category + "/" + k: v}, step=d["env_step"])
self._wandb.log({category + "/" + k: v}, step=d["step"])
if category == "eval":
# keys = ['env_step', 'avg_reward']
keys = ["env_step", "avg_sum_reward", "avg_max_reward", "pc_success"]
keys = ["step", "avg_sum_reward", "avg_max_reward", "pc_success"]
self._eval.append(np.array([d[key] for key in keys]))
pd.DataFrame(np.array(self._eval)).to_csv(
self._log_dir / "eval.log", header=keys, index=None