ready for review

This commit is contained in:
Alexander Soare
2024-03-21 10:18:50 +00:00
parent d323993569
commit acf1174447
12 changed files with 282 additions and 85 deletions

View File

@@ -192,7 +192,7 @@ class AlohaEnv(AbstractEnv):
{ {
"observation": TensorDict(obs, batch_size=[]), "observation": TensorDict(obs, batch_size=[]),
"reward": torch.tensor([reward], dtype=torch.float32), "reward": torch.tensor([reward], dtype=torch.float32),
# succes and done are true when coverage > self.success_threshold in env # success and done are true when coverage > self.success_threshold in env
"done": torch.tensor([done], dtype=torch.bool), "done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([success], dtype=torch.bool), "success": torch.tensor([success], dtype=torch.bool),
}, },

View File

@@ -62,27 +62,3 @@ def make_env(cfg, transform=None):
{"seed": env_seed} for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) {"seed": env_seed} for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
], ],
) )
# def make_env(env_name, frame_skip, device, is_test=False):
# env = GymEnv(
# env_name,
# frame_skip=frame_skip,
# from_pixels=True,
# pixels_only=False,
# device=device,
# )
# env = TransformedEnv(env)
# env.append_transform(NoopResetEnv(noops=30, random=True))
# if not is_test:
# env.append_transform(EndOfLifeTransform())
# env.append_transform(RewardClipping(-1, 1))
# env.append_transform(ToTensorImage())
# env.append_transform(GrayScale())
# env.append_transform(Resize(84, 84))
# env.append_transform(CatFrames(N=4, dim=-3))
# env.append_transform(RewardSum())
# env.append_transform(StepCounter(max_steps=4500))
# env.append_transform(DoubleToFloat())
# env.append_transform(VecNorm(in_keys=["pixels"]))
# return env

View File

