feat(evaluator): add vision-language model evaluator
desktop_env/evaluators/metrics/vllm_eval.py | 517 lines (new file)
@@ -0,0 +1,517 @@
import os
from typing import Optional, List, Dict, Any
from dotenv import load_dotenv
import logging
import base64
import glob

logger = logging.getLogger("desktopenv.vllm_eval")
load_dotenv()


class UnifiedLLM:
    """Minimal unified wrapper around the OpenAI, Anthropic, and Google Gemini APIs."""

    def __init__(self, model: str):
        # Route to a provider based on the model-name prefix.
        if model.startswith("gpt"):
            self.provider = "openai"
        elif model.startswith("claude"):
            self.provider = "anthropic"
        elif model.startswith("gemini"):
            self.provider = "gemini"
        else:
            self.provider = "unknown"

        self.model = model
        self.client = self._init_client()

    def _init_client(self):
        """Initialize client"""
        if self.provider == "openai":
            from openai import OpenAI
            return OpenAI(
                base_url=os.getenv("OPENAI_BASE_URL"),
                api_key=os.getenv("OPENAI_API_KEY")
            )

        elif self.provider == "anthropic":
            from anthropic import Anthropic
            return Anthropic(
                base_url=os.getenv("ANTHROPIC_BASE_URL"),
                api_key=os.getenv("ANTHROPIC_API_KEY")
            )

        elif self.provider == "gemini":
            logger.warning("Using Google Gemini model, make sure your internet connection is working.")
            import google.generativeai as genai
            genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
            return genai.GenerativeModel(self.model)

        else:
            logger.error(f"Unsupported LLM provider for model: {self.model}")
            raise ValueError(f"Unsupported LLM provider for model: {self.model}")

    def _get_supported_params(self, temperature: float, max_tokens: int, top_p: float) -> Dict[str, Any]:
        """Get supported parameters for each provider"""
        base_params = {
            "temperature": temperature,
            "max_tokens": max_tokens
        }

        if self.provider == "openai":
            # Some newer OpenAI models (the gpt-5 family) do not accept top_p,
            # so only pass it for other models.
            if not self.model.startswith("gpt-5"):
                base_params["top_p"] = top_p
        elif self.provider == "anthropic":
            base_params["top_p"] = top_p
        elif self.provider == "gemini":
            base_params["top_p"] = top_p

        return base_params
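    # Illustrative return values, derived from the logic above (the model names are
    # only examples): for a model such as "gpt-4o",
    # _get_supported_params(0.7, 16384, 1.0) yields
    # {"temperature": 0.7, "max_tokens": 16384, "top_p": 1.0};
    # for a "gpt-5..." model the "top_p" key is omitted.
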
    def generate(
        self,
        prompt: str,
        temperature: float = 0.7,
        max_tokens: int = 16384,
        top_p: float = 1.0,
        **kwargs
    ) -> str:
        """
        Generate a text-only completion for the given prompt.

        Args:
            prompt: Input prompt
            temperature: Temperature (0.0-2.0)
            max_tokens: Maximum number of tokens
            top_p: Top-p sampling (0.0-1.0)

        Returns:
            Generated text
        """
        params = self._get_supported_params(temperature, max_tokens, top_p)

if self.provider == "openai":
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
**params
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API error: {e}")
|
||||
raise e
|
||||
|
||||
elif self.provider == "anthropic":
|
||||
try:
|
||||
response = self.client.messages.create(
|
||||
model=self.model,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
**params
|
||||
)
|
||||
return response.content[0].text
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic API error: {e}")
|
||||
raise e
|
||||
|
||||
elif self.provider == "gemini":
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
config = genai.GenerationConfig(
|
||||
temperature=params["temperature"],
|
||||
max_output_tokens=params["max_tokens"],
|
||||
top_p=params.get("top_p", 1.0)
|
||||
)
|
||||
response = self.client.generate_content(prompt, generation_config=config)
|
||||
return response.text
|
||||
except Exception as e:
|
||||
logger.error(f"Gemini API error: {e}")
|
||||
raise e
|
||||
|
||||
    def generate_with_images(
        self,
        prompt: str,
        images_b64: List[str],
        batch_size: int = 3,
        temperature: float = 0.7,
        max_tokens: int = 16384,
        top_p: float = 1.0,
        **kwargs
    ) -> str:
        """
        Generate with multiple images by batching

        Args:
            prompt: Base instruction prompt
            images_b64: List of base64 encoded images
            batch_size: Number of images per batch
            temperature: Temperature (0.0-2.0)
            max_tokens: Maximum number of tokens
            top_p: Top-p sampling (0.0-1.0)

        Returns:
            Final generated text
        """
        if not images_b64:
            logger.warning("No images provided, falling back to text-only generation")
            return self.generate(prompt, temperature, max_tokens, top_p, **kwargs)

        params = self._get_supported_params(temperature, max_tokens, top_p)
        total_batches = (len(images_b64) + batch_size - 1) // batch_size

        if self.provider == "openai":
            return self._generate_with_images_openai(
                prompt, images_b64, batch_size, total_batches, params
            )
        elif self.provider == "anthropic":
            return self._generate_with_images_anthropic(
                prompt, images_b64, batch_size, total_batches, params
            )
        elif self.provider == "gemini":
            return self._generate_with_images_gemini(
                prompt, images_b64, batch_size, total_batches, params
            )
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")
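    # Shape of the conversation the batched helpers below build (illustrative sketch,
    # not emitted verbatim):
    #   user:      <prompt> + "I will send you images in N batch(es)..." + images 1..batch_size
    #   assistant: acknowledgement
    #   user:      "This is batch 2/N. Please acknowledge receipt." + next images
    #   ...
    #   user:      "ALL IMAGES SENT. Please provide your evaluation now."
    #   assistant: final evaluation, e.g. "Score: 0.5 ..."
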
    def _generate_with_images_openai(
        self,
        prompt: str,
        images_b64: List[str],
        batch_size: int,
        total_batches: int,
        params: Dict[str, Any]
    ) -> str:
        """OpenAI implementation for batched image generation"""
        messages = []

        for batch_idx in range(total_batches):
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, len(images_b64))
            batch_images = images_b64[start_idx:end_idx]

            # Build content for this batch
            content = []

            if batch_idx == 0:
                # First batch: include the main instruction
                content.append({
                    "type": "text",
                    "text": f"""{prompt}

I will send you images in {total_batches} batch(es). Please acknowledge each batch but DO NOT provide your final evaluation until I explicitly say "ALL IMAGES SENT. Please provide your evaluation now."

This is batch {batch_idx + 1}/{total_batches}."""
                })
            else:
                content.append({
                    "type": "text",
                    "text": f"This is batch {batch_idx + 1}/{total_batches}. Please acknowledge receipt."
                })

            # Add images
            for img_b64 in batch_images:
                content.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{img_b64}"
                    }
                })

            messages.append({"role": "user", "content": content})

            # Get acknowledgment (except for last batch)
            if batch_idx < total_batches - 1:
                try:
                    response = self.client.chat.completions.create(
                        model=self.model,
                        messages=messages,
                        **params
                    )
                    assistant_msg = response.choices[0].message.content
                    messages.append({"role": "assistant", "content": assistant_msg})
                    logger.info(f"Batch {batch_idx + 1}/{total_batches} acknowledged")
                except Exception as e:
                    logger.error(f"Error sending batch {batch_idx + 1}: {e}")
                    raise e

        # Send final prompt
        messages.append({
            "role": "user",
            "content": "ALL IMAGES SENT. Please provide your evaluation now."
        })

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                **params
            )
            return response.choices[0].message.content
        except Exception as e:
            logger.error(f"Error getting final evaluation: {e}")
            raise e

    def _generate_with_images_anthropic(
        self,
        prompt: str,
        images_b64: List[str],
        batch_size: int,
        total_batches: int,
        params: Dict[str, Any]
    ) -> str:
        """Anthropic implementation for batched image generation"""
        messages = []

        for batch_idx in range(total_batches):
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, len(images_b64))
            batch_images = images_b64[start_idx:end_idx]

            # Build content for this batch
            content = []

            if batch_idx == 0:
                content.append({
                    "type": "text",
                    "text": f"""{prompt}

I will send you images in {total_batches} batch(es). Please acknowledge each batch but DO NOT provide your final evaluation until I explicitly say "ALL IMAGES SENT. Please provide your evaluation now."

This is batch {batch_idx + 1}/{total_batches}."""
                })
            else:
                content.append({
                    "type": "text",
                    "text": f"This is batch {batch_idx + 1}/{total_batches}. Please acknowledge receipt."
                })

            # Add images
            for img_b64 in batch_images:
                content.append({
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/png",
                        "data": img_b64
                    }
                })

            messages.append({"role": "user", "content": content})

            # Get acknowledgment (except for last batch)
            if batch_idx < total_batches - 1:
                try:
                    response = self.client.messages.create(
                        model=self.model,
                        messages=messages,
                        **params
                    )
                    assistant_msg = response.content[0].text
                    messages.append({"role": "assistant", "content": assistant_msg})
                    logger.info(f"Batch {batch_idx + 1}/{total_batches} acknowledged")
                except Exception as e:
                    logger.error(f"Error sending batch {batch_idx + 1}: {e}")
                    raise e

        # Send final prompt
        messages.append({
            "role": "user",
            "content": "ALL IMAGES SENT. Please provide your evaluation now."
        })

        try:
            response = self.client.messages.create(
                model=self.model,
                messages=messages,
                **params
            )
            return response.content[0].text
        except Exception as e:
            logger.error(f"Error getting final evaluation: {e}")
            raise e

    def _generate_with_images_gemini(
        self,
        prompt: str,
        images_b64: List[str],
        batch_size: int,
        total_batches: int,
        params: Dict[str, Any]
    ) -> str:
        """Gemini implementation for batched image generation"""
        import google.generativeai as genai
        from PIL import Image
        import io

        config = genai.GenerationConfig(
            temperature=params["temperature"],
            max_output_tokens=params["max_tokens"],
            top_p=params.get("top_p", 1.0)
        )

        chat = self.client.start_chat()

        for batch_idx in range(total_batches):
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, len(images_b64))
            batch_images = images_b64[start_idx:end_idx]

            # Build content for this batch
            content_parts = []

            if batch_idx == 0:
                content_parts.append(f"""{prompt}

I will send you images in {total_batches} batch(es). Please acknowledge each batch but DO NOT provide your final evaluation until I explicitly say "ALL IMAGES SENT. Please provide your evaluation now."

This is batch {batch_idx + 1}/{total_batches}.""")
            else:
                content_parts.append(f"This is batch {batch_idx + 1}/{total_batches}. Please acknowledge receipt.")

            # Add images
            for img_b64 in batch_images:
                img_data = base64.b64decode(img_b64)
                img = Image.open(io.BytesIO(img_data))
                content_parts.append(img)

            # Send this batch through the chat session (the session carries the history).
            # The last batch is sent here as well; otherwise its images would never
            # reach the model before the final request.
            try:
                chat.send_message(content_parts, generation_config=config)
                logger.info(f"Batch {batch_idx + 1}/{total_batches} sent")
            except Exception as e:
                logger.error(f"Error sending batch {batch_idx + 1}: {e}")
                raise e

        # Send final prompt
        try:
            response = chat.send_message(
                "ALL IMAGES SENT. Please provide your evaluation now.",
                generation_config=config
            )
            return response.text
        except Exception as e:
            logger.error(f"Error getting final evaluation: {e}")
            raise e

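# Minimal usage sketch for UnifiedLLM (illustrative only; assumes the matching API
# key is set in the environment and that "gpt-4o" names an available
# vision-capable model):
#
#     llm = UnifiedLLM("gpt-4o")
#     text = llm.generate("Describe the current desktop state.", temperature=0.2)
#     verdict = llm.generate_with_images(
#         prompt="Judge whether the task was completed.",
#         images_b64=screenshots,  # list of base64-encoded PNGs
#         batch_size=3,
#     )
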
def _load_screenshots_from_dir(result_dir: str) -> List[str]:
    """
    Load all step screenshots from result directory and convert to base64

    Args:
        result_dir: Path to result directory containing step_*.png files

    Returns:
        List of base64 encoded screenshot strings
    """
    screenshots = []

    # Find all step screenshot files (e.g., step_1_20240101@120000.png)
    pattern = os.path.join(result_dir, "step_*.png")

    def _step_index(path: str) -> int:
        # Sort by the numeric step index so that step_10 comes after step_2
        # (plain lexicographic sorting would put it first).
        parts = os.path.basename(path).split("_")
        try:
            return int(parts[1])
        except (IndexError, ValueError):
            return 0

    screenshot_files = sorted(glob.glob(pattern), key=_step_index)

    if not screenshot_files:
        logger.warning(f"No screenshot files found in {result_dir}")
        return screenshots

    for filepath in screenshot_files:
        try:
            with open(filepath, "rb") as f:
                img_data = f.read()
                img_b64 = base64.b64encode(img_data).decode('utf-8')
                screenshots.append(img_b64)
        except Exception as e:
            logger.error(f"Error loading screenshot {filepath}: {e}")

    logger.info(f"Loaded {len(screenshots)} screenshots from {result_dir}")
    return screenshots

def vllm_eval(result_state, **options) -> float:
    """
    Evaluate task completion using vision-language model

    Args:
        result_state: Current state description
        **options: Additional options including:
            - result_dir: Path to result directory containing step screenshots (recommended)
            - screenshots: List of base64 encoded screenshots (deprecated, use result_dir instead)
            - instruction: Task instruction
            - eval_model: Model name to use
            - batch_size: Number of images per batch (default: 3)
            - temperature: Temperature parameter
            - max_tokens: Maximum tokens
            - top_p: Top-p parameter

    Returns:
        Score between 0.0 and 1.0
    """
    # Try to load screenshots from result_dir if provided
    result_dir = options.get("result_dir", None)
    screenshots = options.get("screenshots", [])

    if result_dir and not screenshots:
        screenshots = _load_screenshots_from_dir(result_dir)
        logger.info(f"Loaded {len(screenshots)} screenshots from result_dir: {result_dir}")
    elif screenshots:
        logger.info(f"Using {len(screenshots)} screenshots from options")

    instruction = options.get("instruction", "")
    eval_model = options.get("eval_model", "gpt-4-vision-preview")
    batch_size = options.get("batch_size", 3)

    params = {
        "temperature": options.get("temperature", 0.7),
        "max_tokens": options.get("max_tokens", 16384),
        "top_p": options.get("top_p", 1.0)
    }

    llm = UnifiedLLM(eval_model)

    prompt = f"""You are an expert evaluator for desktop environment tasks.

Task Instruction: {instruction}

I will provide you with screenshot(s) showing the current state of the desktop environment. Based on the instruction and screenshots, provide a concise evaluation score from 0.0 to 1.0, where:
- 1.0 means the task is perfectly completed
- 0.0 means the task is not completed at all
- Values in between represent partial completion

Please return your response in the format: "Score: X.X" followed by a brief explanation."""

    try:
        result = llm.generate_with_images(
            prompt=prompt,
            images_b64=screenshots,
            batch_size=batch_size,
            **params
        )

        # Parse score from result
        score = _parse_score(result)
        logger.info(f"Evaluation result: {result}")
        logger.info(f"Parsed score: {score}")

        return score
    except Exception as e:
        logger.error(f"Error during evaluation: {e}")
        return 0.0
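# Example invocation (illustrative; the directory, instruction, and model name are
# placeholders, not values used elsewhere in the codebase):
#
#     score = vllm_eval(
#         None,  # result_state is currently unused by this evaluator
#         result_dir="./results/example_task",
#         instruction="Create a file named notes.txt on the desktop",
#         eval_model="gpt-4o",
#         batch_size=3,
#     )
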
def _parse_score(text: str) -> float:
    """Parse score from model response"""
    import re

    # Look for "Score: X.X" pattern
    match = re.search(r'[Ss]core:\s*([0-9]*\.?[0-9]+)', text)
    if match:
        try:
            score = float(match.group(1))
            # Clamp to [0.0, 1.0]
            return max(0.0, min(1.0, score))
        except ValueError:
            logger.warning(f"Could not parse score from: {match.group(1)}")

    logger.warning(f"No valid score found in response: {text[:200]}")
    return 0.0
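# _parse_score behaviour, as implemented above:
#     _parse_score("Score: 0.8 - most subtasks done")  -> 0.8
#     _parse_score("score: 1.25")                      -> 1.0  (clamped to [0.0, 1.0])
#     _parse_score("The task failed.")                 -> 0.0  (no score found; warning logged)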