[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot]
2025-03-24 13:41:27 +00:00
committed by AdilZouitine
parent 2945bbb221
commit 7c05755823
123 changed files with 1161 additions and 3425 deletions

View File

@@ -44,9 +44,7 @@ def main():
     else:
         dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human"
-    ckpt_torch_dir = (
-        Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
-    )
+    ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
     ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}"
     save_dir = Path(f"../openpi/data/{model_name}/save")
@@ -72,9 +70,7 @@ def main():
     # Create LeRobot batch from Jax
     batch = {}
     for cam_key, uint_chw_array in example["images"].items():
-        batch[f"observation.images.{cam_key}"] = (
-            torch.from_numpy(uint_chw_array) / 255.0
-        )
+        batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
     batch["observation.state"] = torch.from_numpy(example["state"])
     batch["action"] = torch.from_numpy(outputs["actions"])
     batch["task"] = example["prompt"]

View File

@@ -54,9 +54,7 @@ def get_paligemma_config(precision: str):
         "projector_hidden_act": "gelu_fast",
         "vision_use_head": False,
     }
-    final_config = PaliGemmaConfig(
-        text_config=text_config, vision_config=vision_config, **config
-    )
+    final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
     return final_config
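
As a rough sketch of what the collapsed constructor call does (sub-config values here are hypothetical; transformers accepts plain dicts for text_config and vision_config and builds the corresponding sub-config objects):

    from transformers import PaliGemmaConfig

    # Hypothetical sub-configs, abbreviated; the real function builds full dicts.
    text_config = {"hidden_size": 2048, "num_hidden_layers": 18}
    vision_config = {"hidden_size": 1152, "num_hidden_layers": 27}
    final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config)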

View File

@@ -322,9 +322,7 @@ def update_keys_with_prefix(d: dict, prefix: str) -> dict:
     return {f"{prefix}{key}": value for key, value in d.items()}


-def convert_pi0_checkpoint(
-    checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str
-):
+def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
     # Break down orbax ckpts - they are in OCDBT
     initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
     # process projection params
@@ -384,9 +382,7 @@ def convert_pi0_checkpoint(
     #     gemma_config=gemma_config, paligemma_config=paligemma_config)
     pi0_model = PI0Policy(pi0_config)

-    paligemma_params = update_keys_with_prefix(
-        paligemma_params, "model.paligemma_with_expert."
-    )
+    paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
     gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
     projection_params = update_keys_with_prefix(projection_params, "model.")

View File

@@ -193,9 +193,7 @@ def aloha_gripper_to_angular(value):
     # This is the inverse of the angular to linear transformation inside the Interbotix code.
     def linear_to_radian(linear_position, arm_length, horn_radius):
-        value = (horn_radius**2 + linear_position**2 - arm_length**2) / (
-            2 * horn_radius * linear_position
-        )
+        value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
         return safe_arcsin(value)

     # The constants are taken from the Interbotix code.
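
The collapsed expression is the law-of-cosines term fed to safe_arcsin; a standalone numeric check, with math.asin plus clamping standing in for safe_arcsin and made-up geometry:

    import math

    def linear_to_radian(linear_position, arm_length, horn_radius):
        value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
        return math.asin(max(-1.0, min(1.0, value)))  # clamped stand-in for safe_arcsin

    # Hypothetical geometry (meters): horn radius 0.036, arm length 0.024.
    print(linear_to_radian(0.04, arm_length=0.024, horn_radius=0.036))  # ~0.94 rad
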
@@ -246,9 +244,7 @@ class PI0Policy(PreTrainedPolicy):
         super().__init__(config)
         config.validate_features()
         self.config = config
-        self.normalize_inputs = Normalize(
-            config.input_features, config.normalization_mapping, dataset_stats
-        )
+        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
         self.normalize_targets = Normalize(
             config.output_features, config.normalization_mapping, dataset_stats
         )
@@ -256,9 +252,7 @@ class PI0Policy(PreTrainedPolicy):
             config.output_features, config.normalization_mapping, dataset_stats
         )

-        self.language_tokenizer = AutoTokenizer.from_pretrained(
-            "google/paligemma-3b-pt-224"
-        )
+        self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
         self.model = PI0FlowMatching(config)

         self.reset()
@@ -271,9 +265,7 @@ class PI0Policy(PreTrainedPolicy):
         return self.parameters()

     @torch.no_grad
-    def select_action(
-        self, batch: dict[str, Tensor], noise: Tensor | None = None
-    ) -> Tensor:
+    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
         """Select a single action given environment observations.

         This method wraps `select_actions` in order to return one action at a time for execution in the
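
The docstring's one-action-at-a-time contract amounts to a chunk queue; a minimal sketch with a hypothetical predict_chunk standing in for the model call:

    from collections import deque

    import torch

    _action_queue: deque = deque()

    def select_action_sketch(predict_chunk) -> torch.Tensor:
        # Refill with a whole chunk when empty, then pop one action per call.
        if len(_action_queue) == 0:
            actions = predict_chunk()  # (batch, n_action_steps, num_motors), hypothetical
            _action_queue.extend(actions.transpose(0, 1))
        return _action_queue.popleft()

    action = select_action_sketch(lambda: torch.zeros(1, 50, 14))  # hypothetical shapes
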
@@ -312,9 +304,7 @@ class PI0Policy(PreTrainedPolicy):
         self._action_queue.extend(actions.transpose(0, 1))
         return self._action_queue.popleft()

-    def forward(
-        self, batch: dict[str, Tensor], noise=None, time=None
-    ) -> tuple[Tensor, dict[str, Tensor]]:
+    def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
         """Do a full training forward pass to compute the loss"""
         if self.config.adapt_to_pi_aloha:
             batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
@@ -330,9 +320,7 @@ class PI0Policy(PreTrainedPolicy):
         actions_is_pad = batch.get("action_is_pad")

         loss_dict = {}
-        losses = self.model.forward(
-            images, img_masks, lang_tokens, lang_masks, state, actions, noise, time
-        )
+        losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
         loss_dict["losses_after_forward"] = losses.clone()

         if actions_is_pad is not None:
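
The hunk cuts off right before the padding mask is applied; presumably padded action steps are zeroed out of the per-element losses, along these lines (shapes hypothetical):

    import torch

    losses = torch.rand(2, 50, 32)                     # (batch, horizon, max_action_dim)
    actions_is_pad = torch.zeros(2, 50, dtype=torch.bool)
    actions_is_pad[:, 40:] = True                      # pretend the last 10 steps are padding

    in_episode_bound = ~actions_is_pad
    losses = losses * in_episode_bound.unsqueeze(-1)   # broadcast over the motor dim
    loss = losses.mean()
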
@@ -359,9 +347,7 @@ class PI0Policy(PreTrainedPolicy):
         img_masks = []

         present_img_keys = [key for key in self.config.image_features if key in batch]
-        missing_img_keys = [
-            key for key in self.config.image_features if key not in batch
-        ]
+        missing_img_keys = [key for key in self.config.image_features if key not in batch]

         if len(present_img_keys) == 0:
             raise ValueError(
@@ -373,9 +359,7 @@ class PI0Policy(PreTrainedPolicy):
             img = batch[key]

             if self.config.resize_imgs_with_padding is not None:
-                img = resize_with_pad(
-                    img, *self.config.resize_imgs_with_padding, pad_value=0
-                )
+                img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)

             # Normalize from range [0,1] to [-1,1] as expected by siglip
             img = img * 2.0 - 1.0
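
resize_with_pad itself is not shown in this diff; a hypothetical equivalent (scale to fit while keeping aspect ratio, pad the remainder) followed by the [0,1] to [-1,1] normalization:

    import torch
    import torch.nn.functional as F

    def resize_with_pad_sketch(img, height, width, pad_value=0):
        # Hypothetical helper, not the library's resize_with_pad.
        _, _, h, w = img.shape
        scale = min(height / h, width / w)
        new_h, new_w = int(h * scale), int(w * scale)
        img = F.interpolate(img, size=(new_h, new_w), mode="bilinear", align_corners=False)
        # pad on the right/bottom up to the requested size
        return F.pad(img, (0, width - new_w, 0, height - new_h), value=pad_value)

    img = resize_with_pad_sketch(torch.rand(1, 3, 480, 640), 224, 224)
    img = img * 2.0 - 1.0  # normalize from [0, 1] to [-1, 1] as expected by siglip
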
@@ -414,9 +398,7 @@ class PI0Policy(PreTrainedPolicy):
             return_tensors="pt",
         )

         lang_tokens = tokenized_prompt["input_ids"].to(device=device)
-        lang_masks = tokenized_prompt["attention_mask"].to(
-            device=device, dtype=torch.bool
-        )
+        lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)

         return lang_tokens, lang_masks
@@ -435,9 +417,7 @@ class PI0Policy(PreTrainedPolicy):
             actions[:, :, motor_idx] *= -1
         # Reverse the gripper transformation that is being applied by the Aloha runtime.
         for motor_idx in [6, 13]:
-            actions[:, :, motor_idx] = aloha_gripper_from_angular(
-                actions[:, :, motor_idx]
-            )
+            actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
         return actions

     def _pi_aloha_encode_actions_inv(self, actions):
@@ -446,9 +426,7 @@ class PI0Policy(PreTrainedPolicy):
             actions[:, :, motor_idx] *= -1
         # Reverse the gripper transformation that is being applied by the Aloha runtime.
         for motor_idx in [6, 13]:
-            actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(
-                actions[:, :, motor_idx]
-            )
+            actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
         return actions

     def prepare_state(self, batch):
@@ -498,25 +476,15 @@ class PI0FlowMatching(nn.Module):
             train_expert_only=self.config.train_expert_only,
             attention_implementation=self.config.attention_implementation,
         )
-        self.paligemma_with_expert = PaliGemmaWithExpertModel(
-            paligemma_with_export_config
-        )
+        self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config)

         # Projections are float32
         self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
-        self.action_in_proj = nn.Linear(
-            self.config.max_action_dim, self.config.proj_width
-        )
-        self.action_out_proj = nn.Linear(
-            self.config.proj_width, self.config.max_action_dim
-        )
+        self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
+        self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)

-        self.action_time_mlp_in = nn.Linear(
-            self.config.proj_width * 2, self.config.proj_width
-        )
-        self.action_time_mlp_out = nn.Linear(
-            self.config.proj_width, self.config.proj_width
-        )
+        self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
+        self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)

         self.set_requires_grad()
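
A sketch of how these projections plausibly fit together in the suffix embedding, with hypothetical widths and a swish nonlinearity assumed between the two MLP layers:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    proj_width, max_action_dim, horizon = 1024, 32, 50  # hypothetical sizes
    action_in_proj = nn.Linear(max_action_dim, proj_width)
    action_time_mlp_in = nn.Linear(proj_width * 2, proj_width)
    action_time_mlp_out = nn.Linear(proj_width, proj_width)

    x_t = torch.rand(2, horizon, max_action_dim)  # noisy action chunk
    time_emb = torch.rand(2, proj_width)          # stand-in for the sinusoidal time embedding

    action_emb = action_in_proj(x_t)
    time_emb = time_emb[:, None, :].expand_as(action_emb)
    action_time_emb = torch.cat([action_emb, time_emb], dim=2)     # proj_width * 2
    action_time_emb = action_time_mlp_out(F.silu(action_time_mlp_in(action_time_emb)))
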
@@ -560,9 +528,7 @@ class PI0FlowMatching(nn.Module):
             # Normalize image embeddings
             img_emb_dim = img_emb.shape[-1]
-            img_emb = img_emb * torch.tensor(
-                img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device
-            )
+            img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)

             bsize, num_img_embs = img_emb.shape[:2]
             img_mask = img_mask[:, None].expand(bsize, num_img_embs)
@@ -637,9 +603,7 @@ class PI0FlowMatching(nn.Module):
         embs.append(action_time_emb)

         bsize, action_time_dim = action_time_emb.shape[:2]
-        action_time_mask = torch.ones(
-            bsize, action_time_dim, dtype=torch.bool, device=device
-        )
+        action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
         pad_masks.append(action_time_mask)

         # Set attention masks so that image, language and state inputs do not attend to action tokens
@@ -677,9 +641,7 @@ class PI0FlowMatching(nn.Module):
         prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
             images, img_masks, lang_tokens, lang_masks
         )
-        suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
-            state, x_t, time
-        )
+        suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)

         pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
         att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
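
make_att_2d_masks is not shown in this diff; one plausible reconstruction consistent with the comment in the previous hunk (block-wise attention with padding respected), not necessarily the exact implementation:

    import torch

    def make_att_2d_masks_sketch(pad_masks, att_masks):
        # Hypothetical: att_masks marks block starts, so a token may attend to
        # every token in its own block or an earlier one, and never to padding.
        block_ids = torch.cumsum(att_masks, dim=1)
        att_2d = block_ids[:, None, :] <= block_ids[:, :, None]
        pad_2d = pad_masks[:, None, :] & pad_masks[:, :, None]
        return att_2d & pad_2d

    pad = torch.ones(1, 5, dtype=torch.bool)
    att = torch.tensor([[0, 0, 0, 1, 0]])  # last two tokens start a second block
    print(make_att_2d_masks_sketch(pad, att)[0].int())
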
@@ -703,9 +665,7 @@ class PI0FlowMatching(nn.Module):
         losses = F.mse_loss(u_t, v_t, reduction="none")
         return losses

-    def sample_actions(
-        self, images, img_masks, lang_tokens, lang_masks, state, noise=None
-    ) -> Tensor:
+    def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
         """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
         bsize = state.shape[0]
         device = state.device
@@ -763,16 +723,12 @@ class PI0FlowMatching(nn.Module):
         timestep,
     ):
         """Apply one denoising step of the noise `x_t` at a given timestep."""
-        suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
-            state, x_t, timestep
-        )
+        suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)

         suffix_len = suffix_pad_masks.shape[1]
         batch_size = prefix_pad_masks.shape[0]
         prefix_len = prefix_pad_masks.shape[1]
-        prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(
-            batch_size, suffix_len, prefix_len
-        )
+        prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
         suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
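
denoise_step is driven by a fixed-step Euler integration in sample_actions; a minimal sketch of that loop, with step count and sign conventions assumed:

    import torch

    def sample_actions_sketch(denoise_step, noise, num_steps=10):
        # Integrate x_t from t = 1 (pure noise) toward t = 0 (denoised actions),
        # querying the network for the velocity v_t at each Euler step.
        dt = -1.0 / num_steps
        x_t, time = noise, 1.0
        while time >= -dt / 2:
            v_t = denoise_step(x_t, time)
            x_t = x_t + dt * v_t
            time += dt
        return x_t

    actions = sample_actions_sketch(lambda x, t: -x, torch.randn(1, 50, 32))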

View File

@@ -39,13 +39,9 @@ def apply_rope(x, positions, max_wavelength=10_000):
     dtype = x.dtype
     x = x.to(torch.float32)

-    freq_exponents = (2.0 / x.shape[-1]) * torch.arange(
-        d_half, dtype=torch.float32, device=device
-    )
+    freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
     timescale = max_wavelength**freq_exponents
-    radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(
-        torch.float32
-    )
+    radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)

     radians = radians[..., None, :]
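
For reference, a compact self-contained version of the rotary embedding applied here (split-halves pairing assumed from the surrounding code; device handling dropped):

    import torch

    def apply_rope_sketch(x, positions, max_wavelength=10_000):
        # x: (batch, seq, heads, head_dim); positions: (batch, seq)
        d_half = x.shape[-1] // 2
        freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32)
        timescale = max_wavelength**freq_exponents
        radians = positions[..., None].to(torch.float32) / timescale[None, None, :]
        radians = radians[..., None, :]  # broadcast over heads
        sin, cos = torch.sin(radians), torch.cos(radians)
        x1, x2 = x.to(torch.float32).split(d_half, dim=-1)
        return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).to(x.dtype)

    q = apply_rope_sketch(torch.randn(1, 6, 8, 64), torch.arange(6)[None, :])
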
@@ -178,9 +174,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
     def __init__(self, config: PaliGemmaWithExpertConfig):
         super().__init__(config=config)
         self.config = config
-        self.paligemma = PaliGemmaForConditionalGeneration(
-            config=config.paligemma_config
-        )
+        self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
         self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
         # Remove unused embed_tokens
         self.gemma_expert.model.embed_tokens = None
@@ -297,9 +291,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
                 # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
                 # the max len, then we (for instance) double the cache size. This implementation already exists
                 # in `transformers`. (molbap)
-                key_states = torch.cat(
-                    [past_key_values[layer_idx]["key_states"], key_states], dim=1
-                )
+                key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
                 value_states = torch.cat(
                     [past_key_values[layer_idx]["value_states"], value_states],
                     dim=1,
@@ -392,9 +384,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
         value_states,
     ):
         num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
-        num_key_value_heads = (
-            self.config.paligemma_config.text_config.num_key_value_heads
-        )
+        num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
         num_key_value_groups = num_att_heads // num_key_value_heads

         # query_states: batch_size, sequence_length, num_att_head, head_dim
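
This is the standard grouped-query attention bookkeeping; a sketch of expanding KV heads to match the query heads, with hypothetical head counts:

    import torch

    num_att_heads, num_key_value_heads, head_dim = 8, 1, 256  # hypothetical
    num_key_value_groups = num_att_heads // num_key_value_heads

    key_states = torch.randn(2, 10, num_key_value_heads, head_dim)
    # Each KV head serves num_key_value_groups query heads; repeat to align.
    key_states = key_states.repeat_interleave(num_key_value_groups, dim=2)
    assert key_states.shape == (2, 10, num_att_heads, head_dim)
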
@@ -442,9 +432,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
         att_weights *= head_dim**-0.5
         big_neg = -2.3819763e38  # See gemma/modules.py
-        masked_att_weights = torch.where(
-            attention_mask[:, None, :, :], att_weights, big_neg
-        )
+        masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)

         probs = nn.functional.softmax(masked_att_weights, dim=-1)
         probs = probs.to(dtype=value_states.dtype)
@@ -456,8 +444,6 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
         att_output = att_output.permute(0, 2, 1, 3)
         # we use -1 because sequence length can change
-        att_output = att_output.reshape(
-            batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim
-        )
+        att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)

         return att_output