Refactor TD-MPC (#103)
Co-authored-by: Cadene <re.cadene@gmail.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
```diff
@@ -15,6 +15,7 @@ from lerobot.common.datasets.utils import cycle
 from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
+from lerobot.common.policies.policy_protocol import PolicyWithUpdate
 from lerobot.common.utils.utils import (
     format_big_number,
     get_safe_torch_device,
```
```diff
@@ -39,12 +40,17 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
     optimizer.step()
     optimizer.zero_grad()
 
     if lr_scheduler is not None:
         lr_scheduler.step()
 
-    if hasattr(policy, "ema") and policy.ema is not None:
-        policy.ema.step(policy.diffusion)
+    if isinstance(policy, PolicyWithUpdate):
+        # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
+        policy.update()
 
     info = {
         "loss": loss.item(),
         "grad_norm": float(grad_norm),
```
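For reference, the `PolicyWithUpdate` protocol imported above lets `update_policy` stay agnostic to policy internals: any policy exposing an `update()` method qualifies, instead of special-casing the diffusion EMA. A minimal sketch of such a protocol, assuming `typing.Protocol` (the actual definition in `policy_protocol.py` may differ):

```python
# Sketch of an assumed runtime-checkable protocol; not necessarily the exact
# definition shipped in lerobot/common/policies/policy_protocol.py.
from typing import Protocol, runtime_checkable


@runtime_checkable  # required for the isinstance() check used in update_policy
class PolicyWithUpdate(Protocol):
    def update(self) -> None:
        """Run post-optimizer-step bookkeeping (e.g. an EMA or target-net update)."""
        ...
```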
```diff
@@ -246,11 +252,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
         raise NotImplementedError()
     if job_name is None:
         raise NotImplementedError()
-    if cfg.training.online_steps > 0:
-        assert cfg.eval.batch_size == 1, "eval.batch_size > 1 not supported for online training steps"
 
     init_logging()
 
+    if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1:
+        logging.warning("eval.batch_size > 1 not supported for online training steps")
 
     # Check device is available
     get_safe_torch_device(cfg.device, log=True)
```
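The `get_safe_torch_device` helper called here validates the configured device before training starts. A hedged sketch of what such a guard can look like (the real helper lives in `lerobot.common.utils.utils` and may behave differently):

```python
# Sketch only: assumed behavior of a device guard with this signature.
import logging

import torch


def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
    # Fail fast if an unavailable accelerator was requested.
    if cfg_device == "cuda":
        assert torch.cuda.is_available(), "CUDA requested but not available"
    elif log:
        logging.warning(f"Using device {cfg_device}; training may be slow.")
    return torch.device(cfg_device)
```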
```diff
@@ -305,7 +312,10 @@ def train(cfg: dict, out_dir=None, job_name=None):
             num_training_steps=cfg.training.offline_steps,
         )
     elif policy.name == "tdmpc":
-        raise NotImplementedError("TD-MPC not implemented yet.")
+        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
+        lr_scheduler = None
     else:
         raise NotImplementedError()
 
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
```
```diff
@@ -361,12 +371,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
     )
     dl_iter = cycle(dataloader)
 
-    policy.train()
     step = 0  # number of policy update (forward + backward + optim)
     is_offline = True
     for offline_step in range(cfg.training.offline_steps):
         if offline_step == 0:
             logging.info("Start offline training on a fixed dataset")
+        policy.train()
         batch = next(dl_iter)
 
         for key in batch:
```
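Moving `policy.train()` inside the loop guarantees the module is back in training mode on every update, which matters once the online phase (below) switches to `policy.eval()` for rollouts between updates. `nn.Module` tracks the mode with a boolean flag that propagates to children; a standalone illustration with a hypothetical module:

```python
import torch

net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(0.5))
net.eval()   # e.g. set for an evaluation rollout
assert not net.training
net.train()  # re-arm dropout/batchnorm before the next optimizer step
assert net.training and net[1].training  # .train() propagates to children
```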
```diff
@@ -414,6 +424,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         if env_step == 0:
             logging.info("Start online training by interacting with environment")
 
+        policy.eval()
         with torch.no_grad():
             eval_info = eval_policy(
                 rollout_env,
```
```diff
@@ -422,17 +433,17 @@ def train(cfg: dict, out_dir=None, job_name=None):
                 seed=cfg.seed,
             )
 
-        add_episodes_inplace(
-            online_dataset,
-            concat_dataset,
-            sampler,
-            hf_dataset=eval_info["episodes"]["hf_dataset"],
-            episode_data_index=eval_info["episodes"]["episode_data_index"],
-            pc_online_samples=cfg.get("demo_schedule", 0.5),
-        )
+        add_episodes_inplace(
+            online_dataset,
+            concat_dataset,
+            sampler,
+            hf_dataset=eval_info["episodes"]["hf_dataset"],
+            episode_data_index=eval_info["episodes"]["episode_data_index"],
+            pc_online_samples=cfg.training.online_sampling_ratio,
+        )
 
-        policy.train()
         for _ in range(cfg.training.online_steps_between_rollouts):
+            policy.train()
             batch = next(dl_iter)
 
             for key in batch:
```
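This last hunk replaces the ad-hoc `cfg.get("demo_schedule", 0.5)` lookup with the structured `cfg.training.online_sampling_ratio`. One plausible way such a ratio can be turned into per-index weights over the concatenated offline+online dataset (a hedged sketch with a hypothetical helper; the actual `add_episodes_inplace` may work differently):

```python
import torch


def compute_sampler_weights(
    num_offline: int, num_online: int, pc_online_samples: float
) -> torch.Tensor:
    """Hypothetical helper: weights so a WeightedRandomSampler draws online
    frames with probability `pc_online_samples` and offline frames otherwise."""
    weights = torch.zeros(num_offline + num_online)
    weights[:num_offline] = (1 - pc_online_samples) / num_offline
    weights[num_offline:] = pc_online_samples / num_online
    return weights


# Usage sketch: feed the weights into a PyTorch sampler for the dataloader.
weights = compute_sampler_weights(num_offline=1000, num_online=200, pc_online_samples=0.5)
sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights))
```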