Initial commit

This commit is contained in:
PeterGriffinJin
2025-02-28 15:16:19 +00:00
commit 068516be64
207 changed files with 33063 additions and 0 deletions

View File

@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,111 @@
import re
import random
import ast
import operator
def extract_solution(solution_str):
    """Return the equation inside the last ``<answer>`` tag of the reply.

    The text up to and including the first assistant marker (``Assistant:``,
    or failing that ``<|im_start|>assistant``) is discarded. Only the final
    line of what remains is searched for ``<answer>...</answer>`` pairs.

    Args:
        solution_str: the full prompt + response text.

    Returns:
        The stripped content of the last ``<answer>`` tag on the final line,
        or ``None`` when no assistant marker or no answer tag is present.
    """
    # "Assistant:" takes precedence over the chat-template marker.
    for marker in ("Assistant:", "<|im_start|>assistant"):
        if marker in solution_str:
            _, _, response = solution_str.partition(marker)
            break
    else:
        return None

    # Only the very last line of the response is considered.
    last_line = response.split('\n')[-1]
    candidates = re.findall(r'<answer>(.*?)</answer>', last_line)
    if not candidates:
        return None
    return candidates[-1].strip()
def validate_equation(equation_str, available_numbers):
    """Check that the equation uses exactly the available numbers, once each.

    Every run of digits in ``equation_str`` is read as one integer, so a
    decimal literal such as ``1.5`` counts as the two numbers 1 and 5.

    Args:
        equation_str: the candidate equation text.
        available_numbers: iterable of the numbers that must be used.

    Returns:
        ``True`` when the multiset of integers found in the equation equals
        the multiset of available numbers; ``False`` otherwise, including
        when the inputs cannot be processed at all (e.g. ``equation_str`` is
        not a string).
    """
    try:
        numbers_in_eq = sorted(int(tok) for tok in re.findall(r'\d+', equation_str))
        # Comparing the sorted lists enforces "each number exactly once".
        return numbers_in_eq == sorted(available_numbers)
    except Exception:
        # Malformed input is simply "not valid". Unlike a bare ``except:``,
        # this lets KeyboardInterrupt / SystemExit propagate.
        return False
def evaluate_equation(equation_str):
    """Evaluate an arithmetic expression and return its numeric value.

    The text must consist solely of digits, ``+ - * /``, parentheses, dots
    and whitespace. The power operator ``**`` is rejected even though its
    characters are individually allowed: an expression such as
    ``99**98**97`` would otherwise stall or exhaust memory in the scorer.

    Args:
        equation_str: the expression text.

    Returns:
        The ``int`` or ``float`` result, or ``None`` when the text contains
        disallowed characters, uses ``**``, fails to evaluate (syntax error,
        division by zero, ...) or does not evaluate to a number.
    """
    try:
        # Whitelist: numbers, operators, parentheses, and whitespace only.
        if not re.fullmatch(r'[\d+\-*/().\s]+', equation_str):
            return None
        # `*` `*` passes the character whitelist, so block it explicitly.
        if '**' in equation_str:
            return None
        # SECURITY NOTE(review): this still calls eval() on model-generated
        # text. The whitelist above excludes names, attribute access and
        # strings, and builtins are removed, but an AST-based evaluator
        # would be the safer long-term replacement.
        result = eval(equation_str, {"__builtins__": None}, {})
    except Exception:
        return None
    # e.g. "()" evaluates to an empty tuple, which is not a usable result.
    if not isinstance(result, (int, float)):
        return None
    return result
def compute_score(solution_str, ground_truth, method='strict', format_score=0.1, score=1.):
    """The scoring function for countdown task.

    Args:
        solution_str: the solution text
        ground_truth: dictionary containing target number and available numbers
            (read via the ``'target'`` and ``'numbers'`` keys)
        method: the method to extract the solution (not used by this function)
        format_score: the score for correct format but wrong answer
        score: the score for the correct answer

    Returns:
        ``0`` when no equation can be extracted; ``format_score`` when the
        equation uses the wrong numbers, cannot be evaluated, or misses the
        target; ``score`` when the result is within 1e-5 of the target.
    """
    target = ground_truth['target']
    numbers = ground_truth['numbers']

    equation = extract_solution(solution_str=solution_str)
    # Log roughly one in 64 samples for debugging.
    do_print = random.randint(1, 64) == 1

    if do_print:
        print("--------------------------------")
        print(f"Target: {target} | Numbers: {numbers}")
        print(f"Extracted equation: {equation}")
        print(f"Solution string: {solution_str}")

    if equation is None:
        if do_print:
            print("No equation found")
        return 0

    # Validate equation uses correct numbers
    if not validate_equation(equation, numbers):
        if do_print:
            print("Invalid equation")
        return format_score

    # Evaluate equation
    try:
        result = evaluate_equation(equation)
        if result is None:
            if do_print:
                print("Could not evaluate equation")
            return format_score
        # Account for floating point precision; bool() is kept inside the
        # try so an un-comparable target is still treated as an error.
        is_correct = bool(abs(result - target) < 1e-5)
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt / SystemExit
        # are no longer swallowed by the scorer.
        if do_print:
            print("Error evaluating equation")
        return format_score

    if is_correct:
        if do_print:
            print(f"Correct equation: {equation} = {result}")
        return score

    if do_print:
        print(f"Wrong result: equation = {result}, target = {target}")
    return format_score

