Making Envs module pass MyPy checks (#2048)
* Fix configs.py None MyPy error * Use img_tensor instead of img in utils.py * Add type assertion in factory.py * Resolve merge conflict * Uncomment envs moodule for mypy checks in pyproject.toml --------- Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com> Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
@@ -86,12 +86,12 @@ repos:
|
|||||||
|
|
||||||
# TODO(Steven): Uncomment when ready to use
|
# TODO(Steven): Uncomment when ready to use
|
||||||
##### Static Analysis & Typing #####
|
##### Static Analysis & Typing #####
|
||||||
# - repo: https://github.com/pre-commit/mirrors-mypy
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
# rev: v1.16.0
|
rev: v1.16.0
|
||||||
# hooks:
|
hooks:
|
||||||
# - id: mypy
|
- id: mypy
|
||||||
# args: [--config-file=pyproject.toml]
|
args: [--config-file=pyproject.toml]
|
||||||
# exclude: ^(examples|benchmarks|tests)/
|
exclude: ^(examples|benchmarks|tests)/
|
||||||
|
|
||||||
##### Docstring Checks #####
|
##### Docstring Checks #####
|
||||||
# - repo: https://github.com/akaihola/darglint2
|
# - repo: https://github.com/akaihola/darglint2
|
||||||
|
|||||||
@@ -270,10 +270,10 @@ default.extend-ignore-identifiers-re = [
|
|||||||
# TODO: Enable mypy gradually module by module across multiple PRs
|
# TODO: Enable mypy gradually module by module across multiple PRs
|
||||||
# Uncomment [tool.mypy] first, then uncomment individual module overrides as they get proper type annotations
|
# Uncomment [tool.mypy] first, then uncomment individual module overrides as they get proper type annotations
|
||||||
|
|
||||||
# [tool.mypy]
|
[tool.mypy]
|
||||||
# python_version = "3.10"
|
python_version = "3.10"
|
||||||
# ignore_missing_imports = true
|
ignore_missing_imports = true
|
||||||
# follow_imports = "skip"
|
follow_imports = "skip"
|
||||||
# warn_return_any = true
|
# warn_return_any = true
|
||||||
# warn_unused_configs = true
|
# warn_unused_configs = true
|
||||||
# strict = true
|
# strict = true
|
||||||
@@ -281,14 +281,14 @@ default.extend-ignore-identifiers-re = [
|
|||||||
# disallow_incomplete_defs = true
|
# disallow_incomplete_defs = true
|
||||||
# check_untyped_defs = true
|
# check_untyped_defs = true
|
||||||
|
|
||||||
# [[tool.mypy.overrides]]
|
[[tool.mypy.overrides]]
|
||||||
# module = "lerobot.*"
|
module = "lerobot.*"
|
||||||
# ignore_errors = true
|
ignore_errors = true
|
||||||
|
|
||||||
# [[tool.mypy.overrides]]
|
[[tool.mypy.overrides]]
|
||||||
# module = "lerobot.envs.*"
|
module = "lerobot.envs.*"
|
||||||
# # Enable type checking only for the envs module
|
# Enable type checking only for the envs module
|
||||||
# ignore_errors = false
|
ignore_errors = false
|
||||||
|
|
||||||
|
|
||||||
# [[tool.mypy.overrides]]
|
# [[tool.mypy.overrides]]
|
||||||
@@ -299,7 +299,6 @@ default.extend-ignore-identifiers-re = [
|
|||||||
# module = "lerobot.configs.*"
|
# module = "lerobot.configs.*"
|
||||||
# ignore_errors = false
|
# ignore_errors = false
|
||||||
|
|
||||||
# PHASE 2: Core modules
|
|
||||||
# [[tool.mypy.overrides]]
|
# [[tool.mypy.overrides]]
|
||||||
# module = "lerobot.optim.*"
|
# module = "lerobot.optim.*"
|
||||||
# ignore_errors = false
|
# ignore_errors = false
|
||||||
@@ -340,6 +339,7 @@ default.extend-ignore-identifiers-re = [
|
|||||||
# module = "lerobot.rl.*"
|
# module = "lerobot.rl.*"
|
||||||
# ignore_errors = false
|
# ignore_errors = false
|
||||||
|
|
||||||
|
|
||||||
# [[tool.mypy.overrides]]
|
# [[tool.mypy.overrides]]
|
||||||
# module = "lerobot.async_inference.*"
|
# module = "lerobot.async_inference.*"
|
||||||
# ignore_errors = false
|
# ignore_errors = false
|
||||||
|
|||||||
@@ -254,7 +254,7 @@ class LiberoEnv(EnvConfig):
|
|||||||
render_mode: str = "rgb_array"
|
render_mode: str = "rgb_array"
|
||||||
camera_name: str = "agentview_image,robot0_eye_in_hand_image"
|
camera_name: str = "agentview_image,robot0_eye_in_hand_image"
|
||||||
init_states: bool = True
|
init_states: bool = True
|
||||||
camera_name_mapping: dict[str, str] | None = (None,)
|
camera_name_mapping: dict[str, str] | None = None
|
||||||
features: dict[str, PolicyFeature] = field(
|
features: dict[str, PolicyFeature] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||||
|
|||||||
@@ -63,6 +63,9 @@ def make_env(
|
|||||||
if "libero" in cfg.type:
|
if "libero" in cfg.type:
|
||||||
from lerobot.envs.libero import create_libero_envs
|
from lerobot.envs.libero import create_libero_envs
|
||||||
|
|
||||||
|
if cfg.task is None:
|
||||||
|
raise ValueError("LiberoEnv requires a task to be specified")
|
||||||
|
|
||||||
return create_libero_envs(
|
return create_libero_envs(
|
||||||
task=cfg.task,
|
task=cfg.task,
|
||||||
n_envs=n_envs,
|
n_envs=n_envs,
|
||||||
|
|||||||
@@ -48,25 +48,25 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
|||||||
|
|
||||||
for imgkey, img in imgs.items():
|
for imgkey, img in imgs.items():
|
||||||
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
||||||
img = torch.from_numpy(img)
|
img_tensor = torch.from_numpy(img)
|
||||||
|
|
||||||
# When preprocessing observations in a non-vectorized environment, we need to add a batch dimension.
|
# When preprocessing observations in a non-vectorized environment, we need to add a batch dimension.
|
||||||
# This is the case for human-in-the-loop RL where there is only one environment.
|
# This is the case for human-in-the-loop RL where there is only one environment.
|
||||||
if img.ndim == 3:
|
if img_tensor.ndim == 3:
|
||||||
img = img.unsqueeze(0)
|
img_tensor = img_tensor.unsqueeze(0)
|
||||||
# sanity check that images are channel last
|
# sanity check that images are channel last
|
||||||
_, h, w, c = img.shape
|
_, h, w, c = img_tensor.shape
|
||||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
assert c < h and c < w, f"expect channel last images, but instead got {img_tensor.shape=}"
|
||||||
|
|
||||||
# sanity check that images are uint8
|
# sanity check that images are uint8
|
||||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
assert img_tensor.dtype == torch.uint8, f"expect torch.uint8, but instead {img_tensor.dtype=}"
|
||||||
|
|
||||||
# convert to channel first of type float32 in range [0,1]
|
# convert to channel first of type float32 in range [0,1]
|
||||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
img_tensor = einops.rearrange(img_tensor, "b h w c -> b c h w").contiguous()
|
||||||
img = img.type(torch.float32)
|
img_tensor = img_tensor.type(torch.float32)
|
||||||
img /= 255
|
img_tensor /= 255
|
||||||
|
|
||||||
return_observations[imgkey] = img
|
return_observations[imgkey] = img_tensor
|
||||||
|
|
||||||
if "environment_state" in observations:
|
if "environment_state" in observations:
|
||||||
env_state = torch.from_numpy(observations["environment_state"]).float()
|
env_state = torch.from_numpy(observations["environment_state"]).float()
|
||||||
|
|||||||
Reference in New Issue
Block a user