Merge remote-tracking branch 'upstream/main' into refactor_dp
@@ -1,8 +1,8 @@
 import logging
+from copy import deepcopy
 from pathlib import Path

 import hydra
 import numpy as np
 import torch

 from lerobot.common.datasets.factory import make_dataset
@@ -108,6 +108,64 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
     logger.log_dict(info, step, mode="eval")


+def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
+    """
+    Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from online dataset (on average).
+
+    Parameters:
+    - n_off (int): Number of offline samples, each with a sampling weight of 1.
+    - n_on (int): Number of online samples.
+    - pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5).
+
+    The total weight of offline samples is n_off * 1.0.
+    The total weight of online samples is n_on * w.
+    The total combined weight of all samples is n_off + n_on * w.
+    The fraction of the weight that is online is n_on * w / (n_off + n_on * w).
+    We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on.
+    The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1))
+    """
+    assert 0.0 <= pc_on <= 1.0
+    return -(n_off * pc_on) / (n_on * (pc_on - 1))
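
A quick sanity check of the formula above, with illustrative numbers: for n_off = 1000 offline frames, n_on = 100 online frames and a target pc_on = 0.5,

    n_off, n_on, pc_on = 1000, 100, 0.5
    w = -(n_off * pc_on) / (n_on * (pc_on - 1))      # 10.0
    online_fraction = n_on * w / (n_off + n_on * w)  # 0.5, as requested

so each online frame gets weight 10 and, in expectation, half of each batch is drawn from the online dataset.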
+
+
+def add_episodes_inplace(episodes, online_dataset, concat_dataset, sampler, pc_online_samples):
+    data_dict = episodes["data_dict"]
+    data_ids_per_episode = episodes["data_ids_per_episode"]
+
+    if len(online_dataset) == 0:
+        # initialize online dataset
+        online_dataset.data_dict = data_dict
+        online_dataset.data_ids_per_episode = data_ids_per_episode
+    else:
+        # find episode index and data frame indices according to previous episode in online_dataset
+        start_episode = max(online_dataset.data_ids_per_episode.keys()) + 1
+        start_index = online_dataset.data_dict["index"][-1].item() + 1
+        data_dict["episode"] += start_episode
+        data_dict["index"] += start_index
+
+        # extend online dataset
+        for key in data_dict:
+            # TODO(rcadene): avoid reallocating memory at every step by preallocating memory or changing our data structure
+            online_dataset.data_dict[key] = torch.cat([online_dataset.data_dict[key], data_dict[key]])
+        for ep_id in data_ids_per_episode:
+            online_dataset.data_ids_per_episode[ep_id + start_episode] = (
+                data_ids_per_episode[ep_id] + start_index
+            )
+
+    # update the concatenated dataset length used during sampling
+    concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
+
+    # update the sampling weights for each frame so that online frames get sampled a certain percentage of times
+    len_online = len(online_dataset)
+    len_offline = len(concat_dataset) - len_online
+    weight_offline = 1.0
+    weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples)
+    sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset))
+
+    # update the total number of samples used during sampling
+    sampler.num_samples = len(concat_dataset)
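
Two details worth noting about the weighting above. torch.utils.data.ConcatDataset indexes its member datasets in the order they are passed, so with ConcatDataset([offline_dataset, online_dataset]) the weight list [weight_offline] * len_offline + [weight_online] * len_online lines up with offline frames first. ConcatDataset also computes cumulative_sizes once at construction, which is why they are recomputed here after the online dataset grows. A minimal, self-contained sketch of the same mixing idea, using dummy tensors and an illustrative 50% target:

    import torch

    offline = torch.utils.data.TensorDataset(torch.zeros(1000))  # stand-in for offline_dataset
    online = torch.utils.data.TensorDataset(torch.ones(100))     # stand-in for online_dataset
    concat = torch.utils.data.ConcatDataset([offline, online])   # offline indices come first

    w_on = -(len(offline) * 0.5) / (len(online) * (0.5 - 1))  # same formula as above -> 10.0
    weights = [1.0] * len(offline) + [w_on] * len(online)
    sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(concat), replacement=True)
    loader = torch.utils.data.DataLoader(concat, batch_size=256, sampler=sampler)

    batch = next(iter(loader))[0]
    print(batch.mean())  # close to 0.5: about half of the sampled frames come from the online dataset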


 def train(cfg: dict, out_dir=None, job_name=None):
     if out_dir is None:
         raise NotImplementedError()
@@ -126,26 +184,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     set_global_seed(cfg.seed)

     logging.info("make_dataset")
-    dataset = make_dataset(cfg)
-
-    # TODO(rcadene): move balanced_sampling, per_alpha, per_beta outside policy
-    # if cfg.policy.balanced_sampling:
-    #     logging.info("make online_buffer")
-    #     num_traj_per_batch = cfg.policy.batch_size
-
-    #     online_sampler = PrioritizedSliceSampler(
-    #         max_capacity=100_000,
-    #         alpha=cfg.policy.per_alpha,
-    #         beta=cfg.policy.per_beta,
-    #         num_slices=num_traj_per_batch,
-    #         strict_length=True,
-    #     )
-
-    #     online_buffer = TensorDictReplayBuffer(
-    #         storage=LazyMemmapStorage(100_000),
-    #         sampler=online_sampler,
-    #         transform=dataset.transform,
-    #     )
+    offline_dataset = make_dataset(cfg)

     logging.info("make_env")
     env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
@@ -163,9 +202,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
     logging.info(f"{cfg.env.task=}")
     logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
     logging.info(f"{cfg.online_steps=}")
-    logging.info(f"{cfg.env.action_repeat=}")
-    logging.info(f"{dataset.num_samples=} ({format_big_number(dataset.num_samples)})")
-    logging.info(f"{dataset.num_episodes=}")
+    logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})")
+    logging.info(f"{offline_dataset.num_episodes=}")
     logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

@@ -173,18 +211,17 @@ def train(cfg: dict, out_dir=None, job_name=None):
     def _maybe_eval_and_maybe_save(step):
         if step % cfg.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
-            eval_info, first_video = eval_policy(
+            eval_info = eval_policy(
                 env,
                 policy,
-                return_first_video=True,
                 video_dir=Path(out_dir) / "eval",
-                save_video=True,
-                transform=dataset.transform,
+                max_episodes_rendered=4,
+                transform=offline_dataset.transform,
                 seed=cfg.seed,
             )
-            log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline)
+            log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
             if cfg.wandb.enable:
-                logger.log_video(first_video, step, mode="eval")
+                logger.log_video(eval_info["videos"][0], step, mode="eval")
             logging.info("Resume training")

         if cfg.save_model and step % cfg.save_freq == 0:
@@ -192,18 +229,19 @@ def train(cfg: dict, out_dir=None, job_name=None):
             logger.save_model(policy, identifier=step)
             logging.info("Resume training")

-    step = 0  # number of policy update (forward + backward + optim)
-
-    is_offline = True
+    # create dataloader for offline training
     dataloader = torch.utils.data.DataLoader(
-        dataset,
+        offline_dataset,
         num_workers=4,
         batch_size=cfg.policy.batch_size,
         shuffle=True,
         pin_memory=cfg.device != "cpu",
-        drop_last=True,
+        drop_last=False,
     )
     dl_iter = cycle(dataloader)
+
+    step = 0  # number of policy update (forward + backward + optim)
+    is_offline = True
     for offline_step in range(cfg.offline_steps):
         if offline_step == 0:
             logging.info("Start offline training on a fixed dataset")
@@ -217,7 +255,7 @@ def train(cfg: dict, out_dir=None, job_name=None):

         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0:
-            log_train_info(logger, train_info, step, cfg, dataset, is_offline)
+            log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)

         # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
         # step + 1.
@@ -225,61 +263,60 @@ def train(cfg: dict, out_dir=None, job_name=None):

         step += 1

     raise NotImplementedError()
+    # create an env dedicated to online episodes collection from policy rollout
+    rollout_env = make_env(cfg, num_parallel_envs=1)
+
+    # create an empty online dataset similar to offline dataset
+    online_dataset = deepcopy(offline_dataset)
+    online_dataset.data_dict = {}
+    online_dataset.data_ids_per_episode = {}
+
+    # create dataloader for online training
+    concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
+    weights = [1.0] * len(concat_dataset)
+    sampler = torch.utils.data.WeightedRandomSampler(
+        weights, num_samples=len(concat_dataset), replacement=True
+    )
+    dataloader = torch.utils.data.DataLoader(
+        concat_dataset,
+        num_workers=4,
+        batch_size=cfg.policy.batch_size,
+        sampler=sampler,
+        pin_memory=cfg.device != "cpu",
+        drop_last=False,
+    )
+    dl_iter = cycle(dataloader)
+
-    demo_buffer = dataset if cfg.policy.balanced_sampling else None
     online_step = 0
     is_offline = False
     for env_step in range(cfg.online_steps):
         if env_step == 0:
             logging.info("Start online training by interacting with environment")
         # TODO: add configurable number of rollout? (default=1)

-        with torch.no_grad():
-            rollout = env.rollout(
-                max_steps=cfg.env.episode_length,
-                policy=policy,
-                auto_cast_to_device=True,
+        eval_info = eval_policy(
+            rollout_env,
+            policy,
+            transform=offline_dataset.transform,
+            seed=cfg.seed,
         )

-        assert (
-            len(rollout.batch_size) == 2
-        ), "2 dimensions expected: number of env in parallel x max number of steps during rollout"
-
-        num_parallel_env = rollout.batch_size[0]
-        if num_parallel_env != 1:
-            # TODO(rcadene): when num_parallel_env > 1, rollout["episode"] needs to be properly set and we need to add tests
-            raise NotImplementedError()
-
-        num_max_steps = rollout.batch_size[1]
-        assert num_max_steps <= cfg.env.episode_length
-
-        # reshape to have a list of steps to insert into online_buffer
-        rollout = rollout.reshape(num_parallel_env * num_max_steps)
-
-        # set same episode index for all time steps contained in this rollout
-        rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
-        # online_buffer.extend(rollout)
-
-        ep_sum_reward = rollout["next", "reward"].sum()
-        ep_max_reward = rollout["next", "reward"].max()
-        ep_success = rollout["next", "success"].any()
-        rollout_info = {
-            "avg_sum_reward": np.nanmean(ep_sum_reward),
-            "avg_max_reward": np.nanmean(ep_max_reward),
-            "pc_success": np.nanmean(ep_success) * 100,
-            "env_step": env_step,
-            "ep_length": len(rollout),
-        }
+        online_pc_sampling = cfg.get("demo_schedule", 0.5)
+        add_episodes_inplace(
+            eval_info["episodes"], online_dataset, concat_dataset, sampler, online_pc_sampling
+        )

         for _ in range(cfg.policy.utd):
-            train_info = policy.update(
-                # online_buffer,
-                step,
-                demo_buffer=demo_buffer,
-            )
+            policy.train()
+            batch = next(dl_iter)
+
+            for key in batch:
+                batch[key] = batch[key].to(cfg.device, non_blocking=True)
+
+            train_info = policy(batch, step)
+
             if step % cfg.log_freq == 0:
-                train_info.update(rollout_info)
-                log_train_info(logger, train_info, step, cfg, dataset, is_offline)
+                log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)

             # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass
             # in step + 1.