Added gripper control mechanism to gym_manipulator
Moved HilSerl env config to configs/env/configs.py fixes in actor_server and modeling_sac and configuration_sac added the possibility of ignoring missing keys in env_cfg in get_features_from_env_config function
This commit is contained in:
committed by
AdilZouitine
parent
79e0f6e06c
commit
02b9ea9446
@@ -39,7 +39,6 @@ from lerobot.common.policies.utils import get_device_from_parameters
|
||||
class SACPolicy(
|
||||
PreTrainedPolicy,
|
||||
):
|
||||
|
||||
config_class = SACConfig
|
||||
name = "sac"
|
||||
|
||||
@@ -53,9 +52,7 @@ class SACPolicy(
|
||||
self.config = config
|
||||
|
||||
if config.dataset_stats is not None:
|
||||
input_normalization_params = _convert_normalization_params_to_tensor(
|
||||
config.dataset_stats
|
||||
)
|
||||
input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_features,
|
||||
config.normalization_mapping,
|
||||
@@ -64,12 +61,10 @@ class SACPolicy(
|
||||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
|
||||
output_normalization_params = _convert_normalization_params_to_tensor(
|
||||
config.dataset_stats
|
||||
)
|
||||
output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
|
||||
|
||||
# HACK: This is hacky and should be removed
|
||||
dataset_stats = dataset_stats or output_normalization_params
|
||||
dataset_stats = dataset_stats or output_normalization_params
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
@@ -138,7 +133,6 @@ class SACPolicy(
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return {
|
||||
"actor": self.actor.parameters_to_optimize,
|
||||
@@ -655,9 +649,10 @@ class SACObservationEncoder(nn.Module):
|
||||
class DefaultImageEncoder(nn.Module):
|
||||
def __init__(self, config: SACConfig):
|
||||
super().__init__()
|
||||
image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image")) # noqa: SIM118
|
||||
self.image_enc_layers = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels=config.input_features["observation.image"].shape[0],
|
||||
in_channels=config.input_features[image_key].shape[0],
|
||||
out_channels=config.image_encoder_hidden_dim,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
@@ -685,7 +680,9 @@ class DefaultImageEncoder(nn.Module):
|
||||
),
|
||||
nn.ReLU(),
|
||||
)
|
||||
dummy_batch = torch.zeros(1, *config.input_features["observation.image"].shape)
|
||||
# Get first image key from input features
|
||||
image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image")) # noqa: SIM118
|
||||
dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
|
||||
with torch.inference_mode():
|
||||
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
||||
self.image_enc_layers.extend(
|
||||
@@ -844,8 +841,10 @@ if __name__ == "__main__":
|
||||
import draccus
|
||||
|
||||
from lerobot.configs import parser
|
||||
|
||||
@parser.wrap()
|
||||
def main(config: SACConfig):
|
||||
policy = SACPolicy(config=config)
|
||||
print("yolo")
|
||||
main()
|
||||
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user