forked from tangger/lerobot
Refactor SACPolicy and learner_server for improved clarity and functionality
- Updated the `forward` method in `SACPolicy` to handle loss computation for the actor, critic, and temperature models.
- Replaced direct calls to the `compute_loss_*` methods with the unified `forward` method in `learner_server`.
- Enhanced batch processing by consolidating input parameters into a single dictionary for better readability and maintainability.
- Removed redundant code and improved documentation for clarity.
This commit is contained in:
committed by
Michel Aractingi
parent
eb710647bf
commit
6e687e2910
@@ -135,18 +135,6 @@ class SACConfig(PreTrainedConfig):
|
||||
}
|
||||
)
|
||||
|
||||
input_features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
|
||||
}
|
||||
)
|
||||
output_features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,)),
|
||||
}
|
||||
)
|
||||
|
||||
# Architecture specifics
|
||||
camera_number: int = 1
|
||||
device: str = "cuda"
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import asdict
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
from typing import Callable, List, Literal, Optional, Tuple
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
@@ -177,7 +177,64 @@ class SACPolicy(
|
||||
q_values = critics(observations, actions, observation_features)
|
||||
return q_values
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
|
||||
def forward(
|
||||
self,
|
||||
batch: dict[str, Tensor | dict[str, Tensor]],
|
||||
model: Literal["actor", "critic", "temperature"] = "critic",
|
||||
) -> dict[str, Tensor]:
|
||||
"""Compute the loss for the given model
|
||||
|
||||
Args:
|
||||
batch: Dictionary containing:
|
||||
- action: Action tensor
|
||||
- reward: Reward tensor
|
||||
- state: Observations tensor dict
|
||||
- next_state: Next observations tensor dict
|
||||
- done: Done mask tensor
|
||||
- observation_feature: Optional pre-computed observation features
|
||||
- next_observation_feature: Optional pre-computed next observation features
|
||||
model: Which model to compute the loss for ("actor", "critic", or "temperature")
|
||||
|
||||
Returns:
|
||||
The computed loss tensor
|
||||
"""
|
||||
# TODO: (maractingi, azouitine) Respect the function signature we output tensors
|
||||
# Extract common components from batch
|
||||
actions = batch["action"]
|
||||
observations = batch["state"]
|
||||
observation_features = batch.get("observation_feature")
|
||||
|
||||
if model == "critic":
|
||||
# Extract critic-specific components
|
||||
rewards = batch["reward"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
next_observation_features = batch.get("next_observation_feature")
|
||||
|
||||
return self.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
if model == "actor":
|
||||
return self.compute_loss_actor(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
|
||||
if model == "temperature":
|
||||
return self.compute_loss_temperature(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
|
||||
raise ValueError(f"Unknown model type: {model}")
|
||||
|
||||
def update_target_networks(self):
|
||||
"""Update target networks with exponential moving average"""
|
||||
for target_param, param in zip(
|
||||
@@ -257,7 +314,11 @@ class SACPolicy(
|
||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
|
||||
return temperature_loss
|
||||
|
||||
def compute_loss_actor(self, observations, observation_features: Tensor | None = None) -> Tensor:
|
||||
def compute_loss_actor(
|
||||
self,
|
||||
observations,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
||||
|
||||
Reference in New Issue
Block a user