Refactor TD-MPC (#103)

Co-authored-by: Cadene <re.cadene@gmail.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Author: Alexander Soare
Date: 2024-05-01 16:40:04 +01:00
Committed by: GitHub
Parent: a4891095e4
Commit: d1855a202a
17 changed files with 1105 additions and 1205 deletions

@@ -15,6 +15,7 @@ from lerobot.common.datasets.utils import cycle
 from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
+from lerobot.common.policies.policy_protocol import PolicyWithUpdate
 from lerobot.common.utils.utils import (
     format_big_number,
     get_safe_torch_device,
@@ -39,12 +40,17 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
     optimizer.step()
     optimizer.zero_grad()

     if lr_scheduler is not None:
         lr_scheduler.step()

-    if hasattr(policy, "ema") and policy.ema is not None:
-        policy.ema.step(policy.diffusion)
+    if isinstance(policy, PolicyWithUpdate):
+        # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
+        policy.update()

     info = {
         "loss": loss.item(),
         "grad_norm": float(grad_norm),
@@ -246,11 +252,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
         raise NotImplementedError()
     if job_name is None:
         raise NotImplementedError()

-    if cfg.training.online_steps > 0:
-        assert cfg.eval.batch_size == 1, "eval.batch_size > 1 not supported for online training steps"
     init_logging()

+    if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1:
+        logging.warning("eval.batch_size > 1 not supported for online training steps")

     # Check device is available
     get_safe_torch_device(cfg.device, log=True)
@@ -305,7 +312,10 @@ def train(cfg: dict, out_dir=None, job_name=None):
num_training_steps=cfg.training.offline_steps,
)
elif policy.name == "tdmpc":
raise NotImplementedError("TD-MPC not implemented yet.")
optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
lr_scheduler = None
else:
raise NotImplementedError()
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
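
The comment in `update_policy` says what the new hook is for in TD-MPC's case: stepping an exponential moving average, typically of a target network. A hedged sketch of that kind of soft update (the `soft_update` name and its arguments are assumptions for illustration, not lerobot's API):

    import torch

    @torch.no_grad()
    def soft_update(model: torch.nn.Module, model_target: torch.nn.Module, ema_coef: float = 0.005):
        # theta_target <- (1 - ema_coef) * theta_target + ema_coef * theta
        for p, p_target in zip(model.parameters(), model_target.parameters()):
            p_target.lerp_(p, ema_coef)

Called from a policy's `update()` after each optimizer step, this keeps the target network trailing the online network, which is why TD-MPC needs only a plain Adam optimizer here with `lr_scheduler = None`.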
@@ -361,12 +371,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
)
dl_iter = cycle(dataloader)
policy.train()
step = 0 # number of policy update (forward + backward + optim)
is_offline = True
for offline_step in range(cfg.training.offline_steps):
if offline_step == 0:
logging.info("Start offline training on a fixed dataset")
policy.train()
batch = next(dl_iter)
for key in batch:
@@ -414,6 +424,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
if env_step == 0:
logging.info("Start online training by interacting with environment")
policy.eval()
with torch.no_grad():
eval_info = eval_policy(
rollout_env,
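
The added `policy.eval()` is not redundant with `torch.no_grad()`: `no_grad` only disables gradient tracking, while `eval()` switches modules such as dropout and batch norm to inference behavior. A toy demonstration:

    import torch

    net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(p=0.5))
    x = torch.ones(1, 4)

    with torch.no_grad():
        net.train()
        print(torch.equal(net(x), net(x)))  # almost surely False: dropout is active
        net.eval()
        print(torch.equal(net(x), net(x)))  # True: dropout disabled in eval mode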
@@ -422,17 +433,17 @@ def train(cfg: dict, out_dir=None, job_name=None):
seed=cfg.seed,
)
add_episodes_inplace(
online_dataset,
concat_dataset,
sampler,
hf_dataset=eval_info["episodes"]["hf_dataset"],
episode_data_index=eval_info["episodes"]["episode_data_index"],
pc_online_samples=cfg.get("demo_schedule", 0.5),
)
add_episodes_inplace(
online_dataset,
concat_dataset,
sampler,
hf_dataset=eval_info["episodes"]["hf_dataset"],
episode_data_index=eval_info["episodes"]["episode_data_index"],
pc_online_samples=cfg.training.online_sampling_ratio,
)
policy.train()
for _ in range(cfg.training.online_steps_between_rollouts):
policy.train()
batch = next(dl_iter)
for key in batch:
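
Finally, `pc_online_samples` is now read from `cfg.training.online_sampling_ratio` rather than the loosely typed `cfg.get("demo_schedule", 0.5)`. The ratio controls how often the sampler draws from freshly collected online episodes versus the original offline data. A rough sketch of sampler weights implementing that split (`compute_sampler_weights` is a hypothetical helper, not lerobot's `add_episodes_inplace`):

    import torch

    def compute_sampler_weights(n_offline: int, n_online: int, online_sampling_ratio: float) -> torch.Tensor:
        """Weights for a WeightedRandomSampler over the concatenated [offline | online]
        dataset, so that ~online_sampling_ratio of draws come from online episodes."""
        weights = torch.zeros(n_offline + n_online)
        weights[:n_offline] = (1 - online_sampling_ratio) / n_offline
        weights[n_offline:] = online_sampling_ratio / n_online
        return weights

    weights = compute_sampler_weights(n_offline=10_000, n_online=500, online_sampling_ratio=0.5)
    sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=256, replacement=True)

With this scheme the online fraction stays fixed regardless of how many episodes each split contains, which is the behavior the renamed config key makes explicit.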