fix(vllm_eval): add image compression to prevent 413 error with large max_steps
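Each step screenshot is now downscaled and re-encoded as JPEG before being
attached to the evaluation request, so runs with a large max_steps no longer
exceed the endpoint's request-size limit (HTTP 413). The old batched
"acknowledge each batch" protocol in generate_with_images is dropped in favor
of a single request carrying all compressed images.

A rough usage sketch (the keyword options below are the ones added by this
patch; the import path and result_dir value are illustrative only):

    # assumed import path -- adjust to wherever vllm_eval lives in the repo
    from vllm_eval import vllm_eval

    score = vllm_eval(
        result_state="",                       # placeholder for this sketch
        result_dir="./results/example_task",   # example dir containing step_*.png
        instruction="Open the system settings page",
        eval_model="gpt-4-vision-preview",
        compress_images=True,                  # default
        max_image_size=800,                    # default, longest side in pixels
        image_quality=85,                      # default JPEG quality
    )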
@@ -4,13 +4,69 @@ from dotenv import load_dotenv
 import logging
 import base64
 import glob
+from io import BytesIO
+from PIL import Image
 
 logger = logging.getLogger("desktopenv.vllm_eval")
 load_dotenv()
 
 
+def _compress_image(img_b64: str, max_size: int = 800, quality: int = 85) -> str:
+    """
+    Compress base64 encoded image to reduce size
+
+    Args:
+        img_b64: Base64 encoded image string
+        max_size: Maximum dimension (width or height) in pixels
+        quality: JPEG quality (1-100), lower means smaller file size
+
+    Returns:
+        Compressed base64 encoded image string
+    """
+    try:
+        # Decode base64 to image
+        img_data = base64.b64decode(img_b64)
+        img = Image.open(BytesIO(img_data))
+
+        # Convert to RGB if necessary (for PNG with transparency)
+        if img.mode in ('RGBA', 'LA', 'P'):
+            background = Image.new('RGB', img.size, (255, 255, 255))
+            if img.mode == 'P':
+                img = img.convert('RGBA')
+            background.paste(img, mask=img.split()[-1] if img.mode in ('RGBA', 'LA') else None)
+            img = background
+
+        # Resize if image is too large
+        original_size = img.size
+        if max(img.size) > max_size:
+            ratio = max_size / max(img.size)
+            new_size = tuple(int(dim * ratio) for dim in img.size)
+            img = img.resize(new_size, Image.Resampling.LANCZOS)
+            logger.info(f"Resized image from {original_size} to {new_size}")
+
+        # Compress to JPEG
+        buffer = BytesIO()
+        img.save(buffer, format='JPEG', quality=quality, optimize=True)
+        compressed_data = buffer.getvalue()
+
+        # Encode back to base64
+        compressed_b64 = base64.b64encode(compressed_data).decode('utf-8')
+
+        # Log compression ratio
+        original_size_kb = len(img_b64) * 3 / 4 / 1024  # base64 to bytes to KB
+        compressed_size_kb = len(compressed_b64) * 3 / 4 / 1024
+        compression_ratio = (1 - compressed_size_kb / original_size_kb) * 100
+        logger.info(f"Compressed image: {original_size_kb:.1f}KB -> {compressed_size_kb:.1f}KB ({compression_ratio:.1f}% reduction)")
+
+        return compressed_b64
+
+    except Exception as e:
+        logger.warning(f"Failed to compress image: {e}, using original")
+        return img_b64
+
+
 class UnifiedLLM:
 
     def __init__(self, model: str):
         if model.startswith("gpt"):
             self.provider = "openai"
@@ -20,43 +76,43 @@ class UnifiedLLM:
             self.provider = "gemini"
         else:
             self.provider = "unknown"
 
         self.model = model
         self.client = self._init_client()
 
     def _init_client(self):
         """Initialize client"""
         if self.provider == "openai":
             from openai import OpenAI
             return OpenAI(
                 base_url=os.getenv("OPENAI_BASE_URL"),
                 api_key=os.getenv("OPENAI_API_KEY")
             )
 
         elif self.provider == "anthropic":
             from anthropic import Anthropic
             return Anthropic(
                 base_url=os.getenv("ANTHROPIC_BASE_URL"),
                 api_key=os.getenv("ANTHROPIC_API_KEY")
             )
 
         elif self.provider == "gemini":
             logger.warning("Using Google Gemini model, make sure your internet connection is working.")
             import google.generativeai as genai
             genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
             return genai.GenerativeModel(self.model)
 
         else:
             logger.error(f"Unsupported LLM provider for model: {self.model}")
             raise ValueError(f"Unsupported LLM provider for model: {self.model}")
 
     def _get_supported_params(self, temperature: float, max_tokens: int, top_p: float) -> Dict[str, Any]:
         """Get supported parameters for each provider"""
         base_params = {
             "temperature": temperature,
             "max_tokens": max_tokens
         }
 
         # GPT-5.2 and newer models may not support top_p
         if self.provider == "openai":
             # Only add top_p for older models
@@ -66,9 +122,9 @@ class UnifiedLLM:
                 base_params["top_p"] = top_p
         elif self.provider == "gemini":
             base_params["top_p"] = top_p
 
         return base_params
 
     def generate(
         self,
         prompt: str,
@@ -83,12 +139,12 @@ class UnifiedLLM:
             temperature: Temperature (0.0-2.0)
             max_tokens: Maximum number of tokens
             top_p: Top-p sampling (0.0-1.0)
 
         Returns:
             Generated text
         """
         params = self._get_supported_params(temperature, max_tokens, top_p)
 
         if self.provider == "openai":
             try:
                 response = self.client.chat.completions.create(
@@ -100,7 +156,7 @@ class UnifiedLLM:
             except Exception as e:
                 logger.error(f"OpenAI API error: {e}")
                 raise e
 
         elif self.provider == "anthropic":
             try:
                 response = self.client.messages.create(
@@ -112,7 +168,7 @@ class UnifiedLLM:
             except Exception as e:
                 logger.error(f"Anthropic API error: {e}")
                 raise e
 
         elif self.provider == "gemini":
             try:
                 import google.generativeai as genai
@@ -126,281 +182,120 @@ class UnifiedLLM:
             except Exception as e:
                 logger.error(f"Gemini API error: {e}")
                 raise e
 
     def generate_with_images(
         self,
         prompt: str,
         images_b64: List[str],
-        batch_size: int = 3,
         temperature: float = 0.7,
         max_tokens: int = 16384,
         top_p: float = 1.0,
         **kwargs
     ) -> str:
         """
-        Generate with multiple images by batching
+        Generate with multiple images in a single request
 
         Args:
-            prompt: Base instruction prompt
+            prompt: Instruction prompt
             images_b64: List of base64 encoded images
-            batch_size: Number of images per batch
             temperature: Temperature (0.0-2.0)
             max_tokens: Maximum number of tokens
             top_p: Top-p sampling (0.0-1.0)
 
         Returns:
-            Final generated text
+            Generated text
         """
         if not images_b64:
             logger.warning("No images provided, falling back to text-only generation")
             return self.generate(prompt, temperature, max_tokens, top_p, **kwargs)
 
         params = self._get_supported_params(temperature, max_tokens, top_p)
-        total_batches = (len(images_b64) + batch_size - 1) // batch_size
 
         if self.provider == "openai":
-            return self._generate_with_images_openai(
-                prompt, images_b64, batch_size, total_batches, params
-            )
-        elif self.provider == "anthropic":
-            return self._generate_with_images_anthropic(
-                prompt, images_b64, batch_size, total_batches, params
-            )
-        elif self.provider == "gemini":
-            return self._generate_with_images_gemini(
-                prompt, images_b64, batch_size, total_batches, params
-            )
-        else:
-            raise ValueError(f"Unsupported provider: {self.provider}")
-
-    def _generate_with_images_openai(
-        self,
-        prompt: str,
-        images_b64: List[str],
-        batch_size: int,
-        total_batches: int,
-        params: Dict[str, Any]
-    ) -> str:
-        """OpenAI implementation for batched image generation"""
-        messages = []
-
-        for batch_idx in range(total_batches):
-            start_idx = batch_idx * batch_size
-            end_idx = min(start_idx + batch_size, len(images_b64))
-            batch_images = images_b64[start_idx:end_idx]
-
-            # Build content for this batch
-            content = []
-
-            if batch_idx == 0:
-                # First batch: include the main instruction
-                content.append({
-                    "type": "text",
-                    "text": f"""{prompt}
-
-I will send you images in {total_batches} batch(es). Please acknowledge each batch but DO NOT provide your final evaluation until I explicitly say "ALL IMAGES SENT. Please provide your evaluation now."
-
-This is batch {batch_idx + 1}/{total_batches}."""
-                })
-            else:
-                content.append({
-                    "type": "text",
-                    "text": f"This is batch {batch_idx + 1}/{total_batches}. Please acknowledge receipt."
-                })
-
-            # Add images
-            for img_b64 in batch_images:
+            # Build content with text and all images
+            content = [{"type": "text", "text": prompt}]
+            for img_b64 in images_b64:
                 content.append({
                     "type": "image_url",
                     "image_url": {
-                        "url": f"data:image/png;base64,{img_b64}"
+                        "url": f"data:image/jpeg;base64,{img_b64}"
                     }
                 })
 
-            messages.append({"role": "user", "content": content})
-
-            # Get acknowledgment (except for last batch)
-            if batch_idx < total_batches - 1:
-                try:
-                    response = self.client.chat.completions.create(
-                        model=self.model,
-                        messages=messages,
-                        **params
-                    )
-                    assistant_msg = response.choices[0].message.content
-                    messages.append({"role": "assistant", "content": assistant_msg})
-                    logger.info(f"Batch {batch_idx + 1}/{total_batches} acknowledged")
-                except Exception as e:
-                    logger.error(f"Error sending batch {batch_idx + 1}: {e}")
-                    raise e
-
-        # Send final prompt
-        messages.append({
-            "role": "user",
-            "content": "ALL IMAGES SENT. Please provide your evaluation now."
-        })
-
-        try:
-            response = self.client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                **params
-            )
-            return response.choices[0].message.content
-        except Exception as e:
-            logger.error(f"Error getting final evaluation: {e}")
-            raise e
-
-    def _generate_with_images_anthropic(
-        self,
-        prompt: str,
-        images_b64: List[str],
-        batch_size: int,
-        total_batches: int,
-        params: Dict[str, Any]
-    ) -> str:
-        """Anthropic implementation for batched image generation"""
-        messages = []
-
-        for batch_idx in range(total_batches):
-            start_idx = batch_idx * batch_size
-            end_idx = min(start_idx + batch_size, len(images_b64))
-            batch_images = images_b64[start_idx:end_idx]
-
-            # Build content for this batch
-            content = []
-
-            if batch_idx == 0:
-                content.append({
-                    "type": "text",
-                    "text": f"""{prompt}
-
-I will send you images in {total_batches} batch(es). Please acknowledge each batch but DO NOT provide your final evaluation until I explicitly say "ALL IMAGES SENT. Please provide your evaluation now."
-
-This is batch {batch_idx + 1}/{total_batches}."""
-                })
-            else:
-                content.append({
-                    "type": "text",
-                    "text": f"This is batch {batch_idx + 1}/{total_batches}. Please acknowledge receipt."
-                })
-
-            # Add images
-            for img_b64 in batch_images:
+            try:
+                response = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=[{"role": "user", "content": content}],
+                    **params
+                )
+                return response.choices[0].message.content
+            except Exception as e:
+                logger.error(f"OpenAI API error: {e}")
+                raise e
 
+        elif self.provider == "anthropic":
+            # Build content with text and all images
+            content = [{"type": "text", "text": prompt}]
+            for img_b64 in images_b64:
                 content.append({
                     "type": "image",
                     "source": {
                         "type": "base64",
-                        "media_type": "image/png",
+                        "media_type": "image/jpeg",
                         "data": img_b64
                     }
                 })
 
-            messages.append({"role": "user", "content": content})
-
-            # Get acknowledgment (except for last batch)
-            if batch_idx < total_batches - 1:
-                try:
-                    response = self.client.messages.create(
-                        model=self.model,
-                        messages=messages,
-                        **params
-                    )
-                    assistant_msg = response.content[0].text
-                    messages.append({"role": "assistant", "content": assistant_msg})
-                    logger.info(f"Batch {batch_idx + 1}/{total_batches} acknowledged")
-                except Exception as e:
-                    logger.error(f"Error sending batch {batch_idx + 1}: {e}")
-                    raise e
-
-        # Send final prompt
-        messages.append({
-            "role": "user",
-            "content": "ALL IMAGES SENT. Please provide your evaluation now."
-        })
-
-        try:
-            response = self.client.messages.create(
-                model=self.model,
-                messages=messages,
-                **params
-            )
-            return response.content[0].text
-        except Exception as e:
-            logger.error(f"Error getting final evaluation: {e}")
-            raise e
-
-    def _generate_with_images_gemini(
-        self,
-        prompt: str,
-        images_b64: List[str],
-        batch_size: int,
-        total_batches: int,
-        params: Dict[str, Any]
-    ) -> str:
-        """Gemini implementation for batched image generation"""
-        import google.generativeai as genai
-        from PIL import Image
-        import io
-
-        config = genai.GenerationConfig(
-            temperature=params["temperature"],
-            max_output_tokens=params["max_tokens"],
-            top_p=params.get("top_p", 1.0)
-        )
-
-        chat = self.client.start_chat()
-
-        for batch_idx in range(total_batches):
-            start_idx = batch_idx * batch_size
-            end_idx = min(start_idx + batch_size, len(images_b64))
-            batch_images = images_b64[start_idx:end_idx]
-
-            # Build content for this batch
-            content_parts = []
-
-            if batch_idx == 0:
-                content_parts.append(f"""{prompt}
-
-I will send you images in {total_batches} batch(es). Please acknowledge each batch but DO NOT provide your final evaluation until I explicitly say "ALL IMAGES SENT. Please provide your evaluation now."
-
-This is batch {batch_idx + 1}/{total_batches}.""")
-            else:
-                content_parts.append(f"This is batch {batch_idx + 1}/{total_batches}. Please acknowledge receipt.")
-
-            # Add images
-            for img_b64 in batch_images:
+            try:
+                response = self.client.messages.create(
+                    model=self.model,
+                    messages=[{"role": "user", "content": content}],
+                    **params
+                )
+                return response.content[0].text
+            except Exception as e:
+                logger.error(f"Anthropic API error: {e}")
+                raise e
 
+        elif self.provider == "gemini":
+            import google.generativeai as genai
+
+            config = genai.GenerationConfig(
+                temperature=params["temperature"],
+                max_output_tokens=params["max_tokens"],
+                top_p=params.get("top_p", 1.0)
+            )
+
+            # Build content parts
+            content_parts = [prompt]
+
+            for img_b64 in images_b64:
                 img_data = base64.b64decode(img_b64)
-                img = Image.open(io.BytesIO(img_data))
+                img = Image.open(BytesIO(img_data))
                 content_parts.append(img)
 
-            # Get acknowledgment (except for last batch)
-            if batch_idx < total_batches - 1:
-                try:
-                    response = chat.send_message(content_parts, generation_config=config)
-                    logger.info(f"Batch {batch_idx + 1}/{total_batches} acknowledged")
-                except Exception as e:
-                    logger.error(f"Error sending batch {batch_idx + 1}: {e}")
-                    raise e
-
-        # Send final prompt
-        try:
-            response = chat.send_message(
-                "ALL IMAGES SENT. Please provide your evaluation now.",
-                generation_config=config
-            )
-            return response.text
-        except Exception as e:
-            logger.error(f"Error getting final evaluation: {e}")
-            raise e
+            try:
+                response = self.client.generate_content(content_parts, generation_config=config)
+                return response.text
+            except Exception as e:
+                logger.error(f"Gemini API error: {e}")
+                raise e
+
+        else:
+            raise ValueError(f"Unsupported provider: {self.provider}")
 
 
-def _load_screenshots_from_dir(result_dir: str) -> List[str]:
+def _load_screenshots_from_dir(result_dir: str, compress: bool = True, max_size: int = 800, quality: int = 85) -> List[str]:
     """
     Load all step screenshots from result directory and convert to base64
 
     Args:
         result_dir: Path to result directory containing step_*.png files
+        compress: Whether to compress images (default: True)
+        max_size: Maximum dimension for compression (default: 800)
+        quality: JPEG quality for compression (default: 85)
 
     Returns:
         List of base64 encoded screenshot strings
@@ -420,6 +315,11 @@ def _load_screenshots_from_dir(result_dir: str) -> List[str]:
             with open(filepath, "rb") as f:
                 img_data = f.read()
             img_b64 = base64.b64encode(img_data).decode('utf-8')
+
+            # Compress if enabled
+            if compress:
+                img_b64 = _compress_image(img_b64, max_size=max_size, quality=quality)
+
             screenshots.append(img_b64)
         except Exception as e:
             logger.error(f"Error loading screenshot {filepath}: {e}")
@@ -436,10 +336,12 @@ def vllm_eval(result_state, **options) -> float:
         result_state: Current state description
         **options: Additional options including:
             - result_dir: Path to result directory containing step screenshots (recommended)
             - screenshots: List of base64 encoded screenshots (deprecated, use result_dir instead)
             - instruction: Task instruction
             - eval_model: Model name to use
-            - batch_size: Number of images per batch (default: 3)
+            - compress_images: Whether to compress images (default: True)
+            - max_image_size: Maximum image dimension for compression (default: 800)
+            - image_quality: JPEG quality for compression (default: 85)
             - temperature: Temperature parameter
             - max_tokens: Maximum tokens
             - top_p: Top-p parameter
@@ -451,24 +353,37 @@ def vllm_eval(result_state, **options) -> float:
     result_dir = options.get("result_dir", None)
     screenshots = options.get("screenshots", [])
 
+    # Image compression options
+    compress_images = options.get("compress_images", True)
+    max_image_size = options.get("max_image_size", 800)
+    image_quality = options.get("image_quality", 85)
+
     if result_dir and not screenshots:
-        screenshots = _load_screenshots_from_dir(result_dir)
+        screenshots = _load_screenshots_from_dir(
+            result_dir,
+            compress=compress_images,
+            max_size=max_image_size,
+            quality=image_quality
+        )
         logger.info(f"Loaded {len(screenshots)} screenshots from result_dir: {result_dir}")
     elif screenshots:
         logger.info(f"Using {len(screenshots)} screenshots from options")
+        # Compress screenshots if needed
+        if compress_images:
+            logger.info("Compressing provided screenshots...")
+            screenshots = [_compress_image(img, max_size=max_image_size, quality=image_quality) for img in screenshots]
 
     instruction = options.get("instruction", "")
     eval_model = options.get("eval_model", "gpt-4-vision-preview")
-    batch_size = options.get("batch_size", 3)
 
     params = {
         "temperature": options.get("temperature", 0.7),
         "max_tokens": options.get("max_tokens", 16384),
         "top_p": options.get("top_p", 1.0)
     }
 
     llm = UnifiedLLM(eval_model)
 
     prompt = f"""You are an expert evaluator for desktop environment tasks.
 
 Task Instruction: {instruction}
@@ -495,27 +410,26 @@ Where:
 - "score": Integer from 0 to 10, where 10 means perfectly completed and 0 means not completed at all
 
 Remember: Return ONLY the JSON object, no additional text."""
 
     try:
         result = llm.generate_with_images(
             prompt=prompt,
             images_b64=screenshots,
-            batch_size=batch_size,
             **params
         )
 
         # Parse score from result
         score = _parse_score(result)
         logger.info(f"Evaluation result: {result}")
         logger.info(f"Parsed score: {score}")
 
         # Save raw result to file for reference
         if result_dir:
             eval_output_path = os.path.join(result_dir, "vllm_evaluation_result.json")
             with open(eval_output_path, "w", encoding="utf-8") as f:
                 f.write(result)
             logger.info(f"Saved evaluation result to {eval_output_path}")
 
         return score
     except Exception as e:
         logger.error(f"Error during evaluation: {e}")
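For a quick standalone check of the new helper, a sketch like the following
(synthetic image; only the _compress_image name comes from this patch) should
show a clear payload reduction:

    import base64
    from io import BytesIO
    from PIL import Image

    # stand-in for a 1920x1080 PNG screenshot, base64-encoded like the real ones
    buf = BytesIO()
    Image.new("RGB", (1920, 1080), (200, 220, 240)).save(buf, format="PNG")
    original_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

    compressed_b64 = _compress_image(original_b64, max_size=800, quality=85)
    print(f"{len(original_b64)} -> {len(compressed_b64)} base64 chars")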