Files
Search-R1/verl/utils/reward_score/countdown.py
PeterGriffinJin 068516be64 Initial commit
2025-02-28 15:16:19 +00:00

111 lines
3.7 KiB
Python

import re
import random
import ast
import operator
def extract_solution(solution_str):
    """Return the content of the last ``<answer>`` tag in the model's reply.

    Everything up to and including the first assistant marker
    (``"Assistant:"``, or failing that ``"<|im_start|>assistant"``) is
    discarded. Only the final line of the remaining text is searched, so an
    answer tag placed on any earlier line is ignored.

    Args:
        solution_str: full decoded text, prompt included.

    Returns:
        The whitespace-stripped text inside the last
        ``<answer>...</answer>`` pair on the final line, or ``None`` when no
        assistant marker is present or the final line has no such pair.
    """
    # Markers are tried in priority order; the first one present wins.
    for marker in ("Assistant:", "<|im_start|>assistant"):
        if marker in solution_str:
            response = solution_str.partition(marker)[2]
            break
    else:
        return None

    last_line = response.rsplit('\n', 1)[-1]
    # With a single capture group, findall yields the captured strings.
    answers = re.findall(r'<answer>(.*?)</answer>', last_line)
    return answers[-1].strip() if answers else None
def validate_equation(equation_str, available_numbers):
    """Check that the equation uses exactly the available numbers, once each.

    Every run of digits in ``equation_str`` is read as one integer, and the
    sorted list of those integers must equal the sorted list of
    ``available_numbers``.

    Args:
        equation_str: the candidate equation text.
        available_numbers: iterable of the integers that must be used.

    Returns:
        ``True`` when the multiset of integers in the equation matches
        ``available_numbers``; ``False`` otherwise, including when the
        equation contains a fractional literal or the inputs cannot be
        processed (e.g. ``equation_str`` is not a string).
    """
    try:
        # A fractional literal such as "1.5" would otherwise be counted as
        # the two integers 1 and 5, letting an equation glue two available
        # numbers into a value that was never offered.
        if re.search(r'\.\d', equation_str):
            return False

        numbers_in_eq = sorted(int(n) for n in re.findall(r'\d+', equation_str))
        # Each number should be used exactly once
        return numbers_in_eq == sorted(available_numbers)
    except Exception:
        # Malformed input counts as invalid; KeyboardInterrupt and SystemExit
        # are deliberately allowed to propagate.
        return False
def evaluate_equation(equation_str):
    """Evaluate an arithmetic expression string and return its value.

    The string is accepted only when it consists solely of digits, the
    characters ``+ - * / ( ) .`` and whitespace, and does not contain the
    exponentiation operator ``**``.

    Args:
        equation_str: expression text, e.g. ``"(1 + 2) * 3"``.

    Returns:
        The value of the expression, or ``None`` when the string contains
        disallowed characters, uses ``**``, is not a string, or raises
        during evaluation (syntax error, division by zero, ...).
    """
    try:
        # Only numbers, operators, parentheses, and whitespace are allowed.
        allowed_pattern = r'^[\d+\-*/().\s]+$'
        if not re.match(allowed_pattern, equation_str):
            return None

        # The character whitelist cannot tell '*' from '**'. Chained powers
        # such as "9**9**9**9" would make eval() run (or allocate) without
        # bound, so exponentiation is rejected outright.
        if '**' in equation_str:
            return None

        # NOTE(security): eval() runs on model-generated text. The whitelist
        # above plus disabled builtins is the only sandbox; a dedicated
        # expression parser would be a safer replacement.
        return eval(equation_str, {"__builtins__": None}, {})
    except Exception:
        return None
def compute_score(solution_str, ground_truth, method='strict', format_score=0.1, score=1.):
    """The scoring function for countdown task.

    Args:
        solution_str: the solution text
        ground_truth: dictionary containing target number and available
            numbers, read from the keys ``'target'`` and ``'numbers'``
        method: the method to extract the solution (not consulted by this
            function)
        format_score: the score for correct format but wrong answer
        score: the score for the correct answer

    Returns:
        ``0`` when no equation can be extracted from ``solution_str``;
        ``score`` when the equation uses the given numbers and evaluates to
        within 1e-5 of the target; ``format_score`` in every other case.
    """
    target = ground_truth['target']
    numbers = ground_truth['numbers']

    equation = extract_solution(solution_str=solution_str)
    # Emit debug output for roughly one call in 64 to keep logs readable.
    do_print = random.randint(1, 64) == 1

    if do_print:
        print("--------------------------------")
        print(f"Target: {target} | Numbers: {numbers}")
        print(f"Extracted equation: {equation}")
        print(f"Solution string: {solution_str}")

    if equation is None:
        if do_print:
            print("No equation found")
        return 0

    # Validate equation uses correct numbers
    if not validate_equation(equation, numbers):
        if do_print:
            print("Invalid equation")
        return format_score

    # Evaluate equation. Only the evaluation and the numeric comparison are
    # guarded; KeyboardInterrupt and SystemExit are allowed to propagate.
    try:
        result = evaluate_equation(equation)
        # Tolerance accounts for floating point precision.
        is_correct = result is not None and abs(result - target) < 1e-5
    except Exception:
        if do_print:
            print("Error evaluating equation")
        return format_score

    if result is None:
        if do_print:
            print("Could not evaluate equation")
        return format_score

    if is_correct:
        if do_print:
            print(f"Correct equation: {equation} = {result}")
        return score

    if do_print:
        print(f"Wrong result: equation = {result}, target = {target}")
    return format_score