Patch summary: in verl/trainer/main_ppo.py, _select_rm_score_fn previously
returned qa_em.compute_score_em whenever data_source contained the substring
"nq". After this patch it returns qa_em.compute_score_em only when data_source
is exactly one of: nq, triviaqa, popqa, hotpotqa, 2wikimultihopqa, musique,
bamboogle. Any other data_source still raises NotImplementedError.

NOTE(review): the test changes from substring containment to exact membership,
so a data_source that merely contains "nq" (without being equal to "nq") is no
longer accepted. Confirm no caller relies on the old substring behaviour.

diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py
index e442490..6abd2b6 100644
--- a/verl/trainer/main_ppo.py
+++ b/verl/trainer/main_ppo.py
@@ -23,7 +23,7 @@ import re
 import numpy as np
 
 def _select_rm_score_fn(data_source):
-    if "nq" in data_source:
+    if data_source in ['nq', 'triviaqa', 'popqa', 'hotpotqa', '2wikimultihopqa', 'musique', 'bamboogle']:
         return qa_em.compute_score_em
     else:
         raise NotImplementedError