Wxy/opencua (#274)
* OpenCUA Agent code base * update url * debug, modify url input * debug opencua * show result * debug agent history overlap * modify opencua agent; add comment lines * update parallel; clean code; use sleep 3s * ui-tars-0717
This commit is contained in:
@@ -571,11 +571,6 @@ class OpenCUAAgent:
|
||||
logger.info(f"========================== {self.model} ===================================")
|
||||
logger.info(f"Instruction: \n{instruction}")
|
||||
|
||||
image_bytes = BytesIO(obs['screenshot'])
|
||||
with Image.open(image_bytes) as img:
|
||||
print("Actual screen size", img.size)
|
||||
print("Logical screen size", self.screen_size)
|
||||
|
||||
messages = []
|
||||
messages.append({
|
||||
"role": "system",
|
||||
@@ -598,7 +593,7 @@ class OpenCUAAgent:
|
||||
history_content = STEP_TEMPLATE.format(step_num=i+1) + self.HISTORY_TEMPLATE.format(
|
||||
observation=self.cots[i].get('observation'),
|
||||
thought=self.cots[i].get('thought'),
|
||||
action=self.cots[i]['action']
|
||||
action=self.cots[i].get('action')
|
||||
)
|
||||
|
||||
messages.append({
|
||||
@@ -609,7 +604,7 @@ class OpenCUAAgent:
|
||||
history_content = STEP_TEMPLATE.format(step_num=i+1) + self.HISTORY_TEMPLATE.format(
|
||||
observation=self.cots[i].get('observation'),
|
||||
thought=self.cots[i].get('thought'),
|
||||
action=self.cots[i]['action']
|
||||
action=self.cots[i].get('action')
|
||||
)
|
||||
history_step_texts.append(history_content)
|
||||
if i == len(self.actions) - self.max_image_history_length:
|
||||
@@ -640,7 +635,7 @@ class OpenCUAAgent:
|
||||
"temperature": self.temperature
|
||||
}, self.model)
|
||||
|
||||
logger.info(f"Model Output: \n\n{response}")
|
||||
logger.info(f"Model Output: \n{response}")
|
||||
if not response:
|
||||
logger.error("No response found in the response.")
|
||||
return "ERROR", [], {}
|
||||
@@ -666,23 +661,23 @@ class OpenCUAAgent:
|
||||
self.cots.append(other_cot)
|
||||
|
||||
# Print message structure if needed
|
||||
logger.info(f"\nInstruction: {instruction}")
|
||||
messages_to_print = []
|
||||
current_image = 1
|
||||
for msg in messages:
|
||||
msg_copy = copy.deepcopy(msg)
|
||||
if isinstance(msg_copy['content'], list):
|
||||
for content in msg_copy['content']:
|
||||
if content['type'] == 'image_url':
|
||||
content['image_url']['url'] = f'Image {current_image}'
|
||||
current_image += 1
|
||||
messages_to_print.append(msg_copy)
|
||||
# messages_to_print = []
|
||||
# current_image = 1
|
||||
# for msg in messages:
|
||||
# msg_copy = copy.deepcopy(msg)
|
||||
# if isinstance(msg_copy['content'], list):
|
||||
# for content in msg_copy['content']:
|
||||
# if content['type'] == 'image_url':
|
||||
# content['image_url']['url'] = f'Image {current_image}'
|
||||
# current_image += 1
|
||||
# messages_to_print.append(msg_copy)
|
||||
|
||||
messages_to_print.append({
|
||||
"new_step_cot": other_cot,
|
||||
"response": response
|
||||
})
|
||||
logger.info(json.dumps(messages_to_print, indent=2))
|
||||
# messages_to_print.append({
|
||||
# "new_step_cot": other_cot,
|
||||
# "response": response
|
||||
# })
|
||||
# logger.info(json.dumps(messages_to_print, indent=2))
|
||||
logger.info(f"New step cot: {other_cot}")
|
||||
|
||||
return response, pyautogui_actions, {}
|
||||
|
||||
@@ -720,4 +715,10 @@ class OpenCUAAgent:
|
||||
logger.error("Retrying...")
|
||||
time.sleep(5)
|
||||
else:
|
||||
return response.json()['choices'][0]['message']['content']
|
||||
response = response.json()
|
||||
finish_reason = response["choices"][0].get("finish_reason")
|
||||
if finish_reason is not None and finish_reason == "stop": # for most of the time, length will not exceed max_tokens
|
||||
return response['choices'][0]['message']['content']
|
||||
else:
|
||||
logger.error("LLM did not finish properly, retrying...")
|
||||
time.sleep(5)
|
||||
|
||||
@@ -3,7 +3,8 @@ import os
|
||||
import re
|
||||
import base64
|
||||
import requests
|
||||
from typing import Optional, Dict, List, Tuple
|
||||
import logging
|
||||
from typing import Optional, Dict, List, Tuple, Union
|
||||
from loguru import logger
|
||||
|
||||
import ast
|
||||
@@ -573,7 +574,34 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par
|
||||
GROUNDING_DOUBAO = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. \n\n## Output Format\n\nAction: ...\n\n\n## Action Space\nclick(point='<point>x1 y1</point>'')\n\n## User Instruction
|
||||
{instruction}"""
|
||||
|
||||
COMPUTER_USE_NO_THINKING = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
||||
|
||||
## Output Format
|
||||
```
|
||||
Thought: ...
|
||||
Action: ...
|
||||
```
|
||||
|
||||
## Action Space
|
||||
|
||||
click(point='<point>x1 y1</point>')
|
||||
left_double(point='<point>x1 y1</point>')
|
||||
right_single(point='<point>x1 y1</point>')
|
||||
drag(start_point='<point>x1 y1</point>', end_point='<point>x2 y2</point>')
|
||||
hotkey(key='ctrl c') # Split keys with a space and use lowercase. Also, do not use more than 3 keys in one hotkey action.
|
||||
type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content.
|
||||
scroll(point='<point>x1 y1</point>', direction='down or up or right or left') # Show more information on the `direction` side.
|
||||
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
||||
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
|
||||
|
||||
|
||||
## Note
|
||||
- Use Chinese in `Thought` part.
|
||||
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
|
||||
|
||||
## User Instruction
|
||||
{instruction}
|
||||
"""
|
||||
|
||||
class UITarsAgent:
|
||||
"""
|
||||
@@ -638,9 +666,11 @@ class UITarsAgent:
|
||||
self.history_images = []
|
||||
self.history_responses = []
|
||||
|
||||
self.system_prompt = COMPUTER_USE_DOUBAO
|
||||
if use_thinking:
|
||||
self.system_prompt = COMPUTER_USE_DOUBAO
|
||||
else:
|
||||
self.system_prompt = COMPUTER_USE_NO_THINKING
|
||||
|
||||
|
||||
self.action_parse_res_factor = 1000
|
||||
self.model_type = "doubao"
|
||||
self.history_n = 5
|
||||
@@ -648,6 +678,9 @@ class UITarsAgent:
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.platform = "ubuntu"
|
||||
self.use_thinking = use_thinking
|
||||
|
||||
self.inference_func = self.inference_with_thinking if use_thinking else self.inference_without_thinking
|
||||
|
||||
def reset(self, _logger=None):
|
||||
global logger
|
||||
@@ -721,7 +754,36 @@ class UITarsAgent:
|
||||
"details": response.text
|
||||
}
|
||||
|
||||
def predict(self, task_instruction: str, obs: dict) -> Tuple[str, List]:
|
||||
def inference_without_thinking(self, messages):
    """Send ``messages`` to the Doubao endpoint with thinking disabled.

    The endpoint URL and API key are read from the ``DOUBAO_API_URL`` and
    ``DOUBAO_API_KEY`` environment variables. The request carries this
    agent's ``model``, ``max_tokens``, ``top_p`` and ``temperature``.

    Args:
        messages: Chat messages, forwarded unchanged as the ``"messages"``
            field of the request payload.

    Returns:
        On HTTP 200, the ``content`` of the first choice's message.
        On any other status, a dict with an ``"error"`` summary and the raw
        response body under ``"details"``.

    Raises:
        KeyError: If either environment variable is not set.
    """
    api_key = os.environ['DOUBAO_API_KEY']
    api_url = os.environ['DOUBAO_API_URL']
    headers = {
        'Authorization': f'Bearer {api_key}',
        'Content-Type': 'application/json'
    }
    data = {
        "model": self.model,
        "messages": messages,
        "thinking": {"type": "disabled"},
        "max_tokens": self.max_tokens,
        "top_p": self.top_p,
        "temperature": self.temperature,
    }

    # NOTE(review): no timeout is passed, so a stalled connection blocks
    # this call indefinitely; consider adding one once a sensible bound
    # for this endpoint is known.
    response = requests.post(api_url, headers=headers, json=data)

    if response.status_code == 200:
        return response.json()["choices"][0]["message"]["content"]

    print(f"Request failed with status code {response.status_code}")
    # Error bodies are not guaranteed to be JSON (e.g. an HTML page from a
    # gateway). Decoding failures are ValueError subclasses; fall back to
    # the raw text so the structured error below is still returned.
    try:
        print(response.json())
    except ValueError:
        print(response.text)
    return {
        "error": f"Request failed with status code {response.status_code}",
        "details": response.text
    }
|
||||
|
||||
def predict(self, task_instruction: str, obs: dict) -> Tuple[Union[str, Dict, None], List]:
|
||||
"""Predict the next action based on the current observation."""
|
||||
|
||||
self.task_instruction = task_instruction
|
||||
@@ -793,7 +855,7 @@ class UITarsAgent:
|
||||
return prediction, ["FAIL"]
|
||||
try:
|
||||
logger.info(f"Messages: {self.pretty_print_messages(messages[-1])}")
|
||||
prediction = self.inference_with_thinking(messages)
|
||||
prediction = self.inference_func(messages)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error when fetching response from client, with error:\n{e}")
|
||||
|
||||
Reference in New Issue
Block a user