Ran pre-commit run --all-files

This commit is contained in:
Simon Alibert
2024-02-29 13:37:48 +01:00
parent 0e0e760e43
commit 7e024fdce6
16 changed files with 124 additions and 237 deletions

View File

@@ -1,7 +1,6 @@
import importlib
from typing import Optional
import numpy as np
import torch
from tensordict import TensorDict
from torchrl.data.tensor_specs import (
@@ -20,7 +19,6 @@ _has_diffpolicy = importlib.util.find_spec("diffusion_policy") is not None and _
class PushtEnv(EnvBase):
def __init__(
self,
frame_skip: int = 1,
@@ -46,8 +44,6 @@ class PushtEnv(EnvBase):
if not _has_gym:
raise ImportError("Cannot import gym.")
from diffusion_policy.env.pusht.pusht_env import PushTEnv
if not from_pixels:
raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
@@ -71,14 +67,10 @@ class PushtEnv(EnvBase):
obs = {"image": torch.from_numpy(raw_obs["image"])}
if not self.pixels_only:
obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(
torch.float32
)
obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(torch.float32)
else:
# TODO:
obs = {
"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)
}
obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}
obs = TensorDict(obs, batch_size=[])
return obs
@@ -109,7 +101,7 @@ class PushtEnv(EnvBase):
# step expects shape=(4,) so we pad if necessary
# TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0
for t in range(self.frame_skip):
for _ in range(self.frame_skip):
raw_obs, reward, done, info = self._env.step(action)
sum_reward += reward

View File

@@ -15,12 +15,13 @@ from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
from lerobot.common.utils import set_seed
MAX_NUM_ACTIONS = 4
_has_gym = importlib.util.find_spec("gym") is not None
_has_simxarm = importlib.util.find_spec("simxarm") is not None and _has_gym
class SimxarmEnv(EnvBase):
def __init__(
self,
task,
@@ -52,18 +53,13 @@ class SimxarmEnv(EnvBase):
from simxarm import TASKS
if self.task not in TASKS:
raise ValueError(
f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}"
)
raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}")
self._env = TASKS[self.task]["env"]()
MAX_NUM_ACTIONS = 4
num_actions = len(TASKS[self.task]["action_space"])
self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
self._action_padding = np.zeros(
(MAX_NUM_ACTIONS - num_actions), dtype=np.float32
)
self._action_padding = np.zeros((MAX_NUM_ACTIONS - num_actions), dtype=np.float32)
if "w" not in TASKS[self.task]["action_space"]:
self._action_padding[-1] = 1.0
@@ -75,9 +71,7 @@ class SimxarmEnv(EnvBase):
def _format_raw_obs(self, raw_obs):
if self.from_pixels:
image = self.render(
mode="rgb_array", width=self.image_size, height=self.image_size
)
image = self.render(mode="rgb_array", width=self.image_size, height=self.image_size)
image = image.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
image = torch.tensor(image.copy(), dtype=torch.uint8)
@@ -114,7 +108,7 @@ class SimxarmEnv(EnvBase):
action = np.concatenate([action, self._action_padding])
# TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0
for t in range(self.frame_skip):
for _ in range(self.frame_skip):
raw_obs, reward, done, info = self._env.step(action)
sum_reward += reward

View File

@@ -5,7 +5,6 @@ from torchrl.envs.transforms import ObservationTransform
class Prod(ObservationTransform):
def __init__(self, in_keys: Sequence[NestedKey], prod: float):
super().__init__()
self.in_keys = in_keys