ready for review
This commit is contained in:
@@ -3,6 +3,8 @@ import logging
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.tensor_specs import (
|
||||
@@ -59,12 +61,30 @@ class PushtEnv(AbstractEnv):
|
||||
|
||||
self._env = PushTImageEnv(render_size=self.image_size)
|
||||
|
||||
def render(self, mode="rgb_array", width=384, height=384):
|
||||
def render(self, mode="rgb_array", width=96, height=96, with_marker=True):
|
||||
"""
|
||||
with_marker adds a cursor showing the targeted action for the controller.
|
||||
"""
|
||||
if width != height:
|
||||
raise NotImplementedError()
|
||||
tmp = self._env.render_size
|
||||
self._env.render_size = width
|
||||
out = self._env.render(mode)
|
||||
if width != self._env.render_size:
|
||||
self._env.render_cache = None
|
||||
self._env.render_size = width
|
||||
out = self._env.render(mode).copy()
|
||||
if with_marker and self._env.latest_action is not None:
|
||||
action = np.array(self._env.latest_action)
|
||||
coord = (action / 512 * self._env.render_size).astype(np.int32)
|
||||
marker_size = int(8 / 96 * self._env.render_size)
|
||||
thickness = int(1 / 96 * self._env.render_size)
|
||||
cv2.drawMarker(
|
||||
out,
|
||||
coord,
|
||||
color=(255, 0, 0),
|
||||
markerType=cv2.MARKER_CROSS,
|
||||
markerSize=marker_size,
|
||||
thickness=thickness,
|
||||
)
|
||||
self._env.render_size = tmp
|
||||
return out
|
||||
|
||||
|
||||
Reference in New Issue
Block a user