Eval reproduced! Train running (but not reproduced)

This commit is contained in:
Cadene
2024-02-10 15:46:24 +00:00
parent 937b2f8cba
commit 228c045674
14 changed files with 787 additions and 118 deletions

View File

@@ -7,6 +7,7 @@ def make_env(cfg):
assert cfg.env == "simxarm"
env = SimxarmEnv(
task=cfg.task,
frame_skip=cfg.action_repeat,
from_pixels=cfg.from_pixels,
pixels_only=cfg.pixels_only,
image_size=cfg.image_size,

View File

@@ -24,6 +24,7 @@ class SimxarmEnv(EnvBase):
def __init__(
self,
task,
frame_skip: int = 1,
from_pixels: bool = False,
pixels_only: bool = False,
image_size=None,
@@ -32,6 +33,7 @@ class SimxarmEnv(EnvBase):
):
super().__init__(device=device, batch_size=[])
self.task = task
self.frame_skip = frame_skip
self.from_pixels = from_pixels
self.pixels_only = pixels_only
self.image_size = image_size
@@ -115,12 +117,15 @@ class SimxarmEnv(EnvBase):
# step expects shape=(4,) so we pad if necessary
action = np.concatenate([action, self._action_padding])
# TODO(rcadene): add info["is_success"] and info["success"] ?
raw_obs, reward, done, info = self._env.step(action)
sum_reward = 0
for t in range(self.frame_skip):
raw_obs, reward, done, info = self._env.step(action)
sum_reward += reward
td = TensorDict(
{
"observation": self._format_raw_obs(raw_obs),
"reward": torch.tensor([reward], dtype=torch.float32),
"reward": torch.tensor([sum_reward], dtype=torch.float32),
"done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([info["success"]], dtype=torch.bool),
},