[WIP] Update SAC configuration and environment settings

- Reduced frame rate in `ManiskillEnvConfig` from 400 to 200.
- Enhanced `SACConfig` with new dataclasses for actor, learner, and network configurations (see the sketch after this list).
- Improved input and output feature management in `SACConfig`.
- Refactored `actor_server` and `learner_server` to access configuration properties directly.
- Updated training pipeline to validate configurations and handle dataset repo IDs more robustly.
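
For orientation, here is a minimal sketch of the nested-dataclass layout the refactor assumes. Only the attribute names (`learner_host`, `learner_port`, `policy_parameters_push_frequency`, `concurrency.actor`, `concurrency.learner`) appear in the diffs below; the class names and default values here are illustrative assumptions, not the repo's definitions:

```python
# Hypothetical sketch only: attribute names come from the diff,
# class names and defaults are assumptions.
from dataclasses import dataclass, field


@dataclass
class ConcurrencyConfig:
    actor: str = "threads"    # "threads" or "processes"
    learner: str = "threads"


@dataclass
class ActorLearnerConfig:
    learner_host: str = "127.0.0.1"
    learner_port: int = 50051
    policy_parameters_push_frequency: int = 4


@dataclass
class SACConfig:
    actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
    concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
```

With a layout like this, call sites switch from dict lookups such as `cfg.policy.actor_learner_config["learner_host"]` to attribute access like `cfg.policy.actor_learner_config.learner_host`, as the hunks below show.
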
Author: AdilZouitine, 2025-03-27 08:13:20 +00:00; committed by Michel Aractingi
Parent: 0b5b62c8fb, commit: db897a1619
7 changed files with 183 additions and 126 deletions

View File

@@ -73,8 +73,8 @@ def receive_policy(
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
-host=cfg.policy.actor_learner_config["learner_host"],
-port=cfg.policy.actor_learner_config["learner_port"],
+host=cfg.policy.actor_learner_config.learner_host,
+port=cfg.policy.actor_learner_config.learner_port,
)
try:
@@ -85,6 +85,7 @@ def receive_policy(
shutdown_event,
log_prefix="[ACTOR] parameters",
)
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
@@ -153,8 +154,8 @@ def send_transitions(
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
-host=cfg.policy.actor_learner_config["learner_host"],
-port=cfg.policy.actor_learner_config["learner_port"],
+host=cfg.policy.actor_learner_config.learner_host,
+port=cfg.policy.actor_learner_config.learner_port,
)
try:
@@ -193,8 +194,8 @@ def send_interactions(
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
-host=cfg.policy.actor_learner_config["learner_host"],
-port=cfg.policy.actor_learner_config["learner_port"],
+host=cfg.policy.actor_learner_config.learner_host,
+port=cfg.policy.actor_learner_config.learner_port,
)
try:
@@ -286,10 +287,10 @@ def act_with_policy(
logging.info("make_env online")
-online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)
+online_env = make_robot_env( cfg=cfg.env)
set_seed(cfg.seed)
-device = get_safe_torch_device(cfg.device, log=True)
+device = get_safe_torch_device(cfg.policy.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -302,11 +303,7 @@ def act_with_policy(
# TODO: At some point we should just need make sac policy
policy: SACPolicy = make_policy(
cfg=cfg.policy,
-# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
-# Hack: But if we do online training, we do not need dataset_stats
-dataset_stats=None,
-# TODO: Handle resume training
-device=device,
+env_cfg=cfg.env,
)
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
@@ -322,13 +319,13 @@ def act_with_policy(
episode_intervention_steps = 0
episode_total_steps = 0
-for interaction_step in range(cfg.training.online_steps):
+for interaction_step in range(cfg.policy.online_steps):
start_time = time.perf_counter()
if shutdown_event.is_set():
logging.info("[ACTOR] Shutting down act_with_policy")
return
-if interaction_step >= cfg.training.online_step_before_learning:
+if interaction_step >= cfg.policy.online_step_before_learning:
# Time policy inference and check if it meets FPS requirement
with TimerManager(
elapsed_time_list=list_policy_time,
@@ -426,9 +423,9 @@ def act_with_policy(
episode_total_steps = 0
obs, info = online_env.reset()
-if cfg.fps is not None:
+if cfg.env.fps is not None:
dt_time = time.perf_counter() - start_time
-busy_wait(1 / cfg.fps - dt_time)
+busy_wait(1 / cfg.env.fps - dt_time)
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
@@ -467,9 +464,9 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
def log_policy_frequency_issue(policy_fps: float, cfg: TrainPipelineConfig, interaction_step: int):
-if policy_fps < cfg.fps:
+if policy_fps < cfg.env.fps:
logging.warning(
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.env.fps} at step {interaction_step}"
)
@@ -495,7 +492,7 @@ def establish_learner_connection(
def use_threads(cfg: TrainPipelineConfig) -> bool:
-return cfg.policy.concurrency["actor"] == "threads"
+return cfg.policy.concurrency.actor == "threads"
@parser.wrap()
@@ -511,8 +508,8 @@ def actor_cli(cfg: TrainPipelineConfig):
shutdown_event = setup_process_handlers(use_threads(cfg))
learner_client, grpc_channel = learner_service_client(
-host=cfg.policy.actor_learner_config["learner_host"],
-port=cfg.policy.actor_learner_config["learner_port"],
+host=cfg.policy.actor_learner_config.learner_host,
+port=cfg.policy.actor_learner_config.learner_port,
)
logging.info("[ACTOR] Establishing connection with Learner")

View File

@@ -1097,7 +1097,6 @@ class ActionScaleWrapper(gym.ActionWrapper):
return action * self.scale_vector, is_intervention
@parser.wrap()
def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
# def make_robot_env(cfg: TrainPipelineConfig) -> gym.vector.VectorEnv:
# def make_robot_env(cfg: ManiskillEnvConfig) -> gym.vector.VectorEnv:

View File

@@ -48,6 +48,7 @@ from lerobot.common.utils.train_utils import (
load_training_state as utils_load_training_state,
save_checkpoint,
update_last_checkpoint,
+save_training_state,
)
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.utils import (
@@ -160,13 +161,14 @@ def load_training_state(
try:
# Use the utility function from train_utils which loads the optimizer state
# The function returns (step, updated_optimizer, scheduler)
step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None)
-# For interaction step, we still need to load the training_state.pt file
+# Load interaction step separately from training_state.pt
training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
-training_state = torch.load(training_state_path, weights_only=False)
-interaction_step = training_state.get("interaction_step", 0)
+interaction_step = 0
+if os.path.exists(training_state_path):
+training_state = torch.load(training_state_path, weights_only=False)
+interaction_step = training_state.get("interaction_step", 0)
logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
return step, interaction_step
@@ -222,16 +224,20 @@ def initialize_replay_buffer(
logging.info("Resume training load the online dataset")
dataset_path = os.path.join(cfg.output_dir, "dataset")
+# NOTE: In RL is possible to not have a dataset.
+repo_id = None
+if cfg.dataset is not None:
+repo_id = cfg.dataset.dataset_repo_id
dataset = LeRobotDataset(
-repo_id=cfg.dataset.dataset_repo_id,
-local_files_only=True,
+repo_id=repo_id,
root=dataset_path,
)
return ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=dataset,
capacity=cfg.policy.online_buffer_capacity,
device=device,
-state_keys=cfg.policy.input_shapes.keys(),
+state_keys=cfg.policy.input_features.keys(),
optimize_memory=True,
)
@@ -298,7 +304,7 @@ def get_observation_features(
def use_threads(cfg: TrainPipelineConfig) -> bool:
-return cfg.policy.concurrency["learner"] == "threads"
+return cfg.policy.concurrency.learner == "threads"
def start_learner_threads(
@@ -388,7 +394,7 @@ def start_learner_server(
service = learner_service.LearnerService(
shutdown_event=shutdown_event,
parameters_queue=parameters_queue,
-seconds_between_pushes=cfg.policy.actor_learner_config["policy_parameters_push_frequency"],
+seconds_between_pushes=cfg.policy.actor_learner_config.policy_parameters_push_frequency,
transition_queue=transition_queue,
interaction_message_queue=interaction_message_queue,
)
@@ -406,8 +412,8 @@ def start_learner_server(
server,
)
-host = cfg.policy.actor_learner_config["learner_host"]
-port = cfg.policy.actor_learner_config["learner_port"]
+host = cfg.policy.actor_learner_config.learner_host
+port = cfg.policy.actor_learner_config.learner_port
server.add_insecure_port(f"{host}:{port}")
server.start()
@@ -509,7 +515,6 @@ def add_actor_information_and_train(
checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None
pretrained_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR) if checkpoint_dir else None
-# TODO(Adil): This don't work anymore !
policy: SACPolicy = make_policy(
cfg=cfg.policy,
# ds_meta=cfg.dataset,
@@ -575,8 +580,8 @@ def add_actor_information_and_train(
device = cfg.policy.device
storage_device = cfg.policy.storage_device
policy_update_freq = cfg.policy.policy_update_freq
-policy_parameters_push_frequency = cfg.policy.actor_learner_config["policy_parameters_push_frequency"]
-save_checkpoint = cfg.save_checkpoint
+policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
+saving_checkpoint = cfg.save_checkpoint
online_steps = cfg.policy.online_steps
while True:
@@ -598,7 +603,7 @@ def add_actor_information_and_train(
continue
replay_buffer.add(**transition)
-if cfg.dataset.repo_id is not None and transition.get("complementary_info", {}).get(
+if dataset_repo_id is not None and transition.get("complementary_info", {}).get(
"is_intervention"
):
offline_replay_buffer.add(**transition)
@@ -618,9 +623,6 @@ def add_actor_information_and_train(
mode="train",
custom_step_key="Interaction step"
)
-else:
-# Log to console if no WandB logger
-logging.info(f"Interaction: {interaction_message}")
logging.debug("[LEARNER] Received interactions")
@@ -765,9 +767,6 @@ def add_actor_information_and_train(
mode="train",
custom_step_key="Optimization step"
)
-else:
-# Log to console if no WandB logger
-logging.info(f"Training: {training_infos}")
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
@@ -789,7 +788,7 @@ def add_actor_information_and_train(
if optimization_step % log_freq == 0:
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
-if save_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps):
+if saving_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps):
logging.info(f"Checkpoint policy after step {optimization_step}")
_num_digits = max(6, len(str(online_steps)))
step_identifier = f"{optimization_step:0{_num_digits}d}"
@@ -810,6 +809,15 @@ def add_actor_information_and_train(
scheduler=None
)
+# Save interaction step manually
+training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
+os.makedirs(training_state_dir, exist_ok=True)
+training_state = {
+"step": optimization_step,
+"interaction_step": interaction_step
+}
+torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
# Update the "last" symlink
update_last_checkpoint(checkpoint_dir)
@@ -820,8 +828,11 @@ def add_actor_information_and_train(
shutil.rmtree(dataset_dir)
# Save dataset
+# NOTE: Handle the case where the dataset repo id is not specified in the config
+# eg. RL training without demonstrations data
+repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id
replay_buffer.to_lerobot_dataset(
-dataset_repo_id,
+repo_id=repo_id_buffer_save,
fps=fps,
root=dataset_dir
)
@@ -892,8 +903,10 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None):
cfg (TrainPipelineConfig): The training configuration
job_name (str | None, optional): Job name for logging. Defaults to None.
"""
-if cfg.output_dir is None:
-raise ValueError("Output directory must be specified in config")
+cfg.validate()
+# if cfg.output_dir is None:
+# raise ValueError("Output directory must be specified in config")
if job_name is None:
job_name = cfg.job_name
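
Taken together, the learner-side changes add a small round-trip for the interaction step next to the regular checkpoint: it is written to `training_state.pt` at save time and read back (defaulting to 0 when the file is missing) on resume. A condensed sketch of that logic, assuming `TRAINING_STATE_DIR` as imported in the diff; the standalone helper names here are illustrative, not functions from the repo:

```python
import os

import torch

TRAINING_STATE_DIR = "training_state"  # assumption: mirrors the constant imported in the diff


def save_interaction_state(checkpoint_dir: str, optimization_step: int, interaction_step: int) -> None:
    # Mirrors the added checkpoint code: persist both counters next to the checkpoint.
    training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
    os.makedirs(training_state_dir, exist_ok=True)
    training_state = {"step": optimization_step, "interaction_step": interaction_step}
    torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))


def load_interaction_step(checkpoint_dir: str) -> int:
    # Mirrors the updated load path: fall back to 0 when the file is absent.
    path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
    if not os.path.exists(path):
        return 0
    return torch.load(path, weights_only=False).get("interaction_step", 0)
```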