Clean the code and remove todo

This commit is contained in:
AdilZouitine
2025-04-24 16:10:56 +02:00
parent c58b504a9e
commit b8c2b0bb93
5 changed files with 3 additions and 63 deletions

View File

@@ -1,23 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
@dataclass
class HILSerlConfig:
pass

View File

@@ -1,29 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
class HILSerlPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "hilserl"],
):
pass

View File

@@ -15,8 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: (1) better device management
import math
from dataclasses import asdict
from typing import Callable, List, Literal, Optional, Tuple
@@ -254,7 +252,6 @@ class SACPolicy(
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
# TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way
next_action_preds = self.unnormalize_outputs({"action": next_action_preds})["action"]
# 2- compute q targets
@@ -378,7 +375,6 @@ class SACPolicy(
) -> Tensor:
actions_pi, log_probs, _ = self.actor(observations, observation_features)
# TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way
actions_pi: Tensor = self.unnormalize_outputs({"action": actions_pi})["action"]
q_preds = self.critic_forward(

View File

@@ -226,7 +226,6 @@ def act_with_policy(
### Instantiate the policy in both the actor and learner processes
### To avoid sending a SACPolicy object through the port, we create a policy instance
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
# TODO: At some point we should just need make sac policy
policy: SACPolicy = make_policy(
cfg=cfg.policy,
env_cfg=cfg.env,
@@ -280,7 +279,6 @@ def act_with_policy(
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
if "is_intervention" in info and info["is_intervention"]:
# TODO: Check the shape
# NOTE: The action space for demonstration before hand is with the full action space
# but sometimes for example we want to deactivate the gripper
action = info["action_intervention"]
@@ -301,16 +299,13 @@ def act_with_policy(
next_state=next_obs,
done=done,
truncated=truncated, # TODO: (azouitine) Handle truncation properly
complementary_info=info, # TODO Handle information for the transition, is_demonstraction: bool
complementary_info=info,
)
)
# assign obs to the next obs and continue the rollout
obs = next_obs
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
# Because we are using a single environment we can index at zero
if done or truncated:
# TODO: Handle logging for episode information
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
@@ -342,9 +337,10 @@ def act_with_policy(
}
)
)
# Reset intervention counters
sum_reward_episode = 0.0
episode_intervention = False
# Reset intervention counters
episode_intervention_steps = 0
episode_total_steps = 0
obs, info = online_env.reset()