Add gemini agent implementation; Add missed requirements; Minor fix some small bugs

This commit is contained in:
Timothyxxx
2024-01-15 21:58:33 +08:00
parent c68796e842
commit 493b719821
10 changed files with 82 additions and 83 deletions

View File

@@ -6,6 +6,7 @@ import sys
from desktop_env.envs.desktop_env import DesktopEnv
from mm_agents.gpt_4v_agent import GPT4v_Agent
from mm_agents.gemini_agent import GeminiPro_Agent
# Logger Configs {{{ #
logger = logging.getLogger()
@@ -44,7 +45,7 @@ logger = logging.getLogger("desktopenv.experiment")
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"
def run_one_example(example, agent, max_steps=2, example_trajectory_dir="exp_trajectory", recording=True):
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True):
trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
env = DesktopEnv(
path_to_vm=PATH_TO_VM,
@@ -53,7 +54,6 @@ def run_one_example(example, agent, max_steps=2, example_trajectory_dir="exp_tra
)
# reset the environment to certain snapshot
observation = env.reset()
observation['instruction'] = example['instruction']
done = False
step_num = 0
@@ -63,17 +63,14 @@ def run_one_example(example, agent, max_steps=2, example_trajectory_dir="exp_tra
while not done and step_num < max_steps:
actions = agent.predict(observation)
step_num += 1
for action in actions:
step_num += 1
# Capture the timestamp before executing the action
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
logger.info("Step %d: %s", step_num, action)
observation, reward, done, info = env.step(action)
observation['instruction'] = example['instruction']
# Logging
logger.info("Step %d: %s", step_num, action)
logger.info("Reward: %.2f", reward)
logger.info("Done: %s", done)
logger.info("Info: %s", info)
@@ -114,19 +111,22 @@ def run_one_example(example, agent, max_steps=2, example_trajectory_dir="exp_tra
if __name__ == "__main__":
action_space = "pyautogui"
example_class = "vlc"
example_id = "8f080098-ddb1-424c-b438-4e96e5e4786e"
example_class = "thunderbird"
example_id = "bb5e4c0d-f964-439c-97b6-bdb9747de3f4"
with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r") as f:
example = json.load(f)
example["snapshot"] = "exp_setup"
example["snapshot"] = "exp_setup2"
api_key = os.environ.get("OPENAI_API_KEY")
agent = GPT4v_Agent(api_key=api_key, action_space=action_space)
# api_key = os.environ.get("OPENAI_API_KEY")
# agent = GPT4v_Agent(api_key=api_key, instruction=example['instruction'], action_space=action_space)
api_key = os.environ.get("GENAI_API_KEY")
agent = GeminiPro_Agent(api_key=api_key, instruction=example['instruction'], action_space=action_space)
root_trajectory_dir = "exp_trajectory"
example_trajectory_dir = os.path.join(root_trajectory_dir, example_class, example_id)
os.makedirs(example_trajectory_dir, exist_ok=True)
run_one_example(example, agent, 2, example_trajectory_dir)
run_one_example(example, agent, 10, example_trajectory_dir)