@@ -3,6 +3,8 @@ import logging
from collections import deque from collections import deque
from typing import Optional from typing import Optional
import cv2
import numpy as np
import torch import torch
from tensordict import TensorDict from tensordict import TensorDict
from torchrl.data.tensor_specs import ( from torchrl.data.tensor_specs import (
@@ -59,12 +61,30 @@ class PushtEnv(AbstractEnv):
self._env = PushTImageEnv(render_size=self.image_size) self._env = PushTImageEnv(render_size=self.image_size)
def render(self, mode="rgb_array", width=384, height=384): def render(self, mode="rgb_array", width=96, height=96, with_marker=True):
"""
with_marker adds a cursor showing the targeted action for the controller.
"""
if width != height: if width != height:
raise NotImplementedError() raise NotImplementedError()
tmp = self._env.render_size tmp = self._env.render_size
self._env.render_size = width if width != self._env.render_size:
out = self._env.render(mode) self._env.render_cache = None
self._env.render_size = width
out = self._env.render(mode).copy()
if with_marker and self._env.latest_action is not None:
action = np.array(self._env.latest_action)
coord = (action / 512 * self._env.render_size).astype(np.int32)
marker_size = int(8 / 96 * self._env.render_size)
thickness = int(1 / 96 * self._env.render_size)
cv2.drawMarker(
out,
coord,
color=(255, 0, 0),
markerType=cv2.MARKER_CROSS,
markerSize=marker_size,
thickness=thickness,
)
self._env.render_size = tmp self._env.render_size = tmp
return out return out

View File

@@ -27,20 +27,6 @@ class PushTImageEnv(PushTEnv):
img_obs = np.moveaxis(img, -1, 0) img_obs = np.moveaxis(img, -1, 0)
obs = {"image": img_obs, "agent_pos": agent_pos} obs = {"image": img_obs, "agent_pos": agent_pos}
# draw action
if self.latest_action is not None:
action = np.array(self.latest_action)
coord = (action / 512 * 96).astype(np.int32)
marker_size = int(8 / 96 * self.render_size)
thickness = int(1 / 96 * self.render_size)
# cv2.drawMarker(
# img,
# coord,
# color=(255, 0, 0),
# markerType=cv2.MARKER_CROSS,
# markerSize=marker_size,
# thickness=thickness,
# )
self.render_cache = img self.render_cache = img
return obs return obs

View File

@@ -1,3 +1,44 @@
"""Code from the original diffusion policy project.
Notes on how to load a checkpoint from the original repository:
In the original repository, run the eval and use a breakpoint to extract the policy weights.
```
torch.save(policy.state_dict(), "weights.pt")
```
In this repository, add a breakpoint somewhere after creating an equivalent policy and load in the weights:
```
loaded = torch.load("weights.pt")
aligned = {}
their_prefix = "obs_encoder.obs_nets.image.backbone"
our_prefix = "obs_encoder.key_model_map.image.backbone"
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
their_prefix = "obs_encoder.obs_nets.image.pool"
our_prefix = "obs_encoder.key_model_map.image.pool"
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
their_prefix = "obs_encoder.obs_nets.image.nets.3"
our_prefix = "obs_encoder.key_model_map.image.out"
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')})
# Note: here you are loading into the ema model.
missing_keys, unexpected_keys = policy.ema_diffusion.load_state_dict(aligned, strict=False)
assert all('_dummy_variable' in k for k in missing_keys)
assert len(unexpected_keys) == 0
```
Then in that same runtime you can also save the weights with the new aligned state_dict:
```
policy.save("weights.pt")
```
Now you can remove the breakpoint and extra code and load in the weights just like with any other lerobot checkpoint.
"""
from typing import Dict from typing import Dict
import torch import torch

View File

@@ -1,11 +1,10 @@
import copy import copy
from typing import Dict, Optional, Tuple, Union from typing import Dict, Optional, Tuple, Union
import timm
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchvision import torchvision
from robomimic.models.base_nets import SpatialSoftmax from robomimic.models.base_nets import ResNet18Conv, SpatialSoftmax
from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
@@ -15,17 +14,16 @@ from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
class RgbEncoder(nn.Module): class RgbEncoder(nn.Module):
"""Following `VisualCore` from Robomimic 0.2.0.""" """Following `VisualCore` from Robomimic 0.2.0."""
def __init__(self, input_shape, model_name="resnet18", pretrained=False, relu=True, num_keypoints=32): def __init__(self, input_shape, relu=True, pretrained=False, num_keypoints=32):
""" """
input_shape: channel-first input shape (C, H, W) input_shape: channel-first input shape (C, H, W)
resnet_name: a timm model name. resnet_name: a timm model name.
pretrained: whether to use timm pretrained weights. pretrained: whether to use timm pretrained weights.
rele: whether to use relu as a final step. relu: whether to use relu as a final step.
num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image). num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
""" """
super().__init__() super().__init__()
self.backbone = timm.create_model(model_name, pretrained, num_classes=0, global_pool="") self.backbone = ResNet18Conv(input_channel=input_shape[0], pretrained=pretrained)
# self.backbone = ResNet18Conv(input_channel=input_shape[0])
# Figure out the feature map shape. # Figure out the feature map shape.
with torch.inference_mode(): with torch.inference_mode():
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:]) feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
@@ -34,7 +32,6 @@ class RgbEncoder(nn.Module):
self.relu = nn.ReLU() if relu else nn.Identity() self.relu = nn.ReLU() if relu else nn.Identity()
def forward(self, x): def forward(self, x):
# TODO(now): make nonlinearity optional
return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1))) return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))

View File

