This commit is contained in:
Timothyxxx
2024-03-13 23:35:04 +08:00
parent a7782338d8
commit 741e26c3f8
2 changed files with 8 additions and 0 deletions

View File

@@ -1,3 +1,4 @@
# todo: unify all the experiment Python files into one file
import datetime import datetime
import json import json
import logging import logging
@@ -114,23 +115,29 @@ def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_tr
try: try:
func_timeout.func_timeout(120, stop_recording) func_timeout.func_timeout(120, stop_recording)
# todo: make sure we got the video file, check the bug
except func_timeout.exceptions.FunctionTimedOut: except func_timeout.exceptions.FunctionTimedOut:
logger.info("Recording timed out.") logger.info("Recording timed out.")
result = env.evaluate() result = env.evaluate()
logger.info("Result: %.2f", result) logger.info("Result: %.2f", result)
# fixme: change to write the result into a separate file
with open(trajectory_recording_path, "a") as f: with open(trajectory_recording_path, "a") as f:
f.write(json.dumps({ f.write(json.dumps({
"result": result "result": result
})) }))
f.write("\n") f.write("\n")
# todo: append the result to the wandb for visualization
# env.close() # env.close()
logger.info("Environment closed.") logger.info("Environment closed.")
def main(example_class, example_id, gpt4_model="gpt-4-vision-preview"): def main(example_class, example_id, gpt4_model="gpt-4-vision-preview"):
# fixme: move all the settings (action_space, model, etc.) into argparse arguments
action_space = "pyautogui" action_space = "pyautogui"
gemini_model = "gemini-pro-vision" gemini_model = "gemini-pro-vision"

View File

@@ -30,6 +30,7 @@ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_S
SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT
import logging import logging
# todo: cross-check with visualwebarena
logger = logging.getLogger("desktopenv.agent") logger = logging.getLogger("desktopenv.agent")