Add gemini agent implementation; Add missed requirements; Minor fix some small bugs
This commit is contained in:
@@ -6,6 +6,7 @@ import sys
|
||||
|
||||
from desktop_env.envs.desktop_env import DesktopEnv
|
||||
from mm_agents.gpt_4v_agent import GPT4v_Agent
|
||||
from mm_agents.gemini_agent import GeminiPro_Agent
|
||||
|
||||
# Logger Configs {{{ #
|
||||
logger = logging.getLogger()
|
||||
@@ -44,7 +45,7 @@ logger = logging.getLogger("desktopenv.experiment")
|
||||
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"
|
||||
|
||||
|
||||
def run_one_example(example, agent, max_steps=2, example_trajectory_dir="exp_trajectory", recording=True):
|
||||
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True):
|
||||
trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
|
||||
env = DesktopEnv(
|
||||
path_to_vm=PATH_TO_VM,
|
||||
@@ -53,7 +54,6 @@ def run_one_example(example, agent, max_steps=2, example_trajectory_dir="exp_tra
|
||||
)
|
||||
# reset the environment to certain snapshot
|
||||
observation = env.reset()
|
||||
observation['instruction'] = example['instruction']
|
||||
done = False
|
||||
step_num = 0
|
||||
|
||||
@@ -63,17 +63,14 @@ def run_one_example(example, agent, max_steps=2, example_trajectory_dir="exp_tra
|
||||
|
||||
while not done and step_num < max_steps:
|
||||
actions = agent.predict(observation)
|
||||
step_num += 1
|
||||
for action in actions:
|
||||
step_num += 1
|
||||
|
||||
# Capture the timestamp before executing the action
|
||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
logger.info("Step %d: %s", step_num, action)
|
||||
|
||||
observation, reward, done, info = env.step(action)
|
||||
observation['instruction'] = example['instruction']
|
||||
|
||||
# Logging
|
||||
logger.info("Step %d: %s", step_num, action)
|
||||
logger.info("Reward: %.2f", reward)
|
||||
logger.info("Done: %s", done)
|
||||
logger.info("Info: %s", info)
|
||||
@@ -114,19 +111,22 @@ def run_one_example(example, agent, max_steps=2, example_trajectory_dir="exp_tra
|
||||
|
||||
if __name__ == "__main__":
|
||||
action_space = "pyautogui"
|
||||
example_class = "vlc"
|
||||
example_id = "8f080098-ddb1-424c-b438-4e96e5e4786e"
|
||||
example_class = "thunderbird"
|
||||
example_id = "bb5e4c0d-f964-439c-97b6-bdb9747de3f4"
|
||||
|
||||
with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r") as f:
|
||||
example = json.load(f)
|
||||
example["snapshot"] = "exp_setup"
|
||||
example["snapshot"] = "exp_setup2"
|
||||
|
||||
api_key = os.environ.get("OPENAI_API_KEY")
|
||||
agent = GPT4v_Agent(api_key=api_key, action_space=action_space)
|
||||
# api_key = os.environ.get("OPENAI_API_KEY")
|
||||
# agent = GPT4v_Agent(api_key=api_key, instruction=example['instruction'], action_space=action_space)
|
||||
|
||||
api_key = os.environ.get("GENAI_API_KEY")
|
||||
agent = GeminiPro_Agent(api_key=api_key, instruction=example['instruction'], action_space=action_space)
|
||||
|
||||
root_trajectory_dir = "exp_trajectory"
|
||||
|
||||
example_trajectory_dir = os.path.join(root_trajectory_dir, example_class, example_id)
|
||||
os.makedirs(example_trajectory_dir, exist_ok=True)
|
||||
|
||||
run_one_example(example, agent, 2, example_trajectory_dir)
|
||||
run_one_example(example, agent, 10, example_trajectory_dir)
|
||||
|
||||
Reference in New Issue
Block a user