View File

@@ -0,0 +1,63 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
def extract_solution(solution_str, method='strict'):
    """Extract the final numeric answer from a solution text.

    Args:
        solution_str: the solution text.
        method: ``'strict'`` requires the ``#### <number>`` format and
            returns the first such number with thousands separators
            removed; ``'flexible'`` returns the last number-like token found
            anywhere in the text, unmodified.

    Returns:
        The answer as a string, or ``None`` when nothing usable is found.

    Raises:
        AssertionError: if ``method`` is not ``'strict'`` or ``'flexible'``.
    """
    assert method in ['strict', 'flexible']

    if method == 'strict':
        # this also tests the formatting of the model
        solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
        if solution is None:
            return None
        # The capture group can only hold digits, '-', '.' and ',', so
        # dropping commas is the only normalisation that has any effect.
        return solution.group(1).replace(',', '')

    answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str)
    invalid_str = ('', '.')
    # find the last number that is not '.'
    for candidate in reversed(answer):
        if candidate not in invalid_str:
            return candidate
    # No reward if there is no answer. Returning here, rather than reusing
    # the loop variable, avoids reporting a lone "." as the answer when
    # every candidate was invalid.
    return None
def compute_score(solution_str, ground_truth, method='strict', format_score=0., score=1.):
    """The scoring function for GSM8k.

    Reference: Trung, Luong, et al. "Reft: Reasoning with reinforced fine-tuning." Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024.

    Args:
        solution_str: the solution text
        ground_truth: the ground truth, compared to the extracted answer with ``==``
        method: the method to extract the solution, choices are 'strict' and 'flexible'
        format_score: the score returned when an answer is extracted but does not match
        score: the score for the correct answer

    Returns:
        ``0`` when no answer is extracted, ``score`` on a match,
        ``format_score`` otherwise.
    """
    answer = extract_solution(solution_str=solution_str, method=method)
    if answer is None:
        return 0
    return score if answer == ground_truth else format_score

View File

@@ -0,0 +1,227 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
def compute_score(solution_str, ground_truth) -> float:
    """Score a solution by comparing its last boxed answer to the ground truth.

    Args:
        solution_str: the solution text, expected to contain ``\\boxed``/``\\fbox``.
        ground_truth: the reference answer string.

    Returns:
        ``1.`` when the last boxed answer is equivalent to ``ground_truth``
        according to ``is_equiv``; ``0.`` otherwise. Any exception raised
        while extracting or comparing is printed and also scores ``0.``.
    """
    try:
        boxed = last_boxed_only_string(solution_str)
        if boxed is not None and is_equiv(remove_boxed(boxed), ground_truth):
            return 1.
    except Exception as e:
        print(e)
    return 0.
# string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py
def is_equiv(str1, str2, verbose=False):
    """Decide whether two answer strings are equivalent after normalization.

    Args:
        str1: first answer string (or ``None``).
        str2: second answer string (or ``None``).
        verbose: when true, print both normalized strings before comparing.

    Returns:
        ``True`` if both inputs are ``None`` (a warning is printed),
        ``False`` if exactly one is ``None``. Otherwise the equality of the
        two ``strip_string``-normalized values, or the equality of the raw
        inputs if normalization raises.
    """
    if str1 is None or str2 is None:
        if str1 is None and str2 is None:
            print("WARNING: Both None")
            return True
        return False

    try:
        normalized_1 = strip_string(str1)
        normalized_2 = strip_string(str2)
        if verbose:
            print(normalized_1, normalized_2)
        return normalized_1 == normalized_2
    except Exception:
        # Normalization failed: fall back to a plain comparison.
        return str1 == str2
