Wxy/opencua (#274)
* OpenCUA Agent code base * update url * debug, modify url input * debug opencua * show result * debug agent history overlap * modify opencua agent; add comment lines * update parallel; clean code; use sleep 3s * ui-tars-0717
This commit is contained in:
@@ -3,7 +3,8 @@ import os
|
||||
import re
|
||||
import base64
|
||||
import requests
|
||||
from typing import Optional, Dict, List, Tuple
|
||||
import logging
|
||||
from typing import Optional, Dict, List, Tuple, Union
|
||||
from loguru import logger
|
||||
|
||||
import ast
|
||||
@@ -573,7 +574,34 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par
|
||||
GROUNDING_DOUBAO = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. \n\n## Output Format\n\nAction: ...\n\n\n## Action Space\nclick(point='<point>x1 y1</point>'')\n\n## User Instruction
|
||||
{instruction}"""
|
||||
|
||||
COMPUTER_USE_NO_THINKING = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
||||
|
||||
## Output Format
|
||||
```
|
||||
Thought: ...
|
||||
Action: ...
|
||||
```
|
||||
|
||||
## Action Space
|
||||
|
||||
click(point='<point>x1 y1</point>')
|
||||
left_double(point='<point>x1 y1</point>')
|
||||
right_single(point='<point>x1 y1</point>')
|
||||
drag(start_point='<point>x1 y1</point>', end_point='<point>x2 y2</point>')
|
||||
hotkey(key='ctrl c') # Split keys with a space and use lowercase. Also, do not use more than 3 keys in one hotkey action.
|
||||
type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content.
|
||||
scroll(point='<point>x1 y1</point>', direction='down or up or right or left') # Show more information on the `direction` side.
|
||||
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
||||
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
|
||||
|
||||
|
||||
## Note
|
||||
- Use Chinese in `Thought` part.
|
||||
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
|
||||
|
||||
## User Instruction
|
||||
{instruction}
|
||||
"""
|
||||
|
||||
class UITarsAgent:
|
||||
"""
|
||||
@@ -638,9 +666,11 @@ class UITarsAgent:
|
||||
self.history_images = []
|
||||
self.history_responses = []
|
||||
|
||||
self.system_prompt = COMPUTER_USE_DOUBAO
|
||||
if use_thinking:
|
||||
self.system_prompt = COMPUTER_USE_DOUBAO
|
||||
else:
|
||||
self.system_prompt = COMPUTER_USE_NO_THINKING
|
||||
|
||||
|
||||
self.action_parse_res_factor = 1000
|
||||
self.model_type = "doubao"
|
||||
self.history_n = 5
|
||||
@@ -648,6 +678,9 @@ class UITarsAgent:
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.platform = "ubuntu"
|
||||
self.use_thinking = use_thinking
|
||||
|
||||
self.inference_func = self.inference_with_thinking if use_thinking else self.inference_without_thinking
|
||||
|
||||
def reset(self, _logger=None):
|
||||
global logger
|
||||
@@ -721,7 +754,36 @@ class UITarsAgent:
|
||||
"details": response.text
|
||||
}
|
||||
|
||||
def predict(self, task_instruction: str, obs: dict) -> Tuple[str, List]:
|
||||
def inference_without_thinking(self, messages):
|
||||
api_key = os.environ['DOUBAO_API_KEY']
|
||||
api_url = os.environ['DOUBAO_API_URL']
|
||||
headers = {
|
||||
'Authorization': f'Bearer {api_key}',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
data = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"thinking": {"type": "disabled"},
|
||||
"max_tokens": self.max_tokens,
|
||||
"top_p": self.top_p,
|
||||
"temperature": self.temperature,
|
||||
}
|
||||
|
||||
response = requests.post(api_url, headers=headers, json=data)
|
||||
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()["choices"][0]["message"]["content"]
|
||||
else:
|
||||
print(f"Request failed with status code {response.status_code}")
|
||||
print(response.json())
|
||||
return {
|
||||
"error": f"Request failed with status code {response.status_code}",
|
||||
"details": response.text
|
||||
}
|
||||
|
||||
def predict(self, task_instruction: str, obs: dict) -> Tuple[Union[str, Dict, None], List]:
|
||||
"""Predict the next action based on the current observation."""
|
||||
|
||||
self.task_instruction = task_instruction
|
||||
@@ -793,7 +855,7 @@ class UITarsAgent:
|
||||
return prediction, ["FAIL"]
|
||||
try:
|
||||
logger.info(f"Messages: {self.pretty_print_messages(messages[-1])}")
|
||||
prediction = self.inference_with_thinking(messages)
|
||||
prediction = self.inference_func(messages)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error when fetching response from client, with error:\n{e}")
|
||||
|
||||
Reference in New Issue
Block a user