Initialize visual components such as SAM for assistance
mm_agents/sam_test.py (new file, 124 lines)
@@ -0,0 +1,124 @@
import torch
from PIL import Image
import requests
from transformers import SamModel, SamProcessor
import numpy as np
import matplotlib.pyplot as plt
import os

# Work around the duplicate-OpenMP (libiomp5) crash that occurs when torch
# and matplotlib each load their own copy of the runtime.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))


def show_boxes_on_image(raw_image, boxes):
    plt.figure(figsize=(10, 10))
    plt.imshow(raw_image)
    for box in boxes:
        show_box(box, plt.gca())
    plt.axis('on')
    plt.show()


def show_points_on_image(raw_image, input_points, input_labels=None):
    plt.figure(figsize=(10, 10))
    plt.imshow(raw_image)
    input_points = np.array(input_points)
    if input_labels is None:
        labels = np.ones_like(input_points[:, 0])
    else:
        labels = np.array(input_labels)
    show_points(input_points, labels, plt.gca())
    plt.axis('on')
    plt.show()


def show_points_and_boxes_on_image(raw_image, boxes, input_points, input_labels=None):
    plt.figure(figsize=(10, 10))
    plt.imshow(raw_image)
    input_points = np.array(input_points)
    if input_labels is None:
        labels = np.ones_like(input_points[:, 0])
    else:
        labels = np.array(input_labels)
    show_points(input_points, labels, plt.gca())
    for box in boxes:
        show_box(box, plt.gca())
    plt.axis('on')
    plt.show()


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)


def show_masks_on_image(raw_image, masks, scores):
    if len(masks.shape) == 4:
        masks = masks.squeeze()
    if scores.shape[0] == 1:
        scores = scores.squeeze()

    nb_predictions = scores.shape[-1]
    fig, axes = plt.subplots(1, nb_predictions, figsize=(15, 15))
    # plt.subplots returns a bare Axes (not an array) when there is a single
    # prediction, so wrap it to keep the indexing below uniform.
    axes = np.atleast_1d(axes)

    for i, (mask, score) in enumerate(zip(masks, scores)):
        mask = mask.cpu().detach()
        axes[i].imshow(np.array(raw_image))
        show_mask(mask, axes[i])
        axes[i].title.set_text(f"Mask {i + 1}, Score: {score.item():.3f}")
        axes[i].axis("off")
    plt.show()


# Load SAM and its processor, preferring GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

# Fetch a test image.
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

plt.imshow(raw_image)

# Run SAM without any point or box prompts.
inputs = processor(raw_image, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)

# Rescale the predicted masks back to the original image resolution.
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)

scores = outputs.iou_scores
show_masks_on_image(raw_image, masks[0], scores)
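
# The point/box helpers above are never exercised by this smoke test. The
# following is a minimal sketch of prompted segmentation reusing the model,
# processor, and raw_image defined above; the (450, 600) point follows the
# transformers SAM example for this car image and is otherwise arbitrary.
input_points = [[[450, 600]]]  # one image in the batch, one (x, y) point

inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)

masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)
show_points_on_image(raw_image, input_points[0])
show_masks_on_image(raw_image, masks[0], outputs.iou_scores)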
@@ -63,6 +63,8 @@ class DuckTrackEventActionConverter:
    def scroll_event_to_action(self, event: dict):
        # TODO: confirm whether dy < 0 means scroll up or scroll down
        # TODO: needs testing to map this onto our action space, e.g. does one
        #  scroll event here correspond to "scroll 10" or "scroll 20"?
        if event["dy"] < 0:
            down = False
        else:
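A hypothetical way to settle the two TODOs above against a pyautogui-style replay backend (the helper below is an illustration, not part of this commit; pyautogui.scroll takes a click count where positive scrolls up):

import pyautogui

def replay_scroll(event: dict, clicks_per_event: int = 10) -> None:
    # pyautogui.scroll(n): n > 0 scrolls up, n < 0 scrolls down.
    # dy > 0 is assumed to mean "scroll up" here; if recorded traces show
    # the opposite, flip this sign once rather than patching call sites.
    direction = 1 if event["dy"] > 0 else -1
    pyautogui.scroll(direction * clicks_per_event)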
utils/image_processing/contour.py (new file, 34 lines)
@@ -0,0 +1,34 @@
import cv2
from matplotlib import pyplot as plt

# Load the image
image = cv2.imread('../../mm_agents/stackoverflow.png')

# Convert to grayscale
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

# Apply adaptive thresholding to get a binary image
thresh = cv2.adaptiveThreshold(
    gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2
)

# Find contours
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

# Keep only contours whose area falls within a plausible cell-size range,
# assuming cells have a fairly uniform size. The thresholds below are
# placeholders; real values depend on the actual image resolution.
min_cell_size = 500
max_cell_size = 5000
cell_contours = [cnt for cnt in contours if min_cell_size < cv2.contourArea(cnt) < max_cell_size]

# Draw the surviving contours on a copy of the image
contour_output = image.copy()
cv2.drawContours(contour_output, cell_contours, -1, (0, 255, 0), 2)

# Display the image with cell contours
plt.figure(figsize=(12, 6))
plt.imshow(cv2.cvtColor(contour_output, cv2.COLOR_BGR2RGB))
plt.title('Spreadsheet with Cell Contours')
plt.axis('off')
plt.show()
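
# For an agent, the contours only become actionable once converted into
# clickable coordinates. A small follow-on sketch (this helper is hypothetical,
# not part of the committed file) turning each kept contour into a bounding
# box plus a center point:
def contours_to_click_targets(cnts):
    """Return an (x0, y0, x1, y1) box and a center point for each contour."""
    targets = []
    for cnt in cnts:
        x, y, w, h = cv2.boundingRect(cnt)
        targets.append(((x, y, x + w, y + h), (x + w // 2, y + h // 2)))
    return targets

for box, center in contours_to_click_targets(cell_contours):
    print(f"cell box={box}, click target={center}")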