Merge branch 'main' into zdy

This commit is contained in:
David Chang
2024-03-19 17:47:23 +08:00
13 changed files with 120 additions and 71 deletions

View File

@@ -285,7 +285,7 @@ class DesktopEnv(gym.Env):
observation = {
"screenshot": self._get_obs(),
"accessibility_tree": self.controller.get_accessibility_tree(),
"terminal": self.controller.get_terminal_output(),
# "terminal": self.controller.get_terminal_output(),
"instruction": self.instruction
}

View File

@@ -430,11 +430,11 @@ def check_image_size(src_path, rule):
img = Image.open(src_path)
# Check the size
if rule["height"] is not None:
if rule.get("height", None) is not None:
height_same = img.size[1] == rule["height"]
else:
height_same = True
if rule["width"] is not None:
if rule.get("width", None) is not None:
width_same = img.size[0] == rule["width"]
else:
width_same = True
@@ -568,3 +568,51 @@ def check_image_file_size(src_path, rule):
return 1.0
else:
return 0.0
if __name__ == "__main__":
actual_config_path = "../../../cache/sessionrc_test"
rule = {
"key": "hide-docks",
"value": "no"
}
print(check_config_status(actual_config_path, rule))
actual_config_path = "../../../cache/action-history_test"
rule = {
"key": ["history-item", "\"filters-vignette\""],
"value": "1"
}
print(check_config_status(actual_config_path, rule))
actual_config_path = "../../../cache/gimprc_test"
rule = {
"key": "undo-levels",
"value": "100"
}
print(check_config_status(actual_config_path, rule))
src_path = "../../../cache/734d6579-c07d-47a8-9ae2-13339795476b/green_background_with_object.png"
tgt_path = "../../../cache/734d6579-c07d-47a8-9ae2-13339795476b/white_background_with_object.png"
print(check_green_background(src_path, tgt_path))
tgt_path = "../../../cache/f4aec372-4fb0-4df5-a52b-79e0e2a5d6ce/Triangle_In_The_Middle.png"
print(check_triangle_position(tgt_path))
src_path = "../../../cache/bb7db4c2-30b5-4be7-8dd7-b8c4ec7d3108/anmi_sharper.png"
tgt_path = "../../../cache/bb7db4c2-30b5-4be7-8dd7-b8c4ec7d3108/anmi.png"
print(check_sharper(src_path, tgt_path))
src_path = "../../../cache/3c8f201a-009d-4bbe-8b65-a6f8b35bb57f/compressed.jpeg"
rule = {
"max_size": 500000
}
print(check_image_file_size(src_path, rule))
src_path = "../../../cache/d16c99dc-2a1e-46f2-b350-d97c86c85c15/resized.png"
tgt_path = "../../../cache/d16c99dc-2a1e-46f2-b350-d97c86c85c15/dog_with_background.png"
rule = {
"height": 512
}
print(check_image_size(src_path, rule))
print(check_structure_sim_resized(src_path, tgt_path))

View File

@@ -236,6 +236,9 @@ def check_html_background_image(src_path: str, rule: Dict = None) -> float:
Check if the background image is correctly set.
multi-app:bb7db4c2-30b5-4be7-8dd7-b8c4ec7d3108
"""
if not src_path:
return 0.0
from bs4 import BeautifulSoup
with open(src_path, 'r') as f:
html_content = f.read()
@@ -252,6 +255,9 @@ def compare_result_files(src_path, tgt_path):
Compare whether the content of two files are the same.
multi-app:7f35355e-02a6-45b5-b140-f0be698bcf85
"""
if not src_path or not tgt_path:
return 0.0
with open(src_path, 'r') as f:
src_content = f.read().strip()
with open(tgt_path, 'r') as f:

View File

