feat&fix: update configuration management to save model arguments and enhance UI display for model args (#262)
This commit is contained in:
@@ -43,7 +43,7 @@ MAX_STEPS = int(os.getenv("MAX_STEPS", "150"))
|
||||
|
||||
def initialize_default_config():
|
||||
"""Initialize default configuration from the first available config in results directory"""
|
||||
global ACTION_SPACE, OBSERVATION_TYPE, MODEL_NAME, RESULTS_PATH
|
||||
global ACTION_SPACE, OBSERVATION_TYPE, MODEL_NAME, RESULTS_PATH, MAX_STEPS
|
||||
|
||||
if os.path.exists(RESULTS_BASE_PATH):
|
||||
try:
|
||||
@@ -62,14 +62,20 @@ def initialize_default_config():
|
||||
OBSERVATION_TYPE = obs_type
|
||||
MODEL_NAME = model_name
|
||||
RESULTS_PATH = model_path
|
||||
print(f"Initialized default config: {ACTION_SPACE}/{OBSERVATION_TYPE}/{MODEL_NAME}")
|
||||
|
||||
# Read max_steps from args.json if available
|
||||
model_args = get_model_args(action_space, obs_type, model_name)
|
||||
if model_args and 'max_steps' in model_args:
|
||||
MAX_STEPS = model_args['max_steps']
|
||||
|
||||
print(f"Initialized default config: {ACTION_SPACE}/{OBSERVATION_TYPE}/{MODEL_NAME} (max_steps: {MAX_STEPS})")
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"Error scanning results directory for default config: {e}")
|
||||
|
||||
# Fallback to original environment-based path if no configs found
|
||||
RESULTS_PATH = os.path.join(RESULTS_BASE_PATH, ACTION_SPACE, OBSERVATION_TYPE, MODEL_NAME)
|
||||
print(f"Using fallback config from environment: {ACTION_SPACE}/{OBSERVATION_TYPE}/{MODEL_NAME}")
|
||||
print(f"Using fallback config from environment: {ACTION_SPACE}/{OBSERVATION_TYPE}/{MODEL_NAME} (max_steps: {MAX_STEPS})")
|
||||
|
||||
# Initialize default configuration
|
||||
initialize_default_config()
|
||||
@@ -522,19 +528,28 @@ def api_available_configs():
|
||||
|
||||
@app.route('/api/current-config')
|
||||
def api_current_config():
|
||||
"""Get current configuration"""
|
||||
return jsonify({
|
||||
"""Get current configuration including args.json data"""
|
||||
config = {
|
||||
"action_space": ACTION_SPACE,
|
||||
"observation_type": OBSERVATION_TYPE,
|
||||
"model_name": MODEL_NAME,
|
||||
"max_steps": MAX_STEPS,
|
||||
"results_path": RESULTS_PATH
|
||||
})
|
||||
}
|
||||
|
||||
# Add model args from args.json
|
||||
model_args = get_model_args(ACTION_SPACE, OBSERVATION_TYPE, MODEL_NAME)
|
||||
if model_args:
|
||||
config["model_args"] = model_args
|
||||
else:
|
||||
config["model_args"] = {}
|
||||
|
||||
return jsonify(config)
|
||||
|
||||
@app.route('/api/set-config', methods=['POST'])
|
||||
def api_set_config():
|
||||
"""Set current configuration"""
|
||||
global ACTION_SPACE, OBSERVATION_TYPE, MODEL_NAME, RESULTS_PATH
|
||||
global ACTION_SPACE, OBSERVATION_TYPE, MODEL_NAME, RESULTS_PATH, MAX_STEPS
|
||||
|
||||
data = request.get_json()
|
||||
if not data:
|
||||
@@ -548,6 +563,11 @@ def api_set_config():
|
||||
# Update results path
|
||||
RESULTS_PATH = os.path.join(RESULTS_BASE_PATH, ACTION_SPACE, OBSERVATION_TYPE, MODEL_NAME)
|
||||
|
||||
# Update max_steps from args.json if available
|
||||
model_args = get_model_args(ACTION_SPACE, OBSERVATION_TYPE, MODEL_NAME)
|
||||
if model_args and 'max_steps' in model_args:
|
||||
MAX_STEPS = model_args['max_steps']
|
||||
|
||||
if RESULTS_PATH not in TASK_STATUS_CACHE:
|
||||
# Initialize cache for this results path
|
||||
TASK_STATUS_CACHE[RESULTS_PATH] = {}
|
||||
@@ -560,6 +580,17 @@ def api_set_config():
|
||||
"results_path": RESULTS_PATH
|
||||
})
|
||||
|
||||
def get_model_args(action_space, observation_type, model_name):
|
||||
"""Get model arguments from args.json file"""
|
||||
args_file = os.path.join(RESULTS_BASE_PATH, action_space, observation_type, model_name, "args.json")
|
||||
if os.path.exists(args_file):
|
||||
try:
|
||||
with open(args_file, 'r') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
print(f"Error reading args.json: {e}")
|
||||
return None
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Check if necessary directories exist
|
||||
if not os.path.exists(TASK_CONFIG_PATH):
|
||||
|
||||
Reference in New Issue
Block a user