data:
  train_batch_size: 256
  micro_batch_size: 16  # this is also val batch size
  train_files: ~/data/gsm8k/train.parquet
  val_files: ~/data/gsm8k/test.parquet
  prompt_key: question
  response_key: answer
  max_length: 1024
  truncation: error
  balance_dp_token: False
  chat_template: null
model:
  partial_pretrain: ~/models/gemma-1.1-7b-it
  fsdp_config:
    wrap_policy:
      min_num_params: 0
    cpu_offload: False
    offload_params: False
  external_lib: null
  enable_gradient_checkpointing: False
  trust_remote_code: False
  lora_rank: 0  # Set to positive value to enable LoRA (e.g., 32)
  lora_alpha: 16  # LoRA scaling factor
  target_modules: [q_proj, v_proj]  # Target modules for LoRA adaptation
optim:
  lr: 1e-5
  betas: [0.9, 0.95]
  weight_decay: 0.01
  warmup_steps_ratio: 0.1
  clip_grad: 1.0

trainer:
  default_local_dir: /tmp/sft_model
  default_hdfs_dir: hdfs://tmp/experiments/gsm8k/gemma-1.1-7b-it/  # change the hdfs path here
  resume_path: null
  project_name: gsm8k-sft
  experiment_name: test
  total_epochs: 4
  total_training_steps: null
  validate_before_training: False
  logger: ['console']
  seed: 1
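The dotted sections above (`data.*`, `model.*`, `optim.*`, `trainer.*`) can be inspected and overridden programmatically. Below is a minimal sketch using OmegaConf, assuming the config is saved under an illustrative file name `sft_trainer.yaml`; the override values shown are examples, not part of the config.

```python
from omegaconf import OmegaConf

# Load the YAML config shown above (file name is an assumption for illustration).
cfg = OmegaConf.load("sft_trainer.yaml")

# Fields are addressed with the same dotted paths used for command-line overrides.
print(cfg.data.train_batch_size)  # 256
print(cfg.model.lora_rank)        # 0 (LoRA disabled by default)

# Apply overrides, e.g. enable LoRA and shorten training (illustrative values).
overrides = OmegaConf.from_dotlist([
    "model.lora_rank=32",
    "trainer.total_epochs=1",
])
cfg = OmegaConf.merge(cfg, overrides)
```

Keys such as `trainer.default_hdfs_dir` can be overridden the same way when the defaults above do not match your environment.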