@@ -4,21 +4,20 @@ import os
|
||||
|
||||
class GrounderClient(object):
|
||||
def __init__(self, url: str = "") -> None:
    """Create a client for the remote grounding service.

    Args:
        url: Endpoint that ``predict`` POSTs grounding requests to.
            Defaults to the empty string, which is the value that was
            previously hard-coded, so existing ``GrounderClient()``
            callers keep their behavior.
    """
    # Proxy for hosting the grounder model + UiElementPredictor; the
    # notes in this file name both UI-TARS and a finetuned Qwen3VL as
    # the hosted grounder.
    # Could be replaced with a vLLM server and grounder-specific
    # processing, or any other grounder.
    self.url = url
|
||||
|
||||
async def predict(
|
||||
self, image_base64: str, action_description: str, action: str | None = None
|
||||
self, image_base64: str, action_description: str, action: str, element_description: str | None = None,
|
||||
) -> utils.GroundingOutput:
|
||||
request = utils.GroundingRequest(
|
||||
description=action_description,
|
||||
image_base64=image_base64,
|
||||
action_type=action,
|
||||
element_description=element_description
|
||||
)
|
||||
api_key = os.getenv("SERVICE_KEY")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.url,
|
||||
@@ -26,6 +25,7 @@ class GrounderClient(object):
|
||||
"image_base64": request.image_base64,
|
||||
"action_description": request.description,
|
||||
"action": request.action_type,
|
||||
"element_description": request.element_description,
|
||||
},
|
||||
headers={
|
||||
"X-API-KEY": api_key
|
||||
@@ -37,6 +37,8 @@ class GrounderClient(object):
|
||||
raise ValueError(f"Prediction failed: {response.text}")
|
||||
|
||||
data = response.json()
|
||||
if tuple(data["position"]) == (-1, -1):
|
||||
raise utils.GroundingOutputValidationException(f"Element {request.description} not found in image", request.description)
|
||||
return utils.GroundingOutput(
|
||||
description=data["description"],
|
||||
position=tuple(data["position"]),
|
||||
|
||||
Reference in New Issue
Block a user