chore(rl): move rl related code to its directory at top level (#2002)
* chore(rl): move rl related code to its directory at top level * chore(style): apply pre-commit to renamed headers * test(rl): fix rl imports * docs(rl): update rl headers doc
This commit is contained in:
@@ -24,7 +24,7 @@ Examples of usage:
|
||||
|
||||
- Start an actor server for real robot training with human-in-the-loop intervention:
|
||||
```bash
|
||||
python -m lerobot.scripts.rl.actor --config_path src/lerobot/configs/train_config_hilserl_so100.json
|
||||
python -m lerobot.rl.actor --config_path src/lerobot/configs/train_config_hilserl_so100.json
|
||||
```
|
||||
|
||||
**NOTE**: The actor server requires a running learner server to connect to. Ensure the learner
|
||||
@@ -64,12 +64,6 @@ from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.processor import TransitionKey
|
||||
from lerobot.robots import so100_follower # noqa: F401
|
||||
from lerobot.scripts.rl.gym_manipulator import (
|
||||
create_transition,
|
||||
make_processors,
|
||||
make_robot_env,
|
||||
step_env_and_process_transition,
|
||||
)
|
||||
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.transport import services_pb2, services_pb2_grpc
|
||||
@@ -96,6 +90,13 @@ from lerobot.utils.utils import (
|
||||
init_logging,
|
||||
)
|
||||
|
||||
from .gym_manipulator import (
|
||||
create_transition,
|
||||
make_processors,
|
||||
make_robot_env,
|
||||
step_env_and_process_transition,
|
||||
)
|
||||
|
||||
ACTOR_SHUTDOWN_TIMEOUT = 30
|
||||
|
||||
# Main entry point
|
||||
@@ -25,12 +25,13 @@ from lerobot.robots import ( # noqa: F401
|
||||
make_robot_from_config,
|
||||
so100_follower,
|
||||
)
|
||||
from lerobot.scripts.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.teleoperators import (
|
||||
gamepad, # noqa: F401
|
||||
so101_leader, # noqa: F401
|
||||
)
|
||||
|
||||
from .gym_manipulator import make_robot_env
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ Examples of usage:
|
||||
|
||||
- Start a learner server for training:
|
||||
```bash
|
||||
python -m lerobot.scripts.rl.learner --config_path src/lerobot/configs/train_config_hilserl_so100.json
|
||||
python -m lerobot.rl.learner --config_path src/lerobot/configs/train_config_hilserl_so100.json
|
||||
```
|
||||
|
||||
**NOTE**: Start the learner server before launching the actor server. The learner opens a gRPC server
|
||||
@@ -73,7 +73,6 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.robots import so100_follower # noqa: F401
|
||||
from lerobot.scripts.rl import learner_service
|
||||
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.transport import services_pb2_grpc
|
||||
@@ -100,6 +99,8 @@ from lerobot.utils.utils import (
|
||||
)
|
||||
from lerobot.utils.wandb_utils import WandBLogger
|
||||
|
||||
from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService
|
||||
|
||||
LOG_PREFIX = "[LEARNER]"
|
||||
|
||||
|
||||
@@ -639,7 +640,7 @@ def start_learner(
|
||||
# TODO: Check if its useful
|
||||
_ = ProcessSignalHandler(False, display_pid=True)
|
||||
|
||||
service = learner_service.LearnerService(
|
||||
service = LearnerService(
|
||||
shutdown_event=shutdown_event,
|
||||
parameters_queue=parameters_queue,
|
||||
seconds_between_pushes=cfg.policy.actor_learner_config.policy_parameters_push_frequency,
|
||||
@@ -649,7 +650,7 @@ def start_learner(
|
||||
)
|
||||
|
||||
server = grpc.server(
|
||||
ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
|
||||
ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||
options=[
|
||||
("grpc.max_receive_message_length", MAX_MESSAGE_SIZE),
|
||||
("grpc.max_send_message_length", MAX_MESSAGE_SIZE),
|
||||
@@ -670,7 +671,7 @@ def start_learner(
|
||||
|
||||
shutdown_event.wait()
|
||||
logging.info("[LEARNER] Stopping gRPC server...")
|
||||
server.stop(learner_service.SHUTDOWN_TIMEOUT)
|
||||
server.stop(SHUTDOWN_TIMEOUT)
|
||||
logging.info("[LEARNER] gRPC server stopped")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user