Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Pablo <pablo.montalvo.leroux@gmail.com>
This commit is contained in:
Remi
2025-02-04 18:01:04 +01:00
committed by GitHub
parent dd974529cf
commit 638d411cd3
26 changed files with 2365 additions and 92 deletions

View File

@@ -82,8 +82,13 @@ def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwa
batch = next(iter(dataloader))
obs = {}
for k in batch:
# TODO: regenerate the safetensors
# for backward compatibility
if k.endswith("_is_pad"):
continue
# for backward compatibility
if k == "task":
continue
if k.startswith("observation"):
obs[k] = batch[k]

View File

@@ -323,6 +323,8 @@ def test_backward_compatibility(repo_id):
# TODO (michel-aractingi): transform language obs to langauge embeddings via tokenizer
new_frame.pop("language_instruction", None)
old_frame.pop("language_instruction", None)
new_frame.pop("task", None)
old_frame.pop("task", None)
# Remove task_index to allow for backward compatibility
# TODO(rcadene): remove when new features have been generated

View File

@@ -167,14 +167,16 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
batch = next(dl_iter)
for key in batch:
batch[key] = batch[key].to(DEVICE, non_blocking=True)
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(DEVICE, non_blocking=True)
# Test updating the policy (and test that it does not mutate the batch)
batch_ = deepcopy(batch)
policy.forward(batch)
assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass."
assert all(
torch.equal(batch[k], batch_[k]) for k in batch
torch.equal(batch[k], batch_[k]) if isinstance(batch[k], torch.Tensor) else batch[k] == batch_[k]
for k in batch
), "Batch values are not the same after a forward pass."
# reset the policy and environment