[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Michel Aractingi
parent
8eb3c1510c
commit
eb44a06a9b
@@ -47,7 +47,7 @@ from lerobot.scripts.server.buffer import (
|
||||
python_object_to_bytes,
|
||||
transitions_to_bytes,
|
||||
)
|
||||
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
|
||||
from lerobot.scripts.server.gym_manipulator import make_robot_env
|
||||
from lerobot.scripts.server.network_utils import (
|
||||
receive_bytes_in_chunks,
|
||||
send_bytes_in_chunks,
|
||||
@@ -444,7 +444,7 @@ def receive_policy(
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file)
|
||||
logging.info(f"Actor receive policy process logging initialized")
|
||||
logging.info("Actor receive policy process logging initialized")
|
||||
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
|
||||
@@ -515,7 +515,7 @@ class ReplayBuffer:
|
||||
frame_dict["action"] = self.actions[actual_idx].cpu()
|
||||
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
|
||||
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
|
||||
|
||||
|
||||
# Add task field which is required by LeRobotDataset
|
||||
frame_dict["task"] = task_name
|
||||
|
||||
|
||||
@@ -701,10 +701,10 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from lerobot.common.envs.configs import EEActionSpaceConfig, EnvWrapperConfig, HILSerlRobotEnvConfig
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
|
||||
from lerobot.scripts.server.gym_manipulator import make_robot_env
|
||||
from lerobot.common.envs.configs import HILSerlRobotEnvConfig, EEActionSpaceConfig, EnvWrapperConfig
|
||||
|
||||
parser = argparse.ArgumentParser(description="Test end-effector control")
|
||||
parser.add_argument(
|
||||
|
||||
@@ -5,14 +5,15 @@ import cv2
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.robot_devices.control_utils import is_headless
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
|
||||
from lerobot.scripts.server.kinematics import RobotKinematics
|
||||
from lerobot.configs import parser
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
|
||||
from lerobot.configs import parser
|
||||
from lerobot.scripts.server.kinematics import RobotKinematics
|
||||
|
||||
follower_port = "/dev/tty.usbmodem58760431631"
|
||||
leader_port = "/dev/tty.usbmodem58760433331"
|
||||
|
||||
|
||||
def find_joint_bounds(
|
||||
robot,
|
||||
control_time_s=30,
|
||||
@@ -85,21 +86,22 @@ def find_ee_bounds(
|
||||
def make_robot(robot_type="so100"):
|
||||
"""
|
||||
Create a robot instance using the appropriate robot config class.
|
||||
|
||||
|
||||
Args:
|
||||
robot_type: Robot type string (e.g., "so100", "koch", "aloha")
|
||||
|
||||
|
||||
Returns:
|
||||
Robot instance
|
||||
"""
|
||||
|
||||
|
||||
# Get the appropriate robot config class based on robot_type
|
||||
robot_config = RobotConfig.get_choice_class(robot_type)(mock=False)
|
||||
robot_config.leader_arms["main"].port = leader_port
|
||||
robot_config.follower_arms["main"].port = follower_port
|
||||
|
||||
|
||||
return make_robot_from_config(robot_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Create argparse for script-specific arguments
|
||||
parser = argparse.ArgumentParser(add_help=False) # Set add_help=False to avoid conflict
|
||||
@@ -125,14 +127,14 @@ if __name__ == "__main__":
|
||||
|
||||
# Only parse known args, leaving robot config args for Hydra if used
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# Create robot with the appropriate config
|
||||
robot = make_robot(args.robot_type)
|
||||
|
||||
|
||||
if args.mode == "joint":
|
||||
find_joint_bounds(robot, args.control_time_s)
|
||||
elif args.mode == "ee":
|
||||
find_ee_bounds(robot, args.control_time_s)
|
||||
|
||||
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
robot.disconnect()
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any
|
||||
|
||||
import einops
|
||||
import gymnasium as gym
|
||||
@@ -10,10 +9,8 @@ import torch
|
||||
from mani_skill.utils.wrappers.record import RecordEpisode
|
||||
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
||||
|
||||
from lerobot.common.envs.configs import EnvConfig, ManiskillEnvConfig
|
||||
from lerobot.common.envs.configs import ManiskillEnvConfig
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_ROBOT
|
||||
|
||||
|
||||
def preprocess_maniskill_observation(
|
||||
@@ -53,9 +50,6 @@ def preprocess_maniskill_observation(
|
||||
return return_observations
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class ManiSkillObservationWrapper(gym.ObservationWrapper):
|
||||
def __init__(self, env, device: torch.device = "cuda"):
|
||||
super().__init__(env)
|
||||
@@ -122,6 +116,7 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
|
||||
|
||||
class BatchCompatibleWrapper(gym.ObservationWrapper):
|
||||
"""Ensures observations are batch-compatible by adding a batch dimension if necessary."""
|
||||
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
|
||||
@@ -136,6 +131,7 @@ class BatchCompatibleWrapper(gym.ObservationWrapper):
|
||||
|
||||
class TimeLimitWrapper(gym.Wrapper):
|
||||
"""Adds a time limit to the environment based on fps and control_time."""
|
||||
|
||||
def __init__(self, env, control_time_s, fps):
|
||||
super().__init__(env)
|
||||
self.control_time_s = control_time_s
|
||||
@@ -146,10 +142,10 @@ class TimeLimitWrapper(gym.Wrapper):
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
self.current_step += 1
|
||||
|
||||
|
||||
if self.current_step >= self.max_episode_steps:
|
||||
terminated = True
|
||||
|
||||
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, *, seed=None, options=None):
|
||||
@@ -190,18 +186,18 @@ def make_maniskill(
|
||||
save_video=True,
|
||||
video_fps=30,
|
||||
)
|
||||
|
||||
|
||||
# Add observation and image processing
|
||||
env = ManiSkillObservationWrapper(env, device=cfg.device)
|
||||
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
|
||||
env._max_episode_steps = env.max_episode_steps = cfg.episode_length
|
||||
env.unwrapped.metadata["render_fps"] = cfg.fps
|
||||
|
||||
|
||||
# Add compatibility wrappers
|
||||
env = ManiSkillCompat(env)
|
||||
env = ManiSkillActionWrapper(env)
|
||||
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03) # Scale actions for better control
|
||||
|
||||
|
||||
return env
|
||||
|
||||
|
||||
@@ -210,29 +206,29 @@ def main(cfg: ManiskillEnvConfig):
|
||||
"""Main function to run the ManiSkill environment."""
|
||||
# Create the ManiSkill environment
|
||||
env = make_maniskill(cfg, n_envs=1)
|
||||
|
||||
|
||||
# Reset the environment
|
||||
obs, info = env.reset()
|
||||
|
||||
|
||||
# Run a simple interaction loop
|
||||
sum_reward = 0
|
||||
for i in range(100):
|
||||
# Sample a random action
|
||||
action = env.action_space.sample()
|
||||
|
||||
|
||||
# Step the environment
|
||||
start_time = time.perf_counter()
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
step_time = time.perf_counter() - start_time
|
||||
sum_reward += reward
|
||||
# Log information
|
||||
|
||||
|
||||
# Reset if episode terminated
|
||||
if terminated or truncated:
|
||||
logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
|
||||
sum_reward = 0
|
||||
obs, info = env.reset()
|
||||
|
||||
|
||||
# Close the environment
|
||||
env.close()
|
||||
|
||||
@@ -243,6 +239,10 @@ def main(cfg: ManiskillEnvConfig):
|
||||
|
||||
if __name__ == "__main__":
|
||||
import draccus
|
||||
|
||||
config = ManiskillEnvConfig()
|
||||
draccus.set_config_type("json")
|
||||
draccus.dump(config=config, stream=open(file='run_config.json', mode='w'), )
|
||||
draccus.dump(
|
||||
config=config,
|
||||
stream=open(file="run_config.json", mode="w"),
|
||||
)
|
||||
|
||||
@@ -32,7 +32,6 @@ from tqdm import tqdm
|
||||
from lerobot.common.datasets.factory import resolve_delta_timestamps
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
|
||||
Reference in New Issue
Block a user