diff --git a/desktop_env/evaluators/metrics/vllm_eval.py b/desktop_env/evaluators/metrics/vllm_eval.py
new file mode 100644
index 0000000..d7b971f
--- /dev/null
+++ b/desktop_env/evaluators/metrics/vllm_eval.py
@@ -0,0 +1,517 @@
+import os
+from typing import Optional, List, Dict, Any
+from dotenv import load_dotenv
+import logging
+import base64
+import glob
+
+logger = logging.getLogger("desktopenv.vllm_eval")
+load_dotenv()
+
+
+class UnifiedLLM:
+    """Thin wrapper that routes requests to OpenAI, Anthropic, or Gemini based on the model name"""
+
+    def __init__(self, model: str):
+        if model.startswith("gpt"):
+            self.provider = "openai"
+        elif model.startswith("claude"):
+            self.provider = "anthropic"
+        elif model.startswith("gemini"):
+            self.provider = "gemini"
+        else:
+            self.provider = "unknown"
+
+        self.model = model
+        self.client = self._init_client()
+
+    def _init_client(self):
+        """Initialize the provider client"""
+        if self.provider == "openai":
+            from openai import OpenAI
+            return OpenAI(
+                base_url=os.getenv("OPENAI_BASE_URL"),
+                api_key=os.getenv("OPENAI_API_KEY")
+            )
+
+        elif self.provider == "anthropic":
+            from anthropic import Anthropic
+            return Anthropic(
+                base_url=os.getenv("ANTHROPIC_BASE_URL"),
+                api_key=os.getenv("ANTHROPIC_API_KEY")
+            )
+
+        elif self.provider == "gemini":
+            logger.warning("Using Google Gemini model, make sure your internet connection is working.")
+            import google.generativeai as genai
+            genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
+            return genai.GenerativeModel(self.model)
+
+        else:
+            logger.error(f"Unsupported LLM provider for model: {self.model}")
+            raise ValueError(f"Unsupported LLM provider for model: {self.model}")
+
+    def _get_supported_params(self, temperature: float, max_tokens: int, top_p: float) -> Dict[str, Any]:
+        """Get supported parameters for each provider"""
+        base_params = {
+            "temperature": temperature,
+            "max_tokens": max_tokens
+        }
+
+        # GPT-5 series models may not support top_p
+        if self.provider == "openai":
+            # Only add top_p for non-GPT-5 models
+            if not self.model.startswith("gpt-5"):
+                base_params["top_p"] = top_p
+        elif self.provider == "anthropic":
+            base_params["top_p"] = top_p
+        elif self.provider == "gemini":
+            base_params["top_p"] = top_p
+
+        return base_params
+
+    def generate(
+        self,
+        prompt: str,
+        temperature: float = 0.7,
+        max_tokens: int = 16384,
+        top_p: float = 1.0,
+        **kwargs
+    ) -> str:
+        """
+        Generate text from a text-only prompt
+
+        Args:
+            prompt: Input prompt
+            temperature: Temperature (0.0-2.0)
+            max_tokens: Maximum number of tokens
+            top_p: Top-p sampling (0.0-1.0)
+
+        Returns:
+            Generated text
+        """
+        params = self._get_supported_params(temperature, max_tokens, top_p)
+
+        if self.provider == "openai":
+            try:
+                response = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=[{"role": "user", "content": prompt}],
+                    **params
+                )
+                return response.choices[0].message.content
+            except Exception as e:
+                logger.error(f"OpenAI API error: {e}")
+                raise e
+
+        elif self.provider == "anthropic":
+            try:
+                response = self.client.messages.create(
+                    model=self.model,
+                    messages=[{"role": "user", "content": prompt}],
+                    **params
+                )
+                return response.content[0].text
+            except Exception as e:
+                logger.error(f"Anthropic API error: {e}")
+                raise e
+
+        elif self.provider == "gemini":
+            try:
+                import google.generativeai as genai
+                config = genai.GenerationConfig(
+                    temperature=params["temperature"],
+                    max_output_tokens=params["max_tokens"],
+                    top_p=params.get("top_p", 1.0)
+                )
+                response = self.client.generate_content(prompt, generation_config=config)
+                return response.text
+            except Exception as e:
+                logger.error(f"Gemini API error: {e}")
+                raise e
+
+    def generate_with_images(
+        self,
+        prompt: str,
+        images_b64: List[str],
+        batch_size: int = 3,
+        temperature: float = 0.7,
+        max_tokens: int = 16384,
+        top_p: float = 1.0,
+        **kwargs
+    ) -> str:
+        """
+        Generate with multiple images by batching
+
+        Args:
+            prompt: Base instruction prompt
+            images_b64: List of base64 encoded images
+            batch_size: Number of images per batch
+            temperature: Temperature (0.0-2.0)
+            max_tokens: Maximum number of tokens
+            top_p: Top-p sampling (0.0-1.0)
+
+        Returns:
+            Final generated text
+        """
+        if not images_b64:
+            logger.warning("No images provided, falling back to text-only generation")
+            return self.generate(prompt, temperature, max_tokens, top_p, **kwargs)
+
+        params = self._get_supported_params(temperature, max_tokens, top_p)
+        total_batches = (len(images_b64) + batch_size - 1) // batch_size
+
+        if self.provider == "openai":
+            return self._generate_with_images_openai(
+                prompt, images_b64, batch_size, total_batches, params
+            )
+        elif self.provider == "anthropic":
+            return self._generate_with_images_anthropic(
+                prompt, images_b64, batch_size, total_batches, params
+            )
+        elif self.provider == "gemini":
+            return self._generate_with_images_gemini(
+                prompt, images_b64, batch_size, total_batches, params
+            )
+        else:
+            raise ValueError(f"Unsupported provider: {self.provider}")
+
+    def _generate_with_images_openai(
+        self,
+        prompt: str,
+        images_b64: List[str],
+        batch_size: int,
+        total_batches: int,
+        params: Dict[str, Any]
+    ) -> str:
+        """OpenAI implementation for batched image generation"""
+        messages = []
+
+        for batch_idx in range(total_batches):
+            start_idx = batch_idx * batch_size
+            end_idx = min(start_idx + batch_size, len(images_b64))
+            batch_images = images_b64[start_idx:end_idx]
+
+            # Build content for this batch
+            content = []
+
+            if batch_idx == 0:
+                # First batch: include the main instruction
+                content.append({
+                    "type": "text",
+                    "text": f"""{prompt}
+
+I will send you images in {total_batches} batch(es). Please acknowledge each batch but DO NOT provide your final evaluation until I explicitly say "ALL IMAGES SENT. Please provide your evaluation now."
+
+This is batch {batch_idx + 1}/{total_batches}."""
+                })
+            else:
+                content.append({
+                    "type": "text",
+                    "text": f"This is batch {batch_idx + 1}/{total_batches}. Please acknowledge receipt."
+                })
+
+            # Add images
+            for img_b64 in batch_images:
+                content.append({
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/png;base64,{img_b64}"
+                    }
+                })
+
+            messages.append({"role": "user", "content": content})
+
+            # Get acknowledgment (except for last batch)
+            if batch_idx < total_batches - 1:
+                try:
+                    response = self.client.chat.completions.create(
+                        model=self.model,
+                        messages=messages,
+                        **params
+                    )
+                    assistant_msg = response.choices[0].message.content
+                    messages.append({"role": "assistant", "content": assistant_msg})
+                    logger.info(f"Batch {batch_idx + 1}/{total_batches} acknowledged")
+                except Exception as e:
+                    logger.error(f"Error sending batch {batch_idx + 1}: {e}")
+                    raise e
+
+        # Send final prompt
+        messages.append({
+            "role": "user",
+            "content": "ALL IMAGES SENT. Please provide your evaluation now."
+        })
+
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model,
+                messages=messages,
+                **params
+            )
+            return response.choices[0].message.content
+        except Exception as e:
+            logger.error(f"Error getting final evaluation: {e}")
+            raise e
+
+    def _generate_with_images_anthropic(
+        self,
+        prompt: str,
+        images_b64: List[str],
+        batch_size: int,
+        total_batches: int,
+        params: Dict[str, Any]
+    ) -> str:
+        """Anthropic implementation for batched image generation"""
+        messages = []
+
+        for batch_idx in range(total_batches):
+            start_idx = batch_idx * batch_size
+            end_idx = min(start_idx + batch_size, len(images_b64))
+            batch_images = images_b64[start_idx:end_idx]
+
+            # Build content for this batch
+            content = []
+
+            if batch_idx == 0:
+                content.append({
+                    "type": "text",
+                    "text": f"""{prompt}
+
+I will send you images in {total_batches} batch(es). Please acknowledge each batch but DO NOT provide your final evaluation until I explicitly say "ALL IMAGES SENT. Please provide your evaluation now."
+
+This is batch {batch_idx + 1}/{total_batches}."""
+                })
+            else:
+                content.append({
+                    "type": "text",
+                    "text": f"This is batch {batch_idx + 1}/{total_batches}. Please acknowledge receipt."
+                })
+
+            # Add images
+            for img_b64 in batch_images:
+                content.append({
+                    "type": "image",
+                    "source": {
+                        "type": "base64",
+                        "media_type": "image/png",
+                        "data": img_b64
+                    }
+                })
+
+            messages.append({"role": "user", "content": content})
+
+            # Get acknowledgment (except for last batch)
+            if batch_idx < total_batches - 1:
+                try:
+                    response = self.client.messages.create(
+                        model=self.model,
+                        messages=messages,
+                        **params
+                    )
+                    assistant_msg = response.content[0].text
+                    messages.append({"role": "assistant", "content": assistant_msg})
+                    logger.info(f"Batch {batch_idx + 1}/{total_batches} acknowledged")
+                except Exception as e:
+                    logger.error(f"Error sending batch {batch_idx + 1}: {e}")
+                    raise e
+
+        # Send final prompt
+        messages.append({
+            "role": "user",
+            "content": "ALL IMAGES SENT. Please provide your evaluation now."
+        })
+
+        try:
+            response = self.client.messages.create(
+                model=self.model,
+                messages=messages,
+                **params
+            )
+            return response.content[0].text
+        except Exception as e:
+            logger.error(f"Error getting final evaluation: {e}")
+            raise e
+
+    def _generate_with_images_gemini(
+        self,
+        prompt: str,
+        images_b64: List[str],
+        batch_size: int,
+        total_batches: int,
+        params: Dict[str, Any]
+    ) -> str:
+        """Gemini implementation for batched image generation"""
+        import google.generativeai as genai
+        from PIL import Image
+        import io
+
+        config = genai.GenerationConfig(
+            temperature=params["temperature"],
+            max_output_tokens=params["max_tokens"],
+            top_p=params.get("top_p", 1.0)
+        )
+
+        chat = self.client.start_chat()
+
+        for batch_idx in range(total_batches):
+            start_idx = batch_idx * batch_size
+            end_idx = min(start_idx + batch_size, len(images_b64))
+            batch_images = images_b64[start_idx:end_idx]
+
+            # Build content for this batch
+            content_parts = []
+
+            if batch_idx == 0:
+                content_parts.append(f"""{prompt}
+
+I will send you images in {total_batches} batch(es). Please acknowledge each batch but DO NOT provide your final evaluation until I explicitly say "ALL IMAGES SENT. Please provide your evaluation now."
+
+This is batch {batch_idx + 1}/{total_batches}.""")
+            else:
+                content_parts.append(f"This is batch {batch_idx + 1}/{total_batches}. Please acknowledge receipt.")
+
+            # Add images
+            for img_b64 in batch_images:
+                img_data = base64.b64decode(img_b64)
+                img = Image.open(io.BytesIO(img_data))
+                content_parts.append(img)
+
+            # Get acknowledgment (except for last batch, which is sent with the final request below)
+            if batch_idx < total_batches - 1:
+                try:
+                    response = chat.send_message(content_parts, generation_config=config)
+                    logger.info(f"Batch {batch_idx + 1}/{total_batches} acknowledged")
+                except Exception as e:
+                    logger.error(f"Error sending batch {batch_idx + 1}: {e}")
+                    raise e
+
+        # Send the last batch together with the final prompt so its images are not dropped
+        content_parts.append("ALL IMAGES SENT. Please provide your evaluation now.")
+        try:
+            response = chat.send_message(content_parts, generation_config=config)
+            return response.text
+        except Exception as e:
+            logger.error(f"Error getting final evaluation: {e}")
+            raise e
+
+
+def _load_screenshots_from_dir(result_dir: str) -> List[str]:
+    """
+    Load all step screenshots from result directory and convert to base64
+
+    Args:
+        result_dir: Path to result directory containing step_*.png files
+
+    Returns:
+        List of base64 encoded screenshot strings
+    """
+    screenshots = []
+
+    # Find all step screenshot files (e.g., step_1_20240101@120000.png)
+    pattern = os.path.join(result_dir, "step_*.png")
+    screenshot_files = sorted(glob.glob(pattern))
+
+    if not screenshot_files:
+        logger.warning(f"No screenshot files found in {result_dir}")
+        return screenshots
+
+    for filepath in screenshot_files:
+        try:
+            with open(filepath, "rb") as f:
+                img_data = f.read()
+                img_b64 = base64.b64encode(img_data).decode('utf-8')
+                screenshots.append(img_b64)
+        except Exception as e:
+            logger.error(f"Error loading screenshot {filepath}: {e}")
+
+    logger.info(f"Loaded {len(screenshots)} screenshots from {result_dir}")
+    return screenshots
+
+
+def vllm_eval(result_state, **options) -> float:
+    """
+    Evaluate task completion using a vision-language model
+
+    Args:
+        result_state: Current state description
+        **options: Additional options including:
+            - result_dir: Path to result directory containing step screenshots (recommended)
+            - screenshots: List of base64 encoded screenshots (deprecated, use result_dir instead)
+            - instruction: Task instruction
+            - eval_model: Model name to use
+            - batch_size: Number of images per batch (default: 3)
+            - temperature: Temperature parameter
+            - max_tokens: Maximum tokens
+            - top_p: Top-p parameter
+
+    Returns:
+        Score between 0.0 and 1.0
+    """
+    # Try to load screenshots from result_dir if provided
+    result_dir = options.get("result_dir", None)
+    screenshots = options.get("screenshots", [])
+
+    if result_dir and not screenshots:
+        screenshots = _load_screenshots_from_dir(result_dir)
+        logger.info(f"Loaded {len(screenshots)} screenshots from result_dir: {result_dir}")
+    elif screenshots:
+        logger.info(f"Using {len(screenshots)} screenshots from options")
+
+    instruction = options.get("instruction", "")
+    eval_model = options.get("eval_model", "gpt-4-vision-preview")
+    batch_size = options.get("batch_size", 3)
+
+    params = {
+        "temperature": options.get("temperature", 0.7),
+        "max_tokens": options.get("max_tokens", 16384),
+        "top_p": options.get("top_p", 1.0)
+    }
+
+    llm = UnifiedLLM(eval_model)
+
+    prompt = f"""You are an expert evaluator for desktop environment tasks.
+
+Task Instruction: {instruction}
+
+I will provide you with screenshot(s) showing the current state of the desktop environment. Based on the instruction and screenshots, provide a concise evaluation score from 0.0 to 1.0, where:
+- 1.0 means the task is perfectly completed
+- 0.0 means the task is not completed at all
+- Values in between represent partial completion
+
+Please return your response in the format: "Score: X.X" followed by a brief explanation."""
+
+    try:
+        result = llm.generate_with_images(
+            prompt=prompt,
+            images_b64=screenshots,
+            batch_size=batch_size,
+            **params
+        )
+
+        # Parse score from result
+        score = _parse_score(result)
+        logger.info(f"Evaluation result: {result}")
+        logger.info(f"Parsed score: {score}")
+
+        return score
+    except Exception as e:
+        logger.error(f"Error during evaluation: {e}")
+        return 0.0
+
+
+def _parse_score(text: str) -> float:
+    """Parse score from model response"""
+    import re
+
+    # Look for "Score: X.X" pattern
+    match = re.search(r'[Ss]core:\s*([0-9]*\.?[0-9]+)', text)
+    if match:
+        try:
+            score = float(match.group(1))
+            # Clamp to [0.0, 1.0]
+            return max(0.0, min(1.0, score))
+        except ValueError:
+            logger.warning(f"Could not parse score from: {match.group(1)}")
+
+    logger.warning(f"No valid score found in response: {text[:200]}")
+    return 0.0
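
Usage note: the metric can also be exercised on its own, outside the evaluator harness. The sketch below is illustrative only; the result directory, instruction text, and model name are placeholders rather than values defined by this patch, and it assumes the relevant API credentials (OPENAI_API_KEY/OPENAI_BASE_URL, ANTHROPIC_API_KEY/ANTHROPIC_BASE_URL, or GOOGLE_API_KEY) are available in the environment or a .env file picked up by load_dotenv(). Passing eval_model explicitly is advisable, since the default "gpt-4-vision-preview" is an older preview model.

# Illustrative sketch only: paths, instruction, and model name are placeholders.
from desktop_env.evaluators.metrics.vllm_eval import vllm_eval

score = vllm_eval(
    result_state=None,  # not consumed by the metric itself
    result_dir="./results/example_task",  # directory holding step_*.png screenshots
    instruction="Rename report.txt on the desktop to report_final.txt",
    eval_model="gpt-4o",  # any gpt-*/claude-*/gemini-* name routes to the matching provider
    batch_size=3,
)
print(f"LLM-judged completion score: {score:.2f}")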
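
For reference, score parsing is strict: _parse_score only recognizes a "Score: X.X" pattern, clamps the value into [0.0, 1.0], and otherwise falls back to 0.0. A few illustrative cases, assuming the same import path as above:

from desktop_env.evaluators.metrics.vllm_eval import _parse_score

assert _parse_score("Score: 0.8 (window resized but not maximized)") == 0.8
assert _parse_score("score: 1.5") == 1.0   # out-of-range values are clamped
assert _parse_score("The task looks finished.") == 0.0  # no recognizable pattern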