@@ -5,7 +5,6 @@ import time
import hydra import hydra
import torch import torch
from lerobot.common.ema import update_ema_parameters
from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
@@ -21,6 +20,7 @@ class DiffusionPolicy(AbstractPolicy):
cfg_rgb_model, cfg_rgb_model,
cfg_obs_encoder, cfg_obs_encoder,
cfg_optimizer, cfg_optimizer,
cfg_ema,
shape_meta: dict, shape_meta: dict,
horizon, horizon,
n_action_steps, n_action_steps,
@@ -71,8 +71,13 @@ class DiffusionPolicy(AbstractPolicy):
self.diffusion.cuda() self.diffusion.cuda()
self.ema_diffusion = None self.ema_diffusion = None
if self.cfg.ema.enable: self.ema = None
if self.cfg.use_ema:
self.ema_diffusion = copy.deepcopy(self.diffusion) self.ema_diffusion = copy.deepcopy(self.diffusion)
self.ema = hydra.utils.instantiate(
cfg_ema,
model=self.ema_diffusion,
)
self.optimizer = hydra.utils.instantiate( self.optimizer = hydra.utils.instantiate(
cfg_optimizer, cfg_optimizer,
@@ -175,8 +180,8 @@ class DiffusionPolicy(AbstractPolicy):
self.optimizer.zero_grad() self.optimizer.zero_grad()
self.lr_scheduler.step() self.lr_scheduler.step()
if self.cfg.ema.enable: if self.ema is not None:
update_ema_parameters(self.ema_diffusion, self.diffusion, self.cfg.ema.rate) self.ema.step(self.diffusion)
info = { info = {
"loss": loss.item(), "loss": loss.item(),

View File

@@ -16,6 +16,7 @@ def make_policy(cfg):
cfg_rgb_model=cfg.rgb_model, cfg_rgb_model=cfg.rgb_model,
cfg_obs_encoder=cfg.obs_encoder, cfg_obs_encoder=cfg.obs_encoder,
cfg_optimizer=cfg.optimizer, cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema,
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
**cfg.policy, **cfg.policy,
) )
@@ -39,23 +40,4 @@ def make_policy(cfg):
raise NotImplementedError() raise NotImplementedError()
policy.load(cfg.policy.pretrained_model_path) policy.load(cfg.policy.pretrained_model_path)
# import torch
# loaded = torch.load('/home/alexander/Downloads/dp.pth')
# aligned = {}
# their_prefix = "obs_encoder.obs_nets.image.backbone"
# our_prefix = "obs_encoder.key_model_map.image.backbone"
# aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
# their_prefix = "obs_encoder.obs_nets.image.pool"
# our_prefix = "obs_encoder.key_model_map.image.pool"
# aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
# their_prefix = "obs_encoder.obs_nets.image.nets.3"
# our_prefix = "obs_encoder.key_model_map.image.out"
# aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
# aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')})
# missing_keys, unexpected_keys = policy.diffusion.load_state_dict(aligned, strict=False)
# assert all('_dummy_variable' in k for k in missing_keys)
# assert len(unexpected_keys) == 0
return policy return policy

View File

@@ -12,6 +12,7 @@ shape_meta:
action: action:
shape: [2] shape: [2]
seed: 100000
horizon: 16 horizon: 16
n_obs_steps: 2 n_obs_steps: 2
n_action_steps: 8 n_action_steps: 8
@@ -26,7 +27,7 @@ eval_freq: 5000
save_freq: 5000 save_freq: 5000
log_freq: 250 log_freq: 250
offline_steps: 50000 offline_steps: 200000
online_steps: 0 online_steps: 0
offline_prioritized_sampler: true offline_prioritized_sampler: true
@@ -58,9 +59,7 @@ policy:
balanced_sampling: false balanced_sampling: false
utd: 1 utd: 1
offline_steps: ${offline_steps} offline_steps: ${offline_steps}
ema: use_ema: true
enable: true
rate: 0.999
lr_scheduler: cosine lr_scheduler: cosine
lr_warmup_steps: 500 lr_warmup_steps: 500
grad_clip_norm: 10 grad_clip_norm: 10
@@ -86,11 +85,18 @@ obs_encoder:
norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs) norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs)
rgb_model: rgb_model:
model_name: resnet18
pretrained: false pretrained: false
num_keypoints: 32 num_keypoints: 32
relu: true relu: true
ema:
_target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel
update_after_step: 0
inv_gamma: 1.0
power: 0.75
min_value: 0.0
max_value: 0.9999
optimizer: optimizer:
_target_: torch.optim.AdamW _target_: torch.optim.AdamW
lr: 1.0e-4 lr: 1.0e-4

View File

