[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-04 13:38:47 +00:00
committed by Michel Aractingi
parent bb69cb3c8c
commit 85fe8a3f4e
79 changed files with 2800 additions and 794 deletions

View File

@@ -22,7 +22,6 @@ from typing import Callable, Optional, Sequence, TypedDict
import hydra
import torch
import torch.nn.functional as F
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from torch import nn
from tqdm import tqdm
@@ -30,20 +29,17 @@ from tqdm import tqdm
# TODO: Remove the import of maniskill
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.envs.factory import make_env, make_maniskill_env
from lerobot.common.envs.utils import preprocess_maniskill_observation, preprocess_observation
from lerobot.common.envs.factory import make_maniskill_env
from lerobot.common.envs.utils import preprocess_maniskill_observation
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
format_big_number,
get_safe_torch_device,
init_hydra_config,
init_logging,
set_global_seed,
)
from lerobot.scripts.eval import eval_policy
def make_optimizers_and_scheduler(cfg, policy):
@@ -56,7 +52,9 @@ def make_optimizers_and_scheduler(cfg, policy):
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
)
# We wrap policy log temperature in list because this is a torch tensor and not a nn.Module
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
optimizer_temperature = torch.optim.Adam(
params=[policy.log_alpha], lr=policy.config.critic_lr
)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,
@@ -108,7 +106,9 @@ def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Te
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
# Gather pixels
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
cropped_hwcn = images_hwcn[
torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :
]
# cropped_hwcn => (B, crop_h, crop_w, C)
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
@@ -198,8 +198,12 @@ class ReplayBuffer:
"""
# We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from
# a replay buffer than from a lerobot dataset.
replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys)
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
replay_buffer = cls(
capacity=len(lerobot_dataset), device=device, state_keys=state_keys
)
list_transition = cls._lerobotdataset_to_transitions(
dataset=lerobot_dataset, state_keys=state_keys
)
# Fill the replay buffer with the lerobot dataset transitions
for data in list_transition:
replay_buffer.add(
@@ -244,7 +248,9 @@ class ReplayBuffer:
# If not provided, you can either raise an error or define a default:
if state_keys is None:
raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")
raise ValueError(
"You must provide a list of keys in `state_keys` that define your 'state'."
)
transitions: list[Transition] = []
num_frames = len(dataset)
@@ -298,36 +304,40 @@ class ReplayBuffer:
# -- Build batched states --
batch_state = {}
for key in self.state_keys:
batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
self.device
)
batch_state[key] = torch.cat(
[t["state"][key] for t in list_of_transitions], dim=0
).to(self.device)
if key.startswith("observation.image") and self.use_drq:
batch_state[key] = self.image_augmentation_function(batch_state[key])
# -- Build batched actions --
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
# -- Build batched rewards --
batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(
self.device
)
# -- Build batched rewards --
batch_rewards = torch.tensor(
[t["reward"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
# -- Build batched next states --
batch_next_state = {}
for key in self.state_keys:
batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
self.device
)
batch_next_state[key] = torch.cat(
[t["next_state"][key] for t in list_of_transitions], dim=0
).to(self.device)
if key.startswith("observation.image") and self.use_drq:
batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])
batch_next_state[key] = self.image_augmentation_function(
batch_next_state[key]
)
# -- Build batched dones --
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
batch_dones = torch.tensor(
[t["done"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
batch_dones = torch.tensor(
[t["done"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
# Return a BatchTransition typed dict
return BatchTransition(
@@ -344,7 +354,13 @@ def concatenate_batch_transitions(
) -> BatchTransition:
"""NOTE: Be careful it change the left_batch_transitions in place"""
left_batch_transitions["state"] = {
key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
key: torch.cat(
[
left_batch_transitions["state"][key],
right_batch_transition["state"][key],
],
dim=0,
)
for key in left_batch_transitions["state"]
}
left_batch_transitions["action"] = torch.cat(
@@ -355,7 +371,11 @@ def concatenate_batch_transitions(
)
left_batch_transitions["next_state"] = {
key: torch.cat(
[left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0
[
left_batch_transitions["next_state"][key],
right_batch_transition["next_state"][key],
],
dim=0,
)
for key in left_batch_transitions["next_state"]
}
@@ -407,7 +427,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online traning, we do not need dataset_stats
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
if cfg.resume
else None,
device=device,
)
assert isinstance(policy, nn.Module)
@@ -416,7 +438,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# TODO: Handle resume
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_learnable_params = sum(
p.numel() for p in policy.parameters() if p.requires_grad
)
num_total_params = sum(p.numel() for p in policy.parameters())
log_output_dir(out_dir)
@@ -433,7 +457,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
obs = {key: obs[key].to(device, non_blocking=True) for key in obs}
replay_buffer = ReplayBuffer(
capacity=cfg.training.online_buffer_capacity, device=device, state_keys=cfg.policy.input_shapes.keys()
capacity=cfg.training.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
)
batch_size = cfg.training.batch_size
@@ -455,12 +481,16 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
if interaction_step >= cfg.training.online_step_before_learning:
action = policy.select_action(batch=obs)
next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
next_obs, reward, done, truncated, info = online_env.step(
action.cpu().numpy()
)
else:
action = online_env.action_space.sample()
next_obs, reward, done, truncated, info = online_env.step(action)
# HACK
action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)
action = torch.tensor(action, dtype=torch.float32).to(
device, non_blocking=True
)
# HACK: For maniskill
# next_obs = preprocess_observation(next_obs)
@@ -470,14 +500,20 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# Because we are using a single environment
# we can safely assume that the episode is done
if done[0] or truncated[0]:
logging.info(f"Global step {interaction_step}: Episode reward: {sum_reward_episode}")
logger.log_dict({"Sum episode reward": sum_reward_episode}, interaction_step)
logging.info(
f"Global step {interaction_step}: Episode reward: {sum_reward_episode}"
)
logger.log_dict(
{"Sum episode reward": sum_reward_episode}, interaction_step
)
sum_reward_episode = 0
# HACK: This is for maniskill
logging.info(
f"global step {interaction_step}: episode success: {info['success'].float().item()} \n"
)
logger.log_dict({"Episode success": info["success"].float().item()}, interaction_step)
logger.log_dict(
{"Episode success": info["success"].float().item()}, interaction_step
)
replay_buffer.add(
state=obs,
@@ -551,7 +587,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
training_infos["loss_actor"] = loss_actor.item()
loss_temperature = policy.compute_loss_temperature(observations=observations)
loss_temperature = policy.compute_loss_temperature(
observations=observations
)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
optimizers["temperature"].step()
@@ -573,7 +611,9 @@ def train_cli(cfg: dict):
)
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
def train_notebook(
out_dir=None, job_name=None, config_name="default", config_path="../configs"
):
from hydra import compose, initialize
hydra.core.global_hydra.GlobalHydra.instance().clear()