[viz tool] add policy pred column

This commit is contained in:
Mishig Davaadorj
2024-11-27 15:32:48 +01:00
committed by Remi Cadene
parent 272a9d9427
commit a97c1cb1af
2 changed files with 19 additions and 3 deletions

View File

@@ -88,6 +88,7 @@ def run_server(
port: str, port: str,
static_folder: Path, static_folder: Path,
template_folder: Path, template_folder: Path,
has_policy = False,
): ):
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve()) app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
@@ -130,7 +131,7 @@ def run_server(
dataset_info=dataset_info, dataset_info=dataset_info,
videos_info=videos_info, videos_info=videos_info,
ep_csv_url=ep_csv_url, ep_csv_url=ep_csv_url,
has_policy=False, has_policy = has_policy,
) )
app.run(host=host, port=port) app.run(host=host, port=port)
@@ -344,7 +345,7 @@ def visualize_dataset_html(
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset, policy=policy) write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset, policy=policy)
if serve: if serve:
run_server(dataset, episodes, host, port, static_dir, template_dir) run_server(dataset, episodes, host, port, static_dir, template_dir, has_policy=policy is not None)
def main(): def main():

View File

@@ -229,7 +229,8 @@
dygraph: null, dygraph: null,
currentFrameData: null, currentFrameData: null,
columnNames: ["state", "action", "pred action"], columnNames: ["state", "action", "pred action"],
nColumns: 2, hasPolicy: {% if has_policy %}true{% else %}false{% endif %},
nColumns: {% if has_policy %}3{% else %}2{% endif %},
nStates: 0, nStates: 0,
nActions: 0, nActions: 0,
checked: [], checked: [],
@@ -278,6 +279,9 @@
const seriesNames = this.dygraph.getLabels().slice(1); const seriesNames = this.dygraph.getLabels().slice(1);
this.nStates = seriesNames.findIndex(item => item.startsWith('action_')); this.nStates = seriesNames.findIndex(item => item.startsWith('action_'));
this.nActions = seriesNames.length - this.nStates; this.nActions = seriesNames.length - this.nStates;
if(this.hasPolicy){
this.nActions = Math.floor(this.nActions / 2);
}
const colors = []; const colors = [];
const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness
// colors for "state" lines // colors for "state" lines
@@ -290,6 +294,13 @@
const color = `hsl(${hue}, 100%, ${LIGHTNESS[1]}%)`; const color = `hsl(${hue}, 100%, ${LIGHTNESS[1]}%)`;
colors.push(color); colors.push(color);
} }
if(this.hasPolicy){
// colors for "action" lines
for (let hue = 0; hue < 360; hue += parseInt(360/this.nActions)) {
const color = `hsl(${hue}, 100%, ${LIGHTNESS[2]}%)`;
colors.push(color);
}
}
this.dygraph.updateOptions({ colors }); this.dygraph.updateOptions({ colors });
this.colors = colors; this.colors = colors;
@@ -327,6 +338,10 @@
// row consists of [state value, action value] // row consists of [state value, action value]
row.push(rowIndex < this.nStates ? this.currentFrameData[stateValueIdx] : nullCell); // push "state value" to row row.push(rowIndex < this.nStates ? this.currentFrameData[stateValueIdx] : nullCell); // push "state value" to row
row.push(rowIndex < this.nActions ? this.currentFrameData[actionValueIdx] : nullCell); // push "action value" to row row.push(rowIndex < this.nActions ? this.currentFrameData[actionValueIdx] : nullCell); // push "action value" to row
if(this.hasPolicy){
const predActionValueIdx = stateValueIdx + this.nStates + this.nActions; // because this.currentFrameData = [state0, state1, ..., stateN, action0, action1, ..., actionN, pred_action1, ..., pred_actionN]
row.push(rowIndex < this.nActions ? this.currentFrameData[predActionValueIdx] : nullCell); // push "action value" to row
}
rowIndex += 1; rowIndex += 1;
rows.push(row); rows.push(row);
} }