@@ -63,7 +63,7 @@ def execute_command():
# Execute the command without any safety checks.
try:
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell, text=True)
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell, text=True, timeout=120)
return jsonify({
'status': 'success',
'output': result.stdout,

View File

@@ -86,13 +86,14 @@
],
"func": [
"check_image_size",
"check_structure_sim"
"check_structure_sim_resized"
],
"expected": [
{
"type": "vm_file",
"path": "/home/user/Desktop/dog_with_background.png",
"dest": "dog_with_background.png"
"type": "rule",
"rules": {
"height": 512
}
},
{
"type": "vm_file",
@@ -102,10 +103,9 @@
],
"result": [
{
"type": "rule",
"rules": {
"height": 512
}
"type": "vm_file",
"path": "/home/user/Desktop/dog_with_background.png",
"dest": "dog_with_background.png"
},
{
"type": "vm_file",

View File

@@ -63,6 +63,12 @@
"type": "vm_file",
"path": "/home/user/Desktop/saa-format-guide.pptx",
"dest": "saa-format-guide.pptx"
},
"expected": {
"type": "rule",
"rules": {
"color": "red"
}
}
}
}

View File

@@ -30,12 +30,12 @@
],
"evaluator": {
"func": "check_brightness_decrease_and_structure_sim",
"expected": {
"result": {
"type": "vm_file",
"path": "/home/user/Desktop/background.png",
"dest": "background.png"
},
"result": {
"expected": {
"type": "cloud_file",
"path": "https://drive.usercontent.google.com/download?id=13if1UwZ5ay6ADAVW2jp3rcyvAEBse6MJ&export=download&authuser=0&confirm=t&uuid=2ea03068-1874-4240-baa1-f8bb2f917a99&at=APZUnTXq6dVlASg819jCaI1A-rm2:1710136385956",
"dest": "image_original.png"

View File

@@ -9,7 +9,7 @@
"parameters": {
"files": [
{
"url": "https://drive.usercontent.google.com/download?id=1e12nL_V7bffaLSocQ86EiGCdygzggWeu&export=download",
"url": "https://drive.usercontent.google.com/download?id=1epTcblcYh8j_wFtA-aiXPIF2Oo1IVw8A&export=download",
"path": "/home/user/Desktop/Dickinson_Slides.pptx"
}
]
@@ -36,7 +36,7 @@
},
"expected": {
"type": "cloud_file",
"path": "https://drive.usercontent.google.com/download?id=1Xl6tgQ0K5qA1BDA2fKTK2xFLzXwbtkZ6&export=download",
"path": "https://drive.usercontent.google.com/download?id=1vUvaQLJUtFgbZi7lSzl0y0TS_WecFczm&export=download",
"dest": "notes_gold.docx"
},
"options": {

View File

@@ -1,19 +0,0 @@
import pandas as pd
file_path = "/Users/lxc/Downloads/Speedtest.csv"
# 找到csv第二行的第二个数据格里的值
# with open(file_path, "r") as f:
# for i, line in enumerate(f):
# if i == 1:
# data = line.split(",")[1]
# break
# print(data)
with open(file_path, "r") as f:
reader = pd.read_csv(f, sep=',', header=None)
# for column in reader.columns:
# if column.startswith("TEST_DATE"):
# data_col = column
# break
for data in reader['TEST_DATE']:
print(data)

View File

@@ -2,7 +2,7 @@ import datetime
import json
import logging
import os
import wandb
# import wandb
from wrapt_timeout_decorator import *
@@ -15,13 +15,13 @@ with open("./settings.json", "r") as file:
time_limit = data["time_limit"]
@timeout(time_limit, use_signals=False)
def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores, run):
def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
agent.reset()
obs = env.reset(task_config=example)
done = False
step_idx = 0
env.controller.start_recording()
str_table = wandb.Table(columns=["Screenshot", "A11T", "Modle Response", "Action", "Action timestamp", "Done"])
# str_table = wandb.Table(columns=["Screenshot", "A11T", "Modle Response", "Action", "Action timestamp", "Done"])
while not done and step_idx < max_steps:
response, actions = agent.predict(
instruction,
@@ -42,11 +42,11 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
screenshot = __f.read()
_f.write(screenshot)
# get a11tree and save to wandb
thisrun_a11tree = env.controller.get_accessibility_tree()
str_table.add_data(wandb.Image(data_or_path=os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), caption=f"step_{step_idx + 1}_{action_timestamp}"),
thisrun_a11tree,
response, action, action_timestamp, done)
run.log({"Reward": reward})
# thisrun_a11tree = env.controller.get_accessibility_tree()
# str_table.add_data(wandb.Image(data_or_path=os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), caption=f"step_{step_idx + 1}_{action_timestamp}"),
# thisrun_a11tree,
# response, action, action_timestamp, done)
# run.log({"Reward": reward})
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(json.dumps({
"step_num": step_idx + 1,
@@ -62,11 +62,11 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
logger.info("The episode is done.")
break
step_idx += 1
run.log({"str_trajectory": str_table})
# run.log({"str_trajectory": str_table})
result = env.evaluate()
logger.info("Result: %.2f", result)
scores.append(result)
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
f.write(f"{result}\n")
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
run.log({"Result": result})
# run.log({"Result": result})

View File

@@ -573,7 +573,7 @@ class PromptAgent:
top_p = payload["top_p"]
temperature = payload["temperature"]
misrtal_messages = []
mistral_messages = []
for i, message in enumerate(messages):
mistral_message = {
@@ -584,13 +584,8 @@ class PromptAgent:
for part in message["content"]:
mistral_message['content'] = part['text'] if part['type'] == "text" else ""
misrtal_messages.append(mistral_message)
mistral_messages.append(mistral_message)
# openai.api_base = "http://localhost:8000/v1"
# response = openai.ChatCompletion.create(
# messages=misrtal_messages,
# model="Mixtral-8x7B-Instruct-v0.1"
# )
from openai import OpenAI
@@ -598,12 +593,23 @@ class PromptAgent:
base_url='https://api.together.xyz',
)
logger.info("Generating content with Mistral model: %s", self.model)
response = client.chat.completions.create(
messages=misrtal_messages,
model=self.model,
max_tokens=max_tokens
)
flag = 0
while True:
try:
if flag > 20: break
response = client.chat.completions.create(
messages=mistral_messages,
model=self.model,
max_tokens=max_tokens
)
break
except:
if flag == 0:
mistral_messages = [mistral_messages[0]] + mistral_messages[-1:]
else:
mistral_messages[-1]["content"] = ' '.join(mistral_messages[-1]["content"].split()[:-500])
flag = flag + 1
try:
return response.choices[0].message.content

View File

@@ -48,4 +48,5 @@ easyocr
borb
pypdf2
pdfplumber
wandb
wrapt_timeout_decorator

25
run.py
View File

@@ -8,7 +8,7 @@ import logging
import os
import random
import sys
import wandb
# import wandb
from tqdm import tqdm
@@ -52,7 +52,8 @@ logger = logging.getLogger("desktopenv.experiment")
# wandb config
### set your wandb api key here
wandb.login(key=os.environ.get("WANDB_API_KEY", None))
# os.environ["WANDB_API_KEY"] = "48ec18fb4da7087238c6d6833eab9907565adbf3"
# wandb.login(key=os.environ.get("WANDB_API_KEY", None))
def config() -> argparse.Namespace:
@@ -148,8 +149,8 @@ def test(
for domain in tqdm(test_all_meta, desc="Domain"):
for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False):
run = wandb.init(project=f"OSworld-{args.action_space}-{args.observation_type}-{args.model}", group=f"{domain}",
name=f"{example_id}")
# run = wandb.init(project=f"OSworld-{args.action_space}-{args.observation_type}-{args.model}", group=f"{domain}",
# name=f"{example_id}")
# example setting
config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
with open(config_file, "r", encoding="utf-8") as f:
@@ -164,7 +165,7 @@ def test(
# wandb each example config settings
cfg_args["instruction"] = instruction
cfg_args["start_time"] = datetime.datetime.now().strftime("%Y:%m:%d-%H:%M:%S")
run.config.update(cfg_args)
# run.config.update(cfg_args)
example_result_dir = os.path.join(
args.result_dir,
@@ -178,10 +179,10 @@ def test(
# example start running
try:
lib_run_single.run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir,
scores, run)
scores)
except Exception as e:
logger.error(f"Exception in {domain}/{example_id}: {e}")
wandb.log({"Exception": wandb.Table(data=[[f"Exception in {domain}/{example_id}: {e}"]], columns=["Error"])})
# wandb.log({"Exception": wandb.Table(data=[[f"Exception in {domain}/{example_id}: {e}"]], columns=["Error"])})
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(json.dumps({
@@ -189,11 +190,11 @@ def test(
}))
f.write("\n")
# wandb settings
os.mkdir(os.path.join(wandb.run.dir, "results/"))
for file in os.listdir(example_result_dir):
# move file to just under the root dir
os.rename(os.path.join(example_result_dir, file), os.path.join(wandb.run.dir, f"./results/{file}"))
wandb.finish()
# os.mkdir(os.path.join(wandb.run.dir, "results/"))
# for file in os.listdir(example_result_dir):
# # move file to just under the root dir
# os.rename(os.path.join(example_result_dir, file), os.path.join(wandb.run.dir, f"./results/{file}"))
# wandb.finish()
env.close()
logger.info(f"Average score: {sum(scores) / len(scores)}")