flsol demo: fix top_p/claude/gemini, force coordinates, add reflection comments, screenshot mode
This commit is contained in:
@@ -752,7 +752,6 @@ class PromptAgent:
|
||||
elif self.model.startswith("claude"):
|
||||
messages = payload["messages"]
|
||||
max_tokens = payload["max_tokens"]
|
||||
top_p = payload["top_p"]
|
||||
temperature = payload["temperature"]
|
||||
|
||||
claude_messages = []
|
||||
@@ -796,11 +795,10 @@ class PromptAgent:
|
||||
"max_tokens": max_tokens,
|
||||
"messages": claude_messages,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
"https://api.apiyi.com/v1/messages",
|
||||
os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1").rstrip("/") + "/messages",
|
||||
headers=headers,
|
||||
json=payload
|
||||
)
|
||||
@@ -816,7 +814,7 @@ class PromptAgent:
|
||||
elif self.model.startswith("mistral"):
|
||||
messages = payload["messages"]
|
||||
max_tokens = payload["max_tokens"]
|
||||
top_p = payload["top_p"]
|
||||
top_p = payload.get("top_p", 0.9)
|
||||
temperature = payload["temperature"]
|
||||
|
||||
assert self.observation_type in pure_text_settings, f"The model {self.model} can only support text-based input, please consider change based model or settings"
|
||||
@@ -871,7 +869,7 @@ class PromptAgent:
|
||||
# THUDM/cogagent-chat-hf
|
||||
messages = payload["messages"]
|
||||
max_tokens = payload["max_tokens"]
|
||||
top_p = payload["top_p"]
|
||||
top_p = payload.get("top_p", 0.9)
|
||||
temperature = payload["temperature"]
|
||||
|
||||
cog_messages = []
|
||||
@@ -920,7 +918,7 @@ class PromptAgent:
|
||||
elif self.model in ["gemini-pro", "gemini-pro-vision"]:
|
||||
messages = payload["messages"]
|
||||
max_tokens = payload["max_tokens"]
|
||||
top_p = payload["top_p"]
|
||||
top_p = payload.get("top_p", 0.9)
|
||||
temperature = payload["temperature"]
|
||||
|
||||
if self.model == "gemini-pro":
|
||||
@@ -989,10 +987,10 @@ class PromptAgent:
|
||||
)
|
||||
return response.text
|
||||
|
||||
elif self.model.startswith("gemini"):
|
||||
elif self.model in ["gemini-pro", "gemini-pro-vision", "gemini-1.5-pro", "gemini-1.5-flash"]:
|
||||
messages = payload["messages"]
|
||||
max_tokens = payload["max_tokens"]
|
||||
top_p = payload["top_p"]
|
||||
top_p = payload.get("top_p", 0.9)
|
||||
temperature = payload["temperature"]
|
||||
|
||||
gemini_messages = []
|
||||
@@ -1068,7 +1066,7 @@ class PromptAgent:
|
||||
elif self.model == "llama3-70b":
|
||||
messages = payload["messages"]
|
||||
max_tokens = payload["max_tokens"]
|
||||
top_p = payload["top_p"]
|
||||
top_p = payload.get("top_p", 0.9)
|
||||
temperature = payload["temperature"]
|
||||
|
||||
assert self.observation_type in pure_text_settings, f"The model {self.model} can only support text-based input, please consider change based model or settings"
|
||||
@@ -1121,7 +1119,7 @@ class PromptAgent:
|
||||
elif self.model.startswith("qwen"):
|
||||
messages = payload["messages"]
|
||||
max_tokens = payload["max_tokens"]
|
||||
top_p = payload["top_p"]
|
||||
top_p = payload.get("top_p", 0.9)
|
||||
temperature = payload["temperature"]
|
||||
|
||||
qwen_messages = []
|
||||
@@ -1200,7 +1198,21 @@ class PromptAgent:
|
||||
return ""
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid model: " + self.model)
|
||||
# Fallback: openai-compatible for any unrecognized model (e.g. gemini-3.1 via apiyi)
|
||||
base_url = os.environ.get('OPENAI_BASE_URL', os.environ.get('OPENAI_API_BASE', 'https://api.openai.com'))
|
||||
api_url = f"{base_url}/chat/completions" if base_url.endswith('/v1') else f"{base_url}/v1/chat/completions"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
|
||||
}
|
||||
logger.info("Generating content with openai-compatible model: %s", self.model)
|
||||
response = requests.post(api_url, headers=headers, json=payload)
|
||||
if response.status_code != 200:
|
||||
logger.error("Failed to call LLM: " + response.text)
|
||||
time.sleep(5)
|
||||
return ""
|
||||
else:
|
||||
return response.json()['choices'][0]['message']['content']
|
||||
|
||||
def parse_actions(self, response: str, masks=None):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user