Clean Code; Refactor README
This commit is contained in:

mm_agents/README.md (new file, +65 lines)

@@ -0,0 +1,65 @@
# Agent

## Prompt-based Agents

### Supported Models

We currently support the following models as the foundation models for the agents:

- `GPT-3.5` (gpt-3.5-turbo-16k, ...)
- `GPT-4` (gpt-4-0125-preview, gpt-4-1106-preview, ...)
- `GPT-4V` (gpt-4-vision-preview, ...)
- `Gemini-Pro`
- `Gemini-Pro-Vision`
- `Claude-3, 2` (claude-3-haiku-20240307, claude-3-sonnet-20240229, ...)
- ...

And those from the open-source community:

- `Mixtral 8x7B`
- `QWEN`, `QWEN-VL`
- `CogAgent`
- ...

We will integrate and support more foundation models for digital agents in the future; stay tuned.

### How to use

```python
from mm_agents.agent import PromptAgent

agent = PromptAgent(
    model="gpt-4-0125-preview",
    observation_type="screenshot",
)
agent.reset()

# say we have an instruction and an observation
instruction = "Please help me to find the nearest restaurant."
obs = {"screenshot": "path/to/observation.jpg"}

response, actions = agent.predict(
    instruction,
    obs
)
```
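
Note that the `observation_type` chosen when constructing the agent should match the keys you later supply in `obs`; the supported values are listed in the next section.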

### Observation Space and Action Space

We currently support the following observation spaces:

- `a11y_tree`: the a11y tree of the current screen
- `screenshot`: a screenshot of the current screen
- `screenshot_a11y_tree`: a screenshot of the current screen together with its a11y tree
- `som`: the set-of-mark trick applied to the current screen, with table metadata

And the following action spaces:

- `pyautogui`: valid Python code using the `pyautogui` library
- `computer_13`: a set of enumerated actions designed by us (a sketch of both action shapes follows the example below)

To feed an observation into the agent, keep the `obs` variable as a dict with the information corresponding to the chosen observation type:

```python
obs = {
    "screenshot": "path/to/observation.jpg",
    "a11y_tree": ""  # [a11y_tree data]
}
response, actions = agent.predict(
    instruction,
    obs
)
```
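
The form of the returned `actions` depends on the action space configured on the agent. A minimal sketch of the two shapes; the `computer_13` field names below are illustrative assumptions, not a confirmed schema:

```python
# `pyautogui` action space: each action is a string of executable Python code.
pyautogui_action = "import pyautogui; pyautogui.click(x=320, y=240)"

# `computer_13` action space: each action is one of 13 enumerated action types
# with arguments. The exact field names here are hypothetical.
computer_13_action = {
    "action_type": "CLICK",
    "parameters": {"x": 320, "y": 240},
}
```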

## Efficient Agents, Q* Agents, and more

Stay tuned for more updates.
mm_agents/agent.py

```diff
@@ -180,6 +180,7 @@ def trim_accessibility_tree(linearized_accessibility_tree, max_tokens):
         linearized_accessibility_tree += "[...]\n"
     return linearized_accessibility_tree
 
+
 class PromptAgent:
     def __init__(
             self,
```
```diff
@@ -572,22 +573,10 @@ class PromptAgent:
 
             logger.debug("CLAUDE MESSAGE: %s", repr(claude_messages))
 
-            # headers = {
-            #     "x-api-key": os.environ["ANTHROPIC_API_KEY"],
-            #     "anthropic-version": "2023-06-01",
-            #     "content-type": "application/json"
-            # }
-
-            # headers = {
-            #     "Accept": "application / json",
-            #     "Authorization": "Bearer " + os.environ["ANTHROPIC_API_KEY"],
-            #     "User-Agent": "Apifox/1.0.0 (https://apifox.com)",
-            #     "Content-Type": "application/json"
-            # }
-
             headers = {
-                "Authorization": os.environ["ANTHROPIC_API_KEY"],
-                "Content-Type": "application/json"
+                "x-api-key": os.environ["ANTHROPIC_API_KEY"],
+                "anthropic-version": "2023-06-01",
+                "content-type": "application/json"
             }
 
             payload = {
```
```diff
@@ -598,28 +587,21 @@ class PromptAgent:
                 "top_p": top_p
             }
 
-            max_attempts = 20
-            attempt = 0
-            while attempt < max_attempts:
-                # response = requests.post("https://api.aigcbest.top/v1/chat/completions", headers=headers, json=payload)
-                response = requests.post("https://token.cluade-chat.top/v1/chat/completions", headers=headers,
-                                         json=payload)
-                if response.status_code == 200:
-                    result = response.json()['choices'][0]['message']['content']
-                    break
-                else:
-                    logger.error(f"Failed to call LLM: {response.text}")
-                    time.sleep(10)
-                    attempt += 1
-            else:
-                print("Exceeded maximum attempts to call LLM.")
-                result = ""
-
-            return result
+            response = requests.post(
+                "https://api.anthropic.com/v1/messages",
+                headers=headers,
+                json=payload
+            )
+
+            if response.status_code != 200:
+                logger.error("Failed to call LLM: " + response.text)
+                time.sleep(5)
+                return ""
+
+            return response.json()['content'][0]['text']
 
         elif self.model.startswith("mistral"):
             print("Call mistral")
             messages = payload["messages"]
             max_tokens = payload["max_tokens"]
             top_p = payload["top_p"]
```
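
With this change, the Claude branch calls Anthropic's native Messages API directly instead of a third-party proxy, which is why the header set above switched to `x-api-key` and `anthropic-version`. A self-contained sketch of the resulting request flow (the model name and prompt are placeholders, not taken from this diff):

```python
import os
import requests

headers = {
    "x-api-key": os.environ["ANTHROPIC_API_KEY"],
    "anthropic-version": "2023-06-01",
    "content-type": "application/json",
}
payload = {
    "model": "claude-3-sonnet-20240229",  # placeholder model name
    "max_tokens": 1024,
    "messages": [{"role": "user", "content": "Describe the screenshot."}],
}

response = requests.post("https://api.anthropic.com/v1/messages", headers=headers, json=payload)
if response.status_code == 200:
    # The Messages API returns the generated text under content[0].text.
    print(response.json()["content"][0]["text"])
else:
    print("Failed to call LLM:", response.text)
```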
```diff
@@ -652,7 +634,9 @@ class PromptAgent:
                 response = client.chat.completions.create(
                     messages=mistral_messages,
                     model=self.model,
-                    max_tokens=max_tokens
+                    max_tokens=max_tokens,
+                    top_p=top_p,
+                    temperature=temperature
                 )
                 break
             except:
```
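
For context, this branch talks to Mistral through an OpenAI-compatible client, so forwarding `top_p` and `temperature` is just a matter of passing them to `chat.completions.create`. A minimal sketch (the base URL and model name are placeholders, not taken from this diff):

```python
from openai import OpenAI

# Placeholder endpoint and key: any OpenAI-compatible host serving Mistral works here.
client = OpenAI(base_url="https://example-host/v1", api_key="YOUR_KEY")

response = client.chat.completions.create(
    messages=[{"role": "user", "content": "Hello"}],
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",  # placeholder model name
    max_tokens=512,
    top_p=0.9,
    temperature=0.5,
)
print(response.choices[0].message.content)
```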
```diff
@@ -670,7 +654,6 @@ class PromptAgent:
 
         elif self.model.startswith("THUDM"):
             # THUDM/cogagent-chat-hf
-            print("Call CogAgent")
             messages = payload["messages"]
             max_tokens = payload["max_tokens"]
             top_p = payload["top_p"]
```
```diff
@@ -703,7 +686,9 @@ class PromptAgent:
             payload = {
                 "model": self.model,
                 "max_tokens": max_tokens,
-                "messages": cog_messages
+                "messages": cog_messages,
+                "temperature": temperature,
+                "top_p": top_p
             }
 
             base_url = "http://127.0.0.1:8000"
```
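
Between this hunk and the next, the payload is posted to a locally served CogAgent at `base_url`. A sketch of that round trip; the `/v1/chat/completions` route and the payload fields are hypothetical placeholders, since the actual call lies outside the shown hunks:

```python
import requests

base_url = "http://127.0.0.1:8000"
payload = {"model": "THUDM/cogagent-chat-hf", "max_tokens": 512, "messages": []}  # as built above

# Hypothetical route: the real endpoint path is not visible in this diff.
response = requests.post(base_url + "/v1/chat/completions", json=payload)
if response.status_code != 200:
    print("Failed to call LLM: ", response.status_code)
```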
```diff
@@ -717,7 +702,6 @@ class PromptAgent:
                 print("Failed to call LLM: ", response.status_code)
                 return ""
 
-
         elif self.model.startswith("gemini"):
             def encoded_img_to_pil_img(data_str):
                 base64_str = data_str.replace("data:image/png;base64,", "")
```
```diff
@@ -802,7 +786,8 @@ class PromptAgent:
             messages = payload["messages"]
             max_tokens = payload["max_tokens"]
             top_p = payload["top_p"]
-            temperature = payload["temperature"]
+            if payload["temperature"]:
+                logger.warning("Qwen model does not support temperature parameter, it will be ignored.")
 
             qwen_messages = []
 
```
```diff
@@ -821,7 +806,9 @@ class PromptAgent:
 
             response = dashscope.MultiModalConversation.call(
                 model='qwen-vl-plus',
-                messages=messages,  # todo: add the hyperparameters
+                messages=messages,
+                max_length=max_tokens,
+                top_p=top_p,
             )
             # The response status_code is HTTPStatus.OK indicate success,
             # otherwise indicate request is failed, you can get error code
```
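
The trailing comments describe dashscope's status-check convention. A brief sketch of the typical unpacking, following dashscope's published examples (treat the attribute paths and the placeholder message as assumptions):

```python
from http import HTTPStatus
import dashscope

response = dashscope.MultiModalConversation.call(
    model='qwen-vl-plus',
    messages=[{"role": "user", "content": [{"text": "Hello"}]}],  # placeholder message
)
if response.status_code == HTTPStatus.OK:
    # Generated content lives under output.choices[0].message.content.
    print(response.output.choices[0].message.content)
else:
    # On failure, dashscope surfaces an error code and message.
    print(response.code, response.message)
```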
Binary file not shown. (Before size: 1.5 MiB)