From 32e3f71dd18021127e03f841bb819edb22bc0084 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 20 Mar 2024 09:17:02 +0000 Subject: [PATCH 1/8] backup wip --- .../diffusion/diffusion_unet_image_policy.py | 3 +- .../model/multi_image_obs_encoder.py | 57 ++++++++++++++----- lerobot/common/policies/diffusion/policy.py | 6 +- lerobot/common/policies/factory.py | 19 +++++++ lerobot/configs/policy/diffusion.yaml | 16 +++--- 5 files changed, 75 insertions(+), 26 deletions(-) diff --git a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py index b759802e..c5b00d94 100644 --- a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py +++ b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py @@ -190,11 +190,10 @@ class DiffusionUnetImagePolicy(BaseImagePolicy): # run sampling nsample = self.conditional_sample( - cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond, **self.kwargs + cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond ) action_pred = nsample[..., :action_dim] - # get action start = n_obs_steps - 1 end = start + self.n_action_steps diff --git a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py index 94dc6f49..17252c1c 100644 --- a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py +++ b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py @@ -1,15 +1,40 @@ import copy -from typing import Dict, Tuple, Union +from typing import Dict, Optional, Tuple, Union +import timm import torch import torch.nn as nn import torchvision +from robomimic.models.base_nets import SpatialSoftmax from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules +class RgbEncoder(nn.Module): + """Following `VisualCore` from Robomimic 0.2.0.""" + + def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32): + """ + input_shape: channel-first input shape (C, H, W) + resnet_name: a timm model name. + pretrained: whether to use timm pretrained weights. + num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image). + """ + super().__init__() + self.backbone = timm.create_model(model_name, pretrained, num_classes=0, global_pool="") + # self.backbone = ResNet18Conv(input_channel=input_shape[0]) + # Figure out the feature map shape. + with torch.inference_mode(): + feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:]) + self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints) + self.out = nn.Linear(num_keypoints * 2, num_keypoints * 2) + + def forward(self, x): + return self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)) + + class MultiImageObsEncoder(ModuleAttrMixin): def __init__( self, @@ -24,7 +49,7 @@ class MultiImageObsEncoder(ModuleAttrMixin): share_rgb_model: bool = False, # renormalize rgb input with imagenet normalization # assuming input in [0,1] - imagenet_norm: bool = False, + norm_mean_std: Optional[tuple[float, float]] = None, ): """ Assumes rgb input: B,C,H,W @@ -98,10 +123,9 @@ class MultiImageObsEncoder(ModuleAttrMixin): this_normalizer = torchvision.transforms.CenterCrop(size=(h, w)) # configure normalizer this_normalizer = nn.Identity() - if imagenet_norm: - # TODO(rcadene): move normalizer to dataset and env + if norm_mean_std is not None: this_normalizer = torchvision.transforms.Normalize( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + mean=norm_mean_std[0], std=norm_mean_std[1] ) this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer) @@ -124,6 +148,17 @@ class MultiImageObsEncoder(ModuleAttrMixin): def forward(self, obs_dict): batch_size = None features = [] + + # process lowdim input + for key in self.low_dim_keys: + data = obs_dict[key] + if batch_size is None: + batch_size = data.shape[0] + else: + assert batch_size == data.shape[0] + assert data.shape[1:] == self.key_shape_map[key] + features.append(data) + # process rgb input if self.share_rgb_model: # pass all rgb obs to rgb model @@ -147,6 +182,7 @@ class MultiImageObsEncoder(ModuleAttrMixin): feature = torch.moveaxis(feature, 0, 1) # (B,N*D) feature = feature.reshape(batch_size, -1) + # feature = torch.nn.functional.relu(feature) # TODO: make optional features.append(feature) else: # run each rgb obs to independent models @@ -159,18 +195,9 @@ class MultiImageObsEncoder(ModuleAttrMixin): assert img.shape[1:] == self.key_shape_map[key] img = self.key_transform_map[key](img) feature = self.key_model_map[key](img) + # feature = torch.nn.functional.relu(feature) # TODO: make optional features.append(feature) - # process lowdim input - for key in self.low_dim_keys: - data = obs_dict[key] - if batch_size is None: - batch_size = data.shape[0] - else: - assert batch_size == data.shape[0] - assert data.shape[1:] == self.key_shape_map[key] - features.append(data) - # concatenate all features result = torch.cat(features, dim=-1) return result diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index 2c47f172..f68ffb8e 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -7,7 +7,7 @@ import torch from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler -from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder +from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder class DiffusionPolicy(AbstractPolicy): @@ -38,6 +38,10 @@ class DiffusionPolicy(AbstractPolicy): self.cfg = cfg noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler) + rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape) + if cfg_obs_encoder.crop_shape is not None: + rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape + rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model) rgb_model = hydra.utils.instantiate(cfg_rgb_model) obs_encoder = MultiImageObsEncoder( rgb_model=rgb_model, diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 085baab5..7961beed 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -40,4 +40,23 @@ def make_policy(cfg): raise NotImplementedError() policy.load(cfg.policy.pretrained_model_path) + # import torch + # loaded = torch.load('/home/alexander/Downloads/dp_ema.pth') + # aligned = {} + + # their_prefix = "obs_encoder.obs_nets.image.backbone" + # our_prefix = "obs_encoder.key_model_map.image.backbone" + # aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) + # their_prefix = "obs_encoder.obs_nets.image.pool" + # our_prefix = "obs_encoder.key_model_map.image.pool" + # aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) + # their_prefix = "obs_encoder.obs_nets.image.nets.3" + # our_prefix = "obs_encoder.key_model_map.image.out" + # aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) + + # aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')}) + # missing_keys, unexpected_keys = policy.diffusion.load_state_dict(aligned, strict=False) + # assert all('_dummy_variable' in k for k in missing_keys) + # assert len(unexpected_keys) == 0 + return policy diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 0dae5056..2b63f7e1 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -42,8 +42,8 @@ policy: num_inference_steps: 100 obs_as_global_cond: ${obs_as_global_cond} # crop_shape: null - diffusion_step_embed_dim: 256 # before 128 - down_dims: [256, 512, 1024] # before [512, 1024, 2048] + diffusion_step_embed_dim: 128 + down_dims: [512, 1024, 2048] kernel_size: 5 n_groups: 8 cond_predict_scale: True @@ -76,17 +76,17 @@ noise_scheduler: obs_encoder: shape_meta: ${shape_meta} # resize_shape: null - # crop_shape: [76, 76] + crop_shape: [84, 84] # constant center crop - # random_crop: True + random_crop: True use_group_norm: True share_rgb_model: False - imagenet_norm: True + norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs) rgb_model: - _target_: lerobot.common.policies.diffusion.pytorch_utils.get_resnet - name: resnet18 - weights: null + model_name: resnet18 + pretrained: false + num_keypoints: 32 ema: _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel From d3239935691307c8f8778fe26ebd9c908899f5df Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 20 Mar 2024 15:01:27 +0000 Subject: [PATCH 2/8] backup wip --- lerobot/common/envs/pusht/pusht_image_env.py | 17 +++-- .../model/multi_image_obs_encoder.py | 9 +-- lerobot/common/policies/diffusion/policy.py | 33 ++++++---- lerobot/common/policies/factory.py | 3 +- lerobot/configs/default.yaml | 6 +- lerobot/configs/policy/diffusion.yaml | 21 +++---- lerobot/scripts/train.py | 63 ++++++++----------- 7 files changed, 71 insertions(+), 81 deletions(-) diff --git a/lerobot/common/envs/pusht/pusht_image_env.py b/lerobot/common/envs/pusht/pusht_image_env.py index 0807e849..b30ad874 100644 --- a/lerobot/common/envs/pusht/pusht_image_env.py +++ b/lerobot/common/envs/pusht/pusht_image_env.py @@ -1,4 +1,3 @@ -import cv2 import numpy as np from gym import spaces @@ -34,14 +33,14 @@ class PushTImageEnv(PushTEnv): coord = (action / 512 * 96).astype(np.int32) marker_size = int(8 / 96 * self.render_size) thickness = int(1 / 96 * self.render_size) - cv2.drawMarker( - img, - coord, - color=(255, 0, 0), - markerType=cv2.MARKER_CROSS, - markerSize=marker_size, - thickness=thickness, - ) + # cv2.drawMarker( + # img, + # coord, + # color=(255, 0, 0), + # markerType=cv2.MARKER_CROSS, + # markerSize=marker_size, + # thickness=thickness, + # ) self.render_cache = img return obs diff --git a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py index 17252c1c..c7b9807d 100644 --- a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py +++ b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py @@ -15,11 +15,12 @@ from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules class RgbEncoder(nn.Module): """Following `VisualCore` from Robomimic 0.2.0.""" - def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32): + def __init__(self, input_shape, model_name="resnet18", pretrained=False, relu=True, num_keypoints=32): """ input_shape: channel-first input shape (C, H, W) resnet_name: a timm model name. pretrained: whether to use timm pretrained weights. + rele: whether to use relu as a final step. num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image). """ super().__init__() @@ -30,9 +31,11 @@ class RgbEncoder(nn.Module): feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:]) self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints) self.out = nn.Linear(num_keypoints * 2, num_keypoints * 2) + self.relu = nn.ReLU() if relu else nn.Identity() def forward(self, x): - return self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)) + # TODO(now): make nonlinearity optional + return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1))) class MultiImageObsEncoder(ModuleAttrMixin): @@ -182,7 +185,6 @@ class MultiImageObsEncoder(ModuleAttrMixin): feature = torch.moveaxis(feature, 0, 1) # (B,N*D) feature = feature.reshape(batch_size, -1) - # feature = torch.nn.functional.relu(feature) # TODO: make optional features.append(feature) else: # run each rgb obs to independent models @@ -195,7 +197,6 @@ class MultiImageObsEncoder(ModuleAttrMixin): assert img.shape[1:] == self.key_shape_map[key] img = self.key_transform_map[key](img) feature = self.key_model_map[key](img) - # feature = torch.nn.functional.relu(feature) # TODO: make optional features.append(feature) # concatenate all features diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index f68ffb8e..a4185afc 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -1,9 +1,11 @@ import copy +import logging import time import hydra import torch +from lerobot.common.ema import update_ema_parameters from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler @@ -19,7 +21,6 @@ class DiffusionPolicy(AbstractPolicy): cfg_rgb_model, cfg_obs_encoder, cfg_optimizer, - cfg_ema, shape_meta: dict, horizon, n_action_steps, @@ -42,7 +43,6 @@ class DiffusionPolicy(AbstractPolicy): if cfg_obs_encoder.crop_shape is not None: rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model) - rgb_model = hydra.utils.instantiate(cfg_rgb_model) obs_encoder = MultiImageObsEncoder( rgb_model=rgb_model, **cfg_obs_encoder, @@ -70,12 +70,9 @@ class DiffusionPolicy(AbstractPolicy): if torch.cuda.is_available() and cfg_device == "cuda": self.diffusion.cuda() - self.ema = None - if self.cfg.use_ema: - self.ema = hydra.utils.instantiate( - cfg_ema, - model=copy.deepcopy(self.diffusion), - ) + self.ema_diffusion = None + if self.cfg.ema.enable: + self.ema_diffusion = copy.deepcopy(self.diffusion) self.optimizer = hydra.utils.instantiate( cfg_optimizer, @@ -98,6 +95,9 @@ class DiffusionPolicy(AbstractPolicy): @torch.no_grad() def select_actions(self, observation, step_count): + """ + Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights. + """ # TODO(rcadene): remove unused step_count del step_count @@ -105,7 +105,10 @@ class DiffusionPolicy(AbstractPolicy): "image": observation["image"], "agent_pos": observation["state"], } - out = self.diffusion.predict_action(obs_dict) + if self.training: + out = self.diffusion.predict_action(obs_dict) + else: + out = self.ema_diffusion.predict_action(obs_dict) action = out["action"] return action @@ -172,8 +175,8 @@ class DiffusionPolicy(AbstractPolicy): self.optimizer.zero_grad() self.lr_scheduler.step() - if self.ema is not None: - self.ema.step(self.diffusion) + if self.cfg.ema.enable: + update_ema_parameters(self.ema_diffusion, self.diffusion, self.cfg.ema.rate) info = { "loss": loss.item(), @@ -195,4 +198,10 @@ class DiffusionPolicy(AbstractPolicy): def load(self, fp): d = torch.load(fp) - self.load_state_dict(d) + missing_keys, unexpected_keys = self.load_state_dict(d, strict=False) + if len(missing_keys) > 0: + assert all(k.startswith("ema_diffusion.") for k in missing_keys) + logging.warning( + "DiffusionPolicy.load expected ema parameters in loaded state dict but none were found." + ) + assert len(unexpected_keys) == 0 diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 7961beed..32a366b3 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -16,7 +16,6 @@ def make_policy(cfg): cfg_rgb_model=cfg.rgb_model, cfg_obs_encoder=cfg.obs_encoder, cfg_optimizer=cfg.optimizer, - cfg_ema=cfg.ema, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps, **cfg.policy, ) @@ -41,7 +40,7 @@ def make_policy(cfg): policy.load(cfg.policy.pretrained_model_path) # import torch - # loaded = torch.load('/home/alexander/Downloads/dp_ema.pth') + # loaded = torch.load('/home/alexander/Downloads/dp.pth') # aligned = {} # their_prefix = "obs_encoder.obs_nets.image.backbone" diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 52fd1d60..90d4c06b 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -12,14 +12,14 @@ hydra: seed: 1337 # batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index # NOTE: only diffusion policy supports rollout_batch_size > 1 -rollout_batch_size: 1 +rollout_batch_size: 10 device: cuda # cpu prefetch: 4 eval_freq: ??? save_freq: ??? eval_episodes: ??? save_video: false -save_model: false +save_model: true save_buffer: false train_steps: ??? fps: ??? @@ -34,6 +34,6 @@ policy: ??? wandb: enable: true # Set to true to disable saving an artifact despite save_model == True - disable_artifact: false + disable_artifact: true project: lerobot notes: "" diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 2b63f7e1..a81952e0 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -21,12 +21,12 @@ past_action_visible: False keypoint_visible_rate: 1.0 obs_as_global_cond: True -eval_episodes: 1 -eval_freq: 10000 -save_freq: 100000 +eval_episodes: 50 +eval_freq: 5000 +save_freq: 5000 log_freq: 250 -offline_steps: 1344000 +offline_steps: 50000 online_steps: 0 offline_prioritized_sampler: true @@ -58,7 +58,9 @@ policy: balanced_sampling: false utd: 1 offline_steps: ${offline_steps} - use_ema: true + ema: + enable: true + rate: 0.999 lr_scheduler: cosine lr_warmup_steps: 500 grad_clip_norm: 10 @@ -87,14 +89,7 @@ rgb_model: model_name: resnet18 pretrained: false num_keypoints: 32 - -ema: - _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel - update_after_step: 0 - inv_gamma: 1.0 - power: 0.75 - min_value: 0.0 - max_value: 0.9999 + relu: true optimizer: _target_: torch.optim.AdamW diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 5ecd616d..a2039006 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -155,11 +155,7 @@ def train(cfg: dict, out_dir=None, job_name=None): num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_total_params = sum(p.numel() for p in policy.parameters()) - td_policy = TensorDictModule( - policy, - in_keys=["observation", "step_count"], - out_keys=["action"], - ) + td_policy = TensorDictModule(policy, in_keys=["observation", "step_count"], out_keys=["action"]) # log metrics to terminal and wandb logger = Logger(out_dir, job_name, cfg) @@ -174,19 +170,9 @@ def train(cfg: dict, out_dir=None, job_name=None): logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") - step = 0 # number of policy update (forward + backward + optim) - - is_offline = True - for offline_step in range(cfg.offline_steps): - if offline_step == 0: - logging.info("Start offline training on a fixed dataset") - # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? - policy.train() - train_info = policy.update(offline_buffer, step) - if step % cfg.log_freq == 0: - log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline) - - if step > 0 and step % cfg.eval_freq == 0: + # Note: this helper will be used in offline and online training loops. + def _maybe_eval_and_maybe_save(step): + if step % cfg.eval_freq == 0: logging.info(f"Eval policy at step {step}") eval_info, first_video = eval_policy( env, @@ -202,11 +188,27 @@ def train(cfg: dict, out_dir=None, job_name=None): logger.log_video(first_video, step, mode="eval") logging.info("Resume training") - if step > 0 and cfg.save_model and step % cfg.save_freq == 0: - logging.info(f"Checkpoint policy at step {step}") + if cfg.save_model and step % cfg.save_freq == 0: + logging.info(f"Checkpoint policy after step {step}") logger.save_model(policy, identifier=step) logging.info("Resume training") + step = 0 # number of policy update (forward + backward + optim) + + is_offline = True + for offline_step in range(cfg.offline_steps): + if offline_step == 0: + logging.info("Start offline training on a fixed dataset") + # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? + policy.train() + train_info = policy.update(offline_buffer, step) + if step % cfg.log_freq == 0: + log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline) + + # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in + # step + 1. + _maybe_eval_and_maybe_save(step + 1) + step += 1 demo_buffer = offline_buffer if cfg.policy.balanced_sampling else None @@ -248,24 +250,9 @@ def train(cfg: dict, out_dir=None, job_name=None): train_info.update(rollout_info) log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline) - if step > 0 and step % cfg.eval_freq == 0: - logging.info(f"Eval policy at step {step}") - eval_info, first_video = eval_policy( - env, - td_policy, - num_episodes=cfg.eval_episodes, - max_steps=cfg.env.episode_length // cfg.n_action_steps, - return_first_video=True, - ) - log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline) - if cfg.wandb.enable: - logger.log_video(first_video, step, mode="eval") - logging.info("Resume training") - - if step > 0 and cfg.save_model and step % cfg.save_freq == 0: - logging.info(f"Checkpoint policy at step {step}") - logger.save_model(policy, identifier=step) - logging.info("Resume training") + # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass + # in step + 1. + _maybe_eval_and_maybe_save(step + 1) step += 1 online_step += 1 From acf1174447ff2ffa8616280e0390c38020990738 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 21 Mar 2024 10:18:50 +0000 Subject: [PATCH 3/8] ready for review --- lerobot/common/envs/aloha/env.py | 2 +- lerobot/common/envs/factory.py | 24 --- lerobot/common/envs/pusht/env.py | 26 ++- lerobot/common/envs/pusht/pusht_image_env.py | 14 -- .../diffusion/diffusion_unet_image_policy.py | 41 ++++ .../model/multi_image_obs_encoder.py | 11 +- lerobot/common/policies/diffusion/policy.py | 13 +- lerobot/common/policies/factory.py | 20 +- lerobot/configs/policy/diffusion.yaml | 16 +- lerobot/scripts/eval.py | 1 + poetry.lock | 198 +++++++++++++++++- pyproject.toml | 1 + 12 files changed, 282 insertions(+), 85 deletions(-) diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py index e09564fb..af2b354b 100644 --- a/lerobot/common/envs/aloha/env.py +++ b/lerobot/common/envs/aloha/env.py @@ -192,7 +192,7 @@ class AlohaEnv(AbstractEnv): { "observation": TensorDict(obs, batch_size=[]), "reward": torch.tensor([reward], dtype=torch.float32), - # succes and done are true when coverage > self.success_threshold in env + # success and done are true when coverage > self.success_threshold in env "done": torch.tensor([done], dtype=torch.bool), "success": torch.tensor([success], dtype=torch.bool), }, diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index e187d713..06c7c43f 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -62,27 +62,3 @@ def make_env(cfg, transform=None): {"seed": env_seed} for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) ], ) - - -# def make_env(env_name, frame_skip, device, is_test=False): -# env = GymEnv( -# env_name, -# frame_skip=frame_skip, -# from_pixels=True, -# pixels_only=False, -# device=device, -# ) -# env = TransformedEnv(env) -# env.append_transform(NoopResetEnv(noops=30, random=True)) -# if not is_test: -# env.append_transform(EndOfLifeTransform()) -# env.append_transform(RewardClipping(-1, 1)) -# env.append_transform(ToTensorImage()) -# env.append_transform(GrayScale()) -# env.append_transform(Resize(84, 84)) -# env.append_transform(CatFrames(N=4, dim=-3)) -# env.append_transform(RewardSum()) -# env.append_transform(StepCounter(max_steps=4500)) -# env.append_transform(DoubleToFloat()) -# env.append_transform(VecNorm(in_keys=["pixels"])) -# return env diff --git a/lerobot/common/envs/pusht/env.py b/lerobot/common/envs/pusht/env.py index f440d443..3824a5d2 100644 --- a/lerobot/common/envs/pusht/env.py +++ b/lerobot/common/envs/pusht/env.py @@ -3,6 +3,8 @@ import logging from collections import deque from typing import Optional +import cv2 +import numpy as np import torch from tensordict import TensorDict from torchrl.data.tensor_specs import ( @@ -59,12 +61,30 @@ class PushtEnv(AbstractEnv): self._env = PushTImageEnv(render_size=self.image_size) - def render(self, mode="rgb_array", width=384, height=384): + def render(self, mode="rgb_array", width=96, height=96, with_marker=True): + """ + with_marker adds a cursor showing the targeted action for the controller. + """ if width != height: raise NotImplementedError() tmp = self._env.render_size - self._env.render_size = width - out = self._env.render(mode) + if width != self._env.render_size: + self._env.render_cache = None + self._env.render_size = width + out = self._env.render(mode).copy() + if with_marker and self._env.latest_action is not None: + action = np.array(self._env.latest_action) + coord = (action / 512 * self._env.render_size).astype(np.int32) + marker_size = int(8 / 96 * self._env.render_size) + thickness = int(1 / 96 * self._env.render_size) + cv2.drawMarker( + out, + coord, + color=(255, 0, 0), + markerType=cv2.MARKER_CROSS, + markerSize=marker_size, + thickness=thickness, + ) self._env.render_size = tmp return out diff --git a/lerobot/common/envs/pusht/pusht_image_env.py b/lerobot/common/envs/pusht/pusht_image_env.py index b30ad874..ec8e177b 100644 --- a/lerobot/common/envs/pusht/pusht_image_env.py +++ b/lerobot/common/envs/pusht/pusht_image_env.py @@ -27,20 +27,6 @@ class PushTImageEnv(PushTEnv): img_obs = np.moveaxis(img, -1, 0) obs = {"image": img_obs, "agent_pos": agent_pos} - # draw action - if self.latest_action is not None: - action = np.array(self.latest_action) - coord = (action / 512 * 96).astype(np.int32) - marker_size = int(8 / 96 * self.render_size) - thickness = int(1 / 96 * self.render_size) - # cv2.drawMarker( - # img, - # coord, - # color=(255, 0, 0), - # markerType=cv2.MARKER_CROSS, - # markerSize=marker_size, - # thickness=thickness, - # ) self.render_cache = img return obs diff --git a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py index c5b00d94..7719fdde 100644 --- a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py +++ b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py @@ -1,3 +1,44 @@ +"""Code from the original diffusion policy project. + +Notes on how to load a checkpoint from the original repository: + +In the original repository, run the eval and use a breakpoint to extract the policy weights. + +``` +torch.save(policy.state_dict(), "weights.pt") +``` + +In this repository, add a breakpoint somewhere after creating an equivalent policy and load in the weights: + +``` +loaded = torch.load("weights.pt") +aligned = {} +their_prefix = "obs_encoder.obs_nets.image.backbone" +our_prefix = "obs_encoder.key_model_map.image.backbone" +aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) +their_prefix = "obs_encoder.obs_nets.image.pool" +our_prefix = "obs_encoder.key_model_map.image.pool" +aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) +their_prefix = "obs_encoder.obs_nets.image.nets.3" +our_prefix = "obs_encoder.key_model_map.image.out" +aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) +aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')}) +# Note: here you are loading into the ema model. +missing_keys, unexpected_keys = policy.ema_diffusion.load_state_dict(aligned, strict=False) +assert all('_dummy_variable' in k for k in missing_keys) +assert len(unexpected_keys) == 0 +``` + +Then in that same runtime you can also save the weights with the new aligned state_dict: + +``` +policy.save("weights.pt") +``` + +Now you can remove the breakpoint and extra code and load in the weights just like with any other lerobot checkpoint. + +""" + from typing import Dict import torch diff --git a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py index c7b9807d..d724cd49 100644 --- a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py +++ b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py @@ -1,11 +1,10 @@ import copy from typing import Dict, Optional, Tuple, Union -import timm import torch import torch.nn as nn import torchvision -from robomimic.models.base_nets import SpatialSoftmax +from robomimic.models.base_nets import ResNet18Conv, SpatialSoftmax from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin @@ -15,17 +14,16 @@ from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules class RgbEncoder(nn.Module): """Following `VisualCore` from Robomimic 0.2.0.""" - def __init__(self, input_shape, model_name="resnet18", pretrained=False, relu=True, num_keypoints=32): + def __init__(self, input_shape, relu=True, pretrained=False, num_keypoints=32): """ input_shape: channel-first input shape (C, H, W) resnet_name: a timm model name. pretrained: whether to use timm pretrained weights. - rele: whether to use relu as a final step. + relu: whether to use relu as a final step. num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image). """ super().__init__() - self.backbone = timm.create_model(model_name, pretrained, num_classes=0, global_pool="") - # self.backbone = ResNet18Conv(input_channel=input_shape[0]) + self.backbone = ResNet18Conv(input_channel=input_shape[0], pretrained=pretrained) # Figure out the feature map shape. with torch.inference_mode(): feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:]) @@ -34,7 +32,6 @@ class RgbEncoder(nn.Module): self.relu = nn.ReLU() if relu else nn.Identity() def forward(self, x): - # TODO(now): make nonlinearity optional return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1))) diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index a4185afc..1b3b24b6 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -5,7 +5,6 @@ import time import hydra import torch -from lerobot.common.ema import update_ema_parameters from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler @@ -21,6 +20,7 @@ class DiffusionPolicy(AbstractPolicy): cfg_rgb_model, cfg_obs_encoder, cfg_optimizer, + cfg_ema, shape_meta: dict, horizon, n_action_steps, @@ -71,8 +71,13 @@ class DiffusionPolicy(AbstractPolicy): self.diffusion.cuda() self.ema_diffusion = None - if self.cfg.ema.enable: + self.ema = None + if self.cfg.use_ema: self.ema_diffusion = copy.deepcopy(self.diffusion) + self.ema = hydra.utils.instantiate( + cfg_ema, + model=self.ema_diffusion, + ) self.optimizer = hydra.utils.instantiate( cfg_optimizer, @@ -175,8 +180,8 @@ class DiffusionPolicy(AbstractPolicy): self.optimizer.zero_grad() self.lr_scheduler.step() - if self.cfg.ema.enable: - update_ema_parameters(self.ema_diffusion, self.diffusion, self.cfg.ema.rate) + if self.ema is not None: + self.ema.step(self.diffusion) info = { "loss": loss.item(), diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 32a366b3..085baab5 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -16,6 +16,7 @@ def make_policy(cfg): cfg_rgb_model=cfg.rgb_model, cfg_obs_encoder=cfg.obs_encoder, cfg_optimizer=cfg.optimizer, + cfg_ema=cfg.ema, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps, **cfg.policy, ) @@ -39,23 +40,4 @@ def make_policy(cfg): raise NotImplementedError() policy.load(cfg.policy.pretrained_model_path) - # import torch - # loaded = torch.load('/home/alexander/Downloads/dp.pth') - # aligned = {} - - # their_prefix = "obs_encoder.obs_nets.image.backbone" - # our_prefix = "obs_encoder.key_model_map.image.backbone" - # aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) - # their_prefix = "obs_encoder.obs_nets.image.pool" - # our_prefix = "obs_encoder.key_model_map.image.pool" - # aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) - # their_prefix = "obs_encoder.obs_nets.image.nets.3" - # our_prefix = "obs_encoder.key_model_map.image.out" - # aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) - - # aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')}) - # missing_keys, unexpected_keys = policy.diffusion.load_state_dict(aligned, strict=False) - # assert all('_dummy_variable' in k for k in missing_keys) - # assert len(unexpected_keys) == 0 - return policy diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index a81952e0..acb368ed 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -12,6 +12,7 @@ shape_meta: action: shape: [2] +seed: 100000 horizon: 16 n_obs_steps: 2 n_action_steps: 8 @@ -26,7 +27,7 @@ eval_freq: 5000 save_freq: 5000 log_freq: 250 -offline_steps: 50000 +offline_steps: 200000 online_steps: 0 offline_prioritized_sampler: true @@ -58,9 +59,7 @@ policy: balanced_sampling: false utd: 1 offline_steps: ${offline_steps} - ema: - enable: true - rate: 0.999 + use_ema: true lr_scheduler: cosine lr_warmup_steps: 500 grad_clip_norm: 10 @@ -86,11 +85,18 @@ obs_encoder: norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs) rgb_model: - model_name: resnet18 pretrained: false num_keypoints: 32 relu: true +ema: + _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel + update_after_step: 0 + inv_gamma: 1.0 + power: 0.75 + min_value: 0.0 + max_value: 0.9999 + optimizer: _target_: torch.optim.AdamW lr: 1.0e-4 diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 41d58b91..c0c34629 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -50,6 +50,7 @@ def eval_policy( def maybe_render_frame(env: EnvBase, _): if save_video or (return_first_video and i == 0): # noqa: B023 + # TODO now: generalize kwarg or maybe just remove it ep_frames.append(env.render()) # noqa: B023 with torch.inference_mode(): diff --git a/poetry.lock b/poetry.lock index ddb0a0e3..92449d45 100644 --- a/poetry.lock +++ b/poetry.lock @@ -604,6 +604,16 @@ files = [ [package.dependencies] six = ">=1.4.0" +[[package]] +name = "egl-probe" +version = "1.0.2" +description = "" +optional = false +python-versions = "*" +files = [ + {file = "egl_probe-1.0.2.tar.gz", hash = "sha256:29bdca7b08da1e060cfb42cd46af8300a7ac4f3b1b2eeb16e545ea16d9a5ac93"}, +] + [[package]] name = "einops" version = "0.7.0" @@ -763,6 +773,72 @@ files = [ [package.extras] preview = ["glfw-preview"] +[[package]] +name = "grpcio" +version = "1.62.1" +description = "HTTP/2-based RPC framework" +optional = false +python-versions = ">=3.7" +files = [ + {file = "grpcio-1.62.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:179bee6f5ed7b5f618844f760b6acf7e910988de77a4f75b95bbfaa8106f3c1e"}, + {file = "grpcio-1.62.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:48611e4fa010e823ba2de8fd3f77c1322dd60cb0d180dc6630a7e157b205f7ea"}, + {file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b2a0e71b0a2158aa4bce48be9f8f9eb45cbd17c78c7443616d00abbe2a509f6d"}, + {file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fbe80577c7880911d3ad65e5ecc997416c98f354efeba2f8d0f9112a67ed65a5"}, + {file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58f6c693d446964e3292425e1d16e21a97a48ba9172f2d0df9d7b640acb99243"}, + {file = "grpcio-1.62.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:77c339403db5a20ef4fed02e4d1a9a3d9866bf9c0afc77a42234677313ea22f3"}, + {file = "grpcio-1.62.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b5a4ea906db7dec694098435d84bf2854fe158eb3cd51e1107e571246d4d1d70"}, + {file = "grpcio-1.62.1-cp310-cp310-win32.whl", hash = "sha256:4187201a53f8561c015bc745b81a1b2d278967b8de35f3399b84b0695e281d5f"}, + {file = "grpcio-1.62.1-cp310-cp310-win_amd64.whl", hash = "sha256:844d1f3fb11bd1ed362d3fdc495d0770cfab75761836193af166fee113421d66"}, + {file = "grpcio-1.62.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:833379943d1728a005e44103f17ecd73d058d37d95783eb8f0b28ddc1f54d7b2"}, + {file = "grpcio-1.62.1-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:c7fcc6a32e7b7b58f5a7d27530669337a5d587d4066060bcb9dee7a8c833dfb7"}, + {file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:fa7d28eb4d50b7cbe75bb8b45ed0da9a1dc5b219a0af59449676a29c2eed9698"}, + {file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48f7135c3de2f298b833be8b4ae20cafe37091634e91f61f5a7eb3d61ec6f660"}, + {file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:71f11fd63365ade276c9d4a7b7df5c136f9030e3457107e1791b3737a9b9ed6a"}, + {file = "grpcio-1.62.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4b49fd8fe9f9ac23b78437da94c54aa7e9996fbb220bac024a67469ce5d0825f"}, + {file = "grpcio-1.62.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:482ae2ae78679ba9ed5752099b32e5fe580443b4f798e1b71df412abf43375db"}, + {file = "grpcio-1.62.1-cp311-cp311-win32.whl", hash = "sha256:1faa02530b6c7426404372515fe5ddf66e199c2ee613f88f025c6f3bd816450c"}, + {file = "grpcio-1.62.1-cp311-cp311-win_amd64.whl", hash = "sha256:5bd90b8c395f39bc82a5fb32a0173e220e3f401ff697840f4003e15b96d1befc"}, + {file = "grpcio-1.62.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:b134d5d71b4e0837fff574c00e49176051a1c532d26c052a1e43231f252d813b"}, + {file = "grpcio-1.62.1-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:d1f6c96573dc09d50dbcbd91dbf71d5cf97640c9427c32584010fbbd4c0e0037"}, + {file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:359f821d4578f80f41909b9ee9b76fb249a21035a061a327f91c953493782c31"}, + {file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a485f0c2010c696be269184bdb5ae72781344cb4e60db976c59d84dd6354fac9"}, + {file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b50b09b4dc01767163d67e1532f948264167cd27f49e9377e3556c3cba1268e1"}, + {file = "grpcio-1.62.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3227c667dccbe38f2c4d943238b887bac588d97c104815aecc62d2fd976e014b"}, + {file = "grpcio-1.62.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3952b581eb121324853ce2b191dae08badb75cd493cb4e0243368aa9e61cfd41"}, + {file = "grpcio-1.62.1-cp312-cp312-win32.whl", hash = "sha256:83a17b303425104d6329c10eb34bba186ffa67161e63fa6cdae7776ff76df73f"}, + {file = "grpcio-1.62.1-cp312-cp312-win_amd64.whl", hash = "sha256:6696ffe440333a19d8d128e88d440f91fb92c75a80ce4b44d55800e656a3ef1d"}, + {file = "grpcio-1.62.1-cp37-cp37m-linux_armv7l.whl", hash = "sha256:e3393b0823f938253370ebef033c9fd23d27f3eae8eb9a8f6264900c7ea3fb5a"}, + {file = "grpcio-1.62.1-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:83e7ccb85a74beaeae2634f10eb858a0ed1a63081172649ff4261f929bacfd22"}, + {file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:882020c87999d54667a284c7ddf065b359bd00251fcd70279ac486776dbf84ec"}, + {file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a10383035e864f386fe096fed5c47d27a2bf7173c56a6e26cffaaa5a361addb1"}, + {file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:960edebedc6b9ada1ef58e1c71156f28689978188cd8cff3b646b57288a927d9"}, + {file = "grpcio-1.62.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:23e2e04b83f347d0aadde0c9b616f4726c3d76db04b438fd3904b289a725267f"}, + {file = "grpcio-1.62.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:978121758711916d34fe57c1f75b79cdfc73952f1481bb9583399331682d36f7"}, + {file = "grpcio-1.62.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9084086190cc6d628f282e5615f987288b95457292e969b9205e45b442276407"}, + {file = "grpcio-1.62.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:22bccdd7b23c420a27fd28540fb5dcbc97dc6be105f7698cb0e7d7a420d0e362"}, + {file = "grpcio-1.62.1-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:8999bf1b57172dbc7c3e4bb3c732658e918f5c333b2942243f10d0d653953ba9"}, + {file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:d9e52558b8b8c2f4ac05ac86344a7417ccdd2b460a59616de49eb6933b07a0bd"}, + {file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1714e7bc935780bc3de1b3fcbc7674209adf5208ff825799d579ffd6cd0bd505"}, + {file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8842ccbd8c0e253c1f189088228f9b433f7a93b7196b9e5b6f87dba393f5d5d"}, + {file = "grpcio-1.62.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1f1e7b36bdff50103af95a80923bf1853f6823dd62f2d2a2524b66ed74103e49"}, + {file = "grpcio-1.62.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bba97b8e8883a8038606480d6b6772289f4c907f6ba780fa1f7b7da7dfd76f06"}, + {file = "grpcio-1.62.1-cp38-cp38-win32.whl", hash = "sha256:a7f615270fe534548112a74e790cd9d4f5509d744dd718cd442bf016626c22e4"}, + {file = "grpcio-1.62.1-cp38-cp38-win_amd64.whl", hash = "sha256:e6c8c8693df718c5ecbc7babb12c69a4e3677fd11de8886f05ab22d4e6b1c43b"}, + {file = "grpcio-1.62.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:73db2dc1b201d20ab7083e7041946910bb991e7e9761a0394bbc3c2632326483"}, + {file = "grpcio-1.62.1-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:407b26b7f7bbd4f4751dbc9767a1f0716f9fe72d3d7e96bb3ccfc4aace07c8de"}, + {file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:f8de7c8cef9261a2d0a62edf2ccea3d741a523c6b8a6477a340a1f2e417658de"}, + {file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bd5c8a1af40ec305d001c60236308a67e25419003e9bb3ebfab5695a8d0b369"}, + {file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be0477cb31da67846a33b1a75c611f88bfbcd427fe17701b6317aefceee1b96f"}, + {file = "grpcio-1.62.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:60dcd824df166ba266ee0cfaf35a31406cd16ef602b49f5d4dfb21f014b0dedd"}, + {file = "grpcio-1.62.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:973c49086cabab773525f6077f95e5a993bfc03ba8fc32e32f2c279497780585"}, + {file = "grpcio-1.62.1-cp39-cp39-win32.whl", hash = "sha256:12859468e8918d3bd243d213cd6fd6ab07208195dc140763c00dfe901ce1e1b4"}, + {file = "grpcio-1.62.1-cp39-cp39-win_amd64.whl", hash = "sha256:b7209117bbeebdfa5d898205cc55153a51285757902dd73c47de498ad4d11332"}, + {file = "grpcio-1.62.1.tar.gz", hash = "sha256:6c455e008fa86d9e9a9d85bb76da4277c0d7d9668a3bfa70dbe86e9f3c759947"}, +] + +[package.extras] +protobuf = ["grpcio-tools (>=1.62.1)"] + [[package]] name = "gym" version = "0.26.2" @@ -1038,13 +1114,13 @@ setuptools = "*" [[package]] name = "importlib-metadata" -version = "7.0.2" +version = "7.1.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_metadata-7.0.2-py3-none-any.whl", hash = "sha256:f4bc4c0c070c490abf4ce96d715f68e95923320370efb66143df00199bb6c100"}, - {file = "importlib_metadata-7.0.2.tar.gz", hash = "sha256:198f568f3230878cb1b44fbd7975f87906c22336dba2e4a7f05278c281fbd792"}, + {file = "importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570"}, + {file = "importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2"}, ] [package.dependencies] @@ -1053,7 +1129,7 @@ zipp = ">=0.5" [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] -testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] [[package]] name = "iniconfig" @@ -1265,6 +1341,21 @@ html5 = ["html5lib"] htmlsoup = ["BeautifulSoup4"] source = ["Cython (>=3.0.7)"] +[[package]] +name = "markdown" +version = "3.6" +description = "Python implementation of John Gruber's Markdown." +optional = false +python-versions = ">=3.8" +files = [ + {file = "Markdown-3.6-py3-none-any.whl", hash = "sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f"}, + {file = "Markdown-3.6.tar.gz", hash = "sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224"}, +] + +[package.extras] +docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"] +testing = ["coverage", "pyyaml"] + [[package]] name = "markupsafe" version = "2.1.5" @@ -2460,6 +2551,30 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "robomimic" +version = "0.2.0" +description = "robomimic: A Modular Framework for Robot Learning from Demonstration" +optional = false +python-versions = ">=3" +files = [ + {file = "robomimic-0.2.0.tar.gz", hash = "sha256:ee3bb5cf9c3e1feead6b57b43c5db738fd0a8e0c015fdf6419808af8fffdc463"}, +] + +[package.dependencies] +egl_probe = ">=1.0.1" +h5py = "*" +imageio = "*" +imageio-ffmpeg = "*" +numpy = ">=1.13.3" +psutil = "*" +tensorboard = "*" +tensorboardX = "*" +termcolor = "*" +torch = "*" +torchvision = "*" +tqdm = "*" + [[package]] name = "safetensors" version = "0.4.2" @@ -2684,13 +2799,13 @@ test = ["asv", "gmpy2", "hypothesis", "mpmath", "pooch", "pytest", "pytest-cov", [[package]] name = "sentry-sdk" -version = "1.42.0" +version = "1.43.0" description = "Python client for Sentry (https://sentry.io)" optional = false python-versions = "*" files = [ - {file = "sentry-sdk-1.42.0.tar.gz", hash = "sha256:4a8364b8f7edbf47f95f7163e48334c96100d9c098f0ae6606e2e18183c223e6"}, - {file = "sentry_sdk-1.42.0-py2.py3-none-any.whl", hash = "sha256:a654ee7e497a3f5f6368b36d4f04baeab1fe92b3105f7f6965d6ef0de35a9ba4"}, + {file = "sentry-sdk-1.43.0.tar.gz", hash = "sha256:41df73af89d22921d8733714fb0fc5586c3461907e06688e6537d01a27e0e0f6"}, + {file = "sentry_sdk-1.43.0-py2.py3-none-any.whl", hash = "sha256:8d768724839ca18d7b4c7463ef7528c40b7aa2bfbf7fe554d5f9a7c044acfd36"}, ] [package.dependencies] @@ -2704,6 +2819,7 @@ asyncpg = ["asyncpg (>=0.23)"] beam = ["apache-beam (>=2.12)"] bottle = ["bottle (>=0.12.13)"] celery = ["celery (>=3)"] +celery-redbeat = ["celery-redbeat (>=2)"] chalice = ["chalice (>=1.16.0)"] clickhouse-driver = ["clickhouse-driver (>=0.2.0)"] django = ["django (>=1.8)"] @@ -2948,6 +3064,55 @@ files = [ [package.dependencies] mpmath = ">=0.19" +[[package]] +name = "tensorboard" +version = "2.16.2" +description = "TensorBoard lets you watch Tensors Flow" +optional = false +python-versions = ">=3.9" +files = [ + {file = "tensorboard-2.16.2-py3-none-any.whl", hash = "sha256:9f2b4e7dad86667615c0e5cd072f1ea8403fc032a299f0072d6f74855775cc45"}, +] + +[package.dependencies] +absl-py = ">=0.4" +grpcio = ">=1.48.2" +markdown = ">=2.6.8" +numpy = ">=1.12.0" +protobuf = ">=3.19.6,<4.24.0 || >4.24.0" +setuptools = ">=41.0.0" +six = ">1.9" +tensorboard-data-server = ">=0.7.0,<0.8.0" +werkzeug = ">=1.0.1" + +[[package]] +name = "tensorboard-data-server" +version = "0.7.2" +description = "Fast data loading for TensorBoard" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"}, + {file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"}, + {file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"}, +] + +[[package]] +name = "tensorboardx" +version = "2.6.2.2" +description = "TensorBoardX lets you watch Tensors Flow without Tensorflow" +optional = false +python-versions = "*" +files = [ + {file = "tensorboardX-2.6.2.2-py2.py3-none-any.whl", hash = "sha256:160025acbf759ede23fd3526ae9d9bfbfd8b68eb16c38a010ebe326dc6395db8"}, + {file = "tensorboardX-2.6.2.2.tar.gz", hash = "sha256:c6476d7cd0d529b0b72f4acadb1269f9ed8b22f441e87a84f2a3b940bb87b666"}, +] + +[package.dependencies] +numpy = "*" +packaging = "*" +protobuf = ">=3.20" + [[package]] name = "tensordict" version = "0.4.0+ca4256e" @@ -3289,6 +3454,23 @@ perf = ["orjson"] reports = ["pydantic (>=2.0.0)"] sweeps = ["sweeps (>=0.2.0)"] +[[package]] +name = "werkzeug" +version = "3.0.1" +description = "The comprehensive WSGI web application library." +optional = false +python-versions = ">=3.8" +files = [ + {file = "werkzeug-3.0.1-py3-none-any.whl", hash = "sha256:90a285dc0e42ad56b34e696398b8122ee4c681833fb35b8334a095d82c56da10"}, + {file = "werkzeug-3.0.1.tar.gz", hash = "sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc"}, +] + +[package.dependencies] +MarkupSafe = ">=2.1.1" + +[package.extras] +watchdog = ["watchdog (>=2.3)"] + [[package]] name = "zarr" version = "2.17.1" @@ -3328,4 +3510,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "ee86b84a795e6a3e9c2d79f244a87b55589adbe46d549ac38adf48be27c04cf9" +content-hash = "1a45c808e1c48bcbf4319d4cf6876771b7d50f40a5a8968a8b7f3af36192bf34" diff --git a/pyproject.toml b/pyproject.toml index 2e818a44..7e9996a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ torchvision = "^0.17.1" h5py = "^3.10.0" dm-control = "1.0.14" huggingface-hub = {extras = ["hf-transfer"], version = "^0.21.4"} +robomimic = "0.2.0" [tool.poetry.group.dev.dependencies] From 4e10cd306b00e07b06b0b761f1d42579182f6db7 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 21 Mar 2024 10:27:07 +0000 Subject: [PATCH 4/8] revert changes to default.yaml --- lerobot/configs/default.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 90d4c06b..52fd1d60 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -12,14 +12,14 @@ hydra: seed: 1337 # batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index # NOTE: only diffusion policy supports rollout_batch_size > 1 -rollout_batch_size: 10 +rollout_batch_size: 1 device: cuda # cpu prefetch: 4 eval_freq: ??? save_freq: ??? eval_episodes: ??? save_video: false -save_model: true +save_model: false save_buffer: false train_steps: ??? fps: ??? @@ -34,6 +34,6 @@ policy: ??? wandb: enable: true # Set to true to disable saving an artifact despite save_model == True - disable_artifact: true + disable_artifact: false project: lerobot notes: "" From b562f89c3b94c7b3b3ed4b991e6ae135cb6f896d Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 21 Mar 2024 11:42:45 +0000 Subject: [PATCH 5/8] update deps --- poetry.lock | 94 ++++++++++++++++++++++++++--------------------------- 1 file changed, 47 insertions(+), 47 deletions(-) diff --git a/poetry.lock b/poetry.lock index 92449d45..d2d39e7a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -44,56 +44,56 @@ files = [ [[package]] name = "av" -version = "11.0.0" +version = "12.0.0" description = "Pythonic bindings for FFmpeg's libraries." optional = false python-versions = ">=3.8" files = [ - {file = "av-11.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a01f13b37eb6d181e03bbbbda29093fe2d68f10755795188220acdc89560ec27"}, - {file = "av-11.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b2236faee1b5d71dff3cdef81ef6eec22cc8b71dbfb45eb037e6437fe80f24e7"}, - {file = "av-11.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40543a08e5c84aecd2bc84da5d43548743201897f0ba21bf5ae3a4dcddefca2b"}, - {file = "av-11.0.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2907376884d956376aaf3bc1905fa4e0dcb9ba4e0d183e519392a19d89317d1b"}, - {file = "av-11.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8d5581dcdc81cd601e3ce036809f14da82c46ff187bcefe981ec819390e0ab0"}, - {file = "av-11.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:150490f2a62cfa470f3cb60f3a0060ff93afd807e2b7b3b0eeeb5a992eb8d67b"}, - {file = "av-11.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d9bac0de62f09e2cb4e2132b5a46a89bc31c898189aa285b484c17351d991afe"}, - {file = "av-11.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2122ff8bdace4ce50207920f37de472517921e2ca1f0503464f748fdb8e20506"}, - {file = "av-11.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:527d840697fee6ad4cf47eba987eaf30cd76bd96b2d20eaa907e166b9b8065c8"}, - {file = "av-11.0.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:abeaedddfca9101886eb6fc47318c5f5ece8480d330d73aacf6917d7421981a2"}, - {file = "av-11.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13790fbb889b955baf885fe3761e923e85537ef414173465ec293177cedb7b99"}, - {file = "av-11.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:fc27e27f52480287f44226ad4ae3eb53346bf027959d0f00a9154530bd98b371"}, - {file = "av-11.0.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:892583e2c6b8c2500e5d24310f499caefcdaa2e48c8f7169ad41041aaaf4da11"}, - {file = "av-11.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6943679d70a9f4de974049e7ae2cf0b20afe0d7ddab650526c02a6cf9adcd08f"}, - {file = "av-11.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e6d73b038ccf1df5c16bc643eee5c694fb7732e09375e2f4903c1f4ce90dfb72"}, - {file = "av-11.0.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c83422db3333e97b9680700df5185139352fc3a568b14179da3bdcbeb2f0e91b"}, - {file = "av-11.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8413900f6a3639e0088c018a3a516a1656d4d16799e7aa759a16ddf3bd268e2b"}, - {file = "av-11.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:908e49ee336223801d8f2f7dca5a1deb64e9d8256138b8e7a79013b682a6ebb5"}, - {file = "av-11.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:82411ae4a562da07b76028d2f349fb0e6a86aa78ad2b18d2d7bf5b06b17fba14"}, - {file = "av-11.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:621104bd63e38fa4eca554da3722b1aac329619de39152f27eec8999acc72342"}, - {file = "av-11.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:442878990c094455a16c10127edcc54bc4e78d355e6a13ad2a27608b0ecda38f"}, - {file = "av-11.0.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:658199c92987dc72511f5ee8ade62faef6234b7a04c8b5788de99e366be5e073"}, - {file = "av-11.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad4b381665c49267b46f87297573898b85e5c41384750fee2e70267fbc4ba318"}, - {file = "av-11.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:60de14f71293e36ca4e297cc8a8460f0cf74f38a201694f3c6fc7f40301582f2"}, - {file = "av-11.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a90f04af96374dab94028a7471597bdfcf03083338b9be2eb8ca4805a8ec7ab5"}, - {file = "av-11.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8821ab2d23e4cb5c8abea6b08d2b1bfceca6af2d88fab1d1dc1b3ec7b34933c7"}, - {file = "av-11.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9a92342ed307eeaf9509a6b0f3bafd4337c4880c851b50acc18df48c625b63b6"}, - {file = "av-11.0.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bbe3502975bc844f5d432c1f24d331bf6ef3e05532ebf06f7ed08b60719b8ea5"}, - {file = "av-11.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c278b3a4fd111b4c9190abe6b1a5ca358d5f91e851d470b62577b957e0187b09"}, - {file = "av-11.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:478aa1d54fbc3058ea65ff41086b6adbe1326b456a027d2f3b59dbe60b4ac2ca"}, - {file = "av-11.0.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:e8df10bb2d56a981d02a8a0b41491912b76dad06305d174a2575ef55ad451100"}, - {file = "av-11.0.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b30c51e597785a89241bd61865faff2dbd3327856a8285a1e120dbf60e18348b"}, - {file = "av-11.0.0-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a8b8bd92edb096699b306e7b090ad096925ca3bdae6f89656f023fa2a2da627d"}, - {file = "av-11.0.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9383af733abfc44f6fc29307a6c922fbf671ee343dc97b78b74eac6a2346a46d"}, - {file = "av-11.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:a9df4a60579198b560f641cdfe4c2139948a70193ddc096b275f2cf6d94e3e04"}, - {file = "av-11.0.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8ae5f7ae0a7093fb813686d4aa4c554531f80a28480427f5c155da51b747eff0"}, - {file = "av-11.0.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50fb7d606f8236891d773c701d5650b93af8dbf78eeaac36fc7e1f7f64a9d664"}, - {file = "av-11.0.0-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:543e0f9bf6ff02dedbe66d906fbc89c8907c80a8ea7413fc3fed68ce4a6e9b44"}, - {file = "av-11.0.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:daa279c884457ab194ce78bdd89c0aa391af733da95fb3258d4c6eb8c258299a"}, - {file = "av-11.0.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:1aacc21f4cf96447117a61edfb776afb73186750a5e08a21484ddfc3599aefb5"}, - {file = "av-11.0.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2568b38eef777b916a5d02e42b8f67f92e12023531239ddd32e1ca4f3cdf8c5b"}, - {file = "av-11.0.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:747c6d347e27c59cc2e78c9c505d23cd88eceff0cc9386be73693ae9009a577c"}, - {file = "av-11.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4bbd8f4941b9d3450eff40003b9b9d904667aec7ab085fa31f0f9bca32d755e0"}, - {file = "av-11.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:f39c1244ba0cf185b2722aeec116b8a98a2ee5728ce687cec0bda60ee0360dfc"}, - {file = "av-11.0.0.tar.gz", hash = "sha256:48223f000a252070f8e700ff634bb7fb3aa1b7bc7e450373029fbdd6f369ac31"}, + {file = "av-12.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b9d0890553951f76c479a9f2bb952aebae902b1c7d52feea614d37e1cd728a44"}, + {file = "av-12.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5d7f229a253c2e3fea9682c09c5ae179bd6d5d2da38d89eb7f29ef7bed10cb2f"}, + {file = "av-12.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:61b3555d143aacf02e0446f6030319403538eba4dc713c18dfa653a2a23e7f9c"}, + {file = "av-12.0.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:607e13b2c2b26159a37525d7b6f647a32ce78711fccff23d146d3e255ffa115f"}, + {file = "av-12.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39f0b4cfb89f4f06b339c766f92648e798a96747d4163f2fa78660d1ab1f1b5e"}, + {file = "av-12.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:41dcb8c269fa58a56edf3a3c814c32a0c69586827f132b4e395a951b0ce14fad"}, + {file = "av-12.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4fa78fbe0e4469226512380180063116105048c66cb12e18ab4b518466c57e6c"}, + {file = "av-12.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:60a869be1d6af916e65ea461cb93922f5db0698655ed7a7eae7c3ecd4af4debb"}, + {file = "av-12.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df61811cc551c186f0a0e530d97b8b139453534d0f92c1790a923f666522ceda"}, + {file = "av-12.0.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:99cd2fc53091ebfb9a2fa9dd3580267f5bd1c040d0efd99fbc1a162576b271cb"}, + {file = "av-12.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a6d4f1e261df48932128e6495772faa4cc23f5dd1512eec73daab82ad9f3240"}, + {file = "av-12.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:6aec88e41a498b1e01e2dce5371557e20f9a51aae0c16decc5924ec0be2e22b6"}, + {file = "av-12.0.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:90eb8f2d548e96cbc6f78e89c911cdb15a3d80fd944f31111660ce45939cd037"}, + {file = "av-12.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d7f3a02910e77d750dbd516256a16db15030e5371530ff5a5ae902dc03d9005d"}, + {file = "av-12.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2477cc51526aa50575313d66e5e8ad7ab944588469be5e557b360ed572ae536"}, + {file = "av-12.0.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a2f47149d3ca6deb79f3e515b8bef50e27ebdb160813e6d67dba77278d2a7883"}, + {file = "av-12.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3306e4a3ce8b5bfcc3075793d4ed3a2df69179d8fba22cb944a6164dc235dfb6"}, + {file = "av-12.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:dc1b742e7f6df1b499fb960bd6697d1dd8e7ada7484a041a8c20e70a87225f53"}, + {file = "av-12.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0183be6889e835e1b074b4037bfce4fd44671c606cf1c4ab92ea2f271b544aec"}, + {file = "av-12.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:57337f20b208292ec8d3b11e4d289d8688a43d728174850a81b865d3253fff2c"}, + {file = "av-12.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ec915e8f6521545a38566eefc281042ee504ea3cee0618d8558e4920588b3b2"}, + {file = "av-12.0.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:33ad5c0a23c45b72bd6bd47f3b2c1adcd2935ee3d0b6178ed66bba62b964ff31"}, + {file = "av-12.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfc3a652b12c93120514d56cf025da47442c5ba51530cdf7ba3660257dbb0de1"}, + {file = "av-12.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:037f793dd1ef4a1f57f090191a7f803ad10ec82da0d04ea26bbe0b8a145fe927"}, + {file = "av-12.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fc532376aa264722fae55063abd1871d17a563dc895978e142c8ecfcdeb3a2e8"}, + {file = "av-12.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:abf0c4bc40a0af8a30f4cd96f3be6f19fbce0f21222d7fcec148e085127153f7"}, + {file = "av-12.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81cedd1c072fbebf606724c406b1a1b00adc711f1dfd2bc04c633ce39d8439d8"}, + {file = "av-12.0.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02d60f48be9f15dcda37d50f3ce8d7249d9a455643d4322dd3449986bacfc628"}, + {file = "av-12.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d2619e4c26d661eecfc404f7d739d8b35f0dcef353fabe61512e030254b7031"}, + {file = "av-12.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:1892cc91c888d101777d5432d54e0554c11d1c3a2c65d02a2cae0a2256a8fbb9"}, + {file = "av-12.0.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:4819e3ef6c3a44ef6f75907229133a1ee7f688245b2cf49b6b8e969a81ca72c9"}, + {file = "av-12.0.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb16bb314cf1503b0250fc46b2c455ee196584231101be0123f4f78638227b62"}, + {file = "av-12.0.0-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f3e6a62bda9a1e144feeb59bbee046d7a2d98399634a30f57e4990197313c158"}, + {file = "av-12.0.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e08175ffbafa3a70c7b2f81083e160e34122a208cdf70f150b8f5d02c2de6965"}, + {file = "av-12.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:e1d255be317b7c1ebdc4dae98935b9f3869161112dc829c625e54f90d8bdd7ab"}, + {file = "av-12.0.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:17964b36e08435910aabd5b3f7dca12f99536902529767d276026bc08f94ced7"}, + {file = "av-12.0.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2d5f78de29edee06ddcdd4c2b759914575492d6a0cd4de2ce31ee63a4953eff"}, + {file = "av-12.0.0-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:309b32bc97158d0f0c19e273b8e17a855a86806b7194aebc23bd497326cff11f"}, + {file = "av-12.0.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c409c71bd9c7c2f8d018c822f36b1447cfa96eca158381a96f3319bb0ff6e79e"}, + {file = "av-12.0.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:08fc5eaef60a257d622998626e233bf3ff90d2f817f6695d6a27e0ffcfe9dcff"}, + {file = "av-12.0.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:746ab0eff8a7a21a6c6d16e6b6e61709527eba2ad1a524d92a01bb60d02a3df7"}, + {file = "av-12.0.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:013b3ac3de3aa1c137af0cedafd364fd1c7524ab3e1cd53e04564fd1632ac04d"}, + {file = "av-12.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fa55923527648f51ac005e44fe2797ebc67f53ad4850e0194d3753761ee33a2"}, + {file = "av-12.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:35d514f4dee0cf67e9e6b2a65fb4a28f98da88e71e8c7f7960bd04625d9fe965"}, + {file = "av-12.0.0.tar.gz", hash = "sha256:bcf21ebb722d4538b4099e5a78f730d78814dd70003511c185941dba5651b14d"}, ] [[package]] @@ -3115,7 +3115,7 @@ protobuf = ">=3.20" [[package]] name = "tensordict" -version = "0.4.0+ca4256e" +version = "0.4.0+b4c91e8" description = "" optional = false python-versions = "*" From 48df15ed26f81c88fc46e95b6e36ee271b4ec52d Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 21 Mar 2024 11:58:28 +0000 Subject: [PATCH 6/8] add cpu dep --- .github/poetry/cpu/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/poetry/cpu/pyproject.toml b/.github/poetry/cpu/pyproject.toml index fd7eb226..2f5a5422 100644 --- a/.github/poetry/cpu/pyproject.toml +++ b/.github/poetry/cpu/pyproject.toml @@ -51,6 +51,7 @@ torchvision = {version = "^0.17.1", source = "torch-cpu"} h5py = "^3.10.0" dm = "^1.3" dm-control = "^1.0.16" +robomimic = "0.2.0" huggingface-hub = "^0.21.4" From 98361073efea84c6983d9c38d67bdffd7913aa84 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 21 Mar 2024 12:02:24 +0000 Subject: [PATCH 7/8] cpu poetry lock --- .github/poetry/cpu/poetry.lock | 358 +++++++++++++++++++++++++-------- 1 file changed, 270 insertions(+), 88 deletions(-) diff --git a/.github/poetry/cpu/poetry.lock b/.github/poetry/cpu/poetry.lock index c07e3439..d558d505 100644 --- a/.github/poetry/cpu/poetry.lock +++ b/.github/poetry/cpu/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "absl-py" @@ -44,56 +44,56 @@ files = [ [[package]] name = "av" -version = "11.0.0" +version = "12.0.0" description = "Pythonic bindings for FFmpeg's libraries." optional = false python-versions = ">=3.8" files = [ - {file = "av-11.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a01f13b37eb6d181e03bbbbda29093fe2d68f10755795188220acdc89560ec27"}, - {file = "av-11.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b2236faee1b5d71dff3cdef81ef6eec22cc8b71dbfb45eb037e6437fe80f24e7"}, - {file = "av-11.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40543a08e5c84aecd2bc84da5d43548743201897f0ba21bf5ae3a4dcddefca2b"}, - {file = "av-11.0.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2907376884d956376aaf3bc1905fa4e0dcb9ba4e0d183e519392a19d89317d1b"}, - {file = "av-11.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8d5581dcdc81cd601e3ce036809f14da82c46ff187bcefe981ec819390e0ab0"}, - {file = "av-11.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:150490f2a62cfa470f3cb60f3a0060ff93afd807e2b7b3b0eeeb5a992eb8d67b"}, - {file = "av-11.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d9bac0de62f09e2cb4e2132b5a46a89bc31c898189aa285b484c17351d991afe"}, - {file = "av-11.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2122ff8bdace4ce50207920f37de472517921e2ca1f0503464f748fdb8e20506"}, - {file = "av-11.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:527d840697fee6ad4cf47eba987eaf30cd76bd96b2d20eaa907e166b9b8065c8"}, - {file = "av-11.0.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:abeaedddfca9101886eb6fc47318c5f5ece8480d330d73aacf6917d7421981a2"}, - {file = "av-11.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13790fbb889b955baf885fe3761e923e85537ef414173465ec293177cedb7b99"}, - {file = "av-11.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:fc27e27f52480287f44226ad4ae3eb53346bf027959d0f00a9154530bd98b371"}, - {file = "av-11.0.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:892583e2c6b8c2500e5d24310f499caefcdaa2e48c8f7169ad41041aaaf4da11"}, - {file = "av-11.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6943679d70a9f4de974049e7ae2cf0b20afe0d7ddab650526c02a6cf9adcd08f"}, - {file = "av-11.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e6d73b038ccf1df5c16bc643eee5c694fb7732e09375e2f4903c1f4ce90dfb72"}, - {file = "av-11.0.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c83422db3333e97b9680700df5185139352fc3a568b14179da3bdcbeb2f0e91b"}, - {file = "av-11.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8413900f6a3639e0088c018a3a516a1656d4d16799e7aa759a16ddf3bd268e2b"}, - {file = "av-11.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:908e49ee336223801d8f2f7dca5a1deb64e9d8256138b8e7a79013b682a6ebb5"}, - {file = "av-11.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:82411ae4a562da07b76028d2f349fb0e6a86aa78ad2b18d2d7bf5b06b17fba14"}, - {file = "av-11.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:621104bd63e38fa4eca554da3722b1aac329619de39152f27eec8999acc72342"}, - {file = "av-11.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:442878990c094455a16c10127edcc54bc4e78d355e6a13ad2a27608b0ecda38f"}, - {file = "av-11.0.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:658199c92987dc72511f5ee8ade62faef6234b7a04c8b5788de99e366be5e073"}, - {file = "av-11.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad4b381665c49267b46f87297573898b85e5c41384750fee2e70267fbc4ba318"}, - {file = "av-11.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:60de14f71293e36ca4e297cc8a8460f0cf74f38a201694f3c6fc7f40301582f2"}, - {file = "av-11.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a90f04af96374dab94028a7471597bdfcf03083338b9be2eb8ca4805a8ec7ab5"}, - {file = "av-11.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8821ab2d23e4cb5c8abea6b08d2b1bfceca6af2d88fab1d1dc1b3ec7b34933c7"}, - {file = "av-11.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9a92342ed307eeaf9509a6b0f3bafd4337c4880c851b50acc18df48c625b63b6"}, - {file = "av-11.0.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bbe3502975bc844f5d432c1f24d331bf6ef3e05532ebf06f7ed08b60719b8ea5"}, - {file = "av-11.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c278b3a4fd111b4c9190abe6b1a5ca358d5f91e851d470b62577b957e0187b09"}, - {file = "av-11.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:478aa1d54fbc3058ea65ff41086b6adbe1326b456a027d2f3b59dbe60b4ac2ca"}, - {file = "av-11.0.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:e8df10bb2d56a981d02a8a0b41491912b76dad06305d174a2575ef55ad451100"}, - {file = "av-11.0.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b30c51e597785a89241bd61865faff2dbd3327856a8285a1e120dbf60e18348b"}, - {file = "av-11.0.0-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a8b8bd92edb096699b306e7b090ad096925ca3bdae6f89656f023fa2a2da627d"}, - {file = "av-11.0.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9383af733abfc44f6fc29307a6c922fbf671ee343dc97b78b74eac6a2346a46d"}, - {file = "av-11.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:a9df4a60579198b560f641cdfe4c2139948a70193ddc096b275f2cf6d94e3e04"}, - {file = "av-11.0.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8ae5f7ae0a7093fb813686d4aa4c554531f80a28480427f5c155da51b747eff0"}, - {file = "av-11.0.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50fb7d606f8236891d773c701d5650b93af8dbf78eeaac36fc7e1f7f64a9d664"}, - {file = "av-11.0.0-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:543e0f9bf6ff02dedbe66d906fbc89c8907c80a8ea7413fc3fed68ce4a6e9b44"}, - {file = "av-11.0.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:daa279c884457ab194ce78bdd89c0aa391af733da95fb3258d4c6eb8c258299a"}, - {file = "av-11.0.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:1aacc21f4cf96447117a61edfb776afb73186750a5e08a21484ddfc3599aefb5"}, - {file = "av-11.0.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2568b38eef777b916a5d02e42b8f67f92e12023531239ddd32e1ca4f3cdf8c5b"}, - {file = "av-11.0.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:747c6d347e27c59cc2e78c9c505d23cd88eceff0cc9386be73693ae9009a577c"}, - {file = "av-11.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4bbd8f4941b9d3450eff40003b9b9d904667aec7ab085fa31f0f9bca32d755e0"}, - {file = "av-11.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:f39c1244ba0cf185b2722aeec116b8a98a2ee5728ce687cec0bda60ee0360dfc"}, - {file = "av-11.0.0.tar.gz", hash = "sha256:48223f000a252070f8e700ff634bb7fb3aa1b7bc7e450373029fbdd6f369ac31"}, + {file = "av-12.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b9d0890553951f76c479a9f2bb952aebae902b1c7d52feea614d37e1cd728a44"}, + {file = "av-12.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5d7f229a253c2e3fea9682c09c5ae179bd6d5d2da38d89eb7f29ef7bed10cb2f"}, + {file = "av-12.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:61b3555d143aacf02e0446f6030319403538eba4dc713c18dfa653a2a23e7f9c"}, + {file = "av-12.0.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:607e13b2c2b26159a37525d7b6f647a32ce78711fccff23d146d3e255ffa115f"}, + {file = "av-12.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39f0b4cfb89f4f06b339c766f92648e798a96747d4163f2fa78660d1ab1f1b5e"}, + {file = "av-12.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:41dcb8c269fa58a56edf3a3c814c32a0c69586827f132b4e395a951b0ce14fad"}, + {file = "av-12.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4fa78fbe0e4469226512380180063116105048c66cb12e18ab4b518466c57e6c"}, + {file = "av-12.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:60a869be1d6af916e65ea461cb93922f5db0698655ed7a7eae7c3ecd4af4debb"}, + {file = "av-12.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df61811cc551c186f0a0e530d97b8b139453534d0f92c1790a923f666522ceda"}, + {file = "av-12.0.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:99cd2fc53091ebfb9a2fa9dd3580267f5bd1c040d0efd99fbc1a162576b271cb"}, + {file = "av-12.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a6d4f1e261df48932128e6495772faa4cc23f5dd1512eec73daab82ad9f3240"}, + {file = "av-12.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:6aec88e41a498b1e01e2dce5371557e20f9a51aae0c16decc5924ec0be2e22b6"}, + {file = "av-12.0.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:90eb8f2d548e96cbc6f78e89c911cdb15a3d80fd944f31111660ce45939cd037"}, + {file = "av-12.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d7f3a02910e77d750dbd516256a16db15030e5371530ff5a5ae902dc03d9005d"}, + {file = "av-12.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2477cc51526aa50575313d66e5e8ad7ab944588469be5e557b360ed572ae536"}, + {file = "av-12.0.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a2f47149d3ca6deb79f3e515b8bef50e27ebdb160813e6d67dba77278d2a7883"}, + {file = "av-12.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3306e4a3ce8b5bfcc3075793d4ed3a2df69179d8fba22cb944a6164dc235dfb6"}, + {file = "av-12.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:dc1b742e7f6df1b499fb960bd6697d1dd8e7ada7484a041a8c20e70a87225f53"}, + {file = "av-12.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0183be6889e835e1b074b4037bfce4fd44671c606cf1c4ab92ea2f271b544aec"}, + {file = "av-12.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:57337f20b208292ec8d3b11e4d289d8688a43d728174850a81b865d3253fff2c"}, + {file = "av-12.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ec915e8f6521545a38566eefc281042ee504ea3cee0618d8558e4920588b3b2"}, + {file = "av-12.0.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:33ad5c0a23c45b72bd6bd47f3b2c1adcd2935ee3d0b6178ed66bba62b964ff31"}, + {file = "av-12.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfc3a652b12c93120514d56cf025da47442c5ba51530cdf7ba3660257dbb0de1"}, + {file = "av-12.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:037f793dd1ef4a1f57f090191a7f803ad10ec82da0d04ea26bbe0b8a145fe927"}, + {file = "av-12.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fc532376aa264722fae55063abd1871d17a563dc895978e142c8ecfcdeb3a2e8"}, + {file = "av-12.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:abf0c4bc40a0af8a30f4cd96f3be6f19fbce0f21222d7fcec148e085127153f7"}, + {file = "av-12.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81cedd1c072fbebf606724c406b1a1b00adc711f1dfd2bc04c633ce39d8439d8"}, + {file = "av-12.0.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02d60f48be9f15dcda37d50f3ce8d7249d9a455643d4322dd3449986bacfc628"}, + {file = "av-12.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d2619e4c26d661eecfc404f7d739d8b35f0dcef353fabe61512e030254b7031"}, + {file = "av-12.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:1892cc91c888d101777d5432d54e0554c11d1c3a2c65d02a2cae0a2256a8fbb9"}, + {file = "av-12.0.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:4819e3ef6c3a44ef6f75907229133a1ee7f688245b2cf49b6b8e969a81ca72c9"}, + {file = "av-12.0.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb16bb314cf1503b0250fc46b2c455ee196584231101be0123f4f78638227b62"}, + {file = "av-12.0.0-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f3e6a62bda9a1e144feeb59bbee046d7a2d98399634a30f57e4990197313c158"}, + {file = "av-12.0.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e08175ffbafa3a70c7b2f81083e160e34122a208cdf70f150b8f5d02c2de6965"}, + {file = "av-12.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:e1d255be317b7c1ebdc4dae98935b9f3869161112dc829c625e54f90d8bdd7ab"}, + {file = "av-12.0.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:17964b36e08435910aabd5b3f7dca12f99536902529767d276026bc08f94ced7"}, + {file = "av-12.0.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2d5f78de29edee06ddcdd4c2b759914575492d6a0cd4de2ce31ee63a4953eff"}, + {file = "av-12.0.0-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:309b32bc97158d0f0c19e273b8e17a855a86806b7194aebc23bd497326cff11f"}, + {file = "av-12.0.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c409c71bd9c7c2f8d018c822f36b1447cfa96eca158381a96f3319bb0ff6e79e"}, + {file = "av-12.0.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:08fc5eaef60a257d622998626e233bf3ff90d2f817f6695d6a27e0ffcfe9dcff"}, + {file = "av-12.0.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:746ab0eff8a7a21a6c6d16e6b6e61709527eba2ad1a524d92a01bb60d02a3df7"}, + {file = "av-12.0.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:013b3ac3de3aa1c137af0cedafd364fd1c7524ab3e1cd53e04564fd1632ac04d"}, + {file = "av-12.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fa55923527648f51ac005e44fe2797ebc67f53ad4850e0194d3753761ee33a2"}, + {file = "av-12.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:35d514f4dee0cf67e9e6b2a65fb4a28f98da88e71e8c7f7960bd04625d9fe965"}, + {file = "av-12.0.0.tar.gz", hash = "sha256:bcf21ebb722d4538b4099e5a78f730d78814dd70003511c185941dba5651b14d"}, ] [[package]] @@ -614,6 +614,16 @@ files = [ [package.dependencies] six = ">=1.4.0" +[[package]] +name = "egl-probe" +version = "1.0.2" +description = "" +optional = false +python-versions = "*" +files = [ + {file = "egl_probe-1.0.2.tar.gz", hash = "sha256:29bdca7b08da1e060cfb42cd46af8300a7ac4f3b1b2eeb16e545ea16d9a5ac93"}, +] + [[package]] name = "einops" version = "0.7.0" @@ -705,13 +715,13 @@ typing = ["typing-extensions (>=4.8)"] [[package]] name = "fsspec" -version = "2024.2.0" +version = "2024.3.1" description = "File-system specification" optional = false python-versions = ">=3.8" files = [ - {file = "fsspec-2024.2.0-py3-none-any.whl", hash = "sha256:817f969556fa5916bc682e02ca2045f96ff7f586d45110fcb76022063ad2c7d8"}, - {file = "fsspec-2024.2.0.tar.gz", hash = "sha256:b6ad1a679f760dda52b1168c859d01b7b80648ea6f7f7c7f5a8a91dc3f3ecb84"}, + {file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"}, + {file = "fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9"}, ] [package.extras] @@ -810,6 +820,72 @@ files = [ [package.extras] preview = ["glfw-preview"] +[[package]] +name = "grpcio" +version = "1.62.1" +description = "HTTP/2-based RPC framework" +optional = false +python-versions = ">=3.7" +files = [ + {file = "grpcio-1.62.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:179bee6f5ed7b5f618844f760b6acf7e910988de77a4f75b95bbfaa8106f3c1e"}, + {file = "grpcio-1.62.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:48611e4fa010e823ba2de8fd3f77c1322dd60cb0d180dc6630a7e157b205f7ea"}, + {file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b2a0e71b0a2158aa4bce48be9f8f9eb45cbd17c78c7443616d00abbe2a509f6d"}, + {file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fbe80577c7880911d3ad65e5ecc997416c98f354efeba2f8d0f9112a67ed65a5"}, + {file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58f6c693d446964e3292425e1d16e21a97a48ba9172f2d0df9d7b640acb99243"}, + {file = "grpcio-1.62.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:77c339403db5a20ef4fed02e4d1a9a3d9866bf9c0afc77a42234677313ea22f3"}, + {file = "grpcio-1.62.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b5a4ea906db7dec694098435d84bf2854fe158eb3cd51e1107e571246d4d1d70"}, + {file = "grpcio-1.62.1-cp310-cp310-win32.whl", hash = "sha256:4187201a53f8561c015bc745b81a1b2d278967b8de35f3399b84b0695e281d5f"}, + {file = "grpcio-1.62.1-cp310-cp310-win_amd64.whl", hash = "sha256:844d1f3fb11bd1ed362d3fdc495d0770cfab75761836193af166fee113421d66"}, + {file = "grpcio-1.62.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:833379943d1728a005e44103f17ecd73d058d37d95783eb8f0b28ddc1f54d7b2"}, + {file = "grpcio-1.62.1-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:c7fcc6a32e7b7b58f5a7d27530669337a5d587d4066060bcb9dee7a8c833dfb7"}, + {file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:fa7d28eb4d50b7cbe75bb8b45ed0da9a1dc5b219a0af59449676a29c2eed9698"}, + {file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48f7135c3de2f298b833be8b4ae20cafe37091634e91f61f5a7eb3d61ec6f660"}, + {file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:71f11fd63365ade276c9d4a7b7df5c136f9030e3457107e1791b3737a9b9ed6a"}, + {file = "grpcio-1.62.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4b49fd8fe9f9ac23b78437da94c54aa7e9996fbb220bac024a67469ce5d0825f"}, + {file = "grpcio-1.62.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:482ae2ae78679ba9ed5752099b32e5fe580443b4f798e1b71df412abf43375db"}, + {file = "grpcio-1.62.1-cp311-cp311-win32.whl", hash = "sha256:1faa02530b6c7426404372515fe5ddf66e199c2ee613f88f025c6f3bd816450c"}, + {file = "grpcio-1.62.1-cp311-cp311-win_amd64.whl", hash = "sha256:5bd90b8c395f39bc82a5fb32a0173e220e3f401ff697840f4003e15b96d1befc"}, + {file = "grpcio-1.62.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:b134d5d71b4e0837fff574c00e49176051a1c532d26c052a1e43231f252d813b"}, + {file = "grpcio-1.62.1-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:d1f6c96573dc09d50dbcbd91dbf71d5cf97640c9427c32584010fbbd4c0e0037"}, + {file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:359f821d4578f80f41909b9ee9b76fb249a21035a061a327f91c953493782c31"}, + {file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a485f0c2010c696be269184bdb5ae72781344cb4e60db976c59d84dd6354fac9"}, + {file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b50b09b4dc01767163d67e1532f948264167cd27f49e9377e3556c3cba1268e1"}, + {file = "grpcio-1.62.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3227c667dccbe38f2c4d943238b887bac588d97c104815aecc62d2fd976e014b"}, + {file = "grpcio-1.62.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3952b581eb121324853ce2b191dae08badb75cd493cb4e0243368aa9e61cfd41"}, + {file = "grpcio-1.62.1-cp312-cp312-win32.whl", hash = "sha256:83a17b303425104d6329c10eb34bba186ffa67161e63fa6cdae7776ff76df73f"}, + {file = "grpcio-1.62.1-cp312-cp312-win_amd64.whl", hash = "sha256:6696ffe440333a19d8d128e88d440f91fb92c75a80ce4b44d55800e656a3ef1d"}, + {file = "grpcio-1.62.1-cp37-cp37m-linux_armv7l.whl", hash = "sha256:e3393b0823f938253370ebef033c9fd23d27f3eae8eb9a8f6264900c7ea3fb5a"}, + {file = "grpcio-1.62.1-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:83e7ccb85a74beaeae2634f10eb858a0ed1a63081172649ff4261f929bacfd22"}, + {file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:882020c87999d54667a284c7ddf065b359bd00251fcd70279ac486776dbf84ec"}, + {file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a10383035e864f386fe096fed5c47d27a2bf7173c56a6e26cffaaa5a361addb1"}, + {file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:960edebedc6b9ada1ef58e1c71156f28689978188cd8cff3b646b57288a927d9"}, + {file = "grpcio-1.62.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:23e2e04b83f347d0aadde0c9b616f4726c3d76db04b438fd3904b289a725267f"}, + {file = "grpcio-1.62.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:978121758711916d34fe57c1f75b79cdfc73952f1481bb9583399331682d36f7"}, + {file = "grpcio-1.62.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9084086190cc6d628f282e5615f987288b95457292e969b9205e45b442276407"}, + {file = "grpcio-1.62.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:22bccdd7b23c420a27fd28540fb5dcbc97dc6be105f7698cb0e7d7a420d0e362"}, + {file = "grpcio-1.62.1-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:8999bf1b57172dbc7c3e4bb3c732658e918f5c333b2942243f10d0d653953ba9"}, + {file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:d9e52558b8b8c2f4ac05ac86344a7417ccdd2b460a59616de49eb6933b07a0bd"}, + {file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1714e7bc935780bc3de1b3fcbc7674209adf5208ff825799d579ffd6cd0bd505"}, + {file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8842ccbd8c0e253c1f189088228f9b433f7a93b7196b9e5b6f87dba393f5d5d"}, + {file = "grpcio-1.62.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1f1e7b36bdff50103af95a80923bf1853f6823dd62f2d2a2524b66ed74103e49"}, + {file = "grpcio-1.62.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bba97b8e8883a8038606480d6b6772289f4c907f6ba780fa1f7b7da7dfd76f06"}, + {file = "grpcio-1.62.1-cp38-cp38-win32.whl", hash = "sha256:a7f615270fe534548112a74e790cd9d4f5509d744dd718cd442bf016626c22e4"}, + {file = "grpcio-1.62.1-cp38-cp38-win_amd64.whl", hash = "sha256:e6c8c8693df718c5ecbc7babb12c69a4e3677fd11de8886f05ab22d4e6b1c43b"}, + {file = "grpcio-1.62.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:73db2dc1b201d20ab7083e7041946910bb991e7e9761a0394bbc3c2632326483"}, + {file = "grpcio-1.62.1-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:407b26b7f7bbd4f4751dbc9767a1f0716f9fe72d3d7e96bb3ccfc4aace07c8de"}, + {file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:f8de7c8cef9261a2d0a62edf2ccea3d741a523c6b8a6477a340a1f2e417658de"}, + {file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bd5c8a1af40ec305d001c60236308a67e25419003e9bb3ebfab5695a8d0b369"}, + {file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be0477cb31da67846a33b1a75c611f88bfbcd427fe17701b6317aefceee1b96f"}, + {file = "grpcio-1.62.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:60dcd824df166ba266ee0cfaf35a31406cd16ef602b49f5d4dfb21f014b0dedd"}, + {file = "grpcio-1.62.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:973c49086cabab773525f6077f95e5a993bfc03ba8fc32e32f2c279497780585"}, + {file = "grpcio-1.62.1-cp39-cp39-win32.whl", hash = "sha256:12859468e8918d3bd243d213cd6fd6ab07208195dc140763c00dfe901ce1e1b4"}, + {file = "grpcio-1.62.1-cp39-cp39-win_amd64.whl", hash = "sha256:b7209117bbeebdfa5d898205cc55153a51285757902dd73c47de498ad4d11332"}, + {file = "grpcio-1.62.1.tar.gz", hash = "sha256:6c455e008fa86d9e9a9d85bb76da4277c0d7d9668a3bfa70dbe86e9f3c759947"}, +] + +[package.extras] +protobuf = ["grpcio-tools (>=1.62.1)"] + [[package]] name = "gym" version = "0.26.2" @@ -1012,13 +1088,13 @@ setuptools = "*" [[package]] name = "importlib-metadata" -version = "7.0.2" +version = "7.1.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_metadata-7.0.2-py3-none-any.whl", hash = "sha256:f4bc4c0c070c490abf4ce96d715f68e95923320370efb66143df00199bb6c100"}, - {file = "importlib_metadata-7.0.2.tar.gz", hash = "sha256:198f568f3230878cb1b44fbd7975f87906c22336dba2e4a7f05278c281fbd792"}, + {file = "importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570"}, + {file = "importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2"}, ] [package.dependencies] @@ -1027,17 +1103,17 @@ zipp = ">=0.5" [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] -testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] [[package]] name = "importlib-resources" -version = "6.3.0" +version = "6.3.2" description = "Read resources from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_resources-6.3.0-py3-none-any.whl", hash = "sha256:783407aa1cd05550e3aa123e8f7cfaebee35ffa9cb0242919e2d1e4172222705"}, - {file = "importlib_resources-6.3.0.tar.gz", hash = "sha256:166072a97e86917a9025876f34286f549b9caf1d10b35a1b372bffa1600c6569"}, + {file = "importlib_resources-6.3.2-py3-none-any.whl", hash = "sha256:f41f4098b16cd140a97d256137cfd943d958219007990b2afb00439fc623f580"}, + {file = "importlib_resources-6.3.2.tar.gz", hash = "sha256:963eb79649252b0160c1afcfe5a1d3fe3ad66edd0a8b114beacffb70c0674223"}, ] [package.extras] @@ -1254,6 +1330,21 @@ html5 = ["html5lib"] htmlsoup = ["BeautifulSoup4"] source = ["Cython (>=3.0.7)"] +[[package]] +name = "markdown" +version = "3.6" +description = "Python implementation of John Gruber's Markdown." +optional = false +python-versions = ">=3.8" +files = [ + {file = "Markdown-3.6-py3-none-any.whl", hash = "sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f"}, + {file = "Markdown-3.6.tar.gz", hash = "sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224"}, +] + +[package.extras] +docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"] +testing = ["coverage", "pyyaml"] + [[package]] name = "markupsafe" version = "2.1.5" @@ -1459,32 +1550,32 @@ setuptools = "*" [[package]] name = "numba" -version = "0.59.0" +version = "0.59.1" description = "compiling Python code using LLVM" optional = false python-versions = ">=3.9" files = [ - {file = "numba-0.59.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8d061d800473fb8fef76a455221f4ad649a53f5e0f96e3f6c8b8553ee6fa98fa"}, - {file = "numba-0.59.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c086a434e7d3891ce5dfd3d1e7ee8102ac1e733962098578b507864120559ceb"}, - {file = "numba-0.59.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9e20736bf62e61f8353fb71b0d3a1efba636c7a303d511600fc57648b55823ed"}, - {file = "numba-0.59.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e86e6786aec31d2002122199486e10bbc0dc40f78d76364cded375912b13614c"}, - {file = "numba-0.59.0-cp310-cp310-win_amd64.whl", hash = "sha256:0307ee91b24500bb7e64d8a109848baf3a3905df48ce142b8ac60aaa406a0400"}, - {file = "numba-0.59.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d540f69a8245fb714419c2209e9af6104e568eb97623adc8943642e61f5d6d8e"}, - {file = "numba-0.59.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1192d6b2906bf3ff72b1d97458724d98860ab86a91abdd4cfd9328432b661e31"}, - {file = "numba-0.59.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:90efb436d3413809fcd15298c6d395cb7d98184350472588356ccf19db9e37c8"}, - {file = "numba-0.59.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cd3dac45e25d927dcb65d44fb3a973994f5add2b15add13337844afe669dd1ba"}, - {file = "numba-0.59.0-cp311-cp311-win_amd64.whl", hash = "sha256:753dc601a159861808cc3207bad5c17724d3b69552fd22768fddbf302a817a4c"}, - {file = "numba-0.59.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ce62bc0e6dd5264e7ff7f34f41786889fa81a6b860662f824aa7532537a7bee0"}, - {file = "numba-0.59.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8cbef55b73741b5eea2dbaf1b0590b14977ca95a13a07d200b794f8f6833a01c"}, - {file = "numba-0.59.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:70d26ba589f764be45ea8c272caa467dbe882b9676f6749fe6f42678091f5f21"}, - {file = "numba-0.59.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e125f7d69968118c28ec0eed9fbedd75440e64214b8d2eac033c22c04db48492"}, - {file = "numba-0.59.0-cp312-cp312-win_amd64.whl", hash = "sha256:4981659220b61a03c1e557654027d271f56f3087448967a55c79a0e5f926de62"}, - {file = "numba-0.59.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fe4d7562d1eed754a7511ed7ba962067f198f86909741c5c6e18c4f1819b1f47"}, - {file = "numba-0.59.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6feb1504bb432280f900deaf4b1dadcee68812209500ed3f81c375cbceab24dc"}, - {file = "numba-0.59.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:944faad25ee23ea9dda582bfb0189fb9f4fc232359a80ab2a028b94c14ce2b1d"}, - {file = "numba-0.59.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5516a469514bfae52a9d7989db4940653a5cbfac106f44cb9c50133b7ad6224b"}, - {file = "numba-0.59.0-cp39-cp39-win_amd64.whl", hash = "sha256:32bd0a41525ec0b1b853da244808f4e5333867df3c43c30c33f89cf20b9c2b63"}, - {file = "numba-0.59.0.tar.gz", hash = "sha256:12b9b064a3e4ad00e2371fc5212ef0396c80f41caec9b5ec391c8b04b6eaf2a8"}, + {file = "numba-0.59.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:97385a7f12212c4f4bc28f648720a92514bee79d7063e40ef66c2d30600fd18e"}, + {file = "numba-0.59.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0b77aecf52040de2a1eb1d7e314497b9e56fba17466c80b457b971a25bb1576d"}, + {file = "numba-0.59.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3476a4f641bfd58f35ead42f4dcaf5f132569c4647c6f1360ccf18ee4cda3990"}, + {file = "numba-0.59.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:525ef3f820931bdae95ee5379c670d5c97289c6520726bc6937a4a7d4230ba24"}, + {file = "numba-0.59.1-cp310-cp310-win_amd64.whl", hash = "sha256:990e395e44d192a12105eca3083b61307db7da10e093972ca285c85bef0963d6"}, + {file = "numba-0.59.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:43727e7ad20b3ec23ee4fc642f5b61845c71f75dd2825b3c234390c6d8d64051"}, + {file = "numba-0.59.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:411df625372c77959570050e861981e9d196cc1da9aa62c3d6a836b5cc338966"}, + {file = "numba-0.59.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2801003caa263d1e8497fb84829a7ecfb61738a95f62bc05693fcf1733e978e4"}, + {file = "numba-0.59.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:dd2842fac03be4e5324ebbbd4d2d0c8c0fc6e0df75c09477dd45b288a0777389"}, + {file = "numba-0.59.1-cp311-cp311-win_amd64.whl", hash = "sha256:0594b3dfb369fada1f8bb2e3045cd6c61a564c62e50cf1f86b4666bc721b3450"}, + {file = "numba-0.59.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1cce206a3b92836cdf26ef39d3a3242fec25e07f020cc4feec4c4a865e340569"}, + {file = "numba-0.59.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8c8b4477763cb1fbd86a3be7050500229417bf60867c93e131fd2626edb02238"}, + {file = "numba-0.59.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d80bce4ef7e65bf895c29e3889ca75a29ee01da80266a01d34815918e365835"}, + {file = "numba-0.59.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f7ad1d217773e89a9845886401eaaab0a156a90aa2f179fdc125261fd1105096"}, + {file = "numba-0.59.1-cp312-cp312-win_amd64.whl", hash = "sha256:5bf68f4d69dd3a9f26a9b23548fa23e3bcb9042e2935257b471d2a8d3c424b7f"}, + {file = "numba-0.59.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4e0318ae729de6e5dbe64c75ead1a95eb01fabfe0e2ebed81ebf0344d32db0ae"}, + {file = "numba-0.59.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0f68589740a8c38bb7dc1b938b55d1145244c8353078eea23895d4f82c8b9ec1"}, + {file = "numba-0.59.1-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:649913a3758891c77c32e2d2a3bcbedf4a69f5fea276d11f9119677c45a422e8"}, + {file = "numba-0.59.1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9712808e4545270291d76b9a264839ac878c5eb7d8b6e02c970dc0ac29bc8187"}, + {file = "numba-0.59.1-cp39-cp39-win_amd64.whl", hash = "sha256:8d51ccd7008a83105ad6a0082b6a2b70f1142dc7cfd76deb8c5a862367eb8c86"}, + {file = "numba-0.59.1.tar.gz", hash = "sha256:76f69132b96028d2774ed20415e8c528a34e3299a40581bae178f0994a2f370b"}, ] [package.dependencies] @@ -2310,6 +2401,30 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "robomimic" +version = "0.2.0" +description = "robomimic: A Modular Framework for Robot Learning from Demonstration" +optional = false +python-versions = ">=3" +files = [ + {file = "robomimic-0.2.0.tar.gz", hash = "sha256:ee3bb5cf9c3e1feead6b57b43c5db738fd0a8e0c015fdf6419808af8fffdc463"}, +] + +[package.dependencies] +egl_probe = ">=1.0.1" +h5py = "*" +imageio = "*" +imageio-ffmpeg = "*" +numpy = ">=1.13.3" +psutil = "*" +tensorboard = "*" +tensorboardX = "*" +termcolor = "*" +torch = "*" +torchvision = "*" +tqdm = "*" + [[package]] name = "safetensors" version = "0.4.2" @@ -2534,13 +2649,13 @@ test = ["asv", "gmpy2", "hypothesis", "mpmath", "pooch", "pytest", "pytest-cov", [[package]] name = "sentry-sdk" -version = "1.42.0" +version = "1.43.0" description = "Python client for Sentry (https://sentry.io)" optional = false python-versions = "*" files = [ - {file = "sentry-sdk-1.42.0.tar.gz", hash = "sha256:4a8364b8f7edbf47f95f7163e48334c96100d9c098f0ae6606e2e18183c223e6"}, - {file = "sentry_sdk-1.42.0-py2.py3-none-any.whl", hash = "sha256:a654ee7e497a3f5f6368b36d4f04baeab1fe92b3105f7f6965d6ef0de35a9ba4"}, + {file = "sentry-sdk-1.43.0.tar.gz", hash = "sha256:41df73af89d22921d8733714fb0fc5586c3461907e06688e6537d01a27e0e0f6"}, + {file = "sentry_sdk-1.43.0-py2.py3-none-any.whl", hash = "sha256:8d768724839ca18d7b4c7463ef7528c40b7aa2bfbf7fe554d5f9a7c044acfd36"}, ] [package.dependencies] @@ -2554,6 +2669,7 @@ asyncpg = ["asyncpg (>=0.23)"] beam = ["apache-beam (>=2.12)"] bottle = ["bottle (>=0.12.13)"] celery = ["celery (>=3)"] +celery-redbeat = ["celery-redbeat (>=2)"] chalice = ["chalice (>=1.16.0)"] clickhouse-driver = ["clickhouse-driver (>=0.2.0)"] django = ["django (>=1.8)"] @@ -2798,9 +2914,58 @@ files = [ [package.dependencies] mpmath = ">=0.19" +[[package]] +name = "tensorboard" +version = "2.16.2" +description = "TensorBoard lets you watch Tensors Flow" +optional = false +python-versions = ">=3.9" +files = [ + {file = "tensorboard-2.16.2-py3-none-any.whl", hash = "sha256:9f2b4e7dad86667615c0e5cd072f1ea8403fc032a299f0072d6f74855775cc45"}, +] + +[package.dependencies] +absl-py = ">=0.4" +grpcio = ">=1.48.2" +markdown = ">=2.6.8" +numpy = ">=1.12.0" +protobuf = ">=3.19.6,<4.24.0 || >4.24.0" +setuptools = ">=41.0.0" +six = ">1.9" +tensorboard-data-server = ">=0.7.0,<0.8.0" +werkzeug = ">=1.0.1" + +[[package]] +name = "tensorboard-data-server" +version = "0.7.2" +description = "Fast data loading for TensorBoard" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"}, + {file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"}, + {file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"}, +] + +[[package]] +name = "tensorboardx" +version = "2.6.2.2" +description = "TensorBoardX lets you watch Tensors Flow without Tensorflow" +optional = false +python-versions = "*" +files = [ + {file = "tensorboardX-2.6.2.2-py2.py3-none-any.whl", hash = "sha256:160025acbf759ede23fd3526ae9d9bfbfd8b68eb16c38a010ebe326dc6395db8"}, + {file = "tensorboardX-2.6.2.2.tar.gz", hash = "sha256:c6476d7cd0d529b0b72f4acadb1269f9ed8b22f441e87a84f2a3b940bb87b666"}, +] + +[package.dependencies] +numpy = "*" +packaging = "*" +protobuf = ">=3.20" + [[package]] name = "tensordict" -version = "0.4.0+6a56ecd" +version = "0.4.0+b4c91e8" description = "" optional = false python-versions = "*" @@ -2821,7 +2986,7 @@ tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures type = "git" url = "https://github.com/pytorch/tensordict" reference = "HEAD" -resolved_reference = "6a56ecd728757feee387f946b7da66dd452b739b" +resolved_reference = "b4c91e8828c538ca0a50d8383fd99311a9afb078" [[package]] name = "termcolor" @@ -3084,6 +3249,23 @@ perf = ["orjson"] reports = ["pydantic (>=2.0.0)"] sweeps = ["sweeps (>=0.2.0)"] +[[package]] +name = "werkzeug" +version = "3.0.1" +description = "The comprehensive WSGI web application library." +optional = false +python-versions = ">=3.8" +files = [ + {file = "werkzeug-3.0.1-py3-none-any.whl", hash = "sha256:90a285dc0e42ad56b34e696398b8122ee4c681833fb35b8334a095d82c56da10"}, + {file = "werkzeug-3.0.1.tar.gz", hash = "sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc"}, +] + +[package.dependencies] +MarkupSafe = ">=2.1.1" + +[package.extras] +watchdog = ["watchdog (>=2.3)"] + [[package]] name = "zarr" version = "2.17.1" @@ -3107,13 +3289,13 @@ jupyter = ["ipytree (>=0.2.2)", "ipywidgets (>=8.0.0)", "notebook"] [[package]] name = "zipp" -version = "3.18.0" +version = "3.18.1" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.8" files = [ - {file = "zipp-3.18.0-py3-none-any.whl", hash = "sha256:c1bb803ed69d2cce2373152797064f7e79bc43f0a3748eb494096a867e0ebf79"}, - {file = "zipp-3.18.0.tar.gz", hash = "sha256:df8d042b02765029a09b157efd8e820451045890acc30f8e37dd2f94a060221f"}, + {file = "zipp-3.18.1-py3-none-any.whl", hash = "sha256:206f5a15f2af3dbaee80769fb7dc6f249695e940acca08dfb2a4769fe61e538b"}, + {file = "zipp-3.18.1.tar.gz", hash = "sha256:2884ed22e7d8961de1c9a05142eb69a247f120291bc0206a00a7642f09b5b715"}, ] [package.extras] @@ -3123,4 +3305,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "4aa6a1e3f29560dd4a1c24d493ee1154089da4aa8d2190ad1f786c125ab2b735" +content-hash = "cbd9aedcb3a24417b85124fb94db706dd6ca0a90dfb610b0aebdcd3aa2a0333c" From 41912b962b78887e3336fe54dde6defe3b73e4ac Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 21 Mar 2024 13:51:26 +0000 Subject: [PATCH 8/8] remove TODO --- lerobot/scripts/eval.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 55a2d3df..76deb2fe 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -51,7 +51,6 @@ def eval_policy( def maybe_render_frame(env: EnvBase, _): if save_video or (return_first_video and i == 0): # noqa: B023 - # TODO now: generalize kwarg or maybe just remove it ep_frames.append(env.render()) # noqa: B023 with torch.inference_mode():