[viz tool] add policy pred column
This commit is contained in:
committed by
Remi Cadene
parent
272a9d9427
commit
a97c1cb1af
@@ -88,6 +88,7 @@ def run_server(
|
|||||||
port: str,
|
port: str,
|
||||||
static_folder: Path,
|
static_folder: Path,
|
||||||
template_folder: Path,
|
template_folder: Path,
|
||||||
|
has_policy = False,
|
||||||
):
|
):
|
||||||
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
|
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
|
||||||
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
|
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
|
||||||
@@ -130,7 +131,7 @@ def run_server(
|
|||||||
dataset_info=dataset_info,
|
dataset_info=dataset_info,
|
||||||
videos_info=videos_info,
|
videos_info=videos_info,
|
||||||
ep_csv_url=ep_csv_url,
|
ep_csv_url=ep_csv_url,
|
||||||
has_policy=False,
|
has_policy = has_policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
app.run(host=host, port=port)
|
app.run(host=host, port=port)
|
||||||
@@ -344,7 +345,7 @@ def visualize_dataset_html(
|
|||||||
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset, policy=policy)
|
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset, policy=policy)
|
||||||
|
|
||||||
if serve:
|
if serve:
|
||||||
run_server(dataset, episodes, host, port, static_dir, template_dir)
|
run_server(dataset, episodes, host, port, static_dir, template_dir, has_policy=policy is not None)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|||||||
@@ -229,7 +229,8 @@
|
|||||||
dygraph: null,
|
dygraph: null,
|
||||||
currentFrameData: null,
|
currentFrameData: null,
|
||||||
columnNames: ["state", "action", "pred action"],
|
columnNames: ["state", "action", "pred action"],
|
||||||
nColumns: 2,
|
hasPolicy: {% if has_policy %}true{% else %}false{% endif %},
|
||||||
|
nColumns: {% if has_policy %}3{% else %}2{% endif %},
|
||||||
nStates: 0,
|
nStates: 0,
|
||||||
nActions: 0,
|
nActions: 0,
|
||||||
checked: [],
|
checked: [],
|
||||||
@@ -278,6 +279,9 @@
|
|||||||
const seriesNames = this.dygraph.getLabels().slice(1);
|
const seriesNames = this.dygraph.getLabels().slice(1);
|
||||||
this.nStates = seriesNames.findIndex(item => item.startsWith('action_'));
|
this.nStates = seriesNames.findIndex(item => item.startsWith('action_'));
|
||||||
this.nActions = seriesNames.length - this.nStates;
|
this.nActions = seriesNames.length - this.nStates;
|
||||||
|
if(this.hasPolicy){
|
||||||
|
this.nActions = Math.floor(this.nActions / 2);
|
||||||
|
}
|
||||||
const colors = [];
|
const colors = [];
|
||||||
const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness
|
const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness
|
||||||
// colors for "state" lines
|
// colors for "state" lines
|
||||||
@@ -290,6 +294,13 @@
|
|||||||
const color = `hsl(${hue}, 100%, ${LIGHTNESS[1]}%)`;
|
const color = `hsl(${hue}, 100%, ${LIGHTNESS[1]}%)`;
|
||||||
colors.push(color);
|
colors.push(color);
|
||||||
}
|
}
|
||||||
|
if(this.hasPolicy){
|
||||||
|
// colors for "action" lines
|
||||||
|
for (let hue = 0; hue < 360; hue += parseInt(360/this.nActions)) {
|
||||||
|
const color = `hsl(${hue}, 100%, ${LIGHTNESS[2]}%)`;
|
||||||
|
colors.push(color);
|
||||||
|
}
|
||||||
|
}
|
||||||
this.dygraph.updateOptions({ colors });
|
this.dygraph.updateOptions({ colors });
|
||||||
this.colors = colors;
|
this.colors = colors;
|
||||||
|
|
||||||
@@ -327,6 +338,10 @@
|
|||||||
// row consists of [state value, action value]
|
// row consists of [state value, action value]
|
||||||
row.push(rowIndex < this.nStates ? this.currentFrameData[stateValueIdx] : nullCell); // push "state value" to row
|
row.push(rowIndex < this.nStates ? this.currentFrameData[stateValueIdx] : nullCell); // push "state value" to row
|
||||||
row.push(rowIndex < this.nActions ? this.currentFrameData[actionValueIdx] : nullCell); // push "action value" to row
|
row.push(rowIndex < this.nActions ? this.currentFrameData[actionValueIdx] : nullCell); // push "action value" to row
|
||||||
|
if(this.hasPolicy){
|
||||||
|
const predActionValueIdx = stateValueIdx + this.nStates + this.nActions; // because this.currentFrameData = [state0, state1, ..., stateN, action0, action1, ..., actionN, pred_action1, ..., pred_actionN]
|
||||||
|
row.push(rowIndex < this.nActions ? this.currentFrameData[predActionValueIdx] : nullCell); // push "action value" to row
|
||||||
|
}
|
||||||
rowIndex += 1;
|
rowIndex += 1;
|
||||||
rows.push(row);
|
rows.push(row);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user