Clean the code and remove TODOs
@@ -1,23 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import dataclass
-
-
-@dataclass
-class HILSerlConfig:
-    pass
@@ -1,29 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch.nn as nn
-from huggingface_hub import PyTorchModelHubMixin
-
-
-class HILSerlPolicy(
-    nn.Module,
-    PyTorchModelHubMixin,
-    library_name="lerobot",
-    repo_url="https://github.com/huggingface/lerobot",
-    tags=["robotics", "hilserl"],
-):
-    pass
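
(For context: both deleted files were empty placeholders. The pattern the second one stubbed out is worth noting; a minimal sketch follows, where TinyPolicy and its dimensions are hypothetical and only the mixin usage mirrors the deleted code. Subclassing PyTorchModelHubMixin adds save_pretrained / from_pretrained / push_to_hub to a plain nn.Module, and the class keyword arguments record Hub metadata.)

# Minimal sketch (hypothetical TinyPolicy, not part of this commit):
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class TinyPolicy(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="lerobot",
    repo_url="https://github.com/huggingface/lerobot",
    tags=["robotics", "hilserl"],
):
    def __init__(self, obs_dim: int = 8, action_dim: int = 2):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, action_dim))

    def forward(self, obs):
        return self.net(obs)


# policy = TinyPolicy()
# policy.save_pretrained("local_dir")                # serializes weights + init config
# restored = TinyPolicy.from_pretrained("local_dir")
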
@@ -15,8 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# TODO: (1) better device management
-
 import math
 from dataclasses import asdict
 from typing import Callable, List, Literal, Optional, Tuple
@@ -254,7 +252,6 @@ class SACPolicy(
         with torch.no_grad():
             next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
 
-            # TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way
             next_action_preds = self.unnormalize_outputs({"action": next_action_preds})["action"]
 
             # 2- compute q targets
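
(The removed TODO flagged the action-unnormalization step in the critic update as slow. A standalone sketch of the target computation it sits inside; all names and shapes here, gamma and temperature included, are illustrative rather than lerobot's actual API.)

# Sketch of the SAC target-Q step that follows the unnormalized next actions:
import torch


def compute_td_target(rewards, done, next_q_values, next_log_probs, gamma=0.99, temperature=0.2):
    # next_q_values: (num_critics, batch) predictions from the *target* critics
    # for the next actions sampled from the actor.
    min_q = next_q_values.min(dim=0).values
    # Soft value: min over the critic ensemble minus the entropy bonus.
    return rewards + (1.0 - done) * gamma * (min_q - temperature * next_log_probs)


# Example with random tensors (two critics, batch of 4):
q = torch.randn(2, 4)
td = compute_td_target(torch.zeros(4), torch.zeros(4), q, torch.zeros(4))
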
@@ -378,7 +375,6 @@ class SACPolicy(
     ) -> Tensor:
         actions_pi, log_probs, _ = self.actor(observations, observation_features)
 
-        # TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way
         actions_pi: Tensor = self.unnormalize_outputs({"action": actions_pi})["action"]
 
         q_preds = self.critic_forward(
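
(The same unnormalization appears in the actor update above, feeding q_preds. A companion sketch of the soft actor loss, again with illustrative names only.)

# The policy is trained to maximize the soft Q-value of its own actions:
import torch


def actor_loss(q_preds, log_probs, temperature=0.2):
    # q_preds: (num_critics, batch) critic scores for actions sampled from the
    # current policy; log_probs: (batch,) their log-probabilities.
    min_q = q_preds.min(dim=0).values
    return (temperature * log_probs - min_q).mean()


loss = actor_loss(torch.randn(2, 4), torch.zeros(4))
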
@@ -226,7 +226,6 @@ def act_with_policy(
     ### Instantiate the policy in both the actor and learner processes
     ### To avoid sending a SACPolicy object through the port, we create a policy instance
     ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
-    # TODO: At some point we should just need make sac policy
     policy: SACPolicy = make_policy(
         cfg=cfg.policy,
         env_cfg=cfg.env,
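
(The comments in this hunk describe the actor/learner split: each process builds its own SACPolicy, and the learner periodically streams fresh weights through a queue. A simplified sketch of that hand-off follows; the function bodies are assumptions, not the real update_policy_parameters.)

import torch.nn as nn
from multiprocessing import Queue


def push_parameters(policy, parameters_queue):
    # Learner side: move tensors to CPU so they cross the process boundary cheaply.
    state = {k: v.detach().cpu() for k, v in policy.state_dict().items()}
    parameters_queue.put(state)


def update_policy_parameters(policy, parameters_queue, device):
    # Actor side: load the latest snapshot, if any, onto the local device.
    if parameters_queue.empty():
        return
    state = parameters_queue.get()
    policy.load_state_dict({k: v.to(device) for k, v in state.items()})
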
@@ -280,7 +279,6 @@ def act_with_policy(
 
             # NOTE: We override the action if the intervention is True, because the action applied is the intervention action
             if "is_intervention" in info and info["is_intervention"]:
-                # TODO: Check the shape
                 # NOTE: The action space for demonstration before hand is with the full action space
                 # but sometimes for example we want to deactivate the gripper
                 action = info["action_intervention"]
@@ -301,16 +299,13 @@ def act_with_policy(
                     next_state=next_obs,
                     done=done,
                     truncated=truncated,  # TODO: (azouitine) Handle truncation properly
-                    complementary_info=info,  # TODO Handle information for the transition, is_demonstraction: bool
+                    complementary_info=info,
                 )
             )
             # assign obs to the next obs and continue the rollout
             obs = next_obs
 
-            # HACK: We have only one env but we want to batch it, it will be resolved with the torch box
-            # Because we are using a single environment we can index at zero
             if done or truncated:
-                # TODO: Handle logging for episode information
                 logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
 
                 update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
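
(The transition being queued above carries done, truncated, and complementary_info fields. A sketch of such a record as a plain dataclass; the real Transition type in lerobot's HIL-SERL code may differ.)

from dataclasses import dataclass, field
from typing import Any


@dataclass
class Transition:
    state: Any
    action: Any
    reward: float
    next_state: Any
    done: bool
    truncated: bool  # episode cut off by a time limit, not by failure/success
    complementary_info: dict = field(default_factory=dict)


t = Transition(state=0, action=0, reward=1.0, next_state=1, done=False,
               truncated=True, complementary_info={"is_intervention": False})
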
@@ -342,9 +337,10 @@ def act_with_policy(
                         }
                     )
                 )
+
+                # Reset intervention counters
                 sum_reward_episode = 0.0
                 episode_intervention = False
-                # Reset intervention counters
                 episode_intervention_steps = 0
                 episode_total_steps = 0
                 obs, info = online_env.reset()