Add Pi0 (#681)
Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Pablo <pablo.montalvo.leroux@gmail.com>
This commit is contained in:
@@ -82,8 +82,13 @@ def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwa
|
||||
batch = next(iter(dataloader))
|
||||
obs = {}
|
||||
for k in batch:
|
||||
# TODO: regenerate the safetensors
|
||||
# for backward compatibility
|
||||
if k.endswith("_is_pad"):
|
||||
continue
|
||||
# for backward compatibility
|
||||
if k == "task":
|
||||
continue
|
||||
if k.startswith("observation"):
|
||||
obs[k] = batch[k]
|
||||
|
||||
|
||||
@@ -323,6 +323,8 @@ def test_backward_compatibility(repo_id):
|
||||
# TODO (michel-aractingi): transform language obs to language embeddings via tokenizer
|
||||
new_frame.pop("language_instruction", None)
|
||||
old_frame.pop("language_instruction", None)
|
||||
new_frame.pop("task", None)
|
||||
old_frame.pop("task", None)
|
||||
|
||||
# Remove task_index to allow for backward compatibility
|
||||
# TODO(rcadene): remove when new features have been generated
|
||||
|
||||
@@ -167,14 +167,16 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
||||
batch = next(dl_iter)
|
||||
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(DEVICE, non_blocking=True)
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
batch[key] = batch[key].to(DEVICE, non_blocking=True)
|
||||
|
||||
# Test updating the policy (and test that it does not mutate the batch)
|
||||
batch_ = deepcopy(batch)
|
||||
policy.forward(batch)
|
||||
assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass."
|
||||
assert all(
|
||||
torch.equal(batch[k], batch_[k]) for k in batch
|
||||
torch.equal(batch[k], batch_[k]) if isinstance(batch[k], torch.Tensor) else batch[k] == batch_[k]
|
||||
for k in batch
|
||||
), "Batch values are not the same after a forward pass."
|
||||
|
||||
# reset the policy and environment
|
||||
|
||||
Reference in New Issue
Block a user