Compare commits

10 Commits: 231f7a8fbc...os_world

| Author | SHA1 | Date |
|---|---|---|
| | 613f55f0da | |
| | ba03784196 | |
| | 3890ee5fc3 | |
| | 9bc54c0a66 | |
| | 1e9281a1ab | |
| | 63484c7b7b | |
| | ad46acc5f3 | |
| | 58d411bf86 | |
| | be24e77d93 | |
| | dd58a1de03 | |
@@ -111,6 +111,7 @@ class DesktopEnv(gym.Env):
         os_type: str = "Ubuntu",
         enable_proxy: bool = False,
         client_password: str = "",
+        eval_model: str = "gpt-5.2-chat-latest"
     ):
         """
         Args:
@@ -127,6 +128,7 @@ class DesktopEnv(gym.Env):
             require_terminal (bool): whether to require terminal output
             os_type (str): operating system type, default to "Ubuntu"
             enable_proxy (bool): whether to enable proxy support, default to False
+            eval_model (str): evaluation model to use, default to "gpt-5.2-chat-latest"
         """
         # Initialize VM manager and vitualization provider
         self.region = region
@@ -179,6 +181,9 @@ class DesktopEnv(gym.Env):
         self.require_a11y_tree = require_a11y_tree
         self.require_terminal = require_terminal

+        # Evaluation model
+        self.eval_model = eval_model
+
         # Initialize emulator and controller
         logger.info("Initializing...")
         self._start_emulator()
@@ -425,7 +430,7 @@ class DesktopEnv(gym.Env):

         return observation, reward, done, info

-    def evaluate(self):
+    def evaluate(self, result_dir: Optional[str] = None):
         """
         Evaluate whether the task is successfully completed.
         """
@@ -448,6 +453,20 @@ class DesktopEnv(gym.Env):
         if last_action == "FAIL" or (type(last_action) == dict and last_action.get('action_type') == 'FAIL'):
             return 0

+        if self.evaluator['func'] == "vllm_eval":
+            logger.info("Preparing vllm_eval metric options...")
+            screenshot_bytes = self.controller.get_screenshot()
+
+            import base64
+            self.metric_options["instruction"] = self.instruction
+            self.metric_options["eval_model"] = self.eval_model
+
+            if result_dir:
+                self.metric_options["result_dir"] = result_dir
+                logger.info(f"Using result_dir for vllm_eval: {result_dir}")
+
+            logger.info(f"Evaluation options prepared: {self.metric_options.keys()}")
+
         if type(self.metric) == list:
             # Multiple metrics to evaluate whether the task is successfully completed
             results = []
@@ -455,13 +474,18 @@ class DesktopEnv(gym.Env):
             if "expected" in self.evaluator:
                 assert len(self.metric) == len(self.expected_getter), "The number of metrics and expected getters must be the same"
             for idx, metric in enumerate(self.metric):
-                try:
-                    config = self.evaluator["result"][idx]
-                    result_state = self.result_getter[idx](self, config)
-                except FileNotFoundError:
-                    logger.error("File not found!")
-                    if self.metric_conj == 'and':
-                        return 0
+                # Skip result state extraction if result_getter is None (e.g., for vllm_eval)
+                if self.result_getter[idx] is not None:
+                    try:
+                        config = self.evaluator["result"][idx]
+                        result_state = self.result_getter[idx](self, config)
+                    except FileNotFoundError:
+                        logger.error("File not found!")
+                        if self.metric_conj == 'and':
+                            return 0
+                else:
+                    # For evaluators that don't need result state (e.g., vllm_eval)
+                    result_state = None

                 if "expected" in self.evaluator and self.expected_getter and self.evaluator["expected"]:
                     expected_state = self.expected_getter[idx](self, self.evaluator["expected"][idx])
@@ -479,11 +503,16 @@ class DesktopEnv(gym.Env):
             return sum(results) / len(results) if self.metric_conj == 'and' else max(results)
         else:
             # Single metric to evaluate whether the task is successfully completed
-            try:
-                result_state = self.result_getter(self, self.evaluator["result"])
-            except FileNotFoundError:
-                logger.error("File not found!")
-                return 0
+            # For evaluators like vllm_eval that don't need result_getter, skip result state extraction
+            if self.result_getter is not None:
+                try:
+                    result_state = self.result_getter(self, self.evaluator["result"])
+                except FileNotFoundError:
+                    logger.error("File not found!")
+                    return 0
+            else:
+                # For evaluators that don't need result state (e.g., vllm_eval)
+                result_state = None

             if "expected" in self.evaluator and self.expected_getter and self.evaluator["expected"]:
                 expected_state = self.expected_getter(self, self.evaluator["expected"])

@@ -158,3 +158,5 @@ from .vscode import (

 def infeasible():
     pass
+
+from .vllm_eval import vllm_eval
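Taken together, these hunks add an opt-in VLM judging path: when a task's `evaluator["func"]` is `"vllm_eval"`, `evaluate()` skips result-state extraction and instead forwards the instruction, the configured `eval_model`, and the caller's `result_dir` through `metric_options`. A minimal sketch of the intended call path (the import path, task config, and result directory are assumptions, not confirmed by this diff):

```python
# Sketch only: driving the new evaluate() signature.
# task_config and the result directory are hypothetical placeholders.
from desktop_env.desktop_env import DesktopEnv

env = DesktopEnv(os_type="Ubuntu", eval_model="gpt-5.2-chat-latest")
env.reset(task_config=task_config)  # task_config["evaluator"]["func"] == "vllm_eval"

# ... agent loop runs here, saving step_*.png screenshots into result_dir ...

score = env.evaluate(result_dir="results/example_task")  # float in [0.0, 1.0]
```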
desktop_env/evaluators/metrics/vllm_eval.py (new file, 529 lines)

@@ -0,0 +1,529 @@
import os
from typing import Optional, List, Dict, Any
from dotenv import load_dotenv
import logging
import base64
import glob
from io import BytesIO
from PIL import Image

logger = logging.getLogger("desktopenv.vllm_eval")
load_dotenv()


def _compress_image(img_b64: str, max_size: int = 800, quality: int = 85) -> str:
    """
    Compress base64 encoded image to reduce size

    Args:
        img_b64: Base64 encoded image string
        max_size: Maximum dimension (width or height) in pixels
        quality: JPEG quality (1-100), lower means smaller file size

    Returns:
        Compressed base64 encoded image string
    """
    try:
        # Decode base64 to image
        img_data = base64.b64decode(img_b64)
        img = Image.open(BytesIO(img_data))

        # Convert to RGB if necessary (for PNG with transparency)
        if img.mode in ('RGBA', 'LA', 'P'):
            background = Image.new('RGB', img.size, (255, 255, 255))
            if img.mode == 'P':
                img = img.convert('RGBA')
            background.paste(img, mask=img.split()[-1] if img.mode in ('RGBA', 'LA') else None)
            img = background

        # Resize if image is too large
        original_size = img.size
        if max(img.size) > max_size:
            ratio = max_size / max(img.size)
            new_size = tuple(int(dim * ratio) for dim in img.size)
            img = img.resize(new_size, Image.Resampling.LANCZOS)
            logger.info(f"Resized image from {original_size} to {new_size}")

        # Compress to JPEG
        buffer = BytesIO()
        img.save(buffer, format='JPEG', quality=quality, optimize=True)
        compressed_data = buffer.getvalue()

        # Encode back to base64
        compressed_b64 = base64.b64encode(compressed_data).decode('utf-8')

        # Log compression ratio
        original_size_kb = len(img_b64) * 3 / 4 / 1024  # base64 to bytes to KB
        compressed_size_kb = len(compressed_b64) * 3 / 4 / 1024
        compression_ratio = (1 - compressed_size_kb / original_size_kb) * 100
        logger.info(f"Compressed image: {original_size_kb:.1f}KB -> {compressed_size_kb:.1f}KB ({compression_ratio:.1f}% reduction)")

        return compressed_b64

    except Exception as e:
        logger.warning(f"Failed to compress image: {e}, using original")
        return img_b64
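`_compress_image` is deliberately fail-soft: any decode or re-encode error returns the original string unchanged. A small usage sketch (the screenshot path is a hypothetical placeholder):

```python
# Hypothetical usage: shrink one screenshot before sending it to the judge.
import base64

with open("step_1.png", "rb") as f:  # placeholder path
    original_b64 = base64.b64encode(f.read()).decode("utf-8")

small_b64 = _compress_image(original_b64, max_size=800, quality=85)
# Result is still a base64 string, but now JPEG-encoded and capped at 800px.
```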
class UnifiedLLM:

    def __init__(self, model: str):
        if model.startswith("gpt"):
            self.provider = "openai"
        elif model.startswith("claude"):
            self.provider = "anthropic"
        elif model.startswith("gemini"):
            self.provider = "gemini"
        else:
            self.provider = "unknown"

        self.model = model
        self.client = self._init_client()

    def _init_client(self):
        """Initialize client"""
        if self.provider == "openai":
            from openai import OpenAI
            return OpenAI(
                base_url=os.getenv("OPENAI_BASE_URL"),
                api_key=os.getenv("OPENAI_API_KEY")
            )

        elif self.provider == "anthropic":
            from anthropic import Anthropic
            return Anthropic(
                base_url=os.getenv("ANTHROPIC_BASE_URL"),
                api_key=os.getenv("ANTHROPIC_API_KEY")
            )

        elif self.provider == "gemini":
            logger.warning("Using Google Gemini model, make sure your internet connection is working.")
            import google.generativeai as genai
            genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
            return genai.GenerativeModel(self.model)

        else:
            logger.error(f"Unsupported LLM provider for model: {self.model}")
            raise ValueError(f"Unsupported LLM provider for model: {self.model}")

    def _get_supported_params(self, temperature: float, max_tokens: int, top_p: float) -> Dict[str, Any]:
        """Get supported parameters for each provider"""
        base_params = {
            "temperature": temperature,
            "max_tokens": max_tokens
        }

        # GPT-5.2 and newer models may not support top_p
        if self.provider == "openai":
            # Only add top_p for older models
            if not self.model.startswith("gpt-5"):
                base_params["top_p"] = top_p
        elif self.provider == "anthropic":
            base_params["top_p"] = top_p
        elif self.provider == "gemini":
            base_params["top_p"] = top_p

        return base_params

    def generate(
        self,
        prompt: str,
        temperature: float = 0.7,
        max_tokens: int = 16384,
        top_p: float = 1.0,
        **kwargs
    ) -> str:
        """
        Args:
            prompt: Input prompt
            temperature: Temperature (0.0-2.0)
            max_tokens: Maximum number of tokens
            top_p: Top-p sampling (0.0-1.0)

        Returns:
            Generated text
        """
        params = self._get_supported_params(temperature, max_tokens, top_p)

        if self.provider == "openai":
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": "user", "content": prompt}],
                    **params
                )
                return response.choices[0].message.content
            except Exception as e:
                logger.error(f"OpenAI API error: {e}")
                raise e

        elif self.provider == "anthropic":
            try:
                response = self.client.messages.create(
                    model=self.model,
                    messages=[{"role": "user", "content": prompt}],
                    **params
                )
                return response.content[0].text
            except Exception as e:
                logger.error(f"Anthropic API error: {e}")
                raise e

        elif self.provider == "gemini":
            try:
                import google.generativeai as genai
                config = genai.GenerationConfig(
                    temperature=params["temperature"],
                    max_output_tokens=params["max_tokens"],
                    top_p=params.get("top_p", 1.0)
                )
                response = self.client.generate_content(prompt, generation_config=config)
                return response.text
            except Exception as e:
                logger.error(f"Gemini API error: {e}")
                raise e

    def generate_with_images(
        self,
        prompt: str,
        images_b64: List[str],
        temperature: float = 0.7,
        max_tokens: int = 16384,
        top_p: float = 1.0,
        **kwargs
    ) -> str:
        """
        Generate with multiple images in a single request

        Args:
            prompt: Instruction prompt
            images_b64: List of base64 encoded images
            temperature: Temperature (0.0-2.0)
            max_tokens: Maximum number of tokens
            top_p: Top-p sampling (0.0-1.0)

        Returns:
            Generated text
        """
        if not images_b64:
            logger.warning("No images provided, falling back to text-only generation")
            return self.generate(prompt, temperature, max_tokens, top_p, **kwargs)

        params = self._get_supported_params(temperature, max_tokens, top_p)

        if self.provider == "openai":
            # Build content with text and all images
            content = [{"type": "text", "text": prompt}]

            for img_b64 in images_b64:
                content.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{img_b64}"
                    }
                })

            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": "user", "content": content}],
                    **params
                )
                return response.choices[0].message.content
            except Exception as e:
                logger.error(f"OpenAI API error: {e}")
                raise e

        elif self.provider == "anthropic":
            # Build content with text and all images
            content = [{"type": "text", "text": prompt}]

            for img_b64 in images_b64:
                content.append({
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": img_b64
                    }
                })

            try:
                response = self.client.messages.create(
                    model=self.model,
                    messages=[{"role": "user", "content": content}],
                    **params
                )
                return response.content[0].text
            except Exception as e:
                logger.error(f"Anthropic API error: {e}")
                raise e

        elif self.provider == "gemini":
            import google.generativeai as genai

            config = genai.GenerationConfig(
                temperature=params["temperature"],
                max_output_tokens=params["max_tokens"],
                top_p=params.get("top_p", 1.0)
            )

            # Build content parts
            content_parts = [prompt]

            for img_b64 in images_b64:
                img_data = base64.b64decode(img_b64)
                img = Image.open(BytesIO(img_data))
                content_parts.append(img)

            try:
                response = self.client.generate_content(content_parts, generation_config=config)
                return response.text
            except Exception as e:
                logger.error(f"Gemini API error: {e}")
                raise e

        else:
            raise ValueError(f"Unsupported provider: {self.provider}")
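`UnifiedLLM` infers the provider from the model-name prefix, so callers never import provider SDKs themselves. A minimal sketch (the model name is illustrative; the matching API key must be present in the environment):

```python
# Illustrative: the "gpt" prefix routes to the OpenAI client, so this
# assumes OPENAI_API_KEY (and optionally OPENAI_BASE_URL) are set.
llm = UnifiedLLM("gpt-5.2-chat-latest")
text = llm.generate("Describe this evaluator in one sentence.", temperature=0.2)
print(text)
```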
def _load_screenshots_from_dir(result_dir: str, compress: bool = True, max_size: int = 800, quality: int = 85) -> List[str]:
    """
    Load all step screenshots from result directory and convert to base64

    Args:
        result_dir: Path to result directory containing step_*.png files
        compress: Whether to compress images (default: True)
        max_size: Maximum dimension for compression (default: 800)
        quality: JPEG quality for compression (default: 85)

    Returns:
        List of base64 encoded screenshot strings
    """
    screenshots = []

    # Find all step screenshot files (e.g., step_1_20240101@120000.png)
    pattern = os.path.join(result_dir, "step_*.png")
    screenshot_files = sorted(glob.glob(pattern))

    if not screenshot_files:
        logger.warning(f"No screenshot files found in {result_dir}")
        return screenshots

    for filepath in screenshot_files:
        try:
            with open(filepath, "rb") as f:
                img_data = f.read()
            img_b64 = base64.b64encode(img_data).decode('utf-8')

            # Compress if enabled
            if compress:
                img_b64 = _compress_image(img_b64, max_size=max_size, quality=quality)

            screenshots.append(img_b64)
        except Exception as e:
            logger.error(f"Error loading screenshot {filepath}: {e}")

    logger.info(f"Loaded {len(screenshots)} screenshots from {result_dir}")
    return screenshots
def vllm_eval(result_state, **options) -> float:
    """
    Evaluate task completion using vision-language model

    Args:
        result_state: Current state description
        **options: Additional options including:
            - result_dir: Path to result directory containing step screenshots (recommended)
            - screenshots: List of base64 encoded screenshots (deprecated, use result_dir instead)
            - instruction: Task instruction
            - eval_model: Model name to use
            - compress_images: Whether to compress images (default: True)
            - max_image_size: Maximum image dimension for compression (default: 800)
            - image_quality: JPEG quality for compression (default: 85)
            - temperature: Temperature parameter
            - max_tokens: Maximum tokens
            - top_p: Top-p parameter

    Returns:
        Score between 0.0 and 1.0
    """
    # Try to load screenshots from result_dir if provided
    result_dir = options.get("result_dir", None)
    screenshots = options.get("screenshots", [])

    # Image compression options
    compress_images = options.get("compress_images", True)
    max_image_size = options.get("max_image_size", 800)
    image_quality = options.get("image_quality", 85)

    if result_dir and not screenshots:
        screenshots = _load_screenshots_from_dir(
            result_dir,
            compress=compress_images,
            max_size=max_image_size,
            quality=image_quality
        )
        logger.info(f"Loaded {len(screenshots)} screenshots from result_dir: {result_dir}")
    elif screenshots:
        logger.info(f"Using {len(screenshots)} screenshots from options")
        # Compress screenshots if needed
        if compress_images:
            logger.info("Compressing provided screenshots...")
            screenshots = [_compress_image(img, max_size=max_image_size, quality=image_quality) for img in screenshots]

    instruction = options.get("instruction", "")
    eval_model = options.get("eval_model", "gpt-4-vision-preview")

    params = {
        "temperature": options.get("temperature", 0.7),
        "max_tokens": options.get("max_tokens", 16384),
        "top_p": options.get("top_p", 1.0)
    }

    llm = UnifiedLLM(eval_model)

    prompt = f"""You are an expert evaluator for desktop environment tasks.

Task Instruction: {instruction}

I will provide you with screenshot(s) showing the current state of the desktop environment. Please analyze the task execution step by step and provide a detailed evaluation.

IMPORTANT: You must respond with ONLY a valid JSON object (no additional text before or after). Use the following exact format:

{{
  "steps_analysis": [
    {{"step": "Step description", "status": "Success/Fail", "evidence_img": "step_X.png", "reason": "Brief explanation"}},
    {{"step": "Another step", "status": "Success/Fail", "evidence_img": "step_Y.png", "reason": "Brief explanation"}}
  ],
  "final_completion": "True/False",
  "score": 0-10
}}

Where:
- "steps_analysis": Array of steps you identified from the screenshots (reference screenshot filenames like step_1.png, step_2.png, etc.)
- "status": Either "Success" or "Fail" for each step
- "evidence_img": The screenshot filename that shows evidence for this step (e.g., "step_2.png")
- "reason": Brief explanation of why this step succeeded or failed
- "final_completion": "True" if the overall task is completed, "False" otherwise
- "score": Integer from 0 to 10, where 10 means perfectly completed and 0 means not completed at all

Remember: Return ONLY the JSON object, no additional text."""

    try:
        result = llm.generate_with_images(
            prompt=prompt,
            images_b64=screenshots,
            **params
        )

        # Parse score from result
        score = _parse_score(result)
        logger.info(f"Evaluation result: {result}")
        logger.info(f"Parsed score: {score}")

        # Save raw result to file for reference
        if result_dir:
            eval_output_path = os.path.join(result_dir, "vllm_evaluation_result.json")
            with open(eval_output_path, "w", encoding="utf-8") as f:
                f.write(result)
            logger.info(f"Saved evaluation result to {eval_output_path}")

        return score
    except Exception as e:
        logger.error(f"Error during evaluation: {e}")
        return 0.0
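End to end, the metric is self-contained: given a directory of `step_*.png` screenshots it builds the judging prompt, queries the model, and persists the raw verdict next to the screenshots. A hedged usage sketch (directory and instruction are placeholders, mirroring what `DesktopEnv.evaluate()` assembles in `metric_options`):

```python
# Hypothetical call; the directory and instruction are placeholders.
score = vllm_eval(
    None,  # result_state is unused by this metric
    result_dir="results/example_task",
    instruction="Open the file data.xlsx and sort column A ascending.",
    eval_model="gpt-5.2-chat-latest",
)
print(f"normalized score: {score:.2f}")  # e.g., 0.70 if the judge returns 7/10
```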
def _parse_evaluation_response(text: str) -> Dict[str, Any]:
    """
    Parse the JSON evaluation response from the model

    Returns:
        Dictionary containing steps_analysis, final_completion, and score
    """
    import re
    import json

    # Try to extract JSON from the response
    # Sometimes models wrap JSON in markdown code blocks
    text = text.strip()

    # Remove markdown code blocks if present
    if text.startswith("```"):
        # Extract content between ``` markers
        match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', text, re.DOTALL)
        if match:
            text = match.group(1)
        else:
            # Try to remove opening and closing ```
            text = re.sub(r'^```(?:json)?\s*', '', text)
            text = re.sub(r'\s*```$', '', text)

    try:
        result = json.loads(text)

        # Validate required fields
        if "steps_analysis" not in result:
            logger.warning("Missing 'steps_analysis' field in response")
            result["steps_analysis"] = []

        if "final_completion" not in result:
            logger.warning("Missing 'final_completion' field in response")
            result["final_completion"] = "False"

        if "score" not in result:
            logger.warning("Missing 'score' field in response")
            result["score"] = 0

        return result

    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse JSON response: {e}")
        logger.error(f"Response text: {text[:500]}")

        # Return a default structure
        return {
            "steps_analysis": [],
            "final_completion": "False",
            "score": 0
        }
def _parse_score(text: str) -> float:
    """
    Parse score from model response and convert to 0.0-1.0 range

    Args:
        text: Raw model response (expected to be JSON format)

    Returns:
        Score between 0.0 and 1.0
    """
    result = _parse_evaluation_response(text)

    # Extract score (0-10) and convert to 0.0-1.0
    score = result.get("score", 0)

    try:
        score = float(score)
        # Clamp to [0, 10] then normalize to [0.0, 1.0]
        score = max(0.0, min(10.0, score))
        normalized_score = score / 10.0

        logger.info(f"Final completion: {result.get('final_completion')}")
        logger.info(f"Raw score (0-10): {score}, Normalized score (0-1): {normalized_score}")

        # Log steps analysis if available
        steps = result.get("steps_analysis", [])
        if steps:
            logger.info(f"Steps analysis ({len(steps)} steps):")
            for i, step in enumerate(steps):
                logger.info(f"  Step {i+1}: {step.get('step', 'N/A')} - {step.get('status', 'N/A')}")

        return normalized_score

    except (ValueError, TypeError) as e:
        logger.warning(f"Could not parse score: {e}")
        return 0.0
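A worked example of the clamp-and-normalize path: a judge reply of score 7 becomes 0.7, and unparseable replies degrade to 0.0 rather than raising. The sample response is invented for illustration:

```python
# Illustrative round-trip through the two parsers above.
sample = '{"steps_analysis": [], "final_completion": "True", "score": 7}'

assert _parse_score(sample) == 0.7      # 7 / 10 after clamping to [0, 10]
assert _parse_score("not json") == 0.0  # parse failure falls back to 0
```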
@@ -17,17 +17,10 @@
                 }
             }
         ],
-        "func": "check_include_exclude",
-        "result": {
-            "type": "vm_command_line",
-            "command": "tasklist | findstr /i jade"
-        },
-        "expected": {
-            "type": "rule",
-            "rules": {
-                "include": ["jade"],
-                "exclude": []
-            }
-        }
+        "func": "vllm_eval",
     },
     "proxy": false,
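This flips one example task from a rule-based check to the VLM judge: with `vllm_eval`, no `result`/`expected` blocks are needed because the metric consumes the instruction and step screenshots instead. A sketch of the resulting evaluator block, expressed as a Python dict for reference (values illustrative, matching what extract_instructions.py below emits):

```python
# Sketch of the evaluator block after this change (values illustrative);
# vllm_eval needs no "result"/"expected" getters, only the judge hookup.
evaluator = {
    "postconfig": [{"type": "sleep", "parameters": {"seconds": 3}}],
    "func": "vllm_eval",
}
```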
evaluation_examples/extract_instructions.py (new file, 604 lines)

@@ -0,0 +1,604 @@
import os
import sys
import asyncio
import aiohttp
import base64
import logging
from pathlib import Path
from typing import List, Optional
import tempfile
import shutil
from dataclasses import dataclass
from datetime import datetime
import json

# Configuration
SCRIPT_DIR = Path(__file__).parent
PROJECT_ROOT = SCRIPT_DIR.parent

API_BASE_URL = os.getenv("OPENAI_BASE_URL")
API_URL = f"{API_BASE_URL}/chat/completions" if API_BASE_URL else None
API_KEY = os.getenv("OPENAI_API_KEY")
MODEL_NAME = "gemini-2.5-pro"
MAX_CONCURRENT_REQUESTS = 5
INPUT_FOLDER = "/Users/cuihang/Downloads/test_files"
EXAMPLES_FOLDER = PROJECT_ROOT / "evaluation_examples" / "examples"
TEST_ALL_JSON = PROJECT_ROOT / "evaluation_examples" / "test_all.json"

# Retry configuration
MAX_RETRY_ATTEMPTS = 3
RETRY_DELAY = 5
RETRY_BACKOFF = 2

# Image limit
MAX_IMAGES_PER_REQUEST = 50

# Supported file extensions
SUPPORTED_EXTENSIONS = {'.docx', '.doc', '.ppt', '.pptx', '.pdf', '.mp4', '.avi', '.mov', '.mkv'}

SYSTEM_PROMPT = """You are an AI assistant that generates precise, executable step-by-step instructions for desktop software operations.

Your task:
Convert the provided document information into precise operation instructions that can be executed step-by-step by an AI agent in a software GUI.

Output requirements (no additional explanatory text):
------------------------------------------------

[Task Goal]
Describe in one sentence the final task result to be achieved in the software.

[Input Files]
Specify the file names, types, and locations involved in this operation.
- If the document provides complete paths, record them as is
- If only file names are mentioned (e.g., data.xlsx), record the filename and note "complete path not specified in document"
- If no input files are mentioned, write "no input files required"

[Detailed Operation Steps (GUI Level)]
Break down the task into atomic GUI operation steps.
Each step must meet the following conditions:
- Contains only one explicit, indivisible GUI atomic action
- Must specify the menus, panels, buttons, or controls involved
- Must specify parameter names and option values involved
- Arranged in the actual operation order of the software
- Must include software launch steps (e.g., double-click desktop icon, launch from start menu, etc.)

Step format example:
1. Double-click the [Software Name] icon on the desktop to launch the software.
2. Click "File → Open" in the main menu bar.
3. In the file selection dialog, navigate to the specified directory and select file [filename].
4. Click the "Open" button to confirm.
5. ... (and so on)

------------------------------------------------

[Handling Uncertain Information]
- Strictly generate operation steps based on document content, do not add features or menus not mentioned
- If operation steps are unclear or ambiguous, infer based on common software operation flows
- If parameter values in the document are unclear, note "[set according to actual needs]" in the step

[Output Format]
Output in JSON format with the following fields:
{
    "input_files": ["file1", "file2", "..."],
    "task_goal": "...",
    "steps": "A string containing all operation steps, arranged in order, with numbered prefix for each step, separated by newlines"
}
Note: Output must be strict JSON format, with no extra text or explanations."""
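The prompt pins the model to a three-field JSON contract that the downstream `process_file` parser relies on. An illustrative response satisfying it (content invented for the example):

```python
# Invented example of a response satisfying SYSTEM_PROMPT's contract;
# process_file() below json.loads() exactly this shape.
example_output = {
    "input_files": ["data.xlsx (complete path not specified in document)"],
    "task_goal": "Sort column A of data.xlsx in ascending order.",
    "steps": "1. Double-click the Excel icon on the desktop to launch the software.\n"
             "2. Click \"File → Open\" in the main menu bar.\n"
             "3. In the file selection dialog, select data.xlsx and click \"Open\".\n"
             "4. Select column A, then click \"Data → Sort Ascending\".",
}
```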
# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)


@dataclass
class ProcessingStats:
    """Processing statistics tracker"""
    total_files: int = 0
    completed_files: int = 0
    failed_files: int = 0
    retried_files: int = 0
    start_time: datetime = None
    failed_list: List[tuple] = None

    def __post_init__(self):
        if self.start_time is None:
            self.start_time = datetime.now()
        if self.failed_list is None:
            self.failed_list = []

    def add_completed(self):
        self.completed_files += 1
        self._log_progress()

    def add_failed(self, file_path: str, error: str):
        self.failed_files += 1
        self.failed_list.append((file_path, error))
        self._log_progress()

    def add_retry(self):
        self.retried_files += 1

    def _log_progress(self):
        processed = self.completed_files + self.failed_files
        percentage = (processed / self.total_files * 100) if self.total_files > 0 else 0
        elapsed = (datetime.now() - self.start_time).total_seconds()

        if processed > 0:
            avg_time = elapsed / processed
            remaining = (self.total_files - processed) * avg_time
            eta = f"{int(remaining // 60)}m{int(remaining % 60)}s"
        else:
            eta = "calculating..."

        logger.info(f"Progress: {processed}/{self.total_files} ({percentage:.1f}%) | "
                    f"Success: {self.completed_files} | Failed: {self.failed_files} | "
                    f"Retried: {self.retried_files} | ETA: {eta}")

    def print_summary(self):
        elapsed = (datetime.now() - self.start_time).total_seconds()
        logger.info("=" * 60)
        logger.info("Processing Complete")
        logger.info("=" * 60)
        logger.info(f"Total files: {self.total_files}")
        logger.info(f"Success: {self.completed_files}")
        logger.info(f"Failed: {self.failed_files}")
        logger.info(f"Total retries: {self.retried_files}")
        logger.info(f"Total time: {int(elapsed // 60)}m{int(elapsed % 60)}s")

        if self.failed_list:
            logger.info("\nFailed files:")
            for file_path, error in self.failed_list:
                logger.info(f"  - {file_path}")
                logger.info(f"    Error: {error}")

        self._save_report()

    def _save_report(self):
        report = {
            "total_files": self.total_files,
            "completed": self.completed_files,
            "failed": self.failed_files,
            "retries": self.retried_files,
            "start_time": self.start_time.isoformat(),
            "end_time": datetime.now().isoformat(),
            "elapsed_seconds": (datetime.now() - self.start_time).total_seconds(),
            "failed_files": [{"file": f, "error": e} for f, e in self.failed_list]
        }

        report_file = Path(EXAMPLES_FOLDER) / "processing_report.json"
        with open(report_file, 'w', encoding='utf-8') as f:
            json.dump(report, f, ensure_ascii=False, indent=2)

        logger.info(f"\nDetailed report saved to: {report_file}")


stats = ProcessingStats()
software_tests = {}
def check_dependencies():
    """Check and prompt for missing dependencies"""
    missing = []

    try:
        import pdf2image
    except ImportError:
        missing.append("pdf2image")

    try:
        import PIL
    except ImportError:
        missing.append("Pillow")

    try:
        import cv2
    except ImportError:
        missing.append("opencv-python or opencv-python-headless")

    if not shutil.which("soffice") and not shutil.which("libreoffice"):
        logger.warning("LibreOffice not detected, cannot convert .doc and .ppt files")
        logger.info("Install: sudo apt-get install libreoffice (Linux) or download from https://www.libreoffice.org/")

    if missing:
        logger.error(f"Missing dependencies: {', '.join(missing)}")
        logger.info(f"Install with: pip install {' '.join(missing)}")
        logger.info("Note: pdf2image also requires poppler")
        logger.info("  - Ubuntu/Debian: sudo apt-get install poppler-utils")
        logger.info("  - macOS: brew install poppler")
        logger.info("  - Windows: download from https://github.com/oschwartz10612/poppler-windows/releases/")
        return False
    return True
def convert_pdf_to_images(pdf_path: str) -> List[str]:
    """Convert PDF to base64-encoded images"""
    try:
        from pdf2image import convert_from_path
        from PIL import Image
        import io

        images = convert_from_path(pdf_path, dpi=150, fmt='jpeg')
        base64_images = []

        for img in images:
            buffer = io.BytesIO()
            img.save(buffer, format='JPEG', quality=100)
            img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
            base64_images.append(img_base64)

        return base64_images
    except Exception as e:
        logger.error(f"PDF conversion failed for {pdf_path}: {str(e)}")
        return []


def convert_office_to_pdf(input_path: str) -> Optional[str]:
    """Convert Office documents to PDF using LibreOffice"""
    try:
        import subprocess

        temp_dir = tempfile.mkdtemp()
        soffice_cmd = "soffice" if shutil.which("soffice") else "libreoffice"

        cmd = [
            soffice_cmd,
            "--headless",
            "--convert-to", "pdf",
            "--outdir", temp_dir,
            input_path
        ]

        result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)

        if result.returncode == 0:
            pdf_name = Path(input_path).stem + ".pdf"
            pdf_path = os.path.join(temp_dir, pdf_name)

            if os.path.exists(pdf_path):
                return pdf_path

        logger.error(f"LibreOffice conversion failed: {result.stderr}")
        return None

    except Exception as e:
        logger.error(f"Office conversion failed for {input_path}: {str(e)}")
        return None
def convert_document_to_images(file_path: str) -> List[str]:
    """Convert any supported document to base64-encoded images"""
    file_ext = Path(file_path).suffix.lower()

    if file_ext == '.pdf':
        return convert_pdf_to_images(file_path)

    elif file_ext in ['.docx', '.doc', '.ppt', '.pptx']:
        pdf_path = convert_office_to_pdf(file_path)
        if pdf_path:
            images = convert_pdf_to_images(pdf_path)
            try:
                os.remove(pdf_path)
                os.rmdir(os.path.dirname(pdf_path))
            except:
                pass
            return images
        return []

    elif file_ext in ['.mp4', '.avi', '.mov', '.mkv']:
        return extract_video_frames(file_path)

    return []
def extract_video_frames(video_path: str, num_frames: int = 10) -> List[str]:
    """Extract key frames from video"""
    try:
        import cv2

        cap = cv2.VideoCapture(video_path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        if total_frames == 0:
            return []

        frame_indices = [int(total_frames * i / (num_frames + 1)) for i in range(1, num_frames + 1)]
        base64_frames = []

        for idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()

            if ret:
                height, width = frame.shape[:2]
                if width > 1280:
                    scale = 1280 / width
                    frame = cv2.resize(frame, (1280, int(height * scale)))

                _, buffer = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
                frame_base64 = base64.b64encode(buffer).decode('utf-8')
                base64_frames.append(frame_base64)

        cap.release()
        return base64_frames

    except Exception as e:
        logger.error(f"Video frame extraction failed for {video_path}: {str(e)}")
        return []
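The sampler spaces frames evenly while skipping the very start and end of the clip. A quick worked check of the index formula for a hypothetical 300-frame video:

```python
# Worked check of the frame-index formula used above.
total_frames, num_frames = 300, 10
indices = [int(total_frames * i / (num_frames + 1)) for i in range(1, num_frames + 1)]
print(indices)  # [27, 54, 81, 109, 136, 163, 190, 218, 245, 272]
```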
async def call_api_single_batch(images_batch: List[str], file_type: str,
                                session: aiohttp.ClientSession, batch_num: int = 0) -> tuple[str, bool, int]:
    """
    Call API to process a single batch of images
    Returns: (content, success, status_code)
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    batch_info = f" (batch {batch_num})" if batch_num > 0 else ""
    content = [
        {"type": "text", "text": f"Please analyze the following {file_type} pages/frames{batch_info} and extract the operation workflow:"}
    ]

    for img_b64 in images_batch:
        content.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}
        })

    messages.append({"role": "user", "content": content})

    try:
        headers = {
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json"
        }

        payload = {
            "model": MODEL_NAME,
            "messages": messages,
            "max_tokens": 8192
        }

        async with session.post(API_URL, headers=headers, json=payload, timeout=180) as response:
            status_code = response.status
            if status_code == 200:
                result = await response.json()
                return result['choices'][0]['message']['content'], True, status_code
            else:
                error_text = await response.text()
                return f"[API call failed: {status_code}]\n{error_text}", False, status_code

    except asyncio.TimeoutError:
        return "[API call timeout]", False, 0
    except Exception as e:
        return f"[API call error: {str(e)}]", False, 0
async def call_multimodal_api_with_retry(file_path: str, session: aiohttp.ClientSession) -> tuple[str, bool]:
    """
    Call multimodal API to analyze document images with retry mechanism
    Returns: (content, success)
    """
    images_base64 = convert_document_to_images(file_path)

    if not images_base64:
        error_msg = f"[Document conversion failed: unable to convert {Path(file_path).name} to images]"
        return error_msg, False

    file_type = "video" if Path(file_path).suffix.lower() in ['.mp4', '.avi', '.mov', '.mkv'] else "document"
    total_images = len(images_base64)

    if total_images > MAX_IMAGES_PER_REQUEST:
        images_base64 = images_base64[:MAX_IMAGES_PER_REQUEST]
        total_images = MAX_IMAGES_PER_REQUEST

    for attempt in range(1, MAX_RETRY_ATTEMPTS + 1):
        try:
            content, success, status_code = await call_api_single_batch(images_base64, file_type, session)

            if success:
                return content, True

            if status_code == 413:
                return f"[File too large: server refused to process the file]", False

            if attempt < MAX_RETRY_ATTEMPTS:
                delay = RETRY_DELAY * (RETRY_BACKOFF ** (attempt - 1))
                logger.info(f"\nRetry {attempt}/{MAX_RETRY_ATTEMPTS}: {Path(file_path).name} (waiting {delay}s)")
                stats.add_retry()
                await asyncio.sleep(delay)
                continue

            return content, False

        except asyncio.TimeoutError:
            if attempt < MAX_RETRY_ATTEMPTS:
                delay = RETRY_DELAY * (RETRY_BACKOFF ** (attempt - 1))
                logger.info(f"\nRetry {attempt}/{MAX_RETRY_ATTEMPTS}: {Path(file_path).name} (timeout, waiting {delay}s)")
                stats.add_retry()
                await asyncio.sleep(delay)
                continue
            return "[API call timeout]", False

        except Exception as e:
            if attempt < MAX_RETRY_ATTEMPTS:
                delay = RETRY_DELAY * (RETRY_BACKOFF ** (attempt - 1))
                logger.info(f"\nRetry {attempt}/{MAX_RETRY_ATTEMPTS}: {Path(file_path).name} (error, waiting {delay}s)")
                stats.add_retry()
                await asyncio.sleep(delay)
                continue
            return f"[API call error: {str(e)}]", False

    return "[Max retry attempts reached]", False
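With the defaults above (RETRY_DELAY=5, RETRY_BACKOFF=2, MAX_RETRY_ATTEMPTS=3), a failing file waits 5s before attempt 2 and 10s before attempt 3:

```python
# Worked check of the exponential backoff schedule used above.
RETRY_DELAY, RETRY_BACKOFF, MAX_RETRY_ATTEMPTS = 5, 2, 3
delays = [RETRY_DELAY * (RETRY_BACKOFF ** (attempt - 1))
          for attempt in range(1, MAX_RETRY_ATTEMPTS)]
print(delays)  # [5, 10] -> waits before the 2nd and 3rd attempts
```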
async def process_file(file_path: str, session: aiohttp.ClientSession,
                       semaphore: asyncio.Semaphore):
    """Process a single file"""
    async with semaphore:
        try:
            content, success = await call_multimodal_api_with_retry(file_path, session)

            file_path_obj = Path(file_path).resolve()
            input_folder_obj = Path(INPUT_FOLDER).resolve()

            try:
                rel_path = file_path_obj.relative_to(input_folder_obj)
                software_name = rel_path.parts[0] if len(rel_path.parts) > 1 else "unknown"
            except ValueError:
                software_name = "unknown"

            file_stem = file_path_obj.stem
            test_id = file_stem
            output_file = Path(EXAMPLES_FOLDER) / software_name / f"{file_stem}.json"
            output_file.parent.mkdir(parents=True, exist_ok=True)

            import re
            match = re.search(r'```json\s*([\s\S]*?)\s*```', content)
            content = match.group(1) if match else content

            if success:
                api_result = json.loads(content)

                data = {
                    "id": test_id,
                    "snapshot": "snapshot",
                    "instruction": api_result.get("steps", ""),
                    "source": "custom",
                    "config": [],
                    "trajectory": "trajectories/",
                    "related_apps": [software_name],
                    "evaluator": {
                        "postconfig": [
                            {
                                "type": "sleep",
                                "parameters": {
                                    "seconds": 3
                                }
                            }
                        ],
                        "func": "vllm_eval"
                    },
                    "proxy": False,
                    "fixed_ip": False,
                    "possibility_of_env_change": "low",
                    "metadata": {
                        "input_files": api_result.get("input_files", []),
                        "task_goal": api_result.get("task_goal", "")
                    }
                }

                if software_name not in software_tests:
                    software_tests[software_name] = []
                software_tests[software_name].append(test_id)

            else:
                data = {
                    "id": test_id,
                    "error": content,
                    "status": "failed"
                }

            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)

            if success:
                stats.add_completed()
            else:
                stats.add_failed(file_path, content)

        except Exception as e:
            error_msg = str(e)
            stats.add_failed(file_path, error_msg)
            logger.error(f"\nError processing {file_path}: {error_msg}")
def find_all_files(input_folder: str) -> List[str]:
    """Recursively find all supported files"""
    all_files = []

    for root, dirs, files in os.walk(input_folder):
        for file in files:
            file_path = os.path.join(root, file)
            if Path(file_path).suffix.lower() in SUPPORTED_EXTENSIONS:
                all_files.append(file_path)

    return all_files
def save_test_all_json():
    """Save aggregated test_all.json"""
    test_all_path = Path(TEST_ALL_JSON)
    if test_all_path.exists():
        with open(test_all_path, 'r', encoding='utf-8') as f:
            existing_data = json.load(f)
    else:
        existing_data = {}

    for software, test_ids in software_tests.items():
        if software in existing_data:
            existing_data[software] = list(set(existing_data[software] + test_ids))
        else:
            existing_data[software] = test_ids

    test_all_path.parent.mkdir(parents=True, exist_ok=True)
    with open(test_all_path, 'w', encoding='utf-8') as f:
        json.dump(existing_data, f, ensure_ascii=False, indent=2)

    logger.info(f"\nTest index updated: {test_all_path}")
    logger.info(f"Software included: {list(existing_data.keys())}")
async def main():
    """Main function"""
    if not check_dependencies():
        return

    if not Path(INPUT_FOLDER).exists():
        logger.error(f"Input directory does not exist: {INPUT_FOLDER}")
        return

    Path(EXAMPLES_FOLDER).mkdir(parents=True, exist_ok=True)

    logger.info("Scanning files...")
    logger.info(f"Input directory: {INPUT_FOLDER}")
    logger.info(f"Output directory: {EXAMPLES_FOLDER}")
    logger.info(f"Test index file: {TEST_ALL_JSON}\n")

    files = find_all_files(INPUT_FOLDER)
    stats.total_files = len(files)

    logger.info(f"Found {len(files)} files")
    logger.info(f"Configuration: max retries={MAX_RETRY_ATTEMPTS}, concurrency={MAX_CONCURRENT_REQUESTS}")
    logger.info("=" * 60 + "\n")

    if not files:
        logger.warning("No supported files found")
        return

    semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)

    async with aiohttp.ClientSession() as session:
        tasks = [
            process_file(file, session, semaphore)
            for file in files
        ]
        await asyncio.gather(*tasks, return_exceptions=True)

    save_test_all_json()
    stats.print_summary()

    logger.info("\nCompleted!")
    logger.info(f"  - Test cases saved to: {EXAMPLES_FOLDER}")
    logger.info(f"  - Test index updated: {TEST_ALL_JSON}")


if __name__ == "__main__":
    asyncio.run(main())
@@ -13,17 +13,10 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
     runtime_logger = setup_logger(example, example_result_dir)

-    # Reset environment first to get fresh VM IP
-    # env.reset(task_config=example)
-    # logger.info("=======Environment reset completed=======")
+    env.reset(task_config=example)
+    logger.info("=======Environment reset completed=======")

-    # # Reset agent with fresh VM IP (for snapshot reverts)
-    # try:
-    #     agent.reset(runtime_logger, vm_ip=env.vm_ip)
-    # except Exception as e:
-    #     agent.reset(vm_ip=env.vm_ip)
-
-    # time.sleep(10) # Wait for the environment to be ready

     # get initial observation
     logger.info("Getting initial observation...")
     obs = env._get_obs()  # Get the initial observation
@@ -74,8 +74,7 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
             break
         step_idx += 1
     time.sleep(20)  # Wait for the environment to settle
-    result = env.evaluate()
-    logger.info("Result: %.2f", result)
+    result = env.evaluate(result_dir=example_result_dir)
     scores.append(result)
     with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
         f.write(f"{result}\n")
@@ -112,7 +111,7 @@ def run_single_example_human(env, example, max_steps, instruction, args, example
     f.write("\n")

     # Evaluate the result
-    result = env.evaluate()
+    result = env.evaluate(result_dir=example_result_dir)
     logger.info("Result: %.2f", result)
     scores.append(result)
     with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
@@ -549,7 +548,7 @@ def run_single_example_os_symphony(agent, env, example, max_steps, instruction,
             break
         step_idx += 1
     end_time = time.time()
-    result = float(env.evaluate())
+    result = float(env.evaluate(result_dir=example_result_dir))
     logger.info("Result: %.2f", result)
     scores.append(result)
     with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
@@ -647,7 +646,7 @@ def run_single_example_evocua(agent, env, example, max_steps, instruction, args,
         step_idx += 1

     time.sleep(20)  # Wait for environment to settle
-    result = env.evaluate()
+    result = env.evaluate(result_dir=example_result_dir)
     logger.info("Result: %.2f", result)
     scores.append(result)
run.py (4 changed lines)

@@ -85,7 +85,7 @@ def config() -> argparse.Namespace:
     parser.add_argument("--screen_width", type=int, default=1920)
     parser.add_argument("--screen_height", type=int, default=1080)
     parser.add_argument("--sleep_after_execution", type=float, default=0.0)
-    parser.add_argument("--max_steps", type=int, default=15)
+    parser.add_argument("--max_steps", type=int, default=8)
     parser.add_argument("--enable_recording", action="store_true", help="Enable video recording (disabled by default)")

     # agent config
@@ -100,6 +100,7 @@ def config() -> argparse.Namespace:
     parser.add_argument("--top_p", type=float, default=0.9)
     parser.add_argument("--max_tokens", type=int, default=16384)
     parser.add_argument("--stop_token", type=str, default=None)
+    parser.add_argument("--eval_model", type=str, default="gpt-5.2-chat-latest")

     # example config
     parser.add_argument("--domain", type=str, default="all")
@@ -161,6 +162,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
         os_type = "Windows",
         require_a11y_tree=args.observation_type
         in ["a11y_tree", "screenshot_a11y_tree", "som"],
+        eval_model=args.eval_model
     )

     # get actual VM screen size after environment initialization
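With the new flag, the judge model is chosen at launch and threaded into `DesktopEnv`. An illustrative launch, wrapped in Python to keep the examples in one language (flag values are examples, not recommendations):

```python
# Illustrative launcher. Equivalent shell call:
#   python run.py --eval_model gpt-5.2-chat-latest --max_steps 8
import shlex
import subprocess

cmd = "python run.py --eval_model gpt-5.2-chat-latest --max_steps 8"
subprocess.run(shlex.split(cmd), check=True)
```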