Ran pre-commit run --all-files
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import importlib
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.tensor_specs import (
|
||||
@@ -20,7 +19,6 @@ _has_diffpolicy = importlib.util.find_spec("diffusion_policy") is not None and _
|
||||
|
||||
|
||||
class PushtEnv(EnvBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frame_skip: int = 1,
|
||||
@@ -46,8 +44,6 @@ class PushtEnv(EnvBase):
|
||||
if not _has_gym:
|
||||
raise ImportError("Cannot import gym.")
|
||||
|
||||
from diffusion_policy.env.pusht.pusht_env import PushTEnv
|
||||
|
||||
if not from_pixels:
|
||||
raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
|
||||
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
|
||||
@@ -71,14 +67,10 @@ class PushtEnv(EnvBase):
|
||||
obs = {"image": torch.from_numpy(raw_obs["image"])}
|
||||
|
||||
if not self.pixels_only:
|
||||
obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(
|
||||
torch.float32
|
||||
)
|
||||
obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(torch.float32)
|
||||
else:
|
||||
# TODO:
|
||||
obs = {
|
||||
"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)
|
||||
}
|
||||
obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}
|
||||
|
||||
obs = TensorDict(obs, batch_size=[])
|
||||
return obs
|
||||
@@ -109,7 +101,7 @@ class PushtEnv(EnvBase):
|
||||
# step expects shape=(4,) so we pad if necessary
|
||||
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
||||
sum_reward = 0
|
||||
for t in range(self.frame_skip):
|
||||
for _ in range(self.frame_skip):
|
||||
raw_obs, reward, done, info = self._env.step(action)
|
||||
sum_reward += reward
|
||||
|
||||
|
||||
@@ -15,12 +15,13 @@ from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
|
||||
|
||||
from lerobot.common.utils import set_seed
|
||||
|
||||
MAX_NUM_ACTIONS = 4
|
||||
|
||||
_has_gym = importlib.util.find_spec("gym") is not None
|
||||
_has_simxarm = importlib.util.find_spec("simxarm") is not None and _has_gym
|
||||
|
||||
|
||||
class SimxarmEnv(EnvBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
@@ -52,18 +53,13 @@ class SimxarmEnv(EnvBase):
|
||||
from simxarm import TASKS
|
||||
|
||||
if self.task not in TASKS:
|
||||
raise ValueError(
|
||||
f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}"
|
||||
)
|
||||
raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}")
|
||||
|
||||
self._env = TASKS[self.task]["env"]()
|
||||
|
||||
MAX_NUM_ACTIONS = 4
|
||||
num_actions = len(TASKS[self.task]["action_space"])
|
||||
self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
|
||||
self._action_padding = np.zeros(
|
||||
(MAX_NUM_ACTIONS - num_actions), dtype=np.float32
|
||||
)
|
||||
self._action_padding = np.zeros((MAX_NUM_ACTIONS - num_actions), dtype=np.float32)
|
||||
if "w" not in TASKS[self.task]["action_space"]:
|
||||
self._action_padding[-1] = 1.0
|
||||
|
||||
@@ -75,9 +71,7 @@ class SimxarmEnv(EnvBase):
|
||||
|
||||
def _format_raw_obs(self, raw_obs):
|
||||
if self.from_pixels:
|
||||
image = self.render(
|
||||
mode="rgb_array", width=self.image_size, height=self.image_size
|
||||
)
|
||||
image = self.render(mode="rgb_array", width=self.image_size, height=self.image_size)
|
||||
image = image.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
|
||||
image = torch.tensor(image.copy(), dtype=torch.uint8)
|
||||
|
||||
@@ -114,7 +108,7 @@ class SimxarmEnv(EnvBase):
|
||||
action = np.concatenate([action, self._action_padding])
|
||||
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
||||
sum_reward = 0
|
||||
for t in range(self.frame_skip):
|
||||
for _ in range(self.frame_skip):
|
||||
raw_obs, reward, done, info = self._env.step(action)
|
||||
sum_reward += reward
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ from torchrl.envs.transforms import ObservationTransform
|
||||
|
||||
|
||||
class Prod(ObservationTransform):
|
||||
|
||||
def __init__(self, in_keys: Sequence[NestedKey], prod: float):
|
||||
super().__init__()
|
||||
self.in_keys = in_keys
|
||||
|
||||
Reference in New Issue
Block a user