[viz] Fixes & updates to html visualizer (#617)

This commit is contained in:
Mishig
2025-01-09 11:39:54 +01:00
committed by GitHub
parent b8b368310c
commit 25a8597680
2 changed files with 56 additions and 46 deletions

View File

@@ -232,69 +232,54 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
"""Get a csv str containing timeseries data of an episode (e.g. state and action). """Get a csv str containing timeseries data of an episode (e.g. state and action).
This file will be loaded by Dygraph javascript to plot data in real time.""" This file will be loaded by Dygraph javascript to plot data in real time."""
columns = [] columns = []
has_state = "observation.state" in dataset.features
has_action = "action" in dataset.features selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] == "float32"]
selected_columns.remove("timestamp")
# init header of csv with state and action names # init header of csv with state and action names
header = ["timestamp"] header = ["timestamp"]
if has_state:
for column_name in selected_columns:
dim_state = ( dim_state = (
dataset.meta.shapes["observation.state"][0] dataset.meta.shapes[column_name][0]
if isinstance(dataset, LeRobotDataset) if isinstance(dataset, LeRobotDataset)
else dataset.features["observation.state"].shape[0] else dataset.features[column_name].shape[0]
) )
header += [f"state_{i}" for i in range(dim_state)] header += [f"{column_name}_{i}" for i in range(dim_state)]
column_names = dataset.features["observation.state"]["names"]
while not isinstance(column_names, list): if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
column_names = list(column_names.values())[0] column_names = dataset.features[column_name]["names"]
columns.append({"key": "state", "value": column_names}) while not isinstance(column_names, list):
if has_action: column_names = list(column_names.values())[0]
dim_action = ( else:
dataset.meta.shapes["action"][0] column_names = [f"motor_{i}" for i in range(dim_state)]
if isinstance(dataset, LeRobotDataset) columns.append({"key": column_name, "value": column_names})
else dataset.features.action.shape[0]
) selected_columns.insert(0, "timestamp")
header += [f"action_{i}" for i in range(dim_action)]
column_names = dataset.features["action"]["names"]
while not isinstance(column_names, list):
column_names = list(column_names.values())[0]
columns.append({"key": "action", "value": column_names})
if isinstance(dataset, LeRobotDataset): if isinstance(dataset, LeRobotDataset):
from_idx = dataset.episode_data_index["from"][episode_index] from_idx = dataset.episode_data_index["from"][episode_index]
to_idx = dataset.episode_data_index["to"][episode_index] to_idx = dataset.episode_data_index["to"][episode_index]
selected_columns = ["timestamp"]
if has_state:
selected_columns += ["observation.state"]
if has_action:
selected_columns += ["action"]
data = ( data = (
dataset.hf_dataset.select(range(from_idx, to_idx)) dataset.hf_dataset.select(range(from_idx, to_idx))
.select_columns(selected_columns) .select_columns(selected_columns)
.with_format("numpy") .with_format("pandas")
) )
rows = np.hstack(
(np.expand_dims(data["timestamp"], axis=1), *[data[col] for col in selected_columns[1:]])
).tolist()
else: else:
repo_id = dataset.repo_id repo_id = dataset.repo_id
selected_columns = ["timestamp"]
if "observation.state" in dataset.features:
selected_columns.append("observation.state")
if "action" in dataset.features:
selected_columns.append("action")
url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format( url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
) )
df = pd.read_parquet(url) df = pd.read_parquet(url)
data = df[selected_columns] # Select specific columns data = df[selected_columns] # Select specific columns
rows = np.hstack(
( rows = np.hstack(
np.expand_dims(data["timestamp"], axis=1), (
*[np.vstack(data[col]) for col in selected_columns[1:]], np.expand_dims(data["timestamp"], axis=1),
) *[np.vstack(data[col]) for col in selected_columns[1:]],
).tolist() )
).tolist()
# Convert data to CSV string # Convert data to CSV string
csv_buffer = StringIO() csv_buffer = StringIO()
@@ -379,10 +364,6 @@ def visualize_dataset_html(
template_folder=template_dir, template_folder=template_dir,
) )
else: else:
image_keys = dataset.meta.image_keys if isinstance(dataset, LeRobotDataset) else []
if len(image_keys) > 0:
raise NotImplementedError(f"Image keys ({image_keys=}) are currently not supported.")
# Create a symlink from the dataset video folder containing mp4 files to the output directory # Create a symlink from the dataset video folder containing mp4 files to the output directory
# so that the http server can get access to the mp4 files. # so that the http server can get access to the mp4 files.
if isinstance(dataset, LeRobotDataset): if isinstance(dataset, LeRobotDataset):

