fix error
This commit is contained in:
@@ -63,6 +63,33 @@ def setup_logger(example, example_result_dir):
|
||||
runtime_logger.addHandler(logging.FileHandler(os.path.join(example_result_dir, "runtime.log")))
|
||||
return runtime_logger
|
||||
|
||||
def run_single_example_human(env, example, max_steps, instruction, args, example_result_dir, scores):
|
||||
runtime_logger = setup_logger(example, example_result_dir)
|
||||
env.reset(task_config=example)
|
||||
time.sleep(60) # Wait for the environment to be ready
|
||||
obs = env._get_obs() # Get the initial observation
|
||||
|
||||
# Save initial screenshot
|
||||
with open(os.path.join(example_result_dir, "initial_state.png"), "wb") as _f:
|
||||
_f.write(obs['screenshot'])
|
||||
|
||||
# Save trajectory information
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(json.dumps({
|
||||
"instruction": instruction,
|
||||
"initial_state": "initial_state.png"
|
||||
}))
|
||||
f.write("\n")
|
||||
|
||||
# Evaluate the result
|
||||
result = env.evaluate()
|
||||
logger.info("Result: %.2f", result)
|
||||
scores.append(result)
|
||||
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(f"{result}\n")
|
||||
|
||||
|
||||
|
||||
def run_single_example_openaicua(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
|
||||
runtime_logger = setup_logger(example, example_result_dir)
|
||||
agent.reset(runtime_logger)
|
||||
|
||||
Reference in New Issue
Block a user