WIP

WIP train.py works, loss going down

WIP eval.py

Fix

WIP (eval running, TODO: verify results reproduced)

Eval works! (testing reproducibility)

WIP

pretrained model pusht reproduces same results as torchrl

Remove AbstractPolicy, Move all queues in select_action

WIP test_datasets passed (TODO: re-enable NormalizeTransform)
commit 1cdfbc8b52
parent 920e0d118b
Author: Cadene
Date: 2024-03-31 15:05:25 +00:00
17 changed files with 826 additions and 621 deletions
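
The title's "Move all queues in select_action" refers to policy-side buffering that is not visible in the train.py hunks below. As a rough sketch of that pattern (class and attribute names here are illustrative, not taken from the repo): the policy predicts a chunk of future actions in one forward pass, stores them in a deque, and pops one action per environment step, so callers no longer manage any queue themselves.

# Hypothetical illustration of queues living inside select_action.
# ChunkPolicy, obs_dim, and n_action_steps are made-up names.
from collections import deque

import torch


class ChunkPolicy(torch.nn.Module):
    def __init__(self, obs_dim: int = 4, action_dim: int = 2, n_action_steps: int = 8):
        super().__init__()
        self.n_action_steps = n_action_steps
        self.net = torch.nn.Linear(obs_dim, action_dim * n_action_steps)
        self.reset()

    def reset(self):
        # The queue is owned by the policy and refilled lazily.
        self._action_queue = deque(maxlen=self.n_action_steps)

    @torch.no_grad()
    def select_action(self, observation: torch.Tensor) -> torch.Tensor:
        # Refill only when the buffered action chunk is exhausted.
        if len(self._action_queue) == 0:
            chunk = self.net(observation).reshape(self.n_action_steps, -1)
            self._action_queue.extend(chunk)
        return self._action_queue.popleft()

Calling reset() at every episode boundary keeps stale buffered actions from leaking into the next rollout.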

lerobot/scripts/train.py

@@ -1,14 +1,12 @@
 import logging
+from itertools import cycle
 from pathlib import Path
 
 import hydra
 import numpy as np
 import torch
-from tensordict.nn import TensorDictModule
-from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
-from torchrl.data.replay_buffers import PrioritizedSliceSampler
 
-from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
@@ -34,7 +32,7 @@ def train_notebook(out_dir=None, job_name=None, config_name="default", config_pa
     train(cfg, out_dir=out_dir, job_name=job_name)
 
 
-def log_train_info(logger, info, step, cfg, offline_buffer, is_offline):
+def log_train_info(logger, info, step, cfg, dataset, is_offline):
     loss = info["loss"]
     grad_norm = info["grad_norm"]
     lr = info["lr"]
@@ -44,9 +42,9 @@ def log_train_info(logger, info, step, cfg, offline_buffer, is_offline):
     # A sample is an (observation,action) pair, where observation and action
    # can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
     num_samples = (step + 1) * cfg.policy.batch_size
-    avg_samples_per_ep = offline_buffer.num_samples / offline_buffer.num_episodes
+    avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
     num_episodes = num_samples / avg_samples_per_ep
-    num_epochs = num_samples / offline_buffer.num_samples
+    num_epochs = num_samples / dataset.num_samples
     log_items = [
         f"step:{format_big_number(step)}",
         # number of samples seen during training
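
As a quick numeric check of the counters above (all sizes here are made up for illustration):

# Worked example of the logging arithmetic; sizes are hypothetical.
batch_size = 256
step = 999                                     # step is 0-indexed
num_samples = (step + 1) * batch_size          # 256_000 samples seen
dataset_num_samples = 25_650
dataset_num_episodes = 50
avg_samples_per_ep = dataset_num_samples / dataset_num_episodes  # 513.0
num_episodes = num_samples / avg_samples_per_ep   # ~499 episodes' worth
num_epochs = num_samples / dataset_num_samples    # ~9.98 passes over the data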
@@ -73,7 +71,7 @@ def log_train_info(logger, info, step, cfg, offline_buffer, is_offline):
     logger.log_dict(info, step, mode="train")
 
 
-def log_eval_info(logger, info, step, cfg, offline_buffer, is_offline):
+def log_eval_info(logger, info, step, cfg, dataset, is_offline):
     eval_s = info["eval_s"]
     avg_sum_reward = info["avg_sum_reward"]
     pc_success = info["pc_success"]
@@ -81,9 +79,9 @@ def log_eval_info(logger, info, step, cfg, offline_buffer, is_offline):
     # A sample is an (observation,action) pair, where observation and action
     # can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
     num_samples = (step + 1) * cfg.policy.batch_size
-    avg_samples_per_ep = offline_buffer.num_samples / offline_buffer.num_episodes
+    avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
     num_episodes = num_samples / avg_samples_per_ep
-    num_epochs = num_samples / offline_buffer.num_samples
+    num_epochs = num_samples / dataset.num_samples
     log_items = [
         f"step:{format_big_number(step)}",
         # number of samples seen during training
@@ -124,30 +122,30 @@ def train(cfg: dict, out_dir=None, job_name=None):
     torch.backends.cuda.matmul.allow_tf32 = True
     set_global_seed(cfg.seed)
 
-    logging.info("make_offline_buffer")
-    offline_buffer = make_offline_buffer(cfg)
+    logging.info("make_dataset")
+    dataset = make_dataset(cfg)
 
     # TODO(rcadene): move balanced_sampling, per_alpha, per_beta outside policy
-    if cfg.policy.balanced_sampling:
-        logging.info("make online_buffer")
-        num_traj_per_batch = cfg.policy.batch_size
+    # if cfg.policy.balanced_sampling:
+    #     logging.info("make online_buffer")
+    #     num_traj_per_batch = cfg.policy.batch_size
 
-        online_sampler = PrioritizedSliceSampler(
-            max_capacity=100_000,
-            alpha=cfg.policy.per_alpha,
-            beta=cfg.policy.per_beta,
-            num_slices=num_traj_per_batch,
-            strict_length=True,
-        )
+    #     online_sampler = PrioritizedSliceSampler(
+    #         max_capacity=100_000,
+    #         alpha=cfg.policy.per_alpha,
+    #         beta=cfg.policy.per_beta,
+    #         num_slices=num_traj_per_batch,
+    #         strict_length=True,
+    #     )
 
-        online_buffer = TensorDictReplayBuffer(
-            storage=LazyMemmapStorage(100_000),
-            sampler=online_sampler,
-            transform=offline_buffer.transform,
-        )
+    #     online_buffer = TensorDictReplayBuffer(
+    #         storage=LazyMemmapStorage(100_000),
+    #         sampler=online_sampler,
+    #         transform=dataset.transform,
+    #     )
 
     logging.info("make_env")
-    env = make_env(cfg, transform=offline_buffer.transform)
+    env = make_env(cfg)
 
     logging.info("make_policy")
     policy = make_policy(cfg)
@@ -155,8 +153,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
 
-    td_policy = TensorDictModule(policy, in_keys=["observation", "step_count"], out_keys=["action"])
-
     # log metrics to terminal and wandb
     logger = Logger(out_dir, job_name, cfg)
 
@@ -165,8 +161,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
     logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
     logging.info(f"{cfg.online_steps=}")
     logging.info(f"{cfg.env.action_repeat=}")
-    logging.info(f"{offline_buffer.num_samples=} ({format_big_number(offline_buffer.num_samples)})")
-    logging.info(f"{offline_buffer.num_episodes=}")
+    logging.info(f"{dataset.num_samples=} ({format_big_number(dataset.num_samples)})")
+    logging.info(f"{dataset.num_episodes=}")
     logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
 
@@ -176,14 +172,15 @@ def train(cfg: dict, out_dir=None, job_name=None):
         logging.info(f"Eval policy at step {step}")
         eval_info, first_video = eval_policy(
             env,
-            td_policy,
+            policy,
             num_episodes=cfg.eval_episodes,
             max_steps=cfg.env.episode_length,
             return_first_video=True,
             video_dir=Path(out_dir) / "eval",
             save_video=True,
+            transform=dataset.transform,
         )
-        log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_buffer, is_offline)
+        log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline)
         if cfg.wandb.enable:
             logger.log_video(first_video, step, mode="eval")
         logging.info("Resume training")
@@ -196,14 +193,29 @@ def train(cfg: dict, out_dir=None, job_name=None):
     step = 0  # number of policy update (forward + backward + optim)
 
     is_offline = True
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=4,
+        batch_size=cfg.policy.batch_size,
+        shuffle=True,
+        pin_memory=cfg.device != "cpu",
+        drop_last=True,
+    )
+    dl_iter = cycle(dataloader)
     for offline_step in range(cfg.offline_steps):
         if offline_step == 0:
             logging.info("Start offline training on a fixed dataset")
-        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         policy.train()
-        train_info = policy.update(offline_buffer, step)
+        batch = next(dl_iter)
+
+        for key in batch:
+            batch[key] = batch[key].to(cfg.device, non_blocking=True)
+
+        train_info = policy.update(batch, step)
+
+        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0:
-            log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
+            log_train_info(logger, train_info, step, cfg, dataset, is_offline)
 
         # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
         # step + 1.
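
One caveat about the new dl_iter = cycle(dataloader) line: itertools.cycle saves every item from its first pass and then replays that saved copy, so shuffle=True only randomizes the batch order once, and all first-epoch batches stay alive in memory. A re-shuffling alternative (a sketch, assuming a fresh iterator per pass is acceptable) simply restarts the DataLoader:

# Replacement for itertools.cycle that re-creates the iterator each
# epoch, letting the DataLoader reshuffle instead of replaying a cached
# copy of the first pass.
from typing import Iterable, Iterator


def repeat_iter(iterable: Iterable) -> Iterator:
    while True:
        yield from iterable  # fresh __iter__ per epoch -> new shuffle

Using dl_iter = repeat_iter(dataloader) leaves the rest of the loop unchanged.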
@@ -211,7 +223,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         step += 1
 
-    demo_buffer = offline_buffer if cfg.policy.balanced_sampling else None
+    demo_buffer = dataset if cfg.policy.balanced_sampling else None
 
     online_step = 0
     is_offline = False
     for env_step in range(cfg.online_steps):
@@ -221,7 +233,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         with torch.no_grad():
             rollout = env.rollout(
                 max_steps=cfg.env.episode_length,
-                policy=td_policy,
+                policy=policy,
                 auto_cast_to_device=True,
             )
 
@@ -242,7 +254,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
 
         # set same episode index for all time steps contained in this rollout
         rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
-        online_buffer.extend(rollout)
+        # online_buffer.extend(rollout)
 
         ep_sum_reward = rollout["next", "reward"].sum()
         ep_max_reward = rollout["next", "reward"].max()
@@ -257,13 +269,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
         for _ in range(cfg.policy.utd):
             train_info = policy.update(
-                online_buffer,
+                # online_buffer,
                 step,
                 demo_buffer=demo_buffer,
             )
 
             if step % cfg.log_freq == 0:
                 train_info.update(rollout_info)
-                log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
+                log_train_info(logger, train_info, step, cfg, dataset, is_offline)
 
         # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass
         # in step + 1.