Files
sci-gui-agent-benchmark/mm_agents/os_symphony/agents/ocr.py
2025-12-23 14:30:44 +08:00

179 lines
6.2 KiB
Python
Executable File

import re
from io import BytesIO
from typing import Dict, List, Optional, Tuple

import easyocr
import numpy as np
import pytesseract
from PIL import Image, ImageDraw, ImageFont
from pytesseract import Output


class OCRProcessor:
    """Run OCR on an image with either Tesseract or EasyOCR.

    Both backends return the same pair: a tab-separated text table and a
    list of element dicts with the keys ``id``, ``text``, ``mode``,
    ``left``, ``top``, ``width``, ``height`` and ``conf``.
    """

    # Matches a run of characters at the very start or very end of a string
    # that are not ASCII letters, digits, whitespace or one of . , ! ? ; : - +
    # Compiled once so both backends share exactly the same cleaning rule.
    _EDGE_NOISE_RE = re.compile(
        r"^[^a-zA-Z0-9\s.,!?;:\-\+]+|[^a-zA-Z0-9\s.,!?;:\-\+]+$"
    )

    def __init__(self, use_gpu: bool = False, languages: Optional[List[str]] = None):
        """Initialize the processor.

        Args:
            use_gpu: Whether the EasyOCR reader should be created with GPU
                support.
            languages: Language codes passed to EasyOCR, e.g.
                ``['en', 'ch_sim']``. Defaults to ``['en']``.
        """
        self.use_gpu = use_gpu
        # Copy so instances never share (or mutate) the caller's list or a
        # shared default object.
        self.languages = ['en'] if languages is None else list(languages)
        self.reader = None  # lazy-load EasyOCR Reader

    def _get_easyocr_reader(self):
        """Return the EasyOCR reader, creating and caching it on first use."""
        if self.reader is None:
            print(f"Loading EasyOCR model (GPU={self.use_gpu})...")
            self.reader = easyocr.Reader(self.languages, gpu=self.use_gpu)
        return self.reader

    def get_ocr_elements(self, bytes_image_data: bytes, mode: str = 'tesseract') -> Tuple[str, List[Dict]]:
        """Run OCR on an encoded image.

        Args:
            bytes_image_data: Raw bytes of an encoded image file (the data is
                wrapped in ``BytesIO`` and opened with PIL).
            mode: ``'tesseract'`` (faster) or ``'easyocr'`` (more precise).

        Returns:
            ``(text_table, elements)``. If the image cannot be opened or
            decoded, an error is printed and ``("", [])`` is returned.

        Raises:
            ValueError: If ``mode`` is not one of the supported values.
        """
        try:
            image = Image.open(BytesIO(bytes_image_data))
            # Image.open is lazy; force decoding here so truncated or corrupt
            # pixel data is reported by this handler instead of surfacing
            # later inside the OCR backend.
            image.load()
        except Exception as e:
            print(f"Error decoding or opening image: {e}")
            return "", []

        if mode == 'tesseract':
            return self._process_tesseract(image)
        if mode == 'easyocr':
            return self._process_easyocr(image)
        raise ValueError(f"Unknown mode: {mode}. Use 'tesseract' or 'easyocr'.")

    def _process_tesseract(self, image: "Image.Image") -> Tuple[str, List[Dict]]:
        """Run Tesseract and collect every word with a positive confidence.

        Rows whose confidence is not a positive number, or whose text is
        empty after edge cleaning, are skipped.
        """
        data = pytesseract.image_to_data(image, output_type=Output.DICT)

        ocr_elements = []
        table_parts = ["Text Table (Tesseract):\nWord id\tText\n"]

        for i, raw_text in enumerate(data['text']):
            # Filter out low-confidence rows; tolerate confidence values that
            # arrive as floats or numeric strings, and skip unparsable ones
            # instead of aborting the whole page.
            try:
                confidence = int(float(data['conf'][i]))
            except (TypeError, ValueError, OverflowError):
                continue
            if confidence <= 0 or not raw_text.strip():
                continue

            clean_text = self._EDGE_NOISE_RE.sub("", raw_text)
            # Same emptiness rule as the EasyOCR path: whitespace-only text
            # is not a usable element.
            if not clean_text.strip():
                continue

            ocr_id = len(ocr_elements)
            table_parts.append(f"{ocr_id}\t{clean_text}\n")
            ocr_elements.append({
                "id": ocr_id,
                "text": clean_text,
                "mode": "tesseract",
                "left": data["left"][i],
                "top": data["top"][i],
                "width": data["width"][i],
                "height": data["height"][i],
                "conf": data["conf"][i],
            })

        return "".join(table_parts), ocr_elements

    def _process_easyocr(self, image: "Image.Image") -> Tuple[str, List[Dict]]:
        """Run EasyOCR and convert its corner-point boxes to left/top/width/height."""
        reader = self._get_easyocr_reader()
        image_np = np.array(image)
        # detail=1 means returning (bbox, text, conf)
        results = reader.readtext(image_np, detail=1, paragraph=False, width_ths=0.1)

        ocr_elements = []
        table_parts = ["Text Table (EasyOCR):\nWord id\tText\n"]

        for bbox, text, conf in results:
            clean_text = self._EDGE_NOISE_RE.sub("", text)
            if not clean_text.strip():
                continue

            # bbox is a sequence of corner points (tl, tr, br, bl). Convert
            # every coordinate to a plain int so the resulting geometry never
            # carries numpy scalars or floats, then take the axis-aligned
            # extent of all corners.
            xs = [int(point[0]) for point in bbox]
            ys = [int(point[1]) for point in bbox]
            left, top = min(xs), min(ys)
            width = max(xs) - left
            height = max(ys) - top

            ocr_id = len(ocr_elements)
            table_parts.append(f"{ocr_id}\t{clean_text}\n")
            ocr_elements.append({
                "id": ocr_id,
                "text": clean_text,
                "mode": "easyocr",
                "left": left,
                "top": top,
                "width": width,
                "height": height,
                "conf": float(conf),
            })

        return "".join(table_parts), ocr_elements

    @staticmethod
    def visualize_ocr_results(image_path: str, ocr_elements: List[Dict], output_path: str):
        """Draw bounding boxes and IDs on the original image.

        Boxes are green for EasyOCR elements and red otherwise. Failures are
        printed rather than raised, so this is a best-effort debugging aid.

        Args:
            image_path: Path of the image the elements were extracted from.
            ocr_elements: Element dicts with ``id``, ``left``, ``top``,
                ``width``, ``height`` and optionally ``mode``.
            output_path: Where the annotated image is saved.
        """
        try:
            # Use a context manager so the source file handle is released;
            # convert() returns an independent in-memory copy.
            with Image.open(image_path) as source:
                image = source.convert("RGB")
            draw = ImageDraw.Draw(image)

            try:
                font = ImageFont.truetype("arial.ttf", 16)
            except OSError:
                font = ImageFont.load_default()

            for element in ocr_elements:
                left, top = element["left"], element["top"]
                width, height = element["width"], element["height"]
                color = "green" if element.get("mode") == "easyocr" else "red"

                draw.rectangle([(left, top), (left + width, top + height)], outline=color, width=2)

                text_str = str(element["id"])
                if hasattr(draw, "textbbox"):
                    bbox = draw.textbbox((0, 0), text_str, font=font)
                    text_w, text_h = bbox[2] - bbox[0], bbox[3] - bbox[1]
                else:
                    # Fallback for Pillow versions without textbbox.
                    text_w, text_h = draw.textsize(text_str, font=font)

                # The label normally sits directly above the box; clamp it to
                # the top of the canvas so ids of boxes touching the top edge
                # remain visible.
                label_top = max(top - text_h - 4, 0)
                label_bg = [left, label_top, left + text_w + 4, label_top + text_h + 4]
                draw.rectangle(label_bg, fill=color)
                draw.text((left + 2, label_top), text_str, fill="white", font=font)

            image.save(output_path)
            print(f"Visualization saved to: {output_path}")
        except FileNotFoundError:
            print(f"Error: Image {image_path} not found.")
        except Exception as e:
            print(f"Visualization error: {e}")