Add online training with TD-MPC as proof of concept (#338)
This commit is contained in:
@@ -32,19 +32,54 @@ video_backend: pyav
|
||||
|
||||
training:
|
||||
offline_steps: ???
|
||||
# NOTE: `online_steps` is not implemented yet. It's here as a placeholder.
|
||||
online_steps: ???
|
||||
online_steps_between_rollouts: ???
|
||||
online_sampling_ratio: 0.5
|
||||
# `online_env_seed` is used for environments for online training data rollouts.
|
||||
online_env_seed: ???
|
||||
|
||||
# Number of workers for the offline training dataloader.
|
||||
num_workers: 4
|
||||
|
||||
batch_size: ???
|
||||
|
||||
eval_freq: ???
|
||||
log_freq: 200
|
||||
save_checkpoint: true
|
||||
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
||||
save_freq: ???
|
||||
num_workers: 4
|
||||
batch_size: ???
|
||||
|
||||
# Online training. Note that the online training loop adopts most of the options above apart from the
|
||||
# dataloader options. Unless otherwise specified.
|
||||
# The online training look looks something like:
|
||||
#
|
||||
# for i in range(online_steps):
|
||||
# do_online_rollout_and_update_online_buffer()
|
||||
# for j in range(online_steps_between_rollouts):
|
||||
# batch = next(dataloader_with_offline_and_online_data)
|
||||
# loss = policy(batch)
|
||||
# loss.backward()
|
||||
# optimizer.step()
|
||||
#
|
||||
online_steps: ???
|
||||
# How many episodes to collect at once when we reach the online rollout part of the training loop.
|
||||
online_rollout_n_episodes: 1
|
||||
# The number of environments to use in the gym.vector.VectorEnv. This ends up also being the batch size for
|
||||
# the policy. Ideally you should set this to by an even divisor or online_rollout_n_episodes.
|
||||
online_rollout_batch_size: 1
|
||||
# How many optimization steps (forward, backward, optimizer step) to do between running rollouts.
|
||||
online_steps_between_rollouts: null
|
||||
# The proportion of online samples (vs offline samples) to include in the online training batches.
|
||||
online_sampling_ratio: 0.5
|
||||
# First seed to use for the online rollout environment. Seeds for subsequent rollouts are incremented by 1.
|
||||
online_env_seed: null
|
||||
# Sets the maximum number of frames that are stored in the online buffer for online training. The buffer is
|
||||
# FIFO.
|
||||
online_buffer_capacity: null
|
||||
# The minimum number of frames to have in the online buffer before commencing online training.
|
||||
# If online_buffer_seed_size > online_rollout_n_episodes, the rollout will be run multiple times until the
|
||||
# seed size condition is satisfied.
|
||||
online_buffer_seed_size: 0
|
||||
# Whether to run the online rollouts asynchronously. This means we can run the online training steps in
|
||||
# parallel with the rollouts. This might be advised if your GPU has the bandwidth to handle training
|
||||
# + eval + environment rendering simultaneously.
|
||||
do_online_rollout_async: false
|
||||
|
||||
image_transforms:
|
||||
# These transforms are all using standard torchvision.transforms.v2
|
||||
# You can find out how these transformations affect images here:
|
||||
|
||||
Reference in New Issue
Block a user