Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Pablo <pablo.montalvo.leroux@gmail.com>
This commit is contained in:
Remi
2025-02-04 18:01:04 +01:00
committed by GitHub
parent dd974529cf
commit 638d411cd3
26 changed files with 2365 additions and 92 deletions

View File

@@ -38,6 +38,7 @@ from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
format_big_number,
get_safe_dtype,
get_safe_torch_device,
has_method,
init_logging,
@@ -86,6 +87,10 @@ def update_policy(
optimizer.zero_grad()
if hasattr(policy, "update_ema_modules"):
policy.update_ema_modules()
# Step through pytorch scheduler at every batch instead of epoch
if lr_scheduler is not None:
lr_scheduler.step()
@@ -215,6 +220,7 @@ def train(cfg: TrainPipelineConfig):
device=device,
ds_meta=offline_dataset.meta,
)
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
grad_scaler = GradScaler(device, enabled=cfg.use_amp)
@@ -296,6 +302,10 @@ def train(cfg: TrainPipelineConfig):
dl_iter = cycle(dataloader)
policy.train()
if hasattr(policy, "init_ema_modules"):
policy.init_ema_modules()
offline_step = 0
for _ in range(step, cfg.offline.steps):
if offline_step == 0:
@@ -306,7 +316,8 @@ def train(cfg: TrainPipelineConfig):
dataloading_s = time.perf_counter() - start_time
for key in batch:
batch[key] = batch[key].to(device, non_blocking=True)
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(device, non_blocking=True)
train_info = update_policy(
policy,
@@ -365,6 +376,8 @@ def train(cfg: TrainPipelineConfig):
"next.reward": {"shape": (), "dtype": np.dtype("float32")},
"next.done": {"shape": (), "dtype": np.dtype("?")},
"task_index": {"shape": (), "dtype": np.dtype("int64")},
# FIXME: 'task' is a string
# "task": {"shape": (), "dtype": np.dtype("?")},
# FIXME: 'next.success' is expected by pusht env but not xarm
"next.success": {"shape": (), "dtype": np.dtype("?")},
},
@@ -451,9 +464,10 @@ def train(cfg: TrainPipelineConfig):
if len(offline_dataset.meta.tasks) > 1:
raise NotImplementedError("Add support for multi task.")
# Hack to add a task to the online_dataset (0 is the first task of the offline_dataset)
# TODO(rcadene, aliberts): Hack to add a task to the online_dataset (0 is the first task of the offline_dataset)
total_num_frames = eval_info["episodes"]["index"].shape[0]
eval_info["episodes"]["task_index"] = torch.tensor([0] * total_num_frames, dtype=torch.int64)
eval_info["episodes"]["task"] = ["do the thing"] * total_num_frames
with lock if lock is not None else nullcontext():
start_update_buffer_time = time.perf_counter()
@@ -499,7 +513,9 @@ def train(cfg: TrainPipelineConfig):
dataloading_s = time.perf_counter() - start_time
for key in batch:
batch[key] = batch[key].to(device, non_blocking=True)
if isinstance(batch[key], torch.Tensor):
dtype = get_safe_dtype(batch[key].dtype, device)
batch[key] = batch[key].to(device=device, dtype=dtype, non_blocking=True)
train_info = update_policy(
policy,