Added caching function in the learner_server and modeling sac in order to limit the number of forward passes through the pretrained encoder when its frozen.

Added tensordict dependencies
Updated the version of torch and torchvision

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi
2025-02-21 10:13:43 +00:00
parent 3ffe0cf0f4
commit 546719137a
8 changed files with 67 additions and 44 deletions

View File

@@ -217,7 +217,7 @@ def learner_service_client(
{
"name": [{}], # Applies to ALL methods in ALL services
"retryPolicy": {
"maxAttempts": 5, # Max retries (total attempts = 5)
"maxAttempts": 7, # Max retries (total attempts = 5)
"initialBackoff": "0.1s", # First retry after 0.1s
"maxBackoff": "2s", # Max wait time between retries
"backoffMultiplier": 2, # Exponential backoff factor

View File

@@ -169,6 +169,25 @@ def initialize_replay_buffer(
)
def get_observation_features(policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor) -> tuple[torch.Tensor | None, torch.Tensor | None]:
if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
return None, None
with torch.no_grad():
observation_features = (
policy.actor.encoder(observations)
if policy.actor.encoder is not None
else None
)
next_observation_features = (
policy.actor.encoder(next_observations)
if policy.actor.encoder is not None
else None
)
return observation_features, next_observation_features
def start_learner_threads(
cfg: DictConfig,
device: str,
@@ -345,9 +364,6 @@ def add_actor_information_and_train(
if len(replay_buffer) < cfg.training.online_step_before_learning:
continue
# logging.info(f"Size of replay buffer: {len(replay_buffer)}")
# logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}")
time_for_one_optimization_step = time.time()
for _ in range(cfg.policy.utd_ratio - 1):
batch = replay_buffer.sample(batch_size)
@@ -356,6 +372,7 @@ def add_actor_information_and_train(
batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(batch, batch_offline)
actions = batch["action"]
rewards = batch["reward"]
observations = batch["state"]
@@ -365,6 +382,7 @@ def add_actor_information_and_train(
observations=observations, actions=actions, next_state=next_observations
)
observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
@@ -372,6 +390,8 @@ def add_actor_information_and_train(
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
@@ -395,6 +415,7 @@ def add_actor_information_and_train(
observations=observations, actions=actions, next_state=next_observations
)
observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
@@ -402,6 +423,8 @@ def add_actor_information_and_train(
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
@@ -413,7 +436,8 @@ def add_actor_information_and_train(
if optimization_step % cfg.training.policy_update_freq == 0:
for _ in range(cfg.training.policy_update_freq):
with policy_lock:
loss_actor = policy.compute_loss_actor(observations=observations)
loss_actor = policy.compute_loss_actor(observations=observations,
observation_features=observation_features)
optimizers["actor"].zero_grad()
loss_actor.backward()
@@ -422,7 +446,8 @@ def add_actor_information_and_train(
training_infos["loss_actor"] = loss_actor.item()
loss_temperature = policy.compute_loss_temperature(
observations=observations
observations=observations,
observation_features=observation_features
)
optimizers["temperature"].zero_grad()
loss_temperature.backward()

View File

@@ -41,17 +41,17 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
def _get_policy_state(self):
with self.policy_lock:
params_dict = self.policy.actor.state_dict()
if self.policy.config.vision_encoder_name is not None:
if self.policy.config.freeze_vision_encoder:
params_dict: dict[str, torch.Tensor] = {
k: v
for k, v in params_dict.items()
if not k.startswith("encoder.")
}
else:
raise NotImplementedError(
"Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
)
# if self.policy.config.vision_encoder_name is not None:
# if self.policy.config.freeze_vision_encoder:
# params_dict: dict[str, torch.Tensor] = {
# k: v
# for k, v in params_dict.items()
# if not k.startswith("encoder.")
# }
# else:
# raise NotImplementedError(
# "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
# )
return move_state_dict_to_device(params_dict, device="cpu")

View File

@@ -41,7 +41,6 @@ def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dic
state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
return_observations["observation.image"] = img
return_observations["observation.image.2"] = img
return_observations["observation.state"] = state
return return_observations