import datetime as dt
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Type

import draccus
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import HfHubHTTPError

from lerobot.common import envs
from lerobot.common.optim import OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.common.utils.hub import HubMixin
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available
from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
from lerobot.configs.policies import PreTrainedConfig

TRAIN_CONFIG_NAME = "train_config.json"


@dataclass
class OfflineConfig:
    steps: int = 100_000


@dataclass
class OnlineConfig:
    """
    The online training loop looks something like:

    ```python
    for i in range(steps):
        do_online_rollout_and_update_online_buffer()
        for j in range(steps_between_rollouts):
            batch = next(dataloader_with_offline_and_online_data)
            loss = policy(batch)
            loss.backward()
            optimizer.step()
    ```

    Note that the online training loop adopts most of the options from the offline loop unless specified
    otherwise.
    """

    steps: int = 0
    # How many episodes to collect at once when we reach the online rollout part of the training loop.
    rollout_n_episodes: int = 1
    # The number of environments to use in the gym.vector.VectorEnv. This ends up also being the batch size
    # for the policy. Ideally you should set this to be an even divisor of rollout_n_episodes.
    rollout_batch_size: int = 1
    # How many optimization steps (forward, backward, optimizer step) to do between running rollouts.
    steps_between_rollouts: int | None = None
    # The proportion of online samples (vs offline samples) to include in the online training batches.
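    # For example, with sampling_ratio=0.5 and batch_size=8, a training batch would be expected
    # to mix roughly 4 online and 4 offline frames (an illustration of the ratio, not a spec of
    # the exact sampler behavior).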
    sampling_ratio: float = 0.5
    # First seed to use for the online rollout environment. Seeds for subsequent rollouts are incremented by 1.
    env_seed: int | None = None
    # Sets the maximum number of frames that are stored in the online buffer for online training. The buffer
    # is FIFO.
    buffer_capacity: int | None = None
    # The minimum number of frames to have in the online buffer before commencing online training.
    # If buffer_seed_size > rollout_n_episodes, the rollout will be run multiple times until the
    # seed size condition is satisfied.
    buffer_seed_size: int = 0
    # Whether to run the online rollouts asynchronously. This means we can run the online training steps in
    # parallel with the rollouts. This might be advised if your GPU has the bandwidth to handle training
    # + eval + environment rendering simultaneously.
    do_rollout_async: bool = False

    def __post_init__(self):
        if self.steps == 0:
            return

        if self.steps_between_rollouts is None:
            raise ValueError(
                "'steps_between_rollouts' must be set to a positive integer, but it is currently None."
            )
        if self.env_seed is None:
            raise ValueError("'env_seed' must be set to a positive integer, but it is currently None.")
        if self.buffer_capacity is None:
            raise ValueError("'buffer_capacity' must be set to a positive integer, but it is currently None.")


@dataclass
class TrainPipelineConfig(HubMixin):
    dataset: DatasetConfig
    env: envs.EnvConfig | None = None
    policy: PreTrainedConfig | None = None
    # Set `output_dir` to where you would like to save all of the run outputs. If you run another
    # training session with the same value for `output_dir`, its contents will be overwritten unless
    # you set `resume` to true.
    output_dir: Path | None = None
    job_name: str | None = None
    # Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
    # `output_dir` is the directory of an existing run with at least one checkpoint in it.
    # Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
    # regardless of what's provided with the training command at the time of resumption.
    resume: bool = False
    device: str | None = None  # cuda | cpu | mps
    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation.
    # With AMP, automatic gradient scaling is used.
    use_amp: bool = False
    # `seed` is used for training (e.g. model initialization, dataset shuffling)
    # AND for the evaluation environments.
    seed: int | None = 1000
    # Number of workers for the dataloader.
    num_workers: int = 4
    batch_size: int = 8
    eval_freq: int = 20_000
    log_freq: int = 200
    save_checkpoint: bool = True
    # Checkpoint is saved every `save_freq` training iterations and after the last training step.
    save_freq: int = 20_000
    offline: OfflineConfig = field(default_factory=OfflineConfig)
    online: OnlineConfig = field(default_factory=OnlineConfig)
    use_policy_training_preset: bool = True
    optimizer: OptimizerConfig | None = None
    scheduler: LRSchedulerConfig | None = None
    eval: EvalConfig = field(default_factory=EvalConfig)
    wandb: WandBConfig = field(default_factory=WandBConfig)

    def __post_init__(self):
        self.checkpoint_path = None

    def validate(self):
        if not self.device:
            logging.warning("No device specified, trying to infer device automatically")
            device = auto_select_torch_device()
            self.device = device.type

        # Automatically deactivate AMP if necessary
        if self.use_amp and not is_amp_available(self.device):
            logging.warning(
                f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
            )
            self.use_amp = False

        # HACK: We parse the CLI args again here to pick up any pretrained paths that were passed.
        policy_path = parser.get_path_arg("policy")
        if policy_path:
            # Only load the policy config
            cli_overrides = parser.get_cli_overrides("policy")
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = policy_path
        elif self.resume:
            # The entire train config is already loaded, we just need to get the checkpoint dir
            config_path = parser.parse_arg("config_path")
            if not config_path:
                raise ValueError("A config_path is expected when resuming a run.")
            policy_path = Path(config_path).parent
            self.policy.pretrained_path = policy_path
            self.checkpoint_path = policy_path.parent

        if not self.job_name:
            if self.env is None:
                self.job_name = f"{self.policy.type}"
            else:
                self.job_name = f"{self.env.type}_{self.policy.type}"

        if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
            raise FileExistsError(
                f"Output directory {self.output_dir} already exists and resume is {self.resume}. "
                f"Please change your output directory so that {self.output_dir} is not overwritten."
            )
        elif not self.output_dir:
            now = dt.datetime.now()
            train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
            self.output_dir = Path("outputs/train") / train_dir
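            # e.g. self.output_dir == Path("outputs/train/2025-03-14/09-26-53_pusht_diffusion")
            # (the date, time, and job name here are illustrative).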

        if self.online.steps > 0:
            if isinstance(self.dataset.repo_id, list):
                raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
            if self.env is None:
                raise ValueError("An environment is required for online training")

        if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
            raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
        elif self.use_policy_training_preset and not self.resume:
            self.optimizer = self.policy.get_optimizer_preset()
            self.scheduler = self.policy.get_scheduler_preset()

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
        return ["policy"]

    def _save_pretrained(self, save_directory: Path) -> None:
        with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
            draccus.dump(self, f, indent=4)

    @classmethod
    def from_pretrained(
        cls: Type["TrainPipelineConfig"],
        pretrained_name_or_path: str | Path,
        *,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        **kwargs,
    ) -> "TrainPipelineConfig":
        model_id = str(pretrained_name_or_path)
        config_file: str | None = None
        if Path(model_id).is_dir():
            if TRAIN_CONFIG_NAME in os.listdir(model_id):
                config_file = os.path.join(model_id, TRAIN_CONFIG_NAME)
            else:
                print(f"{TRAIN_CONFIG_NAME} not found in {Path(model_id).resolve()}")
        elif Path(model_id).is_file():
            config_file = model_id
        else:
            try:
                config_file = hf_hub_download(
                    repo_id=model_id,
                    filename=TRAIN_CONFIG_NAME,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            except HfHubHTTPError as e:
                raise FileNotFoundError(
                    f"{TRAIN_CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
                ) from e

        cli_args = kwargs.pop("cli_args", [])
        cfg = draccus.parse(cls, config_file, args=cli_args)

        return cfg
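
# Usage sketch (the path is illustrative): load a saved training config back from a run
# directory, a config file, or the Hub, optionally applying draccus-style CLI overrides
# through the `cli_args` kwarg consumed above.
#
#   cfg = TrainPipelineConfig.from_pretrained(
#       "outputs/train/2025-03-14/09-26-53_pusht_diffusion/checkpoints/last/pretrained_model",
#       cli_args=["--batch_size=16"],
#   )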