Ran pre-commit run --all-files
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import importlib
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.tensor_specs import (
|
||||
@@ -20,7 +19,6 @@ _has_diffpolicy = importlib.util.find_spec("diffusion_policy") is not None and _
|
||||
|
||||
|
||||
class PushtEnv(EnvBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frame_skip: int = 1,
|
||||
@@ -46,8 +44,6 @@ class PushtEnv(EnvBase):
|
||||
if not _has_gym:
|
||||
raise ImportError("Cannot import gym.")
|
||||
|
||||
from diffusion_policy.env.pusht.pusht_env import PushTEnv
|
||||
|
||||
if not from_pixels:
|
||||
raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
|
||||
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
|
||||
@@ -71,14 +67,10 @@ class PushtEnv(EnvBase):
|
||||
obs = {"image": torch.from_numpy(raw_obs["image"])}
|
||||
|
||||
if not self.pixels_only:
|
||||
obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(
|
||||
torch.float32
|
||||
)
|
||||
obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(torch.float32)
|
||||
else:
|
||||
# TODO:
|
||||
obs = {
|
||||
"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)
|
||||
}
|
||||
obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}
|
||||
|
||||
obs = TensorDict(obs, batch_size=[])
|
||||
return obs
|
||||
@@ -109,7 +101,7 @@ class PushtEnv(EnvBase):
|
||||
# step expects shape=(4,) so we pad if necessary
|
||||
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
||||
sum_reward = 0
|
||||
for t in range(self.frame_skip):
|
||||
for _ in range(self.frame_skip):
|
||||
raw_obs, reward, done, info = self._env.step(action)
|
||||
sum_reward += reward
|
||||
|
||||
|
||||
Reference in New Issue
Block a user