Add Pi0 (#681)
Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Pablo <pablo.montalvo.leroux@gmail.com>
@@ -38,6 +38,7 @@ from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
    format_big_number,
    get_safe_dtype,
    get_safe_torch_device,
    has_method,
    init_logging,
@@ -86,6 +87,10 @@ def update_policy(

    optimizer.zero_grad()

    if hasattr(policy, "update_ema_modules"):
        policy.update_ema_modules()

    # Step through pytorch scheduler at every batch instead of epoch
    if lr_scheduler is not None:
        lr_scheduler.step()

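The hasattr(policy, "update_ema_modules") guard added here lets a policy that keeps an exponential-moving-average (EMA) copy of its weights refresh that copy after every optimizer step, without forcing every policy class to implement the hook. A minimal sketch of what such hooks could look like; PolicyWithEMA, ema_net, and the decay value are illustrative assumptions, not the PI0 implementation:

import copy

import torch
from torch import nn


class PolicyWithEMA(nn.Module):
    """Toy policy exposing the duck-typed EMA hooks called from train.py (internals hypothetical)."""

    def __init__(self, ema_decay: float = 0.999):
        super().__init__()
        self.net = nn.Linear(8, 2)
        self.ema_decay = ema_decay
        self.ema_net = None

    def init_ema_modules(self):
        # Start the moving average from a frozen copy of the current weights.
        self.ema_net = copy.deepcopy(self.net).requires_grad_(False)

    @torch.no_grad()
    def update_ema_modules(self):
        # ema <- decay * ema + (1 - decay) * online, parameter by parameter.
        for ema_p, p in zip(self.ema_net.parameters(), self.net.parameters()):
            ema_p.lerp_(p, 1.0 - self.ema_decay)

At evaluation time the averaged weights (here ema_net) would typically be the ones used for inference; whether PI0 does exactly this is not shown in the hunk.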
@@ -215,6 +220,7 @@ def train(cfg: TrainPipelineConfig):
        device=device,
        ds_meta=offline_dataset.meta,
    )

    logging.info("Creating optimizer and scheduler")
    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
    grad_scaler = GradScaler(device, enabled=cfg.use_amp)
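This hunk wires the policy factory to the dataset metadata and creates the optimizer, LR scheduler, and an AMP GradScaler that is only active when cfg.use_amp is set. The scaler pairs with an autocast region inside the update step; a minimal, self-contained sketch of that standard PyTorch AMP pattern (the toy model, loss, and clipping value are placeholders, not the surrounding lerobot code):

import torch
from torch.amp import GradScaler, autocast

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

model = torch.nn.Linear(8, 2).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
grad_scaler = GradScaler(device, enabled=use_amp)

x = torch.randn(16, 8, device=device)
with autocast(device_type=device, enabled=use_amp):
    loss = model(x).pow(2).mean()

# Scale the loss so small fp16 gradients do not underflow, then unscale before clipping/stepping.
grad_scaler.scale(loss).backward()
grad_scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
grad_scaler.step(optimizer)
grad_scaler.update()
optimizer.zero_grad()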
@@ -296,6 +302,10 @@ def train(cfg: TrainPipelineConfig):
    dl_iter = cycle(dataloader)

    policy.train()

    if hasattr(policy, "init_ema_modules"):
        policy.init_ema_modules()

    offline_step = 0
    for _ in range(step, cfg.offline.steps):
        if offline_step == 0:
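Mirroring the per-step hook above, init_ema_modules is called once before the offline loop starts, again behind a hasattr guard so policies without EMA support need no changes. A short sketch of how that duck typing behaves for a plain policy versus one exposing the hook (both classes are toys, not LeRobot policies):

from torch import nn


class TinyEMAPolicy(nn.Module):
    """Stand-in that merely exposes the duck-typed hook names (bodies are placeholders)."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 2)

    def init_ema_modules(self):
        print("EMA state initialized")

    def update_ema_modules(self):
        print("EMA state refreshed")


for policy in (nn.Linear(8, 2), TinyEMAPolicy()):
    policy.train()
    if hasattr(policy, "init_ema_modules"):
        policy.init_ema_modules()  # reached only by the EMA-capable policy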
@@ -306,7 +316,8 @@ def train(cfg: TrainPipelineConfig):
        dataloading_s = time.perf_counter() - start_time

        for key in batch:
            batch[key] = batch[key].to(device, non_blocking=True)
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(device, non_blocking=True)

        train_info = update_policy(
            policy,
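This hunk replaces the unconditional batch[key].to(device, ...) transfer (the first line inside the loop, removed) with one guarded by isinstance(batch[key], torch.Tensor): a batch can now carry non-tensor values, such as the string task field flagged in the FIXME below, and calling .to() on a list of strings would raise. A small sketch of the guarded loop on a toy mixed batch (the keys and values are made up, not a real LeRobot dataset item):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
batch = {
    "observation.state": torch.randn(4, 6),
    "task": ["pick up the cube"] * 4,  # plain Python strings, no .to() method
}

for key in batch:
    if isinstance(batch[key], torch.Tensor):
        batch[key] = batch[key].to(device, non_blocking=True)
    # non-tensor entries (here the "task" strings) are left untouched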
@@ -365,6 +376,8 @@ def train(cfg: TrainPipelineConfig):
            "next.reward": {"shape": (), "dtype": np.dtype("float32")},
            "next.done": {"shape": (), "dtype": np.dtype("?")},
            "task_index": {"shape": (), "dtype": np.dtype("int64")},
            # FIXME: 'task' is a string
            # "task": {"shape": (), "dtype": np.dtype("?")},
            # FIXME: 'next.success' is expected by pusht env but not xarm
            "next.success": {"shape": (), "dtype": np.dtype("?")},
        },
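In this features spec, "?" is NumPy's character code for booleans, so next.done and next.success are scalar bool fields, while the task entry stays commented out, presumably because a free-form string has no fixed-width NumPy dtype to declare here. A quick check of those type codes (plain NumPy behavior, nothing LeRobot-specific):

import numpy as np

assert np.dtype("?") == np.dtype(np.bool_)        # '?' is the boolean type code
assert np.dtype("int64") == np.dtype(np.int64)    # dtype of task_index
assert np.dtype("float32").itemsize == 4          # scalar shape () -> one 4-byte value per frame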
@@ -451,9 +464,10 @@ def train(cfg: TrainPipelineConfig):
        if len(offline_dataset.meta.tasks) > 1:
            raise NotImplementedError("Add support for multi task.")

        # Hack to add a task to the online_dataset (0 is the first task of the offline_dataset)
        # TODO(rcadene, aliberts): Hack to add a task to the online_dataset (0 is the first task of the offline_dataset)
        total_num_frames = eval_info["episodes"]["index"].shape[0]
        eval_info["episodes"]["task_index"] = torch.tensor([0] * total_num_frames, dtype=torch.int64)
        eval_info["episodes"]["task"] = ["do the thing"] * total_num_frames

        with lock if lock is not None else nullcontext():
            start_update_buffer_time = time.perf_counter()
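The TODO keeps the existing workaround but tags its owners: every frame of the online rollout is stamped with task index 0 (the first task of the offline dataset) and a placeholder task string, so the online buffer matches the offline dataset's schema. For illustration, the same per-frame tensor can be built with torch.full; the frame count below is a toy value, not one read from eval_info:

import torch

total_num_frames = 5  # toy value; in train.py this comes from eval_info["episodes"]["index"].shape[0]
task_index = torch.full((total_num_frames,), 0, dtype=torch.int64)
# equivalent to: torch.tensor([0] * total_num_frames, dtype=torch.int64)
task = ["do the thing"] * total_num_frames
print(task_index.tolist(), task[0])  # [0, 0, 0, 0, 0] do the thing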
@@ -499,7 +513,9 @@ def train(cfg: TrainPipelineConfig):
        dataloading_s = time.perf_counter() - start_time

        for key in batch:
            batch[key] = batch[key].to(device, non_blocking=True)
            if isinstance(batch[key], torch.Tensor):
                dtype = get_safe_dtype(batch[key].dtype, device)
                batch[key] = batch[key].to(device=device, dtype=dtype, non_blocking=True)

        train_info = update_policy(
            policy,
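The online loop gets the same isinstance guard, plus a dtype pass through the get_safe_dtype helper imported at the top of the file before the transfer: not every backend supports every dtype (Apple's MPS backend, for instance, has no float64), so a safe fallback has to be picked per device. A rough sketch of the idea only; this is not the actual lerobot.common.utils.utils.get_safe_dtype:

import torch


def get_safe_dtype_sketch(dtype: torch.dtype, device: str | torch.device) -> torch.dtype:
    """Illustrative fallback: downcast dtypes a backend cannot hold (assumption, not lerobot's code)."""
    device_type = torch.device(device).type
    if device_type == "mps" and dtype == torch.float64:
        return torch.float32  # MPS has no 64-bit float support
    return dtype


batch_value = torch.arange(3, dtype=torch.float64)
safe = get_safe_dtype_sketch(batch_value.dtype, "cpu")
batch_value = batch_value.to(device="cpu", dtype=safe, non_blocking=True)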