Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots

This commit is contained in:
Simon Alibert
2025-04-08 10:47:40 +02:00
13 changed files with 1173 additions and 14 deletions

View File

@@ -13,7 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any
import einops
import gymnasium as gym
import numpy as np
import torch
from torch import Tensor
@@ -86,3 +90,38 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
policy_features[policy_key] = feature
return policy_features
def are_all_envs_same_type(env: gym.vector.VectorEnv) -> bool:
first_type = type(env.envs[0]) # Get type of first env
return all(type(e) is first_type for e in env.envs) # Fast type check
def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
with warnings.catch_warnings():
warnings.simplefilter("once", UserWarning) # Apply filter only in this function
if not (hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")):
warnings.warn(
"The environment does not have 'task_description' and 'task'. Some policies require these features.",
UserWarning,
stacklevel=2,
)
if not are_all_envs_same_type(env):
warnings.warn(
"The environments have different types. Make sure you infer the right task from each environment. Empty task will be passed instead.",
UserWarning,
stacklevel=2,
)
def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
"""Adds task feature to the observation dict with respect to the first environment attribute."""
if hasattr(env.envs[0], "task_description"):
observation["task"] = env.call("task_description")
elif hasattr(env.envs[0], "task"):
observation["task"] = env.call("task")
else: # For envs without language instructions, e.g. aloha transfer cube and etc.
num_envs = observation[list(observation.keys())[0]].shape[0]
observation["task"] = ["" for _ in range(num_envs)]
return observation