Ran pre-commit run --all-files

This commit is contained in:
Simon Alibert
2024-02-29 13:37:48 +01:00
parent 0e0e760e43
commit 7e024fdce6
16 changed files with 124 additions and 237 deletions

View File

@@ -1,7 +1,6 @@
import importlib
from typing import Optional
import numpy as np
import torch
from tensordict import TensorDict
from torchrl.data.tensor_specs import (
@@ -20,7 +19,6 @@ _has_diffpolicy = importlib.util.find_spec("diffusion_policy") is not None and _
class PushtEnv(EnvBase):
def __init__(
self,
frame_skip: int = 1,
@@ -46,8 +44,6 @@ class PushtEnv(EnvBase):
if not _has_gym:
raise ImportError("Cannot import gym.")
from diffusion_policy.env.pusht.pusht_env import PushTEnv
if not from_pixels:
raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
@@ -71,14 +67,10 @@ class PushtEnv(EnvBase):
obs = {"image": torch.from_numpy(raw_obs["image"])}
if not self.pixels_only:
obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(
torch.float32
)
obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(torch.float32)
else:
# TODO:
obs = {
"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)
}
obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}
obs = TensorDict(obs, batch_size=[])
return obs
@@ -109,7 +101,7 @@ class PushtEnv(EnvBase):
# step expects shape=(4,) so we pad if necessary
# TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0
for t in range(self.frame_skip):
for _ in range(self.frame_skip):
raw_obs, reward, done, info = self._env.step(action)
sum_reward += reward