Fix bugs in multiple apps example 0e53

This commit is contained in:
Timothyxxx
2024-03-10 15:18:14 +08:00
parent 4645682b9e
commit e51d0e8cc9

View File

@@ -219,7 +219,10 @@ def compare_archive(pred_path: str, gold_path: str, **kwargs) -> float:
""" """
Compare two archives. Note that the files in the archives should be of the same type. Compare two archives. Note that the files in the archives should be of the same type.
""" """
if not pred_path: return 0. file_path = kwargs.pop('file_path', '')
if not pred_path:
return 0.
pred_folder = os.path.splitext(pred_path)[0] + '_pred' pred_folder = os.path.splitext(pred_path)[0] + '_pred'
gold_folder = os.path.splitext(gold_path)[0] + '_gold' gold_folder = os.path.splitext(gold_path)[0] + '_gold'
@@ -227,13 +230,16 @@ def compare_archive(pred_path: str, gold_path: str, **kwargs) -> float:
shutil.rmtree(pred_folder, ignore_errors=True) shutil.rmtree(pred_folder, ignore_errors=True)
os.makedirs(pred_folder) os.makedirs(pred_folder)
shutil.unpack_archive(pred_path, pred_folder) shutil.unpack_archive(pred_path, pred_folder)
if not os.path.exists(gold_folder): # use cache if exists if not os.path.exists(gold_folder): # use cache if exists
os.makedirs(gold_folder) os.makedirs(gold_folder)
shutil.unpack_archive(gold_path, gold_folder) shutil.unpack_archive(gold_path, gold_folder)
pred_files = sorted(os.listdir(pred_folder)) pred_files = sorted(os.listdir(os.path.join(pred_folder, file_path)))
gold_files = sorted(os.listdir(gold_folder)) gold_files = sorted(os.listdir(os.path.join(gold_folder, file_path)))
if pred_files != gold_files: return 0.
if pred_files != gold_files:
return 0.
def get_compare_function(): def get_compare_function():
file_type = kwargs.pop('file_type', 'text') file_type = kwargs.pop('file_type', 'text')
@@ -269,8 +275,8 @@ def compare_archive(pred_path: str, gold_path: str, **kwargs) -> float:
score = 0 score = 0
compare_function = get_compare_function() compare_function = get_compare_function()
for f1, f2 in zip(pred_files, gold_files): for f1, f2 in zip(pred_files, gold_files):
fp1 = os.path.join(pred_folder, f1) fp1 = os.path.join(pred_folder, file_path, f1)
fp2 = os.path.join(gold_folder, f2) fp2 = os.path.join(gold_folder, file_path, f2)
score += compare_function(fp1, fp2, **kwargs) score += compare_function(fp1, fp2, **kwargs)
return score / len(pred_files) return score / len(pred_files)
@@ -390,3 +396,16 @@ def is_added_to_steam_cart(active_tab_info, rule):
return 0. return 0.
return 1. return 1.
if __name__ == '__main__':
result = compare_archive(
r"C:\Users\tianbaox\Desktop\DesktopEnv\cache\0e5303d4-8820-42f6-b18d-daf7e633de21\lecture_slides.zip",
r"C:\Users\tianbaox\Desktop\DesktopEnv\cache\0e5303d4-8820-42f6-b18d-daf7e633de21\gold_lecture_slides.zip",
**{
"file_path": "lecture_slides",
"file_type": "pdf"
})
print(result)