Clean the code and remove todo

Author: AdilZouitine
Date: 2025-04-24 16:10:56 +02:00
Parent: c58b504a9e
Commit: b8c2b0bb93
5 changed files with 3 additions and 63 deletions


@@ -1,23 +0,0 @@
-#!/usr/bin/env python
-# Copyright 2024 The HuggingFace Inc. team.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import dataclass
-
-
-@dataclass
-class HILSerlConfig:
-    pass
-


@@ -1,29 +0,0 @@
-#!/usr/bin/env python
-# Copyright 2024 The HuggingFace Inc. team.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch.nn as nn
-from huggingface_hub import PyTorchModelHubMixin
-
-
-class HILSerlPolicy(
-    nn.Module,
-    PyTorchModelHubMixin,
-    library_name="lerobot",
-    repo_url="https://github.com/huggingface/lerobot",
-    tags=["robotics", "hilserl"],
-):
-    pass
-
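Both deleted files were empty placeholders. For context, the HILSerlPolicy stub followed the PyTorchModelHubMixin pattern that LeRobot policies use to get Hub serialization for free. A minimal sketch of how that mixin is typically exercised (the class and checkpoint names here are illustrative, not from this commit):

```python
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class TinyPolicy(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="lerobot",
    repo_url="https://github.com/huggingface/lerobot",
    tags=["robotics"],
):
    """Toy policy: the mixin adds save_pretrained/from_pretrained/push_to_hub."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 2)

    def forward(self, x):
        return self.net(x)


policy = TinyPolicy()
policy.save_pretrained("tiny_policy_checkpoint")  # writes the weights as safetensors
restored = TinyPolicy.from_pretrained("tiny_policy_checkpoint")
```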


@@ -15,8 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# TODO: (1) better device management
-
 import math
 from dataclasses import asdict
 from typing import Callable, List, Literal, Optional, Tuple
@@ -254,7 +252,6 @@ class SACPolicy(
         with torch.no_grad():
             next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
-            # TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way
             next_action_preds = self.unnormalize_outputs({"action": next_action_preds})["action"]
 
             # 2- compute q targets
@@ -378,7 +375,6 @@ class SACPolicy(
     ) -> Tensor:
         actions_pi, log_probs, _ = self.actor(observations, observation_features)
-        # TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way
         actions_pi: Tensor = self.unnormalize_outputs({"action": actions_pi})["action"]
 
         q_preds = self.critic_forward(
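The two deleted TODOs in SACPolicy flagged unnormalize_outputs as a per-call hotspot, since it runs on every actor forward pass. The actual implementation lives elsewhere in the file; a sketch of the kind of min/max unnormalization such a step usually performs, and of the obvious precompute-once optimization, under that assumption (not the repo's code):

```python
import torch
from torch import Tensor


def unnormalize_action(action: Tensor, min_val: Tensor, max_val: Tensor) -> Tensor:
    """Map an action from [-1, 1] back to the environment's range.

    Recomputing scale and offset on every call is the per-call overhead
    the removed TODOs pointed at.
    """
    scale = (max_val - min_val) / 2.0
    offset = (max_val + min_val) / 2.0
    return action * scale + offset


class ActionUnnormalizer(torch.nn.Module):
    """Precomputed-buffer variant: do the arithmetic once at init time."""

    def __init__(self, min_val: Tensor, max_val: Tensor):
        super().__init__()
        self.register_buffer("scale", (max_val - min_val) / 2.0)
        self.register_buffer("offset", (max_val + min_val) / 2.0)

    def forward(self, action: Tensor) -> Tensor:
        return action * self.scale + self.offset
```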


@@ -226,7 +226,6 @@ def act_with_policy(
     ### Instantiate the policy in both the actor and learner processes
     ### To avoid sending a SACPolicy object through the port, we create a policy instance
     ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
-    # TODO: At some point we should just need make sac policy
     policy: SACPolicy = make_policy(
         cfg=cfg.policy,
         env_cfg=cfg.env,
@@ -280,7 +279,6 @@ def act_with_policy(
 
         # NOTE: We override the action if the intervention is True, because the action applied is the intervention action
         if "is_intervention" in info and info["is_intervention"]:
-            # TODO: Check the shape
             # NOTE: The action space for demonstration before hand is with the full action space
             # but sometimes for example we want to deactivate the gripper
             action = info["action_intervention"]
@@ -301,16 +299,13 @@ def act_with_policy(
                     next_state=next_obs,
                     done=done,
-                    truncated=truncated,  # TODO: (azouitine) Handle truncation properly
-                    complementary_info=info,  # TODO Handle information for the transition, is_demonstraction: bool
+                    truncated=truncated,
+                    complementary_info=info,
                 )
             )
 
         # assign obs to the next obs and continue the rollout
         obs = next_obs
 
-        # HACK: We have only one env but we want to batch it, it will be resolved with the torch box
-        # Because we are using a single environment we can index at zero
         if done or truncated:
-            # TODO: Handle logging for episode information
             logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
             update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
@@ -342,9 +337,10 @@ def act_with_policy(
                     }
                 )
             )
-            # Reset intervention counters
             sum_reward_episode = 0.0
             episode_intervention = False
+
+            # Reset intervention counters
             episode_intervention_steps = 0
             episode_total_steps = 0
             obs, info = online_env.reset()
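The actor loop above pulls fresh weights with update_policy_parameters once an episode ends, which is how the actor and learner processes stay in sync without ever sending a SACPolicy object over the wire. The helper's body is not shown in this diff; a minimal sketch of the queue-based state_dict handoff it suggests, with the queue semantics assumed rather than taken from the repo:

```python
import queue

import torch
import torch.nn as nn


def update_policy_parameters(policy: nn.Module, parameters_queue: queue.Queue, device: str) -> None:
    """Load the newest state_dict pushed by the learner, if any.

    Sketch only: assumes the learner pushes plain state_dicts and that
    discarding stale snapshots (keeping only the last one) is acceptable.
    """
    state_dict = None
    while True:  # drain the queue, keep the most recent snapshot
        try:
            state_dict = parameters_queue.get_nowait()
        except queue.Empty:
            break
    if state_dict is not None:
        state_dict = {k: v.to(device) for k, v in state_dict.items()}
        policy.load_state_dict(state_dict)
```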