def remove_boxed(s):
    """Strip the ``\\boxed`` wrapper from a boxed-answer string.

    Two spellings are handled: ``\\boxed <content>`` (space-separated) and
    ``\\boxed{<content>}``.

    Args:
        s: a string that starts with one of the two ``\\boxed`` forms.

    Returns:
        The content without the wrapper.

    Raises:
        AssertionError: if ``s`` does not start with the expected prefix, or
            (braced form) does not end with ``}``.
    """
    spaced_prefix = "\\boxed "
    if spaced_prefix in s:
        assert s.startswith(spaced_prefix)
        return s[len(spaced_prefix):]

    braced_prefix = "\\boxed{"
    assert s.startswith(braced_prefix)
    assert s.endswith("}")
    return s[len(braced_prefix):-1]
def last_boxed_only_string(string):
    """Return the last ``\\boxed...`` (or ``\\fbox...``) expression in a string.

    If the space-separated form ``\\boxed <content>`` occurs anywhere, the
    result is ``"\\boxed "`` plus the text after its last occurrence, cut at
    the first ``$``. Otherwise the last ``\\boxed`` (falling back to the
    last ``\\fbox``) is located and returned together with its
    brace-balanced argument.

    Args:
        string: the text to search.

    Returns:
        The boxed expression including its wrapper, or ``None`` when no
        wrapper is found or its braces never balance.
    """
    if "\\boxed " in string:
        tail = string.split("\\boxed ")[-1]
        return "\\boxed " + tail.split("$")[0]

    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return None

    # Walk forward until the brace that closes the first opened group.
    depth = 0
    for pos in range(start, len(string)):
        ch = string[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return string[start:pos + 1]
    return None
def fix_fracs(string):
    """Add the braces missing from shorthand ``\\frac`` arguments.

    ``\\frac12`` becomes ``\\frac{1}{2}``, ``\\frac1b`` becomes
    ``\\frac{1}{b}`` and ``\\frac1{72}`` becomes ``\\frac{1}{72}``.
    Fractions whose first argument is already braced are left as they are
    (so ``\\frac{72}1`` is not fixed).

    Args:
        string: a LaTeX answer string.

    Returns:
        The string with shorthand fractions braced. The input is returned
        unchanged if any ``\\frac`` is followed by fewer than two
        characters, including a ``\\frac`` at the very end of the string.
    """
    head, *pieces = string.split("\\frac")
    rebuilt = head
    for piece in pieces:
        rebuilt += "\\frac"
        if piece.startswith("{"):
            # Already in \frac{...} form; copy through untouched.
            rebuilt += piece
            continue
        # An explicit length check (rather than an assert, which is removed
        # under ``python -O``) that also covers the empty piece produced by
        # a trailing "\frac".
        if len(piece) < 2:
            return string
        numerator, second, rest = piece[0], piece[1], piece[2:]
        if second == "{":
            # \frac1{72}: only the numerator needs braces.
            rebuilt += "{" + numerator + "}" + second + rest
        else:
            # \frac12 or \frac1b: both single-character arguments need braces.
            rebuilt += "{" + numerator + "}{" + second + "}" + rest
    return rebuilt
def fix_a_slash_b(string):
    """Rewrite a plain integer ratio ``a/b`` as ``\\frac{a}{b}``.

    Args:
        string: an answer string.

    Returns:
        ``\\frac{a}{b}`` when ``string`` is exactly two canonical integers
        separated by a single ``/`` (e.g. ``"1/2"``). Anything else is
        returned unchanged: more or fewer slashes, non-integer operands
        such as ``"x/y"``, or non-canonical spellings such as ``"01/2"``.
    """
    parts = string.split("/")
    if len(parts) != 2:
        return string
    try:
        a = int(parts[0])
        b = int(parts[1])
    except ValueError:
        # Non-integer operands mean "not a simple fraction"; this must not
        # escape and abort the surrounding normalization.
        return string
    # Reject spellings int() accepts but that do not round-trip
    # (leading zeros, surrounding whitespace, "+1", ...).
    if string != "{}/{}".format(a, b):
        return string
    return "\\frac{" + str(a) + "}{" + str(b) + "}"
def remove_right_units(string):
    """Drop a trailing ``\\text{ ...}`` unit annotation from an answer.

    Args:
        string: an answer string.

    Returns:
        Everything before ``"\\text{ "`` if that marker is present,
        otherwise the input unchanged.

    Raises:
        AssertionError: if the marker occurs more than once.
    """
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    marker = "\\text{ "
    if marker not in string:
        return string
    pieces = string.split(marker)
    assert len(pieces) == 2
    return pieces[0]
def fix_sqrt(string):
    """Brace the single-character argument of shorthand ``\\sqrt``.

    ``\\sqrt3`` becomes ``\\sqrt{3}``; ``\\sqrt{...}`` is left alone. A
    ``\\sqrt`` with nothing after it (end of string, or immediately
    followed by another ``\\sqrt``) is also left alone rather than raising.

    Args:
        string: a LaTeX answer string.

    Returns:
        The string with every shorthand ``\\sqrtX`` rewritten to ``\\sqrt{X}``.
    """
    if "\\sqrt" not in string:
        return string
    head, *pieces = string.split("\\sqrt")
    out = [head]
    for piece in pieces:
        # ``piece`` is empty when "\sqrt" has no argument; indexing it
        # unconditionally would raise IndexError.
        if piece and piece[0] != "{":
            out.append("\\sqrt{" + piece[0] + "}" + piece[1:])
        else:
            out.append("\\sqrt" + piece)
    return "".join(out)
def strip_string(string):
    """Normalize a LaTeX answer string so equivalent answers compare equal.

    The steps run in a fixed order: remove line breaks and ``\\!``,
    collapse ``\\\\`` to ``\\``, map ``tfrac``/``dfrac`` to ``frac``, drop
    ``\\left``/``\\right``, degree marks and ``\\$``, cut trailing units,
    drop ``\\%``, add a leading zero to bare decimals, drop a short
    ``x =`` prefix, brace ``\\sqrt`` arguments, remove spaces, brace
    ``\\frac`` arguments, map ``0.5`` to ``\\frac{1}{2}`` and finally
    convert a plain ``a/b`` to ``\\frac{a}{b}``.

    Args:
        string: the raw answer string.

    Returns:
        The normalized string. An empty intermediate result is returned
        early as ``""``.
    """
    # Order matters: each replacement runs on the output of the previous one.
    for old, new in (
        ("\n", ""),          # linebreaks
        ("\\!", ""),         # remove inverse spaces
        ("\\\\", "\\"),      # replace \\ with \
        ("tfrac", "frac"),   # replace tfrac and dfrac with frac
        ("dfrac", "frac"),
        ("\\left", ""),      # remove \left and \right
        ("\\right", ""),
        ("^{\\circ}", ""),   # remove circ (degrees)
        ("^\\circ", ""),
        ("\\$", ""),         # remove dollar signs
    ):
        string = string.replace(old, new)

    # remove units (on the right)
    string = remove_right_units(string)

    # remove percentage
    string = string.replace("\\%", "")
    # The second pass was originally written "\%", an invalid escape
    # sequence with the same value as "\\%". It is kept as a second pass
    # because removing one "\%" can expose another.
    string = string.replace("\\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    # fix sqrt3 --> sqrt{3}
    string = fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = fix_a_slash_b(string)

    return string

View File

@@ -0,0 +1,58 @@
import re
import random
def extract_solution(solution_str):
    """Return the last ``<answer>`` content of the reply if it is an integer.

    Everything up to and including the first ``Assistant:`` is discarded
    before searching. Answer tags are matched without ``re.DOTALL``, so the
    content may not span lines.

    Args:
        solution_str: the full prompt + response text.

    Returns:
        The stripped answer text (still a string) when ``int()`` accepts
        it; ``None`` when there is no ``Assistant:`` marker, no answer tag,
        or the content is not an integer literal.
    """
    if "Assistant:" not in solution_str:
        return None
    response = solution_str.split("Assistant:", 1)[1]

    found = re.findall(r'<answer>(.*?)</answer>', response)
    if not found:
        return None

    candidate = found[-1].strip()
    try:
        int(candidate)  # validation only; the string itself is returned
    except ValueError:
        return None
    return candidate
def compute_score(solution_str, ground_truth, method='strict', format_score=0.1, score=1.):
    """Score an integer answer extracted from ``<answer>`` tags.

    Reference: Trung, Luong, et al. "Reft: Reasoning with reinforced fine-tuning." Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024.

    Args:
        solution_str: the solution text
        ground_truth: the ground truth; converted with ``int()`` for comparison
        method: accepted for interface compatibility; not used by this function
        format_score: the score returned for a well-formed but wrong answer
        score: the score for the correct answer

    Returns:
        ``0`` when no integer answer is extracted, ``score`` when it equals
        the ground truth as an integer, ``format_score`` otherwise.
    """
    answer = extract_solution(solution_str=solution_str)
    # Log roughly one in 64 samples for debugging.
    do_print = random.randint(1, 64) == 1

    if do_print:
        print("--------------------------------")
        print(f"Ground truth: {ground_truth} | Extracted answer: {answer}")
        print(f"Solution string: {solution_str}")

    if answer is None:
        if do_print:
            print("No answer found")
        return 0

    if int(answer) == int(ground_truth):
        if do_print:
            print(f"Correct answer: {answer}")
        return score

    if do_print:
        print(f"Incorrect answer {answer} | Ground truth: {ground_truth}")
    return format_score

View File

@@ -0,0 +1,138 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import string
import random
def normalize_answer(s):
    """Normalize an answer string for exact-match comparison.

    Applied in order: lowercase, delete every character in
    ``string.punctuation``, replace the standalone words ``a``/``an``/``the``
    with a space, then collapse all whitespace runs to single spaces and
    trim the ends.

    Args:
        s: the answer text.

    Returns:
        The normalized text.
    """
    lowered = s.lower()
    without_punc = lowered.translate(str.maketrans("", "", string.punctuation))
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punc)
    return " ".join(without_articles.split())
def em_check(prediction, golden_answers):
    """Exact-match check of a prediction against one or more golden answers.

    Args:
        prediction: the predicted answer text.
        golden_answers: a single string or an iterable of strings.

    Returns:
        ``1`` if the normalized prediction equals any normalized golden
        answer, else ``0``.
    """
    candidates = [golden_answers] if isinstance(golden_answers, str) else golden_answers
    normalized_prediction = normalize_answer(prediction)
    return int(any(
        normalize_answer(gold) == normalized_prediction for gold in candidates
    ))
def subem_check(prediction, golden_answers):
    """Substring-match check of a prediction against golden answers.

    Args:
        prediction: the predicted answer text.
        golden_answers: a single string or an iterable of strings.

    Returns:
        ``1`` if any normalized golden answer occurs as a substring of the
        normalized prediction, else ``0``.
    """
    candidates = [golden_answers] if isinstance(golden_answers, str) else golden_answers
    normalized_prediction = normalize_answer(prediction)
    return int(any(
        normalize_answer(gold) in normalized_prediction for gold in candidates
    ))
def extract_solution(solution_str):
    """Return the last ``<answer>...</answer>`` content in the text.

    The whole of ``solution_str`` is searched (no assistant-marker
    splitting is performed) and tags may span multiple lines. At least two
    answer tags are required before anything is returned.

    Args:
        solution_str: the text to search.

    Returns:
        The stripped content of the last answer tag, or ``None`` when the
        text contains zero or exactly one answer tag.
    """
    answers = re.findall(r'<answer>(.*?)</answer>', solution_str, re.DOTALL)

    # If there are 0 or exactly 1 matches, return None
    if len(answers) < 2:
        return None

    # If there are 2 or more matches, return the last one
    return answers[-1].strip()
def compute_score_em(solution_str, ground_truth, method='strict', format_score=0., score=1.):
    """The scoring function for exact match (EM).

    Args:
        solution_str: the solution text
        ground_truth: mapping whose ``'target'`` entry holds the golden answer(s)
        method: accepted for interface compatibility; not used by this function
        format_score: the score returned when an answer is extracted but does not match
        score: the score for the correct answer

    Returns:
        ``0`` when no answer is extracted, ``score`` on an exact match,
        ``format_score`` otherwise.
    """
    answer = extract_solution(solution_str=solution_str)
    # Log roughly one in 64 samples for debugging.
    do_print = random.randint(1, 64) == 1

    if do_print:
        print("--------------------------------")
        print(f"Golden answers: {ground_truth['target']}")
        print(f"Extracted answer: {answer}")
        print(f"Solution string: {solution_str}")

    if answer is None:
        return 0
    return score if em_check(answer, ground_truth['target']) else format_score
def compute_score_subem(solution_str, ground_truth, method='strict', format_score=0., score=1.):
    """The scoring function for substring exact match (EM).

    Args:
        solution_str: the solution text
        ground_truth: mapping whose ``'target'`` entry holds the golden answer(s)
        method: accepted for interface compatibility; not used by this function
        format_score: the score returned when an answer is extracted but does not match
        score: the score for the correct answer

    Returns:
        ``0`` when no answer is extracted, ``score`` when a golden answer
        is contained in the extracted answer, ``format_score`` otherwise.
    """
    answer = extract_solution(solution_str=solution_str)
    # Log roughly one in 64 samples for debugging.
    do_print = random.randint(1, 64) == 1

    if do_print:
        print("--------------------------------")
        print(f"Golden answers: {ground_truth['target']}")
        print(f"Extracted answer: {answer}")
        print(f"Solution string: {solution_str}")

    if answer is None:
        return 0
    return score if subem_check(answer, ground_truth['target']) else format_score