# File: backend/defect_detection_server.py
# Last modified: 2025-07-01 18:06:39 +08:00
# Size: 293 lines, 11 KiB (Python)

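# Protocol: every message, in both directions, is a 4-byte big-endian length
# followed by that many bytes of UTF-8 encoded JSON.
#
# Request:
# {"image": "<base64-encoded image>", "algorithm": 1}
# ("algorithm" is optional and defaults to 1; use 1 or 2)
#
# Example successful response:
# {
#     "success": true,
#     "detection_data": {
#         "defect_count": 5,
#         "total_defect_area": 123.45,
#         "total_crystal_area": 5000.0,
#         "defect_score": 2.47,
#         "algorithm_used": 1
#     },
#     "images": {
#         "original_image": "base64_string",
#         "binary_defects": "base64_string",
#         "image_with_defects": "base64_string",
#         "segmented_crystal": "base64_string"
#     }
# }
#
# Failed response:
# {"success": false, "error": "<message>"}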
import base64
import json
import socket
import struct
import threading
from io import BytesIO

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from sklearn.mixture import GaussianMixture
class DefectDetectionServer:
    """TCP server that runs crystal defect detection on images sent by clients.

    Wire format (both directions): a 4-byte big-endian unsigned length
    followed by that many bytes of UTF-8 encoded JSON. A request is a JSON
    object with a base64 ``image`` and an optional ``algorithm``
    (1 = Canny edge detector, 2 = Gaussian-mixture detector; default 1).
    """

    # Largest request payload accepted, in bytes. Protects the server from a
    # peer announcing (or sending) an arbitrarily large frame.
    MAX_MESSAGE_SIZE = 128 * 1024 * 1024

    # Upper bound for a single recv() call. recv(n) may allocate an n-byte
    # buffer before any data arrives, and n would otherwise come straight
    # from the client-controlled length prefix.
    _RECV_CHUNK_SIZE = 1024 * 1024

    def __init__(self, host='localhost', port=8888):
        self.host = host
        self.port = port
        # Listening socket; created by start_server().
        self.socket = None

    def preprocess_image(self, image):
        """Convert a BGR image to grayscale and smooth it with a 5x5 Gaussian blur."""
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        return cv2.GaussianBlur(gray, (5, 5), 0)

    def segment_crystal(self, blurred_image):
        """Isolate the crystal as the largest bright region of the image.

        The image is thresholded at intensity 30, the largest external contour
        is taken as the crystal, and everything outside it is zeroed.

        Returns:
            (segmented_crystal, crystal_contour, binary): the masked grayscale
            image, the crystal contour, and the thresholded binary image.

        Raises:
            ValueError: if thresholding yields no contours at all.
        """
        _, binary = cv2.threshold(blurred_image, 30, 255, cv2.THRESH_BINARY)
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if len(contours) == 0:
            raise ValueError("No crystal detected in the image.")
        crystal_contour = max(contours, key=cv2.contourArea)
        mask = np.zeros_like(blurred_image)
        cv2.drawContours(mask, [crystal_contour], -1, 255, thickness=cv2.FILLED)
        segmented_crystal = cv2.bitwise_and(blurred_image, blurred_image, mask=mask)
        return segmented_crystal, crystal_contour, binary

    @staticmethod
    def _filter_small_contours(contours, min_perimeter=5):
        """Return the contours whose closed perimeter exceeds ``min_perimeter`` pixels."""
        return [c for c in contours if cv2.arcLength(c, True) > min_perimeter]

    def detect_defects_GMM(self, segmented_crystal, crystal_contour):
        """Detect bright regions inside the crystal with a two-component GMM.

        A 1-D Gaussian mixture is fitted to the non-zero pixel intensities
        inside ``crystal_contour``. Pixels at or above the mean of the brighter
        component are treated as defects.

        Returns:
            (defects, binary_defects): contours of bright regions with a
            perimeter above 5 px, and a uint8 image that is 255 where an
            in-crystal pixel is >= the bright-component mean and 0 elsewhere.

        Raises:
            ValueError: if fewer than 10 non-zero pixels lie inside the crystal.
        """
        mask = np.zeros_like(segmented_crystal)
        cv2.drawContours(mask, [crystal_contour], -1, 255, thickness=cv2.FILLED)
        roi = cv2.bitwise_and(segmented_crystal, segmented_crystal, mask=mask)

        # Zero-valued pixels are background and are excluded from the fit.
        pixel_intensities = roi[roi > 0]
        if pixel_intensities.size < 10:
            raise ValueError("Not enough data points to fit Gaussian Mixture Model.")

        gmm = GaussianMixture(n_components=2, random_state=0).fit(
            pixel_intensities.reshape(-1, 1)
        )
        # The component with the larger mean models the bright (defect) pixels.
        high_brightness_mean = float(gmm.means_.max())

        # NOTE(review): THRESH_BINARY keeps pixels > int(mean), while the binary
        # image below keeps pixels >= mean; the two differ when the mean is an
        # exact integer. Confirm which cut-off is intended.
        _, binary_high_brightness = cv2.threshold(
            roi, int(high_brightness_mean), 255, cv2.THRESH_BINARY
        )
        contours, _ = cv2.findContours(
            binary_high_brightness, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        defects = self._filter_small_contours(contours)

        # Vectorised form of a per-pixel loop: in-crystal pixels at or above
        # the bright-component mean become 255.
        binary_defects = np.zeros_like(segmented_crystal, dtype=np.uint8)
        binary_defects[(mask != 0) & (segmented_crystal >= high_brightness_mean)] = 255
        return defects, binary_defects

    def detect_defects(self, segmented_crystal, crystal_contour):
        """Detect defects as Canny edges (thresholds 30/90) inside the crystal.

        Returns:
            (defects, defects_edges): edge contours with a perimeter above 5 px,
            and the binary edge image restricted to the filled crystal contour.
        """
        edges = cv2.Canny(segmented_crystal, 30, 90)
        mask = np.zeros_like(edges)
        cv2.drawContours(mask, [crystal_contour], -1, 255, thickness=cv2.FILLED)
        # NOTE(review): the filled mask includes the crystal outline itself, so
        # the edge along the crystal boundary survives and may be reported as
        # a defect. Confirm whether the boundary should be excluded (e.g. by
        # eroding the mask before applying it).
        defects_edges = cv2.bitwise_and(edges, mask)
        contours, _ = cv2.findContours(defects_edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        return self._filter_small_contours(contours), defects_edges

    def calculate_total_area(self, crystal_contour):
        """Return the area enclosed by the crystal contour, in pixels."""
        return cv2.contourArea(crystal_contour)

    def score_defects(self, defects, total_area):
        """Return the total defect contour area as a percentage of ``total_area``.

        Raises:
            ValueError: if ``total_area`` is not positive (e.g. a degenerate
                crystal contour). Without this check the caller would get a
                bare ZeroDivisionError.
        """
        if total_area <= 0:
            raise ValueError("Crystal area is zero; cannot compute defect score.")
        defect_area = sum(cv2.contourArea(defect) for defect in defects)
        return (defect_area / total_area) * 100

    def image_to_base64(self, image):
        """Encode an OpenCV image as PNG and return it as a base64 string.

        Raises:
            ValueError: if OpenCV reports that PNG encoding failed.
        """
        ok, buffer = cv2.imencode('.png', image)
        if not ok:
            raise ValueError("Failed to encode image as PNG.")
        return base64.b64encode(buffer).decode('utf-8')

    def base64_to_image(self, base64_string):
        """Decode a base64-encoded image into a BGR OpenCV image.

        Returns None instead of raising when the input is not valid base64,
        is not a string/bytes value, or does not decode to an image. Callers
        can then report "Invalid image data".
        """
        try:
            image_data = base64.b64decode(base64_string)
        except (TypeError, ValueError):
            # binascii.Error (bad padding/characters) subclasses ValueError;
            # TypeError covers non-string JSON values such as numbers.
            return None
        if not image_data:
            # Nothing to decode; cv2.imdecode may raise on an empty buffer.
            return None
        nparr = np.frombuffer(image_data, np.uint8)
        try:
            return cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        except cv2.error:
            return None

    def process_image(self, image, algorithm):
        """Run the full detection pipeline on a BGR image.

        Args:
            image: BGR image, as returned by ``base64_to_image``.
            algorithm: 1 for the Canny edge detector, 2 for the GMM detector.

        Returns:
            On success ``{'success': True, 'detection_data': {...},
            'images': {...}}`` with every image PNG-encoded as base64. On any
            exception ``{'success': False, 'error': str(exc)}``.
        """
        try:
            blurred_image = self.preprocess_image(image)
            segmented_crystal, crystal_contour, _ = self.segment_crystal(blurred_image)
            if algorithm == 1:
                defects, binary_defects = self.detect_defects(segmented_crystal, crystal_contour)
            elif algorithm == 2:
                defects, binary_defects = self.detect_defects_GMM(segmented_crystal, crystal_contour)
            else:
                raise ValueError("Invalid algorithm. Use 1 or 2.")
            total_area = self.calculate_total_area(crystal_contour)
            score = self.score_defects(defects, total_area)
            # Outline the detected defects in green on a copy of the input.
            image_with_defects = cv2.drawContours(image.copy(), defects, -1, (0, 255, 0), 2)
            detection_data = {
                'defect_count': len(defects),
                'total_defect_area': float(sum(cv2.contourArea(defect) for defect in defects)),
                'total_crystal_area': float(total_area),
                'defect_score': float(score),
                'algorithm_used': algorithm,
            }
            return {
                'success': True,
                'detection_data': detection_data,
                'images': {
                    'original_image': self.image_to_base64(image),
                    'binary_defects': self.image_to_base64(binary_defects),
                    'image_with_defects': self.image_to_base64(image_with_defects),
                    'segmented_crystal': self.image_to_base64(segmented_crystal),
                },
            }
        except Exception as e:
            # Request boundary: report any failure to the client instead of
            # dropping the connection.
            return {
                'success': False,
                'error': str(e),
            }

    def send_data(self, conn, data):
        """Send ``data`` as JSON, prefixed with its 4-byte big-endian length."""
        data_bytes = json.dumps(data).encode('utf-8')
        conn.sendall(struct.pack('!I', len(data_bytes)))
        conn.sendall(data_bytes)

    def _recv_exact(self, conn, size):
        """Read exactly ``size`` bytes from ``conn``; return None on early EOF."""
        buffer = bytearray()  # amortised O(1) appends, unlike bytes +=
        while len(buffer) < size:
            chunk = conn.recv(min(size - len(buffer), self._RECV_CHUNK_SIZE))
            if not chunk:
                return None
            buffer += chunk
        return bytes(buffer)

    def receive_data(self, conn):
        """Receive one length-prefixed JSON message from ``conn``.

        Returns:
            The decoded JSON value, or None if the peer closed the connection
            before a complete message arrived.

        Raises:
            ValueError: if the announced length exceeds ``MAX_MESSAGE_SIZE``.
                The payload is not read in this case.
            UnicodeDecodeError, json.JSONDecodeError: if the payload is not
                valid UTF-8 JSON. The whole frame has been consumed by then,
                so the stream stays in sync.
        """
        header = self._recv_exact(conn, 4)
        if header is None:
            return None
        (data_length,) = struct.unpack('!I', header)
        if data_length > self.MAX_MESSAGE_SIZE:
            raise ValueError(
                f"Message of {data_length} bytes exceeds limit of "
                f"{self.MAX_MESSAGE_SIZE} bytes"
            )
        payload = self._recv_exact(conn, data_length)
        if payload is None:
            return None
        return json.loads(payload.decode('utf-8'))

    def _handle_request(self, request):
        """Validate one decoded request and return the response dict to send."""
        if not isinstance(request, dict):
            return {'success': False, 'error': 'Request must be a JSON object'}
        image_b64 = request.get('image')
        algorithm = request.get('algorithm', 1)
        if not image_b64:
            return {'success': False, 'error': 'No image provided'}
        image = self.base64_to_image(image_b64)
        if image is None:
            return {'success': False, 'error': 'Invalid image data'}
        return self.process_image(image, algorithm)

    def handle_client(self, conn, addr):
        """Serve requests on ``conn`` until the peer disconnects.

        One JSON response is sent per request. The connection is always
        closed on exit.
        """
        print(f"Connected to {addr}")
        try:
            while True:
                try:
                    request = self.receive_data(conn)
                except (UnicodeDecodeError, json.JSONDecodeError) as e:
                    # The full frame was consumed, so the stream is still in
                    # sync: report the problem and keep serving this client.
                    self.send_data(conn, {'success': False, 'error': f'Malformed request: {e}'})
                    continue
                # Only None means "disconnected"; an empty {} request must get
                # a 'No image provided' reply rather than closing the socket.
                if request is None:
                    break
                print(f"Received request from {addr}")
                response = self._handle_request(request)
                self.send_data(conn, response)
                print(f"Sent response to {addr}")
        except Exception as e:
            print(f"Error handling client {addr}: {e}")
        finally:
            conn.close()
            print(f"Connection to {addr} closed")

    def start_server(self):
        """Listen on (host, port) and serve each client on its own thread.

        Blocks until interrupted (Ctrl-C) or until a socket error occurs. The
        listening socket is closed on exit.
        """
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            self.socket.bind((self.host, self.port))
            self.socket.listen(5)
            print(f"Server listening on {self.host}:{self.port}")
            while True:
                conn, addr = self.socket.accept()
                # handle_client loops until the peer disconnects. Running it
                # inline would let one long-lived client block every other one.
                # Daemon threads do not keep the process alive after Ctrl-C.
                worker = threading.Thread(
                    target=self.handle_client, args=(conn, addr), daemon=True
                )
                worker.start()
        except KeyboardInterrupt:
            print("\nServer shutting down...")
        except Exception as e:
            print(f"Server error: {e}")
        finally:
            if self.socket:
                self.socket.close()
if __name__ == "__main__":
    # Bind to the loopback interface only; use '0.0.0.0' to accept remote clients.
    server = DefectDetectionServer(host='127.0.0.1', port=8888)
    server.start_server()