[viz] Fixes & updates to html visualizer (#617)

This commit is contained in:
Mishig
2025-01-09 11:39:54 +01:00
committed by GitHub
parent b8b368310c
commit 25a8597680
2 changed files with 56 additions and 46 deletions

View File

@@ -232,69 +232,54 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
"""Get a csv str containing timeseries data of an episode (e.g. state and action).
This file will be loaded by Dygraph javascript to plot data in real time."""
columns = []
has_state = "observation.state" in dataset.features
has_action = "action" in dataset.features
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] == "float32"]
selected_columns.remove("timestamp")
# init header of csv with state and action names
header = ["timestamp"]
if has_state:
for column_name in selected_columns:
dim_state = (
dataset.meta.shapes["observation.state"][0]
dataset.meta.shapes[column_name][0]
if isinstance(dataset, LeRobotDataset)
else dataset.features["observation.state"].shape[0]
else dataset.features[column_name].shape[0]
)
header += [f"state_{i}" for i in range(dim_state)]
column_names = dataset.features["observation.state"]["names"]
while not isinstance(column_names, list):
column_names = list(column_names.values())[0]
columns.append({"key": "state", "value": column_names})
if has_action:
dim_action = (
dataset.meta.shapes["action"][0]
if isinstance(dataset, LeRobotDataset)
else dataset.features.action.shape[0]
)
header += [f"action_{i}" for i in range(dim_action)]
column_names = dataset.features["action"]["names"]
while not isinstance(column_names, list):
column_names = list(column_names.values())[0]
columns.append({"key": "action", "value": column_names})
header += [f"{column_name}_{i}" for i in range(dim_state)]
if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
column_names = dataset.features[column_name]["names"]
while not isinstance(column_names, list):
column_names = list(column_names.values())[0]
else:
column_names = [f"motor_{i}" for i in range(dim_state)]
columns.append({"key": column_name, "value": column_names})
selected_columns.insert(0, "timestamp")
if isinstance(dataset, LeRobotDataset):
from_idx = dataset.episode_data_index["from"][episode_index]
to_idx = dataset.episode_data_index["to"][episode_index]
selected_columns = ["timestamp"]
if has_state:
selected_columns += ["observation.state"]
if has_action:
selected_columns += ["action"]
data = (
dataset.hf_dataset.select(range(from_idx, to_idx))
.select_columns(selected_columns)
.with_format("numpy")
.with_format("pandas")
)
rows = np.hstack(
(np.expand_dims(data["timestamp"], axis=1), *[data[col] for col in selected_columns[1:]])
).tolist()
else:
repo_id = dataset.repo_id
selected_columns = ["timestamp"]
if "observation.state" in dataset.features:
selected_columns.append("observation.state")
if "action" in dataset.features:
selected_columns.append("action")
url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
)
df = pd.read_parquet(url)
data = df[selected_columns] # Select specific columns
rows = np.hstack(
(
np.expand_dims(data["timestamp"], axis=1),
*[np.vstack(data[col]) for col in selected_columns[1:]],
)
).tolist()
rows = np.hstack(
(
np.expand_dims(data["timestamp"], axis=1),
*[np.vstack(data[col]) for col in selected_columns[1:]],
)
).tolist()
# Convert data to CSV string
csv_buffer = StringIO()
@@ -379,10 +364,6 @@ def visualize_dataset_html(
template_folder=template_dir,
)
else:
image_keys = dataset.meta.image_keys if isinstance(dataset, LeRobotDataset) else []
if len(image_keys) > 0:
raise NotImplementedError(f"Image keys ({image_keys=}) are currently not supported.")
# Create a symlink from the dataset video folder containing mp4 files to the output directory
# so that the http server can get access to the mp4 files.
if isinstance(dataset, LeRobotDataset):