Add Pi0 (#681)
Co-authored-by: Simon Alibert <simon.alibert@huggingface.co> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Pablo <pablo.montalvo.leroux@gmail.com>
This commit is contained in:
@@ -82,8 +82,13 @@ def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwa
|
||||
batch = next(iter(dataloader))
|
||||
obs = {}
|
||||
for k in batch:
|
||||
# TODO: regenerate the safetensors
|
||||
# for backward compatibility
|
||||
if k.endswith("_is_pad"):
|
||||
continue
|
||||
# for backward compatibility
|
||||
if k == "task":
|
||||
continue
|
||||
if k.startswith("observation"):
|
||||
obs[k] = batch[k]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user