View File

@@ -98,9 +98,34 @@
</div> </div>
<!-- Videos --> <!-- Videos -->
<div class="max-w-32 relative text-sm mb-4 select-none"
@click.outside="isVideosDropdownOpen = false">
<div
@click="isVideosDropdownOpen = !isVideosDropdownOpen"
class="p-2 border border-slate-500 rounded flex justify-between items-center cursor-pointer"
>
<span class="truncate">filter videos</span>
<div class="transition-transform" :class="{ 'rotate-180': isVideosDropdownOpen }">🔽</div>
</div>
<div x-show="isVideosDropdownOpen"
class="absolute mt-1 border border-slate-500 rounded shadow-lg z-10">
<div>
<template x-for="option in videosKeys" :key="option">
<div
@click="videosKeysSelected = videosKeysSelected.includes(option) ? videosKeysSelected.filter(v => v !== option) : [...videosKeysSelected, option]"
class="p-2 cursor-pointer bg-slate-900"
:class="{ 'bg-slate-700': videosKeysSelected.includes(option) }"
x-text="option"
></div>
</template>
</div>
</div>
</div>
<div class="flex flex-wrap gap-x-2 gap-y-6"> <div class="flex flex-wrap gap-x-2 gap-y-6">
{% for video_info in videos_info %} {% for video_info in videos_info %}
<div x-show="!videoCodecError" class="max-w-96 relative"> <div x-show="!videoCodecError && videosKeysSelected.includes('{{ video_info.filename }}')" class="max-w-96 relative">
<p class="absolute inset-x-0 -top-4 text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p> <p class="absolute inset-x-0 -top-4 text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p>
<video muted loop type="video/mp4" class="object-contain w-full h-full" @canplaythrough="videoCanPlay" @timeupdate="() => { <video muted loop type="video/mp4" class="object-contain w-full h-full" @canplaythrough="videoCanPlay" @timeupdate="() => {
if (video.duration) { if (video.duration) {
@@ -250,6 +275,9 @@
nVideos: {{ videos_info | length }}, nVideos: {{ videos_info | length }},
nVideoReadyToPlay: 0, nVideoReadyToPlay: 0,
videoCodecError: false, videoCodecError: false,
isVideosDropdownOpen: false,
videosKeys: {{ videos_info | map(attribute='filename') | list | tojson }},
videosKeysSelected: [],
columns: {{ columns | tojson }}, columns: {{ columns | tojson }},
rowLabels: {{ columns | tojson }}.reduce((colA, colB) => colA.value.length > colB.value.length ? colA : colB).value, rowLabels: {{ columns | tojson }}.reduce((colA, colB) => colA.value.length > colB.value.length ? colA : colB).value,
@@ -261,6 +289,7 @@
if(!canPlayVideos){ if(!canPlayVideos){
this.videoCodecError = true; this.videoCodecError = true;
} }
this.videosKeysSelected = this.videosKeys.map(opt => opt)
// process CSV data // process CSV data
const csvDataStr = {{ episode_data_csv_str|tojson|safe }}; const csvDataStr = {{ episode_data_csv_str|tojson|safe }};