Add online training with TD-MPC as proof of concept (#338)

This commit is contained in:
Alexander Soare
2024-07-25 11:16:38 +01:00
committed by GitHub
parent abbb1d2367
commit f8a6574698
25 changed files with 1291 additions and 233 deletions

View File

@@ -4,19 +4,30 @@ seed: 1
dataset_repo_id: lerobot/xarm_lift_medium
training:
offline_steps: 25000
# TODO(alexander-soare): uncomment when online training gets reinstated
online_steps: 0 # 25000 not implemented yet
eval_freq: 5000
online_steps_between_rollouts: 1
online_sampling_ratio: 0.5
online_env_seed: 10000
log_freq: 100
offline_steps: 50000
num_workers: 4
batch_size: 256
grad_clip_norm: 10.0
lr: 3e-4
eval_freq: 5000
log_freq: 100
online_steps: 50000
online_rollout_n_episodes: 1
online_rollout_batch_size: 1
# Note: in FOWM `online_steps_between_rollouts` is actually dynamically set to match exactly the length of
# the last sampled episode.
online_steps_between_rollouts: 50
online_sampling_ratio: 0.5
online_env_seed: 10000
# FOWM Push uses 10000 for `online_buffer_capacity`. Given that their maximum episode length for this task
# is 25, 10000 is approx 400 of their episodes worth. Since our episodes are about 8 times longer, we'll use
# 80000.
online_buffer_capacity: 80000
delta_timestamps:
observation.image: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
@@ -31,6 +42,7 @@ policy:
# Input / output structure.
n_action_repeats: 2
horizon: 5
n_action_steps: 1
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?