Merge branch 'main' into zdy
This commit is contained in:
@@ -81,7 +81,7 @@ class SetupController:
|
|||||||
downloaded = False
|
downloaded = False
|
||||||
for i in range(max_retries):
|
for i in range(max_retries):
|
||||||
try:
|
try:
|
||||||
response = requests.get(url, stream=True, verify=False)
|
response = requests.get(url, stream=True)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
with open(cache_path, 'wb') as f:
|
with open(cache_path, 'wb') as f:
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ def get_cloud_file(env, config: Dict[str, str]) -> str:
|
|||||||
return _path
|
return _path
|
||||||
|
|
||||||
url = config["path"]
|
url = config["path"]
|
||||||
response = requests.get(url, stream=True, verify=False)
|
response = requests.get(url, stream=True)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
with open(_path, 'wb') as f:
|
with open(_path, 'wb') as f:
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
#import pylightxl
|
|
||||||
import openpyxl
|
import openpyxl
|
||||||
#from openpyxl import Workbook
|
|
||||||
from openpyxl.worksheet.worksheet import Worksheet
|
from openpyxl.worksheet.worksheet import Worksheet
|
||||||
|
|
||||||
from utils import load_charts, load_sparklines
|
from .utils import load_charts, load_sparklines
|
||||||
|
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
def compare_table(actual, expected):
|
def compare_table(actual, expected):
|
||||||
df1 = pd.read_excel(expected)
|
df1 = pd.read_excel(expected)
|
||||||
df2 = pd.read_excel(actual)
|
df2 = pd.read_excel(actual)
|
||||||
@@ -16,6 +15,7 @@ def compare_table(actual, expected):
|
|||||||
# Compare the DataFrames
|
# Compare the DataFrames
|
||||||
return 1 if df1.equals(df2) else 0
|
return 1 if df1.equals(df2) else 0
|
||||||
|
|
||||||
|
|
||||||
def compare_with_sparklines(actual: str, expected: str) -> float:
|
def compare_with_sparklines(actual: str, expected: str) -> float:
|
||||||
df1 = pd.read_excel(actual)
|
df1 = pd.read_excel(actual)
|
||||||
df2 = pd.read_excel(expected)
|
df2 = pd.read_excel(expected)
|
||||||
@@ -29,6 +29,7 @@ def compare_with_sparklines(actual: str, expected: str) -> float:
|
|||||||
|
|
||||||
return float(normal_content_metric and sparkline_metric)
|
return float(normal_content_metric and sparkline_metric)
|
||||||
|
|
||||||
|
|
||||||
def compare_with_charts(actual: str, expected: str, **options) -> float:
|
def compare_with_charts(actual: str, expected: str, **options) -> float:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -45,25 +46,26 @@ def compare_with_charts(actual: str, expected: str, **options) -> float:
|
|||||||
|
|
||||||
charts1 = load_charts(actual, **options)
|
charts1 = load_charts(actual, **options)
|
||||||
charts2 = load_charts(expected, **options)
|
charts2 = load_charts(expected, **options)
|
||||||
chart_metric: bool = charts1==charts2
|
chart_metric: bool = charts1 == charts2
|
||||||
print("Chart Metric: {:}".format(chart_metric))
|
print("Chart Metric: {:}".format(chart_metric))
|
||||||
|
|
||||||
return float(normal_content_metric and chart_metric)
|
return float(normal_content_metric and chart_metric)
|
||||||
|
|
||||||
|
|
||||||
def check_sheet_list(result: str, rules: List[Dict[str, Any]]) -> float:
|
def check_sheet_list(result: str, rules: List[Dict[str, Any]]) -> float:
|
||||||
#workbook: Workbook = openpyxl.load_workbook(filename=result)
|
# workbook: Workbook = openpyxl.load_workbook(filename=result)
|
||||||
workbook = pd.ExcelFile(result)
|
workbook = pd.ExcelFile(result)
|
||||||
worksheet_names: List[str] = workbook.sheet_names
|
worksheet_names: List[str] = workbook.sheet_names
|
||||||
|
|
||||||
passes = True
|
passes = True
|
||||||
for r in rules:
|
for r in rules:
|
||||||
if r["type"]=="sheet_name":
|
if r["type"] == "sheet_name":
|
||||||
expected_name: str = worksheet_names[r["sheet_idx"]]
|
expected_name: str = worksheet_names[r["sheet_idx"]]
|
||||||
actual_name: str = r["sheet_name"]
|
actual_name: str = r["sheet_name"]
|
||||||
metric: bool = expected_name==actual_name
|
metric: bool = expected_name == actual_name
|
||||||
print("Assertion: {:d}.{:} is {:} - {:}".format(r["sheet_idx"], actual_name, expected_name, metric))
|
print("Assertion: {:d}.{:} is {:} - {:}".format(r["sheet_idx"], actual_name, expected_name, metric))
|
||||||
passes = passes and metric
|
passes = passes and metric
|
||||||
elif r["type"]=="sheet_data":
|
elif r["type"] == "sheet_data":
|
||||||
if isinstance(r["sheet_idx0"], int):
|
if isinstance(r["sheet_idx0"], int):
|
||||||
df1: pd.DataFrame = pd.read_excel(workbook, r["sheet_idx0"])
|
df1: pd.DataFrame = pd.read_excel(workbook, r["sheet_idx0"])
|
||||||
else:
|
else:
|
||||||
@@ -88,45 +90,47 @@ def check_sheet_list(result: str, rules: List[Dict[str, Any]]) -> float:
|
|||||||
|
|
||||||
return float(passes)
|
return float(passes)
|
||||||
|
|
||||||
|
|
||||||
def check_xlsx_freeze(result: str, rules: Dict[str, str]) -> float:
|
def check_xlsx_freeze(result: str, rules: Dict[str, str]) -> float:
|
||||||
worksheet: Worksheet = openpyxl.load_workbook(filename=result).active
|
worksheet: Worksheet = openpyxl.load_workbook(filename=result).active
|
||||||
return float(worksheet.freeze_panes==rules["position"])
|
return float(worksheet.freeze_panes == rules["position"])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
#path1 = ""
|
# path1 = ""
|
||||||
#path2 = ""
|
# path2 = ""
|
||||||
#print(compare_table(path1, path2))
|
# print(compare_table(path1, path2))
|
||||||
|
|
||||||
#path1 = "../../../../../任务数据/LibreOffice Calc/OrderId_Month_Chart_gold.xlsx"
|
# path1 = "../../../../../任务数据/LibreOffice Calc/OrderId_Month_Chart_gold.xlsx"
|
||||||
#path2 = "../../../../../任务数据/LibreOffice Calc/OrderId_Month_Chart.xlsx"
|
# path2 = "../../../../../任务数据/LibreOffice Calc/OrderId_Month_Chart.xlsx"
|
||||||
#print(compare_with_sparklines(path1, path2))
|
# print(compare_with_sparklines(path1, path2))
|
||||||
|
|
||||||
#path1 = "../../../../../任务数据/LibreOffice Calc/Freeze_row_column_gold.xlsx"
|
# path1 = "../../../../../任务数据/LibreOffice Calc/Freeze_row_column_gold.xlsx"
|
||||||
#path2 = "../../../../../任务数据/LibreOffice Calc/Freeze_row_column.xlsx"
|
# path2 = "../../../../../任务数据/LibreOffice Calc/Freeze_row_column.xlsx"
|
||||||
#workbook1: Workbook = openpyxl.load_workbook(filename=path1)
|
# workbook1: Workbook = openpyxl.load_workbook(filename=path1)
|
||||||
#worksheet1: Worksheet = workbook1.active
|
# worksheet1: Worksheet = workbook1.active
|
||||||
#print(worksheet1.freeze_panes)
|
# print(worksheet1.freeze_panes)
|
||||||
#workbook2: Workbook = openpyxl.load_workbook(filename=path2)
|
# workbook2: Workbook = openpyxl.load_workbook(filename=path2)
|
||||||
#worksheet2: Worksheet = workbook2.active
|
# worksheet2: Worksheet = workbook2.active
|
||||||
#print(worksheet2.freeze_panes)
|
# print(worksheet2.freeze_panes)
|
||||||
#rule = {"position": "C6"}
|
# rule = {"position": "C6"}
|
||||||
#print(check_xlsx_freeze(path1, rule))
|
# print(check_xlsx_freeze(path1, rule))
|
||||||
|
|
||||||
#path1 = "../../../../../任务数据/LibreOffice Calc/copy_sheet_insert_gold.xlsx"
|
# path1 = "../../../../../任务数据/LibreOffice Calc/copy_sheet_insert_gold.xlsx"
|
||||||
#rule = [ { "type": "sheet_name"
|
# rule = [ { "type": "sheet_name"
|
||||||
#, "sheet_idx": 0
|
# , "sheet_idx": 0
|
||||||
#, "sheet_name": "Sheet1"
|
# , "sheet_name": "Sheet1"
|
||||||
#}
|
# }
|
||||||
#, { "type": "sheet_data"
|
# , { "type": "sheet_data"
|
||||||
#, "sheet_idx0": "../../../../../任务数据/LibreOffice Calc/copy_sheet_insert.xlsx@0"
|
# , "sheet_idx0": "../../../../../任务数据/LibreOffice Calc/copy_sheet_insert.xlsx@0"
|
||||||
#, "sheet_idx1": 1
|
# , "sheet_idx1": 1
|
||||||
#}
|
# }
|
||||||
#, { "type": "sheet_name"
|
# , { "type": "sheet_name"
|
||||||
#, "sheet_idx": 2
|
# , "sheet_idx": 2
|
||||||
#, "sheet_name": "Sheet2"
|
# , "sheet_name": "Sheet2"
|
||||||
#}
|
# }
|
||||||
#]
|
# ]
|
||||||
#print(check_sheet_list(path1, rule))
|
# print(check_sheet_list(path1, rule))
|
||||||
|
|
||||||
path1 = "../../../../../任务数据/LibreOffice Calc/Create_column_charts_using_statistics_gold.xlsx"
|
path1 = "../../../../../任务数据/LibreOffice Calc/Create_column_charts_using_statistics_gold.xlsx"
|
||||||
path2 = "../../../../../任务数据/LibreOffice Calc/Create_column_charts_using_statistics_gold2.xlsx"
|
path2 = "../../../../../任务数据/LibreOffice Calc/Create_column_charts_using_statistics_gold2.xlsx"
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from pathlib import Path
|
|||||||
import platform
|
import platform
|
||||||
import subprocess
|
import subprocess
|
||||||
import requests
|
import requests
|
||||||
|
from .pyxcursor import Xcursor
|
||||||
# import Xlib.display
|
# import Xlib.display
|
||||||
import pyautogui
|
import pyautogui
|
||||||
# from PIL import ImageGrab, Image
|
# from PIL import ImageGrab, Image
|
||||||
@@ -48,7 +48,7 @@ def capture_screen_with_cursor():
|
|||||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||||
|
|
||||||
# fixme: This is a temporary fix for the cursor not being captured on Windows and Linux
|
# fixme: This is a temporary fix for the cursor not being captured on Windows and Linux
|
||||||
if user_platform == "Windows" or user_platform == "Linux":
|
if user_platform == "Windows":
|
||||||
def _download_image(url, path):
|
def _download_image(url, path):
|
||||||
response = requests.get(url)
|
response = requests.get(url)
|
||||||
with open(path, 'wb') as file:
|
with open(path, 'wb') as file:
|
||||||
@@ -65,12 +65,14 @@ def capture_screen_with_cursor():
|
|||||||
cursor = cursor.resize((int(cursor.width / 1.5), int(cursor.height / 1.5)))
|
cursor = cursor.resize((int(cursor.width / 1.5), int(cursor.height / 1.5)))
|
||||||
screenshot.paste(cursor, (cursor_x, cursor_y), cursor)
|
screenshot.paste(cursor, (cursor_x, cursor_y), cursor)
|
||||||
screenshot.save(file_path)
|
screenshot.save(file_path)
|
||||||
# elif user_platform == "Linux":
|
elif user_platform == "Linux":
|
||||||
# # Use xlib to prevent scrot dependency for Linux
|
cursor_obj = Xcursor()
|
||||||
# screen = Xlib.display.Display().screen()
|
imgarray = cursor_obj.getCursorImageArrayFast()
|
||||||
# size = screen.width_in_pixels, screen.height_in_pixels
|
cursor_img = Image.fromarray(imgarray)
|
||||||
# screenshot = ImageGrab.grab(bbox=(0, 0, size[0], size[1]))
|
screenshot = pyautogui.screenshot()
|
||||||
# screenshot.save(file_path)
|
cursor_x, cursor_y = pyautogui.position()
|
||||||
|
screenshot.paste(cursor_img, (cursor_x, cursor_y), cursor_img)
|
||||||
|
screenshot.save(file_path)
|
||||||
elif user_platform == "Darwin": # (Mac OS)
|
elif user_platform == "Darwin": # (Mac OS)
|
||||||
# Use the screencapture utility to capture the screen with the cursor
|
# Use the screencapture utility to capture the screen with the cursor
|
||||||
subprocess.run(["screencapture", "-C", file_path])
|
subprocess.run(["screencapture", "-C", file_path])
|
||||||
@@ -161,7 +163,7 @@ def download_file():
|
|||||||
max_retries = 3
|
max_retries = 3
|
||||||
for i in range(max_retries):
|
for i in range(max_retries):
|
||||||
try:
|
try:
|
||||||
response = requests.get(url, stream=True, verify=False)
|
response = requests.get(url, stream=True)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
with open(path, 'wb') as f:
|
with open(path, 'wb') as f:
|
||||||
|
|||||||
146
desktop_env/server/pyxcursor.py
Normal file
146
desktop_env/server/pyxcursor.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
import os
|
||||||
|
import ctypes
|
||||||
|
import ctypes.util
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# A helper function to convert data from Xlib to byte array.
|
||||||
|
import struct, array
|
||||||
|
|
||||||
|
# Define ctypes version of XFixesCursorImage structure.
|
||||||
|
PIXEL_DATA_PTR = ctypes.POINTER(ctypes.c_ulong)
|
||||||
|
Atom = ctypes.c_ulong
|
||||||
|
|
||||||
|
|
||||||
|
class XFixesCursorImage(ctypes.Structure):
|
||||||
|
"""
|
||||||
|
See /usr/include/X11/extensions/Xfixes.h
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
short x, y;
|
||||||
|
unsigned short width, height;
|
||||||
|
unsigned short xhot, yhot;
|
||||||
|
unsigned long cursor_serial;
|
||||||
|
unsigned long *pixels;
|
||||||
|
if XFIXES_MAJOR >= 2
|
||||||
|
Atom atom; /* Version >= 2 only */
|
||||||
|
const char *name; /* Version >= 2 only */
|
||||||
|
endif
|
||||||
|
} XFixesCursorImage;
|
||||||
|
"""
|
||||||
|
_fields_ = [('x', ctypes.c_short),
|
||||||
|
('y', ctypes.c_short),
|
||||||
|
('width', ctypes.c_ushort),
|
||||||
|
('height', ctypes.c_ushort),
|
||||||
|
('xhot', ctypes.c_ushort),
|
||||||
|
('yhot', ctypes.c_ushort),
|
||||||
|
('cursor_serial', ctypes.c_ulong),
|
||||||
|
('pixels', PIXEL_DATA_PTR),
|
||||||
|
('atom', Atom),
|
||||||
|
('name', ctypes.c_char_p)]
|
||||||
|
|
||||||
|
|
||||||
|
class Display(ctypes.Structure):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Xcursor:
|
||||||
|
display = None
|
||||||
|
|
||||||
|
def __init__(self, display=None):
|
||||||
|
if not display:
|
||||||
|
try:
|
||||||
|
display = os.environ["DISPLAY"].encode("utf-8")
|
||||||
|
except KeyError:
|
||||||
|
raise Exception("$DISPLAY not set.")
|
||||||
|
|
||||||
|
# XFixeslib = ctypes.CDLL('libXfixes.so')
|
||||||
|
XFixes = ctypes.util.find_library("Xfixes")
|
||||||
|
if not XFixes:
|
||||||
|
raise Exception("No XFixes library found.")
|
||||||
|
self.XFixeslib = ctypes.cdll.LoadLibrary(XFixes)
|
||||||
|
|
||||||
|
# xlib = ctypes.CDLL('libX11.so.6')
|
||||||
|
x11 = ctypes.util.find_library("X11")
|
||||||
|
if not x11:
|
||||||
|
raise Exception("No X11 library found.")
|
||||||
|
self.xlib = ctypes.cdll.LoadLibrary(x11)
|
||||||
|
|
||||||
|
# Define ctypes' version of XFixesGetCursorImage function
|
||||||
|
XFixesGetCursorImage = self.XFixeslib.XFixesGetCursorImage
|
||||||
|
XFixesGetCursorImage.restype = ctypes.POINTER(XFixesCursorImage)
|
||||||
|
XFixesGetCursorImage.argtypes = [ctypes.POINTER(Display)]
|
||||||
|
self.XFixesGetCursorImage = XFixesGetCursorImage
|
||||||
|
|
||||||
|
XOpenDisplay = self.xlib.XOpenDisplay
|
||||||
|
XOpenDisplay.restype = ctypes.POINTER(Display)
|
||||||
|
XOpenDisplay.argtypes = [ctypes.c_char_p]
|
||||||
|
|
||||||
|
if not self.display:
|
||||||
|
self.display = self.xlib.XOpenDisplay(display) # (display) or (None)
|
||||||
|
|
||||||
|
def argbdata_to_pixdata(self, data, len):
|
||||||
|
if data == None or len < 1: return None
|
||||||
|
|
||||||
|
# Create byte array
|
||||||
|
b = array.array('b', b'\x00' * 4 * len)
|
||||||
|
|
||||||
|
offset, i = 0, 0
|
||||||
|
while i < len:
|
||||||
|
argb = data[i] & 0xffffffff
|
||||||
|
rgba = (argb << 8) | (argb >> 24)
|
||||||
|
b1 = (rgba >> 24) & 0xff
|
||||||
|
b2 = (rgba >> 16) & 0xff
|
||||||
|
b3 = (rgba >> 8) & 0xff
|
||||||
|
b4 = rgba & 0xff
|
||||||
|
|
||||||
|
struct.pack_into("=BBBB", b, offset, b1, b2, b3, b4)
|
||||||
|
offset = offset + 4
|
||||||
|
i = i + 1
|
||||||
|
|
||||||
|
return b
|
||||||
|
|
||||||
|
def getCursorImageData(self):
|
||||||
|
# Call the function. Read data of cursor/mouse-pointer.
|
||||||
|
cursor_data = self.XFixesGetCursorImage(self.display)
|
||||||
|
|
||||||
|
if not (cursor_data and cursor_data[0]):
|
||||||
|
raise Exception("Cannot read XFixesGetCursorImage()")
|
||||||
|
|
||||||
|
# Note: cursor_data is a pointer, take cursor_data[0]
|
||||||
|
return cursor_data[0]
|
||||||
|
|
||||||
|
def getCursorImageArray(self):
|
||||||
|
data = self.getCursorImageData()
|
||||||
|
# x, y = data.x, data.y
|
||||||
|
height, width = data.height, data.width
|
||||||
|
|
||||||
|
bytearr = self.argbdata_to_pixdata(data.pixels, height * width)
|
||||||
|
|
||||||
|
imgarray = np.array(bytearr, dtype=np.uint8)
|
||||||
|
imgarray = imgarray.reshape(height, width, 4)
|
||||||
|
del bytearr
|
||||||
|
|
||||||
|
return imgarray
|
||||||
|
|
||||||
|
def getCursorImageArrayFast(self):
|
||||||
|
data = self.getCursorImageData()
|
||||||
|
# x, y = data.x, data.y
|
||||||
|
height, width = data.height, data.width
|
||||||
|
|
||||||
|
bytearr = ctypes.cast(data.pixels, ctypes.POINTER(ctypes.c_ulong * height * width))[0]
|
||||||
|
imgarray = np.array(bytearray(bytearr))
|
||||||
|
imgarray = imgarray.reshape(height, width, 8)[:, :, (0, 1, 2, 3)]
|
||||||
|
del bytearr
|
||||||
|
|
||||||
|
return imgarray
|
||||||
|
|
||||||
|
def saveImage(self, imgarray, text):
|
||||||
|
from PIL import Image
|
||||||
|
img = Image.fromarray(imgarray)
|
||||||
|
img.save(text)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cursor = Xcursor()
|
||||||
|
imgarray = cursor.getCursorImageArrayFast()
|
||||||
|
cursor.saveImage(imgarray, 'cursor_image.png')
|
||||||
@@ -4,3 +4,4 @@ Pillow==10.1.0
|
|||||||
git+https://github.com/moses-palmer/pynput.git@refs/pull/541/head # to make sure that it works on Apple Silicon
|
git+https://github.com/moses-palmer/pynput.git@refs/pull/541/head # to make sure that it works on Apple Silicon
|
||||||
requests
|
requests
|
||||||
flask
|
flask
|
||||||
|
numpy
|
||||||
|
|||||||
@@ -9,8 +9,8 @@
|
|||||||
"parameters": {
|
"parameters": {
|
||||||
"files": [
|
"files": [
|
||||||
{
|
{
|
||||||
"url": "https://101.43.24.67/s/LLE8tmGkpNeGBtZ/download/copy_sheet_insert.xlsx",
|
|
||||||
"path": "/home/david/copy_sheet_insert.xlsx"
|
"path": "/home/david/copy_sheet_insert.xlsx"
|
||||||
|
"url": "https://drive.usercontent.google.com/download?id=1ejNXBNOZtn64ugmvXot21pOEjx5xa-I5&export=download&authuser=0&confirm=t&uuid=61aa93e2-03f7-4b28-8e4a-cdff16a642f7&at=APZUnTVgPAHHfXaEjfKau5CDY1_K:1703509323791",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,8 +9,8 @@
|
|||||||
"parameters": {
|
"parameters": {
|
||||||
"files": [
|
"files": [
|
||||||
{
|
{
|
||||||
"url": "https://101.43.24.67/s/wrEyMi8HsmFjQrZ/download/OrderId_Month_Chart.xlsx",
|
|
||||||
"path": "/home/david/OrderId_Month_Chart.xlsx"
|
"path": "/home/david/OrderId_Month_Chart.xlsx"
|
||||||
|
"url": "https://drive.usercontent.google.com/download?id=1uywX5XWMvesnb4-8LPKEzr2HFU7HmoIu&export=download&authuser=0&confirm=t&uuid=267bfe49-a861-4272-ae7c-39c95df35e84&at=APZUnTUbs-FF06hSMv3yWfdXc02l:1703508870351",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -29,7 +29,7 @@
|
|||||||
"func": "compare_with_sparklines",
|
"func": "compare_with_sparklines",
|
||||||
"expected": {
|
"expected": {
|
||||||
"type": "cloud_file",
|
"type": "cloud_file",
|
||||||
"path": "https://101.43.24.67/s/t7pgJxNoAGFQWEM/download/OrderId_Month_Chart_gold.xlsx",
|
"path": "https://drive.usercontent.google.com/download?id=1KQJJLVPGtTL_7ArEWvwwbFbJSiA3cgSE&export=download&authuser=0&confirm=t&uuid=6b11c721-caad-439a-b369-4c13c7a485df&at=APZUnTV5-1isKrDKSHV9NeJ6TDeS:1703509054094",
|
||||||
"dest": "OrderId_Month_Chart_gold.xlsx"
|
"dest": "OrderId_Month_Chart_gold.xlsx"
|
||||||
},
|
},
|
||||||
"result": {
|
"result": {
|
||||||
|
|||||||
@@ -9,8 +9,8 @@
|
|||||||
"parameters": {
|
"parameters": {
|
||||||
"files": [
|
"files": [
|
||||||
{
|
{
|
||||||
"url": "https://101.43.24.67/s/s7aAngonFwaygHr/download/Create_column_charts_using_statistics.xlsx",
|
|
||||||
"path": "/home/david/Create_column_charts_using_statistics.xlsx"
|
"path": "/home/david/Create_column_charts_using_statistics.xlsx"
|
||||||
|
"url": "https://drive.usercontent.google.com/download?id=1GOEacGTLP4EfGS8YwO9aGmmPgud5EavT&export=download&authuser=0&confirm=t&uuid=3971675c-3a76-4f89-863f-7f8afa59c3c5&at=APZUnTWaQ4_l1IiXsAR8VbjKf4uZ:1703595929357",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -35,7 +35,7 @@
|
|||||||
},
|
},
|
||||||
"expected": {
|
"expected": {
|
||||||
"type": "cloud_file",
|
"type": "cloud_file",
|
||||||
"path": "https://101.43.24.67/s/SLL4CgyMiyre3Ss/download/Create_column_charts_using_statistics_gold.xlsx",
|
"path": "https://drive.usercontent.google.com/download?id=1yiTCGZvGccWET9u8K7looD3ybH7PO9gb&export=download&authuser=0&confirm=t&uuid=65f54a6f-bb2e-40c3-8a76-091d785a5aca&at=APZUnTVbeO6maMvzItLvSwdBEZoM:1703595892144",
|
||||||
"dest": "Create_column_charts_using_statistics_gold.xlsx"
|
"dest": "Create_column_charts_using_statistics_gold.xlsx"
|
||||||
},
|
},
|
||||||
"options": {
|
"options": {
|
||||||
|
|||||||
@@ -9,8 +9,8 @@
|
|||||||
"parameters": {
|
"parameters": {
|
||||||
"files": [
|
"files": [
|
||||||
{
|
{
|
||||||
"url": "https://101.43.24.67/s/H7k3zLNaNcMWyLB/download/Freeze_row_column.xlsx",
|
|
||||||
"path": "/home/david/Freeze_row_column.xlsx"
|
"path": "/home/david/Freeze_row_column.xlsx"
|
||||||
|
"url": "https://drive.usercontent.google.com/download?id=1ZhGLDKOden_oxzuN2xN9-jNQSHtCX6GE&export=download&authuser=0&confirm=t&uuid=2c097276-a610-4a9f-b6e4-5b54296c1555&at=APZUnTWc7zKPY_ykygn0mO1SAs4s:1703580957447",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
{
|
{
|
||||||
"id": "f9584479-3d0d-4c79-affa-9ad7afdd8850",
|
"id": "f9584479-3d0d-4c79-affa-9ad7afdd8850",
|
||||||
"snapshot": "libreoffice_calc",
|
"snapshot": "libreoffice_calc",
|
||||||
"instruction": "Fill the missing row and column which show the total value",
|
"instruction": "Fill the missing row and column which show the total value",
|
||||||
"source": "https://youtube.com/shorts/feldd-Pn48c?si=9xJiem2uAHm6Jshb",
|
"source": "https://youtube.com/shorts/feldd-Pn48c?si=9xJiem2uAHm6Jshb",
|
||||||
"config": [
|
"config": [
|
||||||
{
|
{
|
||||||
"type": "download",
|
"type": "download",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"files": [
|
"files": [
|
||||||
{
|
{
|
||||||
"url": "http://101.43.24.67/s/DbaHsQpPA7dxrA8/download/Quarterly_Product_Sales_by_Zone.xlsx",
|
|
||||||
"path": "/home/david/Quarterly_Product_Sales_by_Zone.xlsx"
|
"path": "/home/david/Quarterly_Product_Sales_by_Zone.xlsx"
|
||||||
|
"url": "https://drive.usercontent.google.com/download?id=1rwhniaClEkF8XFzdfaNUA6GmAiy4syMZ&export=download&authuser=0&confirm=t&uuid=6fdd5b04-85f4-45e1-ad74-368f8f2a82ab&at=APZUnTUP-JxPxLfNls6jXWghblQ5:1701766091851",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
277
mm_agents/SoM_agent.py
Normal file
277
mm_agents/SoM_agent.py
Normal file
@@ -0,0 +1,277 @@
|
|||||||
|
# fixme: Need to be rewrite on new action space
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import base64
|
||||||
|
from desktop_env.envs.desktop_env import Action, MouseClick
|
||||||
|
import json
|
||||||
|
import requests
|
||||||
|
from mm_agents.gpt_4v_prompt import SYS_PROMPT
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
# seem
|
||||||
|
from seem.modeling.BaseModel import BaseModel as BaseModel_Seem
|
||||||
|
from seem.utils.distributed import init_distributed as init_distributed_seem
|
||||||
|
from seem.modeling import build_model as build_model_seem
|
||||||
|
from task_adapter.seem.tasks import interactive_seem_m2m_auto, inference_seem_pano, inference_seem_interactive
|
||||||
|
|
||||||
|
# semantic sam
|
||||||
|
from semantic_sam.BaseModel import BaseModel
|
||||||
|
from semantic_sam import build_model
|
||||||
|
from semantic_sam.utils.dist import init_distributed_mode
|
||||||
|
from semantic_sam.utils.arguments import load_opt_from_config_file
|
||||||
|
from semantic_sam.utils.constants import COCO_PANOPTIC_CLASSES
|
||||||
|
from task_adapter.semantic_sam.tasks import inference_semsam_m2m_auto, prompt_switch
|
||||||
|
|
||||||
|
# sam
|
||||||
|
from segment_anything import sam_model_registry
|
||||||
|
from task_adapter.sam.tasks.inference_sam_m2m_auto import inference_sam_m2m_auto
|
||||||
|
from task_adapter.sam.tasks.inference_sam_m2m_interactive import inference_sam_m2m_interactive
|
||||||
|
|
||||||
|
from scipy.ndimage import label
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
'''
|
||||||
|
build args
|
||||||
|
'''
|
||||||
|
semsam_cfg = "configs/semantic_sam_only_sa-1b_swinL.yaml"
|
||||||
|
seem_cfg = "configs/seem_focall_unicl_lang_v1.yaml"
|
||||||
|
|
||||||
|
semsam_ckpt = "./swinl_only_sam_many2many.pth"
|
||||||
|
sam_ckpt = "./sam_vit_h_4b8939.pth"
|
||||||
|
seem_ckpt = "./seem_focall_v1.pt"
|
||||||
|
|
||||||
|
opt_semsam = load_opt_from_config_file(semsam_cfg)
|
||||||
|
opt_seem = load_opt_from_config_file(seem_cfg)
|
||||||
|
opt_seem = init_distributed_seem(opt_seem)
|
||||||
|
|
||||||
|
'''
|
||||||
|
build model
|
||||||
|
'''
|
||||||
|
model_semsam = BaseModel(opt_semsam, build_model(opt_semsam)).from_pretrained(semsam_ckpt).eval().cuda()
|
||||||
|
model_sam = sam_model_registry["vit_h"](checkpoint=sam_ckpt).eval().cuda()
|
||||||
|
model_seem = BaseModel_Seem(opt_seem, build_model_seem(opt_seem)).from_pretrained(seem_ckpt).eval().cuda()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
with torch.autocast(device_type='cuda', dtype=torch.float16):
|
||||||
|
model_seem.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(COCO_PANOPTIC_CLASSES + ["background"], is_eval=True)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def inference(image, slider, mode, alpha, label_mode, anno_mode, *args, **kwargs):
|
||||||
|
if slider < 1.5:
|
||||||
|
model_name = 'seem'
|
||||||
|
elif slider > 2.5:
|
||||||
|
model_name = 'sam'
|
||||||
|
else:
|
||||||
|
if mode == 'Automatic':
|
||||||
|
model_name = 'semantic-sam'
|
||||||
|
if slider < 1.5 + 0.14:
|
||||||
|
level = [1]
|
||||||
|
elif slider < 1.5 + 0.28:
|
||||||
|
level = [2]
|
||||||
|
elif slider < 1.5 + 0.42:
|
||||||
|
level = [3]
|
||||||
|
elif slider < 1.5 + 0.56:
|
||||||
|
level = [4]
|
||||||
|
elif slider < 1.5 + 0.70:
|
||||||
|
level = [5]
|
||||||
|
elif slider < 1.5 + 0.84:
|
||||||
|
level = [6]
|
||||||
|
else:
|
||||||
|
level = [6, 1, 2, 3, 4, 5]
|
||||||
|
else:
|
||||||
|
model_name = 'sam'
|
||||||
|
|
||||||
|
if label_mode == 'Alphabet':
|
||||||
|
label_mode = 'a'
|
||||||
|
else:
|
||||||
|
label_mode = '1'
|
||||||
|
|
||||||
|
text_size, hole_scale, island_scale = 640, 100, 100
|
||||||
|
text, text_part, text_thresh = '', '', '0.0'
|
||||||
|
|
||||||
|
with torch.autocast(device_type='cuda', dtype=torch.float16):
|
||||||
|
semantic = False
|
||||||
|
|
||||||
|
if mode == "Interactive":
|
||||||
|
labeled_array, num_features = label(np.asarray(image['mask'].convert('L')))
|
||||||
|
spatial_masks = torch.stack([torch.from_numpy(labeled_array == i+1) for i in range(num_features)])
|
||||||
|
|
||||||
|
if model_name == 'semantic-sam':
|
||||||
|
model = model_semsam
|
||||||
|
output, mask = inference_semsam_m2m_auto(model, image['image'], level, text, text_part, text_thresh, text_size, hole_scale, island_scale, semantic, label_mode=label_mode, alpha=alpha, anno_mode=anno_mode, *args, **kwargs)
|
||||||
|
|
||||||
|
elif model_name == 'sam':
|
||||||
|
model = model_sam
|
||||||
|
if mode == "Automatic":
|
||||||
|
output, mask = inference_sam_m2m_auto(model, image['image'], text_size, label_mode, alpha, anno_mode)
|
||||||
|
elif mode == "Interactive":
|
||||||
|
output, mask = inference_sam_m2m_interactive(model, image['image'], spatial_masks, text_size, label_mode, alpha, anno_mode)
|
||||||
|
|
||||||
|
elif model_name == 'seem':
|
||||||
|
model = model_seem
|
||||||
|
if mode == "Automatic":
|
||||||
|
output, mask = inference_seem_pano(model, image['image'], text_size, label_mode, alpha, anno_mode)
|
||||||
|
elif mode == "Interactive":
|
||||||
|
output, mask = inference_seem_interactive(model, image['image'], spatial_masks, text_size, label_mode, alpha, anno_mode)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
# Function to encode the image
|
||||||
|
def encode_image(image_path):
|
||||||
|
with open(image_path, "rb") as image_file:
|
||||||
|
return base64.b64encode(image_file.read()).decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
|
def parse_actions_from_string(input_string):
|
||||||
|
# Search for a JSON string within the input string
|
||||||
|
actions = []
|
||||||
|
matches = re.findall(r'```json\s+(.*?)\s+```', input_string, re.DOTALL)
|
||||||
|
if matches:
|
||||||
|
# Assuming there's only one match, parse the JSON string into a dictionary
|
||||||
|
try:
|
||||||
|
for match in matches:
|
||||||
|
action_dict = json.loads(match)
|
||||||
|
actions.append(action_dict)
|
||||||
|
return actions
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
return f"Failed to parse JSON: {e}"
|
||||||
|
else:
|
||||||
|
matches = re.findall(r'```\s+(.*?)\s+```', input_string, re.DOTALL)
|
||||||
|
if matches:
|
||||||
|
# Assuming there's only one match, parse the JSON string into a dictionary
|
||||||
|
try:
|
||||||
|
for match in matches:
|
||||||
|
action_dict = json.loads(match)
|
||||||
|
actions.append(action_dict)
|
||||||
|
return actions
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
return f"Failed to parse JSON: {e}"
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
action_dict = json.loads(input_string)
|
||||||
|
return [action_dict]
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise ValueError("Invalid response format: " + input_string)
|
||||||
|
|
||||||
|
|
||||||
|
class GPT4v_Agent:
|
||||||
|
def __init__(self, api_key, instruction, model="gpt-4-vision-preview", max_tokens=300):
|
||||||
|
self.instruction = instruction
|
||||||
|
self.model = model
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
|
||||||
|
self.headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {api_key}"
|
||||||
|
}
|
||||||
|
|
||||||
|
self.trajectory = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": SYS_PROMPT
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
def predict(self, obs):
|
||||||
|
obs = inference(obs, slider=2.0, mode="Automatic", alpha=0.1, label_mode="Alphabet", anno_mode=["Mask", "Mark"])
|
||||||
|
base64_image = encode_image(obs)
|
||||||
|
self.trajectory.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What's the next step for instruction '{}'?".format(self.instruction)
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": f"data:image/jpeg;base64,{base64_image}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
traj_to_show = []
|
||||||
|
for i in range(len(self.trajectory)):
|
||||||
|
traj_to_show.append(self.trajectory[i]["content"][0]["text"])
|
||||||
|
if len(self.trajectory[i]["content"]) > 1:
|
||||||
|
traj_to_show.append("screenshot_obs")
|
||||||
|
print("Trajectory:", traj_to_show)
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": self.trajectory,
|
||||||
|
"max_tokens": self.max_tokens
|
||||||
|
}
|
||||||
|
response = requests.post("https://api.openai.com/v1/chat/completions", headers=self.headers, json=payload)
|
||||||
|
|
||||||
|
try:
|
||||||
|
actions = self.parse_actions(response.json()['choices'][0]['message']['content'])
|
||||||
|
except:
|
||||||
|
print("Failed to parse action from response:", response.json()['choices'][0]['message']['content'])
|
||||||
|
actions = None
|
||||||
|
|
||||||
|
return actions
|
||||||
|
|
||||||
|
def parse_actions(self, response: str):
|
||||||
|
# response example
|
||||||
|
"""
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"action_type": "CLICK",
|
||||||
|
"click_type": "RIGHT"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
# parse from the response
|
||||||
|
actions = parse_actions_from_string(response)
|
||||||
|
|
||||||
|
# add action into the trajectory
|
||||||
|
self.trajectory.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": response
|
||||||
|
},
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
# parse action
|
||||||
|
parsed_actions = []
|
||||||
|
for action in actions:
|
||||||
|
parsed_action = {}
|
||||||
|
action_type = Action[action['action_type']].value
|
||||||
|
parsed_action["action_type"] = action_type
|
||||||
|
|
||||||
|
if action_type == Action.CLICK.value or action_type == Action.MOUSE_DOWN.value or action_type == Action.MOUSE_UP.value:
|
||||||
|
parsed_action["click_type"] = MouseClick[action['click_type']].value
|
||||||
|
|
||||||
|
if action_type == Action.MOUSE_MOVE.value:
|
||||||
|
parsed_action["x"] = action["x"]
|
||||||
|
parsed_action["y"] = action["y"]
|
||||||
|
|
||||||
|
if action_type == Action.KEY.value:
|
||||||
|
parsed_action["key"] = action["key"] # handle the condition of single key and multiple keys
|
||||||
|
|
||||||
|
if action_type == Action.TYPE.value:
|
||||||
|
parsed_action["text"] = action["text"]
|
||||||
|
|
||||||
|
parsed_actions.append(parsed_action)
|
||||||
|
|
||||||
|
return parsed_actions
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# OpenAI API Key
|
||||||
|
api_key = os.environ.get("OPENAI_API_KEY")
|
||||||
|
|
||||||
|
agent = GPT4v_Agent(api_key=api_key, instruction="Open Google Sheet")
|
||||||
|
print(agent.predict(obs="stackoverflow.png"))
|
||||||
Reference in New Issue
Block a user