forked from tangger/lerobot

Compare commits: depth...thomwolf_2

2 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | 3c2dd1b881 |  |
|  | 63a5d0be39 |  |
```diff
@@ -78,15 +78,29 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
     image_keys = [key for key in df if "observation.images." in key]
 
+    num_unaligned_images = 0
+    max_episode = 0
+
     def get_episode_index(row):
+        nonlocal num_unaligned_images
+        nonlocal max_episode
         episode_index_per_cam = {}
         for key in image_keys:
+            if isinstance(row[key], float):
+                num_unaligned_images += 1
+                return float("nan")
             path = row[key][0]["path"]
             match = re.search(r"_(\d{6}).mp4", path)
             if not match:
                 raise ValueError(path)
             episode_index = int(match.group(1))
             episode_index_per_cam[key] = episode_index
+
+            if episode_index > max_episode:
+                assert episode_index - max_episode == 1
+                max_episode = episode_index
+            else:
+                assert episode_index == max_episode
         if len(set(episode_index_per_cam.values())) != 1:
             raise ValueError(
                 f"All cameras are expected to belong to the same episode, but getting {episode_index_per_cam}"
```
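In the new version, `get_episode_index` returns `float("nan")` for unaligned rows and otherwise parses the episode index out of the video filename. A minimal standalone sketch of that parsing step; the helper name and path below are illustrative, not from the repo:

```python
import re

def parse_episode_index(path: str) -> int:
    # The episode index is encoded as a 6-digit suffix of the video filename.
    match = re.search(r"_(\d{6}).mp4", path)
    if not match:
        raise ValueError(path)
    return int(match.group(1))

# Hypothetical path, for illustration only.
assert parse_episode_index("videos/observation.images.top_000042.mp4") == 42
```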
```diff
@@ -111,11 +125,24 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
     del df["timestamp_utc"]
 
     # sanity check
-    has_nan = df.isna().any().any()
-    if has_nan:
-        raise ValueError("Dataset contains NaN values.")
+    num_rows_with_nan = df.isna().any(axis=1).sum()
+    assert (
+        num_rows_with_nan == num_unaligned_images
+    ), f"Found {num_rows_with_nan} rows with NaN values but {num_unaligned_images} unaligned images."
+    if num_unaligned_images > max_episode * 2:
+        # We allow a few unaligned images, typically at the beginning and end of episodes,
+        # but if there are too many, we raise an error to avoid large chunks of missing data.
+        raise ValueError(
+            f"Found {num_unaligned_images} unaligned images out of {max_episode} episodes. "
+            f"Check the timestamps of the cameras."
+        )
+
+    # Drop rows with NaN values now that we have double checked, and convert episode_index to int.
+    df = df.dropna()
+    df["episode_index"] = df["episode_index"].astype(int)
 
     # sanity check episode indices go from 0 to n-1
+    assert df["episode_index"].max() == max_episode
     ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
     expected_ep_ids = list(range(df["episode_index"].max() + 1))
     if ep_ids != expected_ep_ids:
```
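The reworked sanity check counts rows containing at least one NaN so it can be reconciled against `num_unaligned_images`, then drops them before casting `episode_index` to int (a NaN anywhere forces a float dtype). A toy sketch of the pandas idioms involved, on an invented dataframe:

```python
import pandas as pd

# Toy dataframe standing in for the real one; NaN marks an unaligned image row.
df = pd.DataFrame({"episode_index": [0.0, 1.0, float("nan")], "x": [1, 2, 3]})

# df.isna().any(axis=1) flags rows with at least one NaN; .sum() counts them,
# which is what gets compared to num_unaligned_images.
num_rows_with_nan = df.isna().any(axis=1).sum()
assert num_rows_with_nan == 1

# Dropping those rows makes the cast to int safe.
df = df.dropna()
df["episode_index"] = df["episode_index"].astype(int)
```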
```diff
@@ -214,8 +241,6 @@ def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True):
     if fps is None:
         fps = 30
-    else:
-        raise NotImplementedError()
 
     if not video:
         raise NotImplementedError()
```
```diff
@@ -243,10 +243,11 @@ def load_previous_and_future_frames(
     is_pad = min_ > tolerance_s
 
     # check violated query timestamps are all outside the episode range
-    assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
-        f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tolerance_s=}) inside episode range."
-        "This might be due to synchronization issues with timestamps during data collection."
-    )
+    if not ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all():
+        raise ValueError(
+            f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tolerance_s=}) inside episode range."
+            "This might be due to synchronization issues with timestamps during data collection."
+        )
 
     # get dataset indices corresponding to frames to be loaded
     data_ids = ep_data_ids[argmin_]
```
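Raising `ValueError` instead of asserting gives callers a catchable, well-defined failure mode (and survives `python -O`, which strips asserts). A self-contained sketch of the mask logic; the tensor values and episode range below are assumptions for illustration:

```python
import torch

# Assumed setup: query_ts are requested timestamps, min_ is the distance from
# each query to its closest available frame; queries beyond tolerance_s get padded.
query_ts = torch.tensor([-0.2, 0.0, 0.141])
min_ = torch.tensor([0.2, 0.0, 0.101])
tolerance_s = 0.04
ep_first_ts, ep_last_ts = 0.0, 0.1  # episode time range

is_pad = min_ > tolerance_s
# A padded query is only acceptable if it falls outside the episode range;
# otherwise it points at genuinely missing data inside the episode and we raise.
outside = (query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])
if not outside.all():
    raise ValueError("timestamps violate the tolerance inside the episode range")
```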
```diff
@@ -189,7 +189,7 @@ class Logger:
         training_state["scheduler"] = scheduler.state_dict()
         torch.save(training_state, save_dir / self.training_state_file_name)
 
-    def save_checkpont(
+    def save_checkpoint(
         self,
         train_step: int,
         policy: Policy,
```
```diff
@@ -164,7 +164,10 @@ def rollout(
         # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
         # available if none of the envs finished.
         if "final_info" in info:
-            successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
+            successes = [
+                info["is_success"] if info is not None and "is_success" in info else False
+                for info in info["final_info"]
+            ]
         else:
             successes = [False] * env.num_envs
```
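The added `"is_success" in info` guard matters because an env can finish without reporting `is_success` at all, and `final_info` entries are `None` for envs still running. A sketch with a mocked `info` dict (the dict contents are invented, assuming gymnasium-style vector-env semantics):

```python
# Mocked vectorized-env step output: one entry per env in "final_info",
# None for envs that did not finish on this step.
info = {
    "final_info": [
        {"is_success": True},           # env 0 finished successfully
        None,                           # env 1 is still running
        {"TimeLimit.truncated": True},  # env 2 finished without an is_success key
    ]
}
num_envs = 3

if "final_info" in info:
    successes = [
        sub_info["is_success"] if sub_info is not None and "is_success" in sub_info else False
        for sub_info in info["final_info"]
    ]
else:
    successes = [False] * num_envs

assert successes == [True, False, False]
```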
```diff
@@ -345,7 +345,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
         logging.info(f"Checkpoint policy after step {step}")
         # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
         # needed (choose 6 as a minimum for consistency without being overkill).
-        logger.save_checkpont(
+        logger.save_checkpoint(
             step,
             policy,
             optimizer,
```
```diff
@@ -244,7 +244,7 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(
     delta_timestamps = {"index": [-0.2, 0, 0.141]}
     tol = 0.04
     item = hf_dataset[2]
-    with pytest.raises(AssertionError):
+    with pytest.raises(ValueError):
         load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
```