@@ -50,6 +50,7 @@ def eval_policy(
def maybe_render_frame(env: EnvBase, _): def maybe_render_frame(env: EnvBase, _):
if save_video or (return_first_video and i == 0): # noqa: B023 if save_video or (return_first_video and i == 0): # noqa: B023
# TODO now: generalize kwarg or maybe just remove it
ep_frames.append(env.render()) # noqa: B023 ep_frames.append(env.render()) # noqa: B023
with torch.inference_mode(): with torch.inference_mode():

198
poetry.lock generated
View File

@@ -604,6 +604,16 @@ files = [
[package.dependencies] [package.dependencies]
six = ">=1.4.0" six = ">=1.4.0"
[[package]]
name = "egl-probe"
version = "1.0.2"
description = ""
optional = false
python-versions = "*"
files = [
{file = "egl_probe-1.0.2.tar.gz", hash = "sha256:29bdca7b08da1e060cfb42cd46af8300a7ac4f3b1b2eeb16e545ea16d9a5ac93"},
]
[[package]] [[package]]
name = "einops" name = "einops"
version = "0.7.0" version = "0.7.0"
@@ -763,6 +773,72 @@ files = [
[package.extras] [package.extras]
preview = ["glfw-preview"] preview = ["glfw-preview"]
[[package]]
name = "grpcio"
version = "1.62.1"
description = "HTTP/2-based RPC framework"
optional = false
python-versions = ">=3.7"
files = [
{file = "grpcio-1.62.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:179bee6f5ed7b5f618844f760b6acf7e910988de77a4f75b95bbfaa8106f3c1e"},
{file = "grpcio-1.62.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:48611e4fa010e823ba2de8fd3f77c1322dd60cb0d180dc6630a7e157b205f7ea"},
{file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b2a0e71b0a2158aa4bce48be9f8f9eb45cbd17c78c7443616d00abbe2a509f6d"},
{file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fbe80577c7880911d3ad65e5ecc997416c98f354efeba2f8d0f9112a67ed65a5"},
{file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58f6c693d446964e3292425e1d16e21a97a48ba9172f2d0df9d7b640acb99243"},
{file = "grpcio-1.62.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:77c339403db5a20ef4fed02e4d1a9a3d9866bf9c0afc77a42234677313ea22f3"},
{file = "grpcio-1.62.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b5a4ea906db7dec694098435d84bf2854fe158eb3cd51e1107e571246d4d1d70"},
{file = "grpcio-1.62.1-cp310-cp310-win32.whl", hash = "sha256:4187201a53f8561c015bc745b81a1b2d278967b8de35f3399b84b0695e281d5f"},
{file = "grpcio-1.62.1-cp310-cp310-win_amd64.whl", hash = "sha256:844d1f3fb11bd1ed362d3fdc495d0770cfab75761836193af166fee113421d66"},
{file = "grpcio-1.62.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:833379943d1728a005e44103f17ecd73d058d37d95783eb8f0b28ddc1f54d7b2"},
{file = "grpcio-1.62.1-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:c7fcc6a32e7b7b58f5a7d27530669337a5d587d4066060bcb9dee7a8c833dfb7"},
{file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:fa7d28eb4d50b7cbe75bb8b45ed0da9a1dc5b219a0af59449676a29c2eed9698"},
{file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48f7135c3de2f298b833be8b4ae20cafe37091634e91f61f5a7eb3d61ec6f660"},
{file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:71f11fd63365ade276c9d4a7b7df5c136f9030e3457107e1791b3737a9b9ed6a"},
{file = "grpcio-1.62.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4b49fd8fe9f9ac23b78437da94c54aa7e9996fbb220bac024a67469ce5d0825f"},
{file = "grpcio-1.62.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:482ae2ae78679ba9ed5752099b32e5fe580443b4f798e1b71df412abf43375db"},
{file = "grpcio-1.62.1-cp311-cp311-win32.whl", hash = "sha256:1faa02530b6c7426404372515fe5ddf66e199c2ee613f88f025c6f3bd816450c"},
{file = "grpcio-1.62.1-cp311-cp311-win_amd64.whl", hash = "sha256:5bd90b8c395f39bc82a5fb32a0173e220e3f401ff697840f4003e15b96d1befc"},
{file = "grpcio-1.62.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:b134d5d71b4e0837fff574c00e49176051a1c532d26c052a1e43231f252d813b"},
{file = "grpcio-1.62.1-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:d1f6c96573dc09d50dbcbd91dbf71d5cf97640c9427c32584010fbbd4c0e0037"},
{file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:359f821d4578f80f41909b9ee9b76fb249a21035a061a327f91c953493782c31"},
{file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a485f0c2010c696be269184bdb5ae72781344cb4e60db976c59d84dd6354fac9"},
{file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b50b09b4dc01767163d67e1532f948264167cd27f49e9377e3556c3cba1268e1"},
{file = "grpcio-1.62.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3227c667dccbe38f2c4d943238b887bac588d97c104815aecc62d2fd976e014b"},
{file = "grpcio-1.62.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3952b581eb121324853ce2b191dae08badb75cd493cb4e0243368aa9e61cfd41"},
{file = "grpcio-1.62.1-cp312-cp312-win32.whl", hash = "sha256:83a17b303425104d6329c10eb34bba186ffa67161e63fa6cdae7776ff76df73f"},
{file = "grpcio-1.62.1-cp312-cp312-win_amd64.whl", hash = "sha256:6696ffe440333a19d8d128e88d440f91fb92c75a80ce4b44d55800e656a3ef1d"},
{file = "grpcio-1.62.1-cp37-cp37m-linux_armv7l.whl", hash = "sha256:e3393b0823f938253370ebef033c9fd23d27f3eae8eb9a8f6264900c7ea3fb5a"},
{file = "grpcio-1.62.1-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:83e7ccb85a74beaeae2634f10eb858a0ed1a63081172649ff4261f929bacfd22"},
{file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:882020c87999d54667a284c7ddf065b359bd00251fcd70279ac486776dbf84ec"},
{file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a10383035e864f386fe096fed5c47d27a2bf7173c56a6e26cffaaa5a361addb1"},
{file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:960edebedc6b9ada1ef58e1c71156f28689978188cd8cff3b646b57288a927d9"},
{file = "grpcio-1.62.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:23e2e04b83f347d0aadde0c9b616f4726c3d76db04b438fd3904b289a725267f"},
{file = "grpcio-1.62.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:978121758711916d34fe57c1f75b79cdfc73952f1481bb9583399331682d36f7"},
{file = "grpcio-1.62.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9084086190cc6d628f282e5615f987288b95457292e969b9205e45b442276407"},
{file = "grpcio-1.62.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:22bccdd7b23c420a27fd28540fb5dcbc97dc6be105f7698cb0e7d7a420d0e362"},
{file = "grpcio-1.62.1-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:8999bf1b57172dbc7c3e4bb3c732658e918f5c333b2942243f10d0d653953ba9"},
{file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:d9e52558b8b8c2f4ac05ac86344a7417ccdd2b460a59616de49eb6933b07a0bd"},
{file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1714e7bc935780bc3de1b3fcbc7674209adf5208ff825799d579ffd6cd0bd505"},
{file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8842ccbd8c0e253c1f189088228f9b433f7a93b7196b9e5b6f87dba393f5d5d"},
{file = "grpcio-1.62.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1f1e7b36bdff50103af95a80923bf1853f6823dd62f2d2a2524b66ed74103e49"},
{file = "grpcio-1.62.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bba97b8e8883a8038606480d6b6772289f4c907f6ba780fa1f7b7da7dfd76f06"},
{file = "grpcio-1.62.1-cp38-cp38-win32.whl", hash = "sha256:a7f615270fe534548112a74e790cd9d4f5509d744dd718cd442bf016626c22e4"},
{file = "grpcio-1.62.1-cp38-cp38-win_amd64.whl", hash = "sha256:e6c8c8693df718c5ecbc7babb12c69a4e3677fd11de8886f05ab22d4e6b1c43b"},
{file = "grpcio-1.62.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:73db2dc1b201d20ab7083e7041946910bb991e7e9761a0394bbc3c2632326483"},
{file = "grpcio-1.62.1-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:407b26b7f7bbd4f4751dbc9767a1f0716f9fe72d3d7e96bb3ccfc4aace07c8de"},
{file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:f8de7c8cef9261a2d0a62edf2ccea3d741a523c6b8a6477a340a1f2e417658de"},
{file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bd5c8a1af40ec305d001c60236308a67e25419003e9bb3ebfab5695a8d0b369"},
{file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be0477cb31da67846a33b1a75c611f88bfbcd427fe17701b6317aefceee1b96f"},
{file = "grpcio-1.62.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:60dcd824df166ba266ee0cfaf35a31406cd16ef602b49f5d4dfb21f014b0dedd"},
{file = "grpcio-1.62.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:973c49086cabab773525f6077f95e5a993bfc03ba8fc32e32f2c279497780585"},
{file = "grpcio-1.62.1-cp39-cp39-win32.whl", hash = "sha256:12859468e8918d3bd243d213cd6fd6ab07208195dc140763c00dfe901ce1e1b4"},
{file = "grpcio-1.62.1-cp39-cp39-win_amd64.whl", hash = "sha256:b7209117bbeebdfa5d898205cc55153a51285757902dd73c47de498ad4d11332"},
{file = "grpcio-1.62.1.tar.gz", hash = "sha256:6c455e008fa86d9e9a9d85bb76da4277c0d7d9668a3bfa70dbe86e9f3c759947"},
]
[package.extras]
protobuf = ["grpcio-tools (>=1.62.1)"]
[[package]] [[package]]
name = "gym" name = "gym"
version = "0.26.2" version = "0.26.2"
@@ -1038,13 +1114,13 @@ setuptools = "*"
[[package]] [[package]]
name = "importlib-metadata" name = "importlib-metadata"
version = "7.0.2" version = "7.1.0"
description = "Read metadata from Python packages" description = "Read metadata from Python packages"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "importlib_metadata-7.0.2-py3-none-any.whl", hash = "sha256:f4bc4c0c070c490abf4ce96d715f68e95923320370efb66143df00199bb6c100"}, {file = "importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570"},
{file = "importlib_metadata-7.0.2.tar.gz", hash = "sha256:198f568f3230878cb1b44fbd7975f87906c22336dba2e4a7f05278c281fbd792"}, {file = "importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2"},
] ]
[package.dependencies] [package.dependencies]
@@ -1053,7 +1129,7 @@ zipp = ">=0.5"
[package.extras] [package.extras]
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
perf = ["ipython"] perf = ["ipython"]
testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"]
[[package]] [[package]]
name = "iniconfig" name = "iniconfig"
@@ -1265,6 +1341,21 @@ html5 = ["html5lib"]
htmlsoup = ["BeautifulSoup4"] htmlsoup = ["BeautifulSoup4"]
source = ["Cython (>=3.0.7)"] source = ["Cython (>=3.0.7)"]
[[package]]
name = "markdown"
version = "3.6"
description = "Python implementation of John Gruber's Markdown."
optional = false
python-versions = ">=3.8"
files = [
{file = "Markdown-3.6-py3-none-any.whl", hash = "sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f"},
{file = "Markdown-3.6.tar.gz", hash = "sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224"},
]
[package.extras]
docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"]
testing = ["coverage", "pyyaml"]
[[package]] [[package]]
name = "markupsafe" name = "markupsafe"
version = "2.1.5" version = "2.1.5"
@@ -2460,6 +2551,30 @@ urllib3 = ">=1.21.1,<3"
socks = ["PySocks (>=1.5.6,!=1.5.7)"] socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "robomimic"
version = "0.2.0"
description = "robomimic: A Modular Framework for Robot Learning from Demonstration"
optional = false
python-versions = ">=3"
files = [
{file = "robomimic-0.2.0.tar.gz", hash = "sha256:ee3bb5cf9c3e1feead6b57b43c5db738fd0a8e0c015fdf6419808af8fffdc463"},
]
[package.dependencies]
egl_probe = ">=1.0.1"
h5py = "*"
imageio = "*"
imageio-ffmpeg = "*"
numpy = ">=1.13.3"
psutil = "*"
tensorboard = "*"
tensorboardX = "*"
termcolor = "*"
torch = "*"
torchvision = "*"
tqdm = "*"
[[package]] [[package]]
name = "safetensors" name = "safetensors"
version = "0.4.2" version = "0.4.2"
@@ -2684,13 +2799,13 @@ test = ["asv", "gmpy2", "hypothesis", "mpmath", "pooch", "pytest", "pytest-cov",
[[package]] [[package]]
name = "sentry-sdk" name = "sentry-sdk"
version = "1.42.0" version = "1.43.0"
description = "Python client for Sentry (https://sentry.io)" description = "Python client for Sentry (https://sentry.io)"
optional = false optional = false
python-versions = "*" python-versions = "*"
files = [ files = [
{file = "sentry-sdk-1.42.0.tar.gz", hash = "sha256:4a8364b8f7edbf47f95f7163e48334c96100d9c098f0ae6606e2e18183c223e6"}, {file = "sentry-sdk-1.43.0.tar.gz", hash = "sha256:41df73af89d22921d8733714fb0fc5586c3461907e06688e6537d01a27e0e0f6"},
{file = "sentry_sdk-1.42.0-py2.py3-none-any.whl", hash = "sha256:a654ee7e497a3f5f6368b36d4f04baeab1fe92b3105f7f6965d6ef0de35a9ba4"}, {file = "sentry_sdk-1.43.0-py2.py3-none-any.whl", hash = "sha256:8d768724839ca18d7b4c7463ef7528c40b7aa2bfbf7fe554d5f9a7c044acfd36"},
] ]
[package.dependencies] [package.dependencies]
@@ -2704,6 +2819,7 @@ asyncpg = ["asyncpg (>=0.23)"]
beam = ["apache-beam (>=2.12)"] beam = ["apache-beam (>=2.12)"]
bottle = ["bottle (>=0.12.13)"] bottle = ["bottle (>=0.12.13)"]
celery = ["celery (>=3)"] celery = ["celery (>=3)"]
celery-redbeat = ["celery-redbeat (>=2)"]
chalice = ["chalice (>=1.16.0)"] chalice = ["chalice (>=1.16.0)"]
clickhouse-driver = ["clickhouse-driver (>=0.2.0)"] clickhouse-driver = ["clickhouse-driver (>=0.2.0)"]
django = ["django (>=1.8)"] django = ["django (>=1.8)"]
@@ -2948,6 +3064,55 @@ files = [
[package.dependencies] [package.dependencies]
mpmath = ">=0.19" mpmath = ">=0.19"
[[package]]
name = "tensorboard"
version = "2.16.2"
description = "TensorBoard lets you watch Tensors Flow"
optional = false
python-versions = ">=3.9"
files = [
{file = "tensorboard-2.16.2-py3-none-any.whl", hash = "sha256:9f2b4e7dad86667615c0e5cd072f1ea8403fc032a299f0072d6f74855775cc45"},
]
[package.dependencies]
absl-py = ">=0.4"
grpcio = ">=1.48.2"
markdown = ">=2.6.8"
numpy = ">=1.12.0"
protobuf = ">=3.19.6,<4.24.0 || >4.24.0"
setuptools = ">=41.0.0"
six = ">1.9"
tensorboard-data-server = ">=0.7.0,<0.8.0"
werkzeug = ">=1.0.1"
[[package]]
name = "tensorboard-data-server"
version = "0.7.2"
description = "Fast data loading for TensorBoard"
optional = false
python-versions = ">=3.7"
files = [
{file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"},
{file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"},
{file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"},
]
[[package]]
name = "tensorboardx"
version = "2.6.2.2"
description = "TensorBoardX lets you watch Tensors Flow without Tensorflow"
optional = false
python-versions = "*"
files = [
{file = "tensorboardX-2.6.2.2-py2.py3-none-any.whl", hash = "sha256:160025acbf759ede23fd3526ae9d9bfbfd8b68eb16c38a010ebe326dc6395db8"},
{file = "tensorboardX-2.6.2.2.tar.gz", hash = "sha256:c6476d7cd0d529b0b72f4acadb1269f9ed8b22f441e87a84f2a3b940bb87b666"},
]
[package.dependencies]
numpy = "*"
packaging = "*"
protobuf = ">=3.20"
[[package]] [[package]]
name = "tensordict" name = "tensordict"
version = "0.4.0+ca4256e" version = "0.4.0+ca4256e"
@@ -3289,6 +3454,23 @@ perf = ["orjson"]
reports = ["pydantic (>=2.0.0)"] reports = ["pydantic (>=2.0.0)"]
sweeps = ["sweeps (>=0.2.0)"] sweeps = ["sweeps (>=0.2.0)"]
[[package]]
name = "werkzeug"
version = "3.0.1"
description = "The comprehensive WSGI web application library."
optional = false
python-versions = ">=3.8"
files = [
{file = "werkzeug-3.0.1-py3-none-any.whl", hash = "sha256:90a285dc0e42ad56b34e696398b8122ee4c681833fb35b8334a095d82c56da10"},
{file = "werkzeug-3.0.1.tar.gz", hash = "sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc"},
]
[package.dependencies]
MarkupSafe = ">=2.1.1"
[package.extras]
watchdog = ["watchdog (>=2.3)"]
[[package]] [[package]]
name = "zarr" name = "zarr"
version = "2.17.1" version = "2.17.1"
@@ -3328,4 +3510,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "ee86b84a795e6a3e9c2d79f244a87b55589adbe46d549ac38adf48be27c04cf9" content-hash = "1a45c808e1c48bcbf4319d4cf6876771b7d50f40a5a8968a8b7f3af36192bf34"

View File

@@ -51,6 +51,7 @@ torchvision = "^0.17.1"
h5py = "^3.10.0" h5py = "^3.10.0"
dm-control = "1.0.14" dm-control = "1.0.14"
huggingface-hub = {extras = ["hf-transfer"], version = "^0.21.4"} huggingface-hub = {extras = ["hf-transfer"], version = "^0.21.4"}
robomimic = "0.2.0"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]