Added a caching function in the learner_server and modeling sac in order to limit the number of forward passes through the pretrained encoder when it's frozen.
Added tensordict dependencies. Updated the version of torch and torchvision. Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
committed by
AdilZouitine
parent
d48161da1b
commit
ff223c106d
@@ -41,17 +41,17 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
|
||||
def _get_policy_state(self):
|
||||
with self.policy_lock:
|
||||
params_dict = self.policy.actor.state_dict()
|
||||
if self.policy.config.vision_encoder_name is not None:
|
||||
if self.policy.config.freeze_vision_encoder:
|
||||
params_dict: dict[str, torch.Tensor] = {
|
||||
k: v
|
||||
for k, v in params_dict.items()
|
||||
if not k.startswith("encoder.")
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
|
||||
)
|
||||
# if self.policy.config.vision_encoder_name is not None:
|
||||
# if self.policy.config.freeze_vision_encoder:
|
||||
# params_dict: dict[str, torch.Tensor] = {
|
||||
# k: v
|
||||
# for k, v in params_dict.items()
|
||||
# if not k.startswith("encoder.")
|
||||
# }
|
||||
# else:
|
||||
# raise NotImplementedError(
|
||||
# "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
|
||||
# )
|
||||
|
||||
return move_state_dict_to_device(params_dict, device="cpu")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user