From acf1174447ff2ffa8616280e0390c38020990738 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 21 Mar 2024 10:18:50 +0000 Subject: [PATCH] ready for review --- lerobot/common/envs/aloha/env.py | 2 +- lerobot/common/envs/factory.py | 24 --- lerobot/common/envs/pusht/env.py | 26 ++- lerobot/common/envs/pusht/pusht_image_env.py | 14 -- .../diffusion/diffusion_unet_image_policy.py | 41 ++++ .../model/multi_image_obs_encoder.py | 11 +- lerobot/common/policies/diffusion/policy.py | 13 +- lerobot/common/policies/factory.py | 20 +- lerobot/configs/policy/diffusion.yaml | 16 +- lerobot/scripts/eval.py | 1 + poetry.lock | 198 +++++++++++++++++- pyproject.toml | 1 + 12 files changed, 282 insertions(+), 85 deletions(-) diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py index e09564fbf..af2b354bc 100644 --- a/lerobot/common/envs/aloha/env.py +++ b/lerobot/common/envs/aloha/env.py @@ -192,7 +192,7 @@ class AlohaEnv(AbstractEnv): { "observation": TensorDict(obs, batch_size=[]), "reward": torch.tensor([reward], dtype=torch.float32), - # succes and done are true when coverage > self.success_threshold in env + # success and done are true when coverage > self.success_threshold in env "done": torch.tensor([done], dtype=torch.bool), "success": torch.tensor([success], dtype=torch.bool), }, diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index e187d7131..06c7c43fb 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -62,27 +62,3 @@ def make_env(cfg, transform=None): {"seed": env_seed} for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) ], ) - - -# def make_env(env_name, frame_skip, device, is_test=False): -# env = GymEnv( -# env_name, -# frame_skip=frame_skip, -# from_pixels=True, -# pixels_only=False, -# device=device, -# ) -# env = TransformedEnv(env) -# env.append_transform(NoopResetEnv(noops=30, random=True)) -# if not is_test: -# env.append_transform(EndOfLifeTransform()) -# env.append_transform(RewardClipping(-1, 1)) -# env.append_transform(ToTensorImage()) -# env.append_transform(GrayScale()) -# env.append_transform(Resize(84, 84)) -# env.append_transform(CatFrames(N=4, dim=-3)) -# env.append_transform(RewardSum()) -# env.append_transform(StepCounter(max_steps=4500)) -# env.append_transform(DoubleToFloat()) -# env.append_transform(VecNorm(in_keys=["pixels"])) -# return env diff --git a/lerobot/common/envs/pusht/env.py b/lerobot/common/envs/pusht/env.py index f440d443d..3824a5d22 100644 --- a/lerobot/common/envs/pusht/env.py +++ b/lerobot/common/envs/pusht/env.py @@ -3,6 +3,8 @@ import logging from collections import deque from typing import Optional +import cv2 +import numpy as np import torch from tensordict import TensorDict from torchrl.data.tensor_specs import ( @@ -59,12 +61,30 @@ class PushtEnv(AbstractEnv): self._env = PushTImageEnv(render_size=self.image_size) - def render(self, mode="rgb_array", width=384, height=384): + def render(self, mode="rgb_array", width=96, height=96, with_marker=True): + """ + with_marker adds a cursor showing the targeted action for the controller. + """ if width != height: raise NotImplementedError() tmp = self._env.render_size - self._env.render_size = width - out = self._env.render(mode) + if width != self._env.render_size: + self._env.render_cache = None + self._env.render_size = width + out = self._env.render(mode).copy() + if with_marker and self._env.latest_action is not None: + action = np.array(self._env.latest_action) + coord = (action / 512 * self._env.render_size).astype(np.int32) + marker_size = int(8 / 96 * self._env.render_size) + thickness = int(1 / 96 * self._env.render_size) + cv2.drawMarker( + out, + coord, + color=(255, 0, 0), + markerType=cv2.MARKER_CROSS, + markerSize=marker_size, + thickness=thickness, + ) self._env.render_size = tmp return out diff --git a/lerobot/common/envs/pusht/pusht_image_env.py b/lerobot/common/envs/pusht/pusht_image_env.py index b30ad874a..ec8e177b4 100644 --- a/lerobot/common/envs/pusht/pusht_image_env.py +++ b/lerobot/common/envs/pusht/pusht_image_env.py @@ -27,20 +27,6 @@ class PushTImageEnv(PushTEnv): img_obs = np.moveaxis(img, -1, 0) obs = {"image": img_obs, "agent_pos": agent_pos} - # draw action - if self.latest_action is not None: - action = np.array(self.latest_action) - coord = (action / 512 * 96).astype(np.int32) - marker_size = int(8 / 96 * self.render_size) - thickness = int(1 / 96 * self.render_size) - # cv2.drawMarker( - # img, - # coord, - # color=(255, 0, 0), - # markerType=cv2.MARKER_CROSS, - # markerSize=marker_size, - # thickness=thickness, - # ) self.render_cache = img return obs diff --git a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py index c5b00d945..7719fddea 100644 --- a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py +++ b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py @@ -1,3 +1,44 @@ +"""Code from the original diffusion policy project. + +Notes on how to load a checkpoint from the original repository: + +In the original repository, run the eval and use a breakpoint to extract the policy weights. + +``` +torch.save(policy.state_dict(), "weights.pt") +``` + +In this repository, add a breakpoint somewhere after creating an equivalent policy and load in the weights: + +``` +loaded = torch.load("weights.pt") +aligned = {} +their_prefix = "obs_encoder.obs_nets.image.backbone" +our_prefix = "obs_encoder.key_model_map.image.backbone" +aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) +their_prefix = "obs_encoder.obs_nets.image.pool" +our_prefix = "obs_encoder.key_model_map.image.pool" +aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) +their_prefix = "obs_encoder.obs_nets.image.nets.3" +our_prefix = "obs_encoder.key_model_map.image.out" +aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) +aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')}) +# Note: here you are loading into the ema model. +missing_keys, unexpected_keys = policy.ema_diffusion.load_state_dict(aligned, strict=False) +assert all('_dummy_variable' in k for k in missing_keys) +assert len(unexpected_keys) == 0 +``` + +Then in that same runtime you can also save the weights with the new aligned state_dict: + +``` +policy.save("weights.pt") +``` + +Now you can remove the breakpoint and extra code and load in the weights just like with any other lerobot checkpoint. + +""" + from typing import Dict import torch diff --git a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py index c7b9807df..d724cd493 100644 --- a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py +++ b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py @@ -1,11 +1,10 @@ import copy from typing import Dict, Optional, Tuple, Union -import timm import torch import torch.nn as nn import torchvision -from robomimic.models.base_nets import SpatialSoftmax +from robomimic.models.base_nets import ResNet18Conv, SpatialSoftmax from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin @@ -15,17 +14,16 @@ from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules class RgbEncoder(nn.Module): """Following `VisualCore` from Robomimic 0.2.0.""" - def __init__(self, input_shape, model_name="resnet18", pretrained=False, relu=True, num_keypoints=32): + def __init__(self, input_shape, relu=True, pretrained=False, num_keypoints=32): """ input_shape: channel-first input shape (C, H, W) resnet_name: a timm model name. pretrained: whether to use timm pretrained weights. - rele: whether to use relu as a final step. + relu: whether to use relu as a final step. num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image). """ super().__init__() - self.backbone = timm.create_model(model_name, pretrained, num_classes=0, global_pool="") - # self.backbone = ResNet18Conv(input_channel=input_shape[0]) + self.backbone = ResNet18Conv(input_channel=input_shape[0], pretrained=pretrained) # Figure out the feature map shape. with torch.inference_mode(): feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:]) @@ -34,7 +32,6 @@ class RgbEncoder(nn.Module): self.relu = nn.ReLU() if relu else nn.Identity() def forward(self, x): - # TODO(now): make nonlinearity optional return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1))) diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index a4185afc1..1b3b24b67 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -5,7 +5,6 @@ import time import hydra import torch -from lerobot.common.ema import update_ema_parameters from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler @@ -21,6 +20,7 @@ class DiffusionPolicy(AbstractPolicy): cfg_rgb_model, cfg_obs_encoder, cfg_optimizer, + cfg_ema, shape_meta: dict, horizon, n_action_steps, @@ -71,8 +71,13 @@ class DiffusionPolicy(AbstractPolicy): self.diffusion.cuda() self.ema_diffusion = None - if self.cfg.ema.enable: + self.ema = None + if self.cfg.use_ema: self.ema_diffusion = copy.deepcopy(self.diffusion) + self.ema = hydra.utils.instantiate( + cfg_ema, + model=self.ema_diffusion, + ) self.optimizer = hydra.utils.instantiate( cfg_optimizer, @@ -175,8 +180,8 @@ class DiffusionPolicy(AbstractPolicy): self.optimizer.zero_grad() self.lr_scheduler.step() - if self.cfg.ema.enable: - update_ema_parameters(self.ema_diffusion, self.diffusion, self.cfg.ema.rate) + if self.ema is not None: + self.ema.step(self.diffusion) info = { "loss": loss.item(), diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 32a366b30..085baab58 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -16,6 +16,7 @@ def make_policy(cfg): cfg_rgb_model=cfg.rgb_model, cfg_obs_encoder=cfg.obs_encoder, cfg_optimizer=cfg.optimizer, + cfg_ema=cfg.ema, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps, **cfg.policy, ) @@ -39,23 +40,4 @@ def make_policy(cfg): raise NotImplementedError() policy.load(cfg.policy.pretrained_model_path) - # import torch - # loaded = torch.load('/home/alexander/Downloads/dp.pth') - # aligned = {} - - # their_prefix = "obs_encoder.obs_nets.image.backbone" - # our_prefix = "obs_encoder.key_model_map.image.backbone" - # aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) - # their_prefix = "obs_encoder.obs_nets.image.pool" - # our_prefix = "obs_encoder.key_model_map.image.pool" - # aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) - # their_prefix = "obs_encoder.obs_nets.image.nets.3" - # our_prefix = "obs_encoder.key_model_map.image.out" - # aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) - - # aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')}) - # missing_keys, unexpected_keys = policy.diffusion.load_state_dict(aligned, strict=False) - # assert all('_dummy_variable' in k for k in missing_keys) - # assert len(unexpected_keys) == 0 - return policy diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index a81952e0d..acb368ed0 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -12,6 +12,7 @@ shape_meta: action: shape: [2] +seed: 100000 horizon: 16 n_obs_steps: 2 n_action_steps: 8 @@ -26,7 +27,7 @@ eval_freq: 5000 save_freq: 5000 log_freq: 250 -offline_steps: 50000 +offline_steps: 200000 online_steps: 0 offline_prioritized_sampler: true @@ -58,9 +59,7 @@ policy: balanced_sampling: false utd: 1 offline_steps: ${offline_steps} - ema: - enable: true - rate: 0.999 + use_ema: true lr_scheduler: cosine lr_warmup_steps: 500 grad_clip_norm: 10 @@ -86,11 +85,18 @@ obs_encoder: norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs) rgb_model: - model_name: resnet18 pretrained: false num_keypoints: 32 relu: true +ema: + _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel + update_after_step: 0 + inv_gamma: 1.0 + power: 0.75 + min_value: 0.0 + max_value: 0.9999 + optimizer: _target_: torch.optim.AdamW lr: 1.0e-4 diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 41d58b914..c0c34629c 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -50,6 +50,7 @@ def eval_policy( def maybe_render_frame(env: EnvBase, _): if save_video or (return_first_video and i == 0): # noqa: B023 + # TODO now: generalize kwarg or maybe just remove it ep_frames.append(env.render()) # noqa: B023 with torch.inference_mode(): diff --git a/poetry.lock b/poetry.lock index ddb0a0e31..92449d451 100644 --- a/poetry.lock +++ b/poetry.lock @@ -604,6 +604,16 @@ files = [ [package.dependencies] six = ">=1.4.0" +[[package]] +name = "egl-probe" +version = "1.0.2" +description = "" +optional = false +python-versions = "*" +files = [ + {file = "egl_probe-1.0.2.tar.gz", hash = "sha256:29bdca7b08da1e060cfb42cd46af8300a7ac4f3b1b2eeb16e545ea16d9a5ac93"}, +] + [[package]] name = "einops" version = "0.7.0" @@ -763,6 +773,72 @@ files = [ [package.extras] preview = ["glfw-preview"] +[[package]] +name = "grpcio" +version = "1.62.1" +description = "HTTP/2-based RPC framework" +optional = false +python-versions = ">=3.7" +files = [ + {file = "grpcio-1.62.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:179bee6f5ed7b5f618844f760b6acf7e910988de77a4f75b95bbfaa8106f3c1e"}, + {file = "grpcio-1.62.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:48611e4fa010e823ba2de8fd3f77c1322dd60cb0d180dc6630a7e157b205f7ea"}, + {file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b2a0e71b0a2158aa4bce48be9f8f9eb45cbd17c78c7443616d00abbe2a509f6d"}, + {file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fbe80577c7880911d3ad65e5ecc997416c98f354efeba2f8d0f9112a67ed65a5"}, + {file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58f6c693d446964e3292425e1d16e21a97a48ba9172f2d0df9d7b640acb99243"}, + {file = "grpcio-1.62.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:77c339403db5a20ef4fed02e4d1a9a3d9866bf9c0afc77a42234677313ea22f3"}, + {file = "grpcio-1.62.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b5a4ea906db7dec694098435d84bf2854fe158eb3cd51e1107e571246d4d1d70"}, + {file = "grpcio-1.62.1-cp310-cp310-win32.whl", hash = "sha256:4187201a53f8561c015bc745b81a1b2d278967b8de35f3399b84b0695e281d5f"}, + {file = "grpcio-1.62.1-cp310-cp310-win_amd64.whl", hash = "sha256:844d1f3fb11bd1ed362d3fdc495d0770cfab75761836193af166fee113421d66"}, + {file = "grpcio-1.62.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:833379943d1728a005e44103f17ecd73d058d37d95783eb8f0b28ddc1f54d7b2"}, + {file = "grpcio-1.62.1-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:c7fcc6a32e7b7b58f5a7d27530669337a5d587d4066060bcb9dee7a8c833dfb7"}, + {file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:fa7d28eb4d50b7cbe75bb8b45ed0da9a1dc5b219a0af59449676a29c2eed9698"}, + {file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48f7135c3de2f298b833be8b4ae20cafe37091634e91f61f5a7eb3d61ec6f660"}, + {file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:71f11fd63365ade276c9d4a7b7df5c136f9030e3457107e1791b3737a9b9ed6a"}, + {file = "grpcio-1.62.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4b49fd8fe9f9ac23b78437da94c54aa7e9996fbb220bac024a67469ce5d0825f"}, + {file = "grpcio-1.62.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:482ae2ae78679ba9ed5752099b32e5fe580443b4f798e1b71df412abf43375db"}, + {file = "grpcio-1.62.1-cp311-cp311-win32.whl", hash = "sha256:1faa02530b6c7426404372515fe5ddf66e199c2ee613f88f025c6f3bd816450c"}, + {file = "grpcio-1.62.1-cp311-cp311-win_amd64.whl", hash = "sha256:5bd90b8c395f39bc82a5fb32a0173e220e3f401ff697840f4003e15b96d1befc"}, + {file = "grpcio-1.62.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:b134d5d71b4e0837fff574c00e49176051a1c532d26c052a1e43231f252d813b"}, + {file = "grpcio-1.62.1-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:d1f6c96573dc09d50dbcbd91dbf71d5cf97640c9427c32584010fbbd4c0e0037"}, + {file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:359f821d4578f80f41909b9ee9b76fb249a21035a061a327f91c953493782c31"}, + {file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a485f0c2010c696be269184bdb5ae72781344cb4e60db976c59d84dd6354fac9"}, + {file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b50b09b4dc01767163d67e1532f948264167cd27f49e9377e3556c3cba1268e1"}, + {file = "grpcio-1.62.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3227c667dccbe38f2c4d943238b887bac588d97c104815aecc62d2fd976e014b"}, + {file = "grpcio-1.62.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3952b581eb121324853ce2b191dae08badb75cd493cb4e0243368aa9e61cfd41"}, + {file = "grpcio-1.62.1-cp312-cp312-win32.whl", hash = "sha256:83a17b303425104d6329c10eb34bba186ffa67161e63fa6cdae7776ff76df73f"}, + {file = "grpcio-1.62.1-cp312-cp312-win_amd64.whl", hash = "sha256:6696ffe440333a19d8d128e88d440f91fb92c75a80ce4b44d55800e656a3ef1d"}, + {file = "grpcio-1.62.1-cp37-cp37m-linux_armv7l.whl", hash = "sha256:e3393b0823f938253370ebef033c9fd23d27f3eae8eb9a8f6264900c7ea3fb5a"}, + {file = "grpcio-1.62.1-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:83e7ccb85a74beaeae2634f10eb858a0ed1a63081172649ff4261f929bacfd22"}, + {file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:882020c87999d54667a284c7ddf065b359bd00251fcd70279ac486776dbf84ec"}, + {file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a10383035e864f386fe096fed5c47d27a2bf7173c56a6e26cffaaa5a361addb1"}, + {file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:960edebedc6b9ada1ef58e1c71156f28689978188cd8cff3b646b57288a927d9"}, + {file = "grpcio-1.62.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:23e2e04b83f347d0aadde0c9b616f4726c3d76db04b438fd3904b289a725267f"}, + {file = "grpcio-1.62.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:978121758711916d34fe57c1f75b79cdfc73952f1481bb9583399331682d36f7"}, + {file = "grpcio-1.62.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9084086190cc6d628f282e5615f987288b95457292e969b9205e45b442276407"}, + {file = "grpcio-1.62.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:22bccdd7b23c420a27fd28540fb5dcbc97dc6be105f7698cb0e7d7a420d0e362"}, + {file = "grpcio-1.62.1-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:8999bf1b57172dbc7c3e4bb3c732658e918f5c333b2942243f10d0d653953ba9"}, + {file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:d9e52558b8b8c2f4ac05ac86344a7417ccdd2b460a59616de49eb6933b07a0bd"}, + {file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1714e7bc935780bc3de1b3fcbc7674209adf5208ff825799d579ffd6cd0bd505"}, + {file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8842ccbd8c0e253c1f189088228f9b433f7a93b7196b9e5b6f87dba393f5d5d"}, + {file = "grpcio-1.62.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1f1e7b36bdff50103af95a80923bf1853f6823dd62f2d2a2524b66ed74103e49"}, + {file = "grpcio-1.62.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bba97b8e8883a8038606480d6b6772289f4c907f6ba780fa1f7b7da7dfd76f06"}, + {file = "grpcio-1.62.1-cp38-cp38-win32.whl", hash = "sha256:a7f615270fe534548112a74e790cd9d4f5509d744dd718cd442bf016626c22e4"}, + {file = "grpcio-1.62.1-cp38-cp38-win_amd64.whl", hash = "sha256:e6c8c8693df718c5ecbc7babb12c69a4e3677fd11de8886f05ab22d4e6b1c43b"}, + {file = "grpcio-1.62.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:73db2dc1b201d20ab7083e7041946910bb991e7e9761a0394bbc3c2632326483"}, + {file = "grpcio-1.62.1-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:407b26b7f7bbd4f4751dbc9767a1f0716f9fe72d3d7e96bb3ccfc4aace07c8de"}, + {file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:f8de7c8cef9261a2d0a62edf2ccea3d741a523c6b8a6477a340a1f2e417658de"}, + {file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bd5c8a1af40ec305d001c60236308a67e25419003e9bb3ebfab5695a8d0b369"}, + {file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be0477cb31da67846a33b1a75c611f88bfbcd427fe17701b6317aefceee1b96f"}, + {file = "grpcio-1.62.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:60dcd824df166ba266ee0cfaf35a31406cd16ef602b49f5d4dfb21f014b0dedd"}, + {file = "grpcio-1.62.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:973c49086cabab773525f6077f95e5a993bfc03ba8fc32e32f2c279497780585"}, + {file = "grpcio-1.62.1-cp39-cp39-win32.whl", hash = "sha256:12859468e8918d3bd243d213cd6fd6ab07208195dc140763c00dfe901ce1e1b4"}, + {file = "grpcio-1.62.1-cp39-cp39-win_amd64.whl", hash = "sha256:b7209117bbeebdfa5d898205cc55153a51285757902dd73c47de498ad4d11332"}, + {file = "grpcio-1.62.1.tar.gz", hash = "sha256:6c455e008fa86d9e9a9d85bb76da4277c0d7d9668a3bfa70dbe86e9f3c759947"}, +] + +[package.extras] +protobuf = ["grpcio-tools (>=1.62.1)"] + [[package]] name = "gym" version = "0.26.2" @@ -1038,13 +1114,13 @@ setuptools = "*" [[package]] name = "importlib-metadata" -version = "7.0.2" +version = "7.1.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_metadata-7.0.2-py3-none-any.whl", hash = "sha256:f4bc4c0c070c490abf4ce96d715f68e95923320370efb66143df00199bb6c100"}, - {file = "importlib_metadata-7.0.2.tar.gz", hash = "sha256:198f568f3230878cb1b44fbd7975f87906c22336dba2e4a7f05278c281fbd792"}, + {file = "importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570"}, + {file = "importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2"}, ] [package.dependencies] @@ -1053,7 +1129,7 @@ zipp = ">=0.5" [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] -testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] [[package]] name = "iniconfig" @@ -1265,6 +1341,21 @@ html5 = ["html5lib"] htmlsoup = ["BeautifulSoup4"] source = ["Cython (>=3.0.7)"] +[[package]] +name = "markdown" +version = "3.6" +description = "Python implementation of John Gruber's Markdown." +optional = false +python-versions = ">=3.8" +files = [ + {file = "Markdown-3.6-py3-none-any.whl", hash = "sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f"}, + {file = "Markdown-3.6.tar.gz", hash = "sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224"}, +] + +[package.extras] +docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"] +testing = ["coverage", "pyyaml"] + [[package]] name = "markupsafe" version = "2.1.5" @@ -2460,6 +2551,30 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "robomimic" +version = "0.2.0" +description = "robomimic: A Modular Framework for Robot Learning from Demonstration" +optional = false +python-versions = ">=3" +files = [ + {file = "robomimic-0.2.0.tar.gz", hash = "sha256:ee3bb5cf9c3e1feead6b57b43c5db738fd0a8e0c015fdf6419808af8fffdc463"}, +] + +[package.dependencies] +egl_probe = ">=1.0.1" +h5py = "*" +imageio = "*" +imageio-ffmpeg = "*" +numpy = ">=1.13.3" +psutil = "*" +tensorboard = "*" +tensorboardX = "*" +termcolor = "*" +torch = "*" +torchvision = "*" +tqdm = "*" + [[package]] name = "safetensors" version = "0.4.2" @@ -2684,13 +2799,13 @@ test = ["asv", "gmpy2", "hypothesis", "mpmath", "pooch", "pytest", "pytest-cov", [[package]] name = "sentry-sdk" -version = "1.42.0" +version = "1.43.0" description = "Python client for Sentry (https://sentry.io)" optional = false python-versions = "*" files = [ - {file = "sentry-sdk-1.42.0.tar.gz", hash = "sha256:4a8364b8f7edbf47f95f7163e48334c96100d9c098f0ae6606e2e18183c223e6"}, - {file = "sentry_sdk-1.42.0-py2.py3-none-any.whl", hash = "sha256:a654ee7e497a3f5f6368b36d4f04baeab1fe92b3105f7f6965d6ef0de35a9ba4"}, + {file = "sentry-sdk-1.43.0.tar.gz", hash = "sha256:41df73af89d22921d8733714fb0fc5586c3461907e06688e6537d01a27e0e0f6"}, + {file = "sentry_sdk-1.43.0-py2.py3-none-any.whl", hash = "sha256:8d768724839ca18d7b4c7463ef7528c40b7aa2bfbf7fe554d5f9a7c044acfd36"}, ] [package.dependencies] @@ -2704,6 +2819,7 @@ asyncpg = ["asyncpg (>=0.23)"] beam = ["apache-beam (>=2.12)"] bottle = ["bottle (>=0.12.13)"] celery = ["celery (>=3)"] +celery-redbeat = ["celery-redbeat (>=2)"] chalice = ["chalice (>=1.16.0)"] clickhouse-driver = ["clickhouse-driver (>=0.2.0)"] django = ["django (>=1.8)"] @@ -2948,6 +3064,55 @@ files = [ [package.dependencies] mpmath = ">=0.19" +[[package]] +name = "tensorboard" +version = "2.16.2" +description = "TensorBoard lets you watch Tensors Flow" +optional = false +python-versions = ">=3.9" +files = [ + {file = "tensorboard-2.16.2-py3-none-any.whl", hash = "sha256:9f2b4e7dad86667615c0e5cd072f1ea8403fc032a299f0072d6f74855775cc45"}, +] + +[package.dependencies] +absl-py = ">=0.4" +grpcio = ">=1.48.2" +markdown = ">=2.6.8" +numpy = ">=1.12.0" +protobuf = ">=3.19.6,<4.24.0 || >4.24.0" +setuptools = ">=41.0.0" +six = ">1.9" +tensorboard-data-server = ">=0.7.0,<0.8.0" +werkzeug = ">=1.0.1" + +[[package]] +name = "tensorboard-data-server" +version = "0.7.2" +description = "Fast data loading for TensorBoard" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"}, + {file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"}, + {file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"}, +] + +[[package]] +name = "tensorboardx" +version = "2.6.2.2" +description = "TensorBoardX lets you watch Tensors Flow without Tensorflow" +optional = false +python-versions = "*" +files = [ + {file = "tensorboardX-2.6.2.2-py2.py3-none-any.whl", hash = "sha256:160025acbf759ede23fd3526ae9d9bfbfd8b68eb16c38a010ebe326dc6395db8"}, + {file = "tensorboardX-2.6.2.2.tar.gz", hash = "sha256:c6476d7cd0d529b0b72f4acadb1269f9ed8b22f441e87a84f2a3b940bb87b666"}, +] + +[package.dependencies] +numpy = "*" +packaging = "*" +protobuf = ">=3.20" + [[package]] name = "tensordict" version = "0.4.0+ca4256e" @@ -3289,6 +3454,23 @@ perf = ["orjson"] reports = ["pydantic (>=2.0.0)"] sweeps = ["sweeps (>=0.2.0)"] +[[package]] +name = "werkzeug" +version = "3.0.1" +description = "The comprehensive WSGI web application library." +optional = false +python-versions = ">=3.8" +files = [ + {file = "werkzeug-3.0.1-py3-none-any.whl", hash = "sha256:90a285dc0e42ad56b34e696398b8122ee4c681833fb35b8334a095d82c56da10"}, + {file = "werkzeug-3.0.1.tar.gz", hash = "sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc"}, +] + +[package.dependencies] +MarkupSafe = ">=2.1.1" + +[package.extras] +watchdog = ["watchdog (>=2.3)"] + [[package]] name = "zarr" version = "2.17.1" @@ -3328,4 +3510,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "ee86b84a795e6a3e9c2d79f244a87b55589adbe46d549ac38adf48be27c04cf9" +content-hash = "1a45c808e1c48bcbf4319d4cf6876771b7d50f40a5a8968a8b7f3af36192bf34" diff --git a/pyproject.toml b/pyproject.toml index 2e818a440..7e9996a0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ torchvision = "^0.17.1" h5py = "^3.10.0" dm-control = "1.0.14" huggingface-hub = {extras = ["hf-transfer"], version = "^0.21.4"} +robomimic = "0.2.0" [tool.poetry.group.dev.dependencies]