feat: vllm_eval 关键帧采样 + Gemini OpenAI 代理支持
- vllm_eval.py: 新增 _sample_key_frames 关键帧采样函数
- vllm_eval.py: 当截图超过 max_eval_images 时均匀采样
- vllm_eval.py: Gemini 模型支持通过 OpenAI 兼容代理调用
- test_single.json: 更新测试任务配置
This commit is contained in:
@@ -73,6 +73,12 @@ class UnifiedLLM:
|
|||||||
elif model.startswith("claude"):
|
elif model.startswith("claude"):
|
||||||
self.provider = "anthropic"
|
self.provider = "anthropic"
|
||||||
elif model.startswith("gemini"):
|
elif model.startswith("gemini"):
|
||||||
|
# If OPENAI_API_KEY is set but GOOGLE_API_KEY is not,
|
||||||
|
# use OpenAI-compatible proxy for Gemini models
|
||||||
|
if os.getenv("OPENAI_API_KEY") and not os.getenv("GOOGLE_API_KEY"):
|
||||||
|
self.provider = "openai"
|
||||||
|
logger.info(f"Using OpenAI-compatible proxy for Gemini model: {model}")
|
||||||
|
else:
|
||||||
self.provider = "gemini"
|
self.provider = "gemini"
|
||||||
else:
|
else:
|
||||||
self.provider = "unknown"
|
self.provider = "unknown"
|
||||||
@@ -287,15 +293,52 @@ class UnifiedLLM:
|
|||||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||||
|
|
||||||
|
|
||||||
def _load_screenshots_from_dir(result_dir: str, compress: bool = True, max_size: int = 800, quality: int = 85) -> tuple:
|
def _sample_key_frames(items: list, max_count: int) -> list:
|
||||||
"""
|
"""
|
||||||
Load all step screenshots from result directory and convert to base64
|
Uniformly sample key frames while always keeping the first and last items.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
items: List of items to sample from
|
||||||
|
max_count: Maximum number of items to keep (must be >= 2)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of sampled indices (sorted)
|
||||||
|
"""
|
||||||
|
n = len(items)
|
||||||
|
if n <= max_count:
|
||||||
|
return list(range(n))
|
||||||
|
|
||||||
|
# Always keep first and last
|
||||||
|
if max_count < 2:
|
||||||
|
max_count = 2
|
||||||
|
|
||||||
|
indices = [0] # first frame
|
||||||
|
# Uniformly sample (max_count - 2) frames from the middle
|
||||||
|
middle_count = max_count - 2
|
||||||
|
if middle_count > 0:
|
||||||
|
step = (n - 2) / (middle_count + 1)
|
||||||
|
for i in range(1, middle_count + 1):
|
||||||
|
idx = int(round(i * step))
|
||||||
|
indices.append(idx)
|
||||||
|
indices.append(n - 1) # last frame
|
||||||
|
|
||||||
|
# Deduplicate and sort
|
||||||
|
indices = sorted(set(indices))
|
||||||
|
return indices
|
||||||
|
|
||||||
|
|
||||||
|
def _load_screenshots_from_dir(result_dir: str, compress: bool = True, max_size: int = 800, quality: int = 85, max_images: int = 0) -> tuple:
|
||||||
|
"""
|
||||||
|
Load step screenshots from result directory and convert to base64.
|
||||||
|
When max_images > 0 and there are more screenshots than max_images,
|
||||||
|
uniformly sample key frames (always keeping first and last).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
result_dir: Path to result directory containing step_*.png files
|
result_dir: Path to result directory containing step_*.png files
|
||||||
compress: Whether to compress images (default: True)
|
compress: Whether to compress images (default: True)
|
||||||
max_size: Maximum dimension for compression (default: 800)
|
max_size: Maximum dimension for compression (default: 800)
|
||||||
quality: JPEG quality for compression (default: 85)
|
quality: JPEG quality for compression (default: 85)
|
||||||
|
max_images: Maximum number of screenshots to load (0 = no limit)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (list of base64 encoded screenshot strings, list of short filenames like 'step_1', 'step_2', ...)
|
Tuple of (list of base64 encoded screenshot strings, list of short filenames like 'step_1', 'step_2', ...)
|
||||||
@@ -311,6 +354,16 @@ def _load_screenshots_from_dir(result_dir: str, compress: bool = True, max_size:
|
|||||||
logger.warning(f"No screenshot files found in {result_dir}")
|
logger.warning(f"No screenshot files found in {result_dir}")
|
||||||
return screenshots, filenames
|
return screenshots, filenames
|
||||||
|
|
||||||
|
# Key frame sampling: if max_images > 0 and we have more files than allowed,
|
||||||
|
# keep first + last + uniformly sampled middle frames
|
||||||
|
total_files = len(screenshot_files)
|
||||||
|
if max_images > 0 and total_files > max_images:
|
||||||
|
sampled_indices = _sample_key_frames(screenshot_files, max_images)
|
||||||
|
screenshot_files_sampled = [screenshot_files[i] for i in sampled_indices]
|
||||||
|
logger.info(f"Key frame sampling: {total_files} screenshots -> {len(screenshot_files_sampled)} "
|
||||||
|
f"(max_images={max_images}, kept indices: {sampled_indices})")
|
||||||
|
screenshot_files = screenshot_files_sampled
|
||||||
|
|
||||||
import re as _re
|
import re as _re
|
||||||
for filepath in screenshot_files:
|
for filepath in screenshot_files:
|
||||||
try:
|
try:
|
||||||
@@ -349,6 +402,8 @@ def vllm_eval(result_state, **options) -> float:
|
|||||||
- compress_images: Whether to compress images (default: True)
|
- compress_images: Whether to compress images (default: True)
|
||||||
- max_image_size: Maximum image dimension for compression (default: 800)
|
- max_image_size: Maximum image dimension for compression (default: 800)
|
||||||
- image_quality: JPEG quality for compression (default: 85)
|
- image_quality: JPEG quality for compression (default: 85)
|
||||||
|
- max_eval_images: Max screenshots for evaluation (0 = no limit, default: 10).
|
||||||
|
When exceeded, keeps first + last + uniformly sampled middle frames.
|
||||||
- temperature: Temperature parameter
|
- temperature: Temperature parameter
|
||||||
- max_tokens: Maximum tokens
|
- max_tokens: Maximum tokens
|
||||||
- top_p: Top-p parameter
|
- top_p: Top-p parameter
|
||||||
@@ -364,6 +419,7 @@ def vllm_eval(result_state, **options) -> float:
|
|||||||
compress_images = options.get("compress_images", True)
|
compress_images = options.get("compress_images", True)
|
||||||
max_image_size = options.get("max_image_size", 800)
|
max_image_size = options.get("max_image_size", 800)
|
||||||
image_quality = options.get("image_quality", 85)
|
image_quality = options.get("image_quality", 85)
|
||||||
|
max_eval_images = options.get("max_eval_images", 10)
|
||||||
|
|
||||||
screenshot_filenames = [] # Short names like 'step_1', 'step_2', ...
|
screenshot_filenames = [] # Short names like 'step_1', 'step_2', ...
|
||||||
|
|
||||||
@@ -372,7 +428,8 @@ def vllm_eval(result_state, **options) -> float:
|
|||||||
result_dir,
|
result_dir,
|
||||||
compress=compress_images,
|
compress=compress_images,
|
||||||
max_size=max_image_size,
|
max_size=max_image_size,
|
||||||
quality=image_quality
|
quality=image_quality,
|
||||||
|
max_images=max_eval_images
|
||||||
)
|
)
|
||||||
logger.info(f"Loaded {len(screenshots)} screenshots from result_dir: {result_dir}")
|
logger.info(f"Loaded {len(screenshots)} screenshots from result_dir: {result_dir}")
|
||||||
elif screenshots:
|
elif screenshots:
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
{"avogadro": ["building-organic-molecules_task1"],
|
{"avogadro": ["building-organic-molecules_task1"],
|
||||||
"jade": ["MDIJade6.5使用手册_task10"]}
|
"jade": ["jade-guide-example_task12"]}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user