Commit Graph

74 Commits

Author SHA1 Message Date
AdilZouitine
5dc7ff6d3c change the tanh distribution to match hil serl
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
2025-04-16 16:46:37 +02:00
AdilZouitine
ee4ebeac9b match target entropy hil serl
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
2025-04-16 16:46:37 +02:00
AdilZouitine
fe7b47f459 stick to hil serl nn architecture
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
2025-04-16 16:46:37 +02:00
AdilZouitine
044ca3b039 Refactor modeling_sac and parameter handling for clarity and reusability.
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
2025-04-16 16:46:37 +02:00
AdilZouitine
bc36c69b71 fix encoder training 2025-04-16 16:46:37 +02:00
Michel Aractingi
9eec7b8bb0 General fixes in code, removed delta action, fixed grasp penalty, added logic to put gripper reward in info 2025-04-16 16:46:37 +02:00
pre-commit-ci[bot]
a80a9cf379 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-16 16:46:37 +02:00
AdilZouitine
7a42af835e fix caching and dataset stats is optional 2025-04-16 16:46:37 +02:00
AdilZouitine
9751328783 Add rounding for safety 2025-04-16 16:46:37 +02:00
AdilZouitine
03b1644bf7 fix sign issue 2025-04-16 16:46:37 +02:00
AdilZouitine
86466b025f Handle gripper penalty 2025-04-16 16:46:37 +02:00
AdilZouitine
54745f111d fix caching 2025-04-16 16:46:37 +02:00
AdilZouitine
2d932b710c Enhance SACPolicy to support shared encoder and optimize action selection
- Cached encoder output in select_action method to reduce redundant computations.
- Updated action selection and grasp critic calls to utilize cached encoder features when available.
2025-04-16 16:46:37 +02:00
AdilZouitine
a54baceabb Enhance SACPolicy and learner server for improved grasp critic integration
- Updated SACPolicy to conditionally compute grasp critic losses based on the presence of discrete actions.
- Refactored the forward method to handle grasp critic model selection and loss computation more clearly.
- Adjusted learner server to utilize optimized parameters for grasp critic during training.
- Improved action handling in the ManiskillMockGripperWrapper to accommodate both tuple and single action inputs.
2025-04-16 16:46:37 +02:00
AdilZouitine
077d18b439 Refactor SACPolicy for improved readability and action dimension handling
- Cleaned up code formatting for better readability, including consistent spacing and removal of unnecessary blank lines.
- Consolidated continuous action dimension calculation to enhance clarity and maintainability.
- Simplified loss return statements in the forward method to improve code structure.
- Ensured grasp critic parameters are included conditionally based on configuration settings.
2025-04-16 16:46:37 +02:00
AdilZouitine
c6cd1475a7 Add mock gripper support and enhance SAC policy action handling
- Introduced mock_gripper parameter in ManiskillEnvConfig to enable gripper simulation.
- Added ManiskillMockGripperWrapper to adjust action space for environments with discrete actions.
- Updated SACPolicy to compute continuous action dimensions correctly, ensuring compatibility with the new gripper setup.
- Refactored action handling in the training loop to accommodate the changes in action dimensions.
2025-04-16 16:46:37 +02:00
AdilZouitine
e35ee47b07 Refactor SAC policy and training loop to enhance discrete action support
- Updated SACPolicy to conditionally compute losses for grasp critic based on num_discrete_actions.
- Simplified forward method to return loss outputs as a dictionary for better clarity.
- Adjusted learner_server to handle both main and grasp critic losses during training.
- Ensured optimizers are created conditionally for grasp critic based on configuration settings.
2025-04-16 16:46:37 +02:00
AdilZouitine
c3f2487026 Refactor SAC configuration and policy to support discrete actions
- Removed GraspCriticNetworkConfig class and integrated its parameters into SACConfig.
- Added num_discrete_actions parameter to SACConfig for better action handling.
- Updated SACPolicy to conditionally create grasp critic networks based on num_discrete_actions.
- Enhanced grasp critic forward pass to handle discrete actions and compute losses accordingly.
2025-04-16 16:46:37 +02:00
pre-commit-ci[bot]
f5cfd9fd48 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-16 16:46:37 +02:00
s1lent4gnt
22da1739b1 Add grasp critic to the training loop
- Integrated the grasp critic gradient update to the training loop in learner_server
- Added Adam optimizer and configured grasp critic learning rate in configuration_sac
- Added target critics networks update after the critics gradient step
2025-04-16 16:46:37 +02:00
s1lent4gnt
384eb2cd07 Add grasp critic
- Implemented grasp critic to evaluate gripper actions
- Added corresponding config parameters for tuning
2025-04-16 16:46:37 +02:00
AdilZouitine
026ad463a9 Fix convergence of sac, multiple torch compile on the same model caused divergence 2025-03-31 13:54:21 +00:00
AdilZouitine
8494634d48 Fix cuda graph break 2025-03-31 07:59:56 +00:00
pre-commit-ci[bot]
c05e4835d0 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-03-28 17:20:39 +00:00
AdilZouitine
0150139668 Refactor SACPolicy for improved type annotations and readability
- Enhanced type annotations for variables in the `SACPolicy` class to improve code clarity.
- Updated method calls to use keyword arguments for better readability.
- Streamlined the extraction of batch components, ensuring consistent typing across the class methods.
2025-03-28 17:18:48 +00:00
AdilZouitine
b3ad63cf6e Refactor SACPolicy and learner_server for improved clarity and functionality
- Updated the `forward` method in `SACPolicy` to handle loss computation for actor, critic, and temperature models.
- Replaced direct calls to `compute_loss_*` methods with a unified `forward` method in `learner_server`.
- Enhanced batch processing by consolidating input parameters into a single dictionary for better readability and maintainability.
- Removed redundant code and improved documentation for clarity.
2025-03-28 17:18:48 +00:00
AdilZouitine
82a6b69e0e Refactor imports in modeling_sac.py for improved organization
- Rearranged import statements for better readability.
- Removed unused imports and streamlined the code structure.
2025-03-28 17:18:48 +00:00
Michel Aractingi
02b9ea9446 Added gripper control mechanism to gym_manipulator
Moved HilSerl env config to configs/env/configs.py
fixes in actor_server and modeling_sac and configuration_sac
added the possibility of ignoring missing keys in env_cfg in get_features_from_env_config function
2025-03-28 17:18:48 +00:00
AdilZouitine
052a4acfc2 [WIP] Update SAC configuration and environment settings
- Reduced frame rate in `ManiskillEnvConfig` from 400 to 200.
- Enhanced `SACConfig` with new dataclasses for actor, learner, and network configurations.
- Improved input and output feature management in `SACConfig`.
- Refactored `actor_server` and `learner_server` to access configuration properties directly.
- Updated training pipeline to validate configurations and handle dataset repo IDs more robustly.
2025-03-28 17:18:48 +00:00
AdilZouitine
f483931fc0 Handle new config with sac 2025-03-28 17:18:48 +00:00
pre-commit-ci[bot]
7c05755823 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-03-28 17:18:48 +00:00
pre-commit-ci[bot]
8e6d5f504c [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-03-28 17:18:48 +00:00
pre-commit-ci[bot]
81952b2092 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-03-28 17:18:48 +00:00
AdilZouitine
0eef49a0f6 Initialize log_alpha with the logarithm of temperature_init in SACPolicy
- Updated the SACPolicy class to set log_alpha using the logarithm of the initial temperature value from the configuration.
2025-03-28 17:18:48 +00:00
pre-commit-ci[bot]
2d5effeeba [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-03-28 17:18:48 +00:00
AdilZouitine
c5c921cd7c Remove unused functions and imports from modeling_sac.py
- Deleted the `find_and_copy_params` function and the `Ensemble` class, as they were deemed unnecessary.
- Cleaned up imports by removing `from_modules` from `tensordict` to enhance code clarity.
- Simplified the assertion in the `Policy` class for better readability.
2025-03-28 17:18:48 +00:00
pre-commit-ci[bot]
cb272294f5 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-03-28 17:18:48 +00:00
AdilZouitine
4bb2077afa Refactor SACPolicy and learner server for improved replay buffer management
- Updated SACPolicy to create critic heads using a list comprehension for better readability.
- Simplified the saving and loading of models using `save_model` and `load_model` functions from the safetensors library.
- Introduced `initialize_offline_replay_buffer` function in the learner server to streamline offline dataset handling and replay buffer initialization.
- Enhanced logging for dataset loading processes to improve traceability during training.
2025-03-28 17:18:48 +00:00
AdilZouitine
7960f2c3c1 Enhance SAC configuration and policy with gradient clipping and temperature management
- Introduced `grad_clip_norm` parameter in SAC configuration for gradient clipping
- Updated SACPolicy to store temperature as an instance variable for consistent usage
- Modified loss calculations in SACPolicy to utilize the instance temperature
- Enhanced MLP and CriticHead to support a customizable final activation function
- Implemented gradient clipping in the learner server during training steps for both actor and critic
- Added tracking for gradient norms in training information
2025-03-28 17:18:48 +00:00
pre-commit-ci[bot]
dee154a1a5 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-03-28 17:18:48 +00:00
AdilZouitine
a3ef7dc6c3 Add custom save and load methods for SAC policy
- Implement `_save_pretrained` method to handle TensorDict state saving
- Add `_from_pretrained` class method for loading SAC policy from files
- Create utility function `find_and_copy_params` to handle parameter copying
2025-03-28 17:18:48 +00:00
AdilZouitine
7e3e1ce173 Remove torch.no_grad decorator and optimize next action prediction in SAC policy
- Removed `@torch.no_grad` decorator from Unnormalize forward method

- Added TODO comment for optimizing next action prediction in SAC policy
- Minor formatting adjustment in NaN assertion for log standard deviation
Co-authored-by: Yoel Chornton <yoel.chornton@gmail.com>
2025-03-28 17:18:48 +00:00
pre-commit-ci[bot]
38f5fa4523 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-03-28 17:18:48 +00:00
Michel Aractingi
ff223c106d Added caching function in the learner_server and modeling sac in order to limit the number of forward passes through the pretrained encoder when its frozen.
Added tensordict dependencies
Updated the version of torch and torchvision

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-03-28 17:18:48 +00:00
AdilZouitine
150def839c Refactor SAC policy with performance optimizations and multi-camera support
- Introduced Ensemble and CriticHead classes for more efficient critic network handling
- Added support for multiple camera inputs in observation encoder
- Optimized image encoding by batching image processing
- Updated configuration for ManiSkill environment with reduced image size and action scaling
- Compiled critic networks for improved performance
- Simplified normalization and ensemble handling in critic networks
Co-authored-by: michel-aractingi <michel.aractingi@gmail.com>
2025-03-28 17:18:24 +00:00
Michel Aractingi
795063aa1b - Fixed big issue in the loading of the policy parameters sent by the learner to the actor -- pass only the actor to the update_policy_parameters and remove strict=False
- Fixed big issue in the normalization of the actions in the `forward` function of the critic -- remove the `torch.no_grad` decorator in `normalize.py` in the normalization function
- Fixed performance issue to boost the optimization frequency by setting the storage device to be the same as the device of learning.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-03-28 17:18:24 +00:00
Michel Aractingi
eb7e28d9d9 Hardcoded some normalization parameters. TODO refactor
Added masking actions on the level of the intervention actions and offline dataset

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-03-28 17:18:24 +00:00
Michel Aractingi
a0e0a9a9b1 fix log_alpha in modeling_sac: change to nn.parameter
added pretrained vision model in policy

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-03-28 17:18:24 +00:00
Michel Aractingi
f4f5b26a21 Several fixes to move the actor_server and learner_server code from the maniskill environment to the real robot environment.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-03-28 17:18:24 +00:00
Michel Aractingi
b29401e4e2 - Refactor observation encoder in modeling_sac.py
- added `torch.compile` to the actor and learner servers.
- organized imports in `train_sac.py`
- optimized the parameters push by not sending the frozen pre-trained encoder.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-03-28 17:18:24 +00:00