From 5998203a33b395143af48a44bcdfef72bcecb022 Mon Sep 17 00:00:00 2001
From: Michel Aractingi
Date: Mon, 5 May 2025 11:33:09 +0200
Subject: [PATCH] [Port HIL-SERL] Final fixes for reward classifier (#1067)

Co-authored-by: s1lent4gnt
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 lerobot/common/datasets/utils.py              |   2 +-
 .../reward_model/configuration_classifier.py  |   2 +
 .../reward_model/modeling_classifier.py       | 118 ++-
 lerobot/scripts/server/gym_manipulator.py     | 879 +++++++++++++++---
 .../hilserl/test_modeling_classifier.py       |  12 +-
 5 files changed, 845 insertions(+), 168 deletions(-)

diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 9d8a54db..4a5874af 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -409,7 +409,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
             names = ft["names"]
             # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
-            if names[2] in ["channel", "channels"]:  # (h, w, c) -> (c, h, w)
+            if names and names[2] in ["channel", "channels"]:  # (h, w, c) -> (c, h, w)
                 shape = (shape[2], shape[0], shape[1])
         elif key == "observation.environment_state":
             type = FeatureType.ENV
diff --git a/lerobot/common/policies/reward_model/configuration_classifier.py b/lerobot/common/policies/reward_model/configuration_classifier.py
index 185f54d7..735236fc 100644
--- a/lerobot/common/policies/reward_model/configuration_classifier.py
+++ b/lerobot/common/policies/reward_model/configuration_classifier.py
@@ -15,6 +15,8 @@ class RewardClassifierConfig(PreTrainedConfig):
     name: str = "reward_classifier"
     num_classes: int = 2
     hidden_dim: int = 256
+    latent_dim: int = 256
+    image_embedding_pooling_dim: int = 8
     dropout_rate: float = 0.1
     model_name: str = "helper2424/resnet10"
     device: str = "cpu"
diff --git a/lerobot/common/policies/reward_model/modeling_classifier.py b/lerobot/common/policies/reward_model/modeling_classifier.py
index c998f83d..476185db 100644
--- a/lerobot/common/policies/reward_model/modeling_classifier.py
+++ b/lerobot/common/policies/reward_model/modeling_classifier.py
@@ -34,6 +34,59 @@ class ClassifierOutput:
         )
 
 
+class SpatialLearnedEmbeddings(nn.Module):
+    def __init__(self, height, width, channel, num_features=8):
+        """
+        PyTorch implementation of learned spatial embeddings.
+
+        Args:
+            height: Spatial height of the input feature map
+            width: Spatial width of the input feature map
+            channel: Number of input channels
+            num_features: Number of output embedding dimensions per channel
+        """
+        super().__init__()
+        self.height = height
+        self.width = width
+        self.channel = channel
+        self.num_features = num_features
+
+        self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features))
+
+        nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear")
+
+    def forward(self, features):
+        """
+        Forward pass for the spatial embedding.
+
+        Args:
+            features: Encoder output whose `last_hidden_state` has shape [B, C, H, W],
+                or [C, H, W] if unbatched
+        Returns:
+            Output tensor of shape [B, C*F], or [C*F] if unbatched
+        """
+
+        features = features.last_hidden_state
+
+        original_shape = features.shape
+        if features.dim() == 3:
+            features = features.unsqueeze(0)  # Add batch dim
+
+        features_expanded = features.unsqueeze(-1)  # [B, C, H, W, 1]
+        kernel_expanded = self.kernel.unsqueeze(0)  # [1, C, H, W, F]
+
+        # Element-wise multiplication and spatial reduction
+        output = (features_expanded * kernel_expanded).sum(dim=(2, 3))  # Sum over H, W -> [B, C, F]
+
# Reshape to combine channel and feature dimensions + output = output.view(output.size(0), -1) # [B, C*F] + + # Remove batch dim + if len(original_shape) == 3: + output = output.squeeze(0) + + return output + + class Classifier(PreTrainedPolicy): """Image classifier built on top of a pre-trained encoder.""" @@ -78,6 +131,18 @@ class Classifier(PreTrainedPolicy): self._setup_cnn_backbone() self._freeze_encoder() + + # Extract image keys from input_features + self.image_keys = [ + key.replace(".", "_") for key in config.input_features if key.startswith(OBS_IMAGE) + ] + + if self.is_cnn: + self.encoders = nn.ModuleDict() + for image_key in self.image_keys: + encoder = self._create_single_encoder() + self.encoders[image_key] = encoder + self._build_classifier_head() def _setup_cnn_backbone(self): @@ -95,11 +160,28 @@ class Classifier(PreTrainedPolicy): for param in self.encoder.parameters(): param.requires_grad = False + def _create_single_encoder(self): + encoder = nn.Sequential( + self.encoder, + SpatialLearnedEmbeddings( + height=4, + width=4, + channel=self.feature_dim, + num_features=self.config.image_embedding_pooling_dim, + ), + nn.Dropout(self.config.dropout_rate), + nn.Linear(self.feature_dim * self.config.image_embedding_pooling_dim, self.config.latent_dim), + nn.LayerNorm(self.config.latent_dim), + nn.Tanh(), + ) + + return encoder + def _build_classifier_head(self) -> None: """Initialize the classifier head architecture.""" # Get input dimension based on model type if self.is_cnn: - input_dim = self.feature_dim + input_dim = self.config.latent_dim else: # Transformer models if hasattr(self.encoder.config, "hidden_size"): input_dim = self.encoder.config.hidden_size @@ -117,26 +199,20 @@ class Classifier(PreTrainedPolicy): ), ) - def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor: + def _get_encoder_output(self, x: torch.Tensor, image_key: str) -> torch.Tensor: """Extract the appropriate output from the encoder.""" with torch.no_grad(): if self.is_cnn: # The HF ResNet applies pooling internally - outputs = self.encoder(x) - # Get pooled output directly - features = outputs.pooler_output - - if features.dim() > 2: - features = features.squeeze(-1).squeeze(-1) - return features + outputs = self.encoders[image_key](x) + return outputs else: # Transformer models outputs = self.encoder(x) - if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None: - return outputs.pooler_output return outputs.last_hidden_state[:, 0, :] def extract_images_and_labels(self, batch: Dict[str, Tensor]) -> Tuple[list, Tensor]: """Extract image tensors and label tensors from batch.""" + # Check for both OBS_IMAGE and OBS_IMAGES prefixes images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)] labels = batch["next.reward"] @@ -144,7 +220,9 @@ class Classifier(PreTrainedPolicy): def predict(self, xs: list) -> ClassifierOutput: """Forward pass of the classifier for inference.""" - encoder_outputs = torch.hstack([self._get_encoder_output(x) for x in xs]) + encoder_outputs = torch.hstack( + [self._get_encoder_output(x, img_key) for x, img_key in zip(xs, self.image_keys, strict=True)] + ) logits = self.classifier_head(encoder_outputs) if self.config.num_classes == 2: @@ -192,8 +270,14 @@ class Classifier(PreTrainedPolicy): return loss, output_dict def predict_reward(self, batch, threshold=0.5): - """Legacy method for compatibility.""" + """Eval method. 
Returns predicted reward with the decision threshold as argument.""" + # Check for both OBS_IMAGE and OBS_IMAGES prefixes + batch = self.normalize_inputs(batch) + batch = self.normalize_targets(batch) + + # Extract images from batch dict images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)] + if self.config.num_classes == 2: probs = self.predict(images).probabilities logging.debug(f"Predicted reward images: {probs}") @@ -201,13 +285,9 @@ class Classifier(PreTrainedPolicy): else: return torch.argmax(self.predict(images).probabilities, dim=1) - def get_optim_params(self) -> dict: + def get_optim_params(self): """Return optimizer parameters for the policy.""" - return { - "params": self.parameters(), - "lr": getattr(self.config, "learning_rate", 1e-4), - "weight_decay": getattr(self.config, "weight_decay", 0.01), - } + return self.parameters() def select_action(self, batch: Dict[str, Tensor]) -> Tensor: """ diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index 712ddf28..48ba91e3 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -22,7 +22,135 @@ from lerobot.configs import parser from lerobot.scripts.server.kinematics import RobotKinematics logging.basicConfig(level=logging.INFO) -MAX_GRIPPER_COMMAND = 40 +MAX_GRIPPER_COMMAND = 30 + + +class TorchBox(gym.spaces.Box): + """ + A version of gym.spaces.Box that handles PyTorch tensors. + + This class extends gym.spaces.Box to work with PyTorch tensors, + providing compatibility between NumPy arrays and PyTorch tensors. + """ + + def __init__( + self, + low: float | Sequence[float] | np.ndarray, + high: float | Sequence[float] | np.ndarray, + shape: Sequence[int] | None = None, + np_dtype: np.dtype | type = np.float32, + torch_dtype: torch.dtype = torch.float32, + device: str = "cpu", + seed: int | np.random.Generator | None = None, + ) -> None: + """ + Initialize the PyTorch-compatible Box space. + + Args: + low: Lower bounds of the space. + high: Upper bounds of the space. + shape: Shape of the space. If None, inferred from low and high. + np_dtype: NumPy data type for internal storage. + torch_dtype: PyTorch data type for tensor conversion. + device: PyTorch device for returned tensors. + seed: Random seed for sampling. + """ + super().__init__(low, high, shape=shape, dtype=np_dtype, seed=seed) + self.torch_dtype = torch_dtype + self.device = device + + def sample(self) -> torch.Tensor: + """ + Sample a random point from the space. + + Returns: + A PyTorch tensor within the space bounds. + """ + arr = super().sample() + return torch.as_tensor(arr, dtype=self.torch_dtype, device=self.device) + + def contains(self, x: torch.Tensor) -> bool: + """ + Check if a tensor is within the space bounds. + + Args: + x: The PyTorch tensor to check. + + Returns: + Boolean indicating whether the tensor is within bounds. + """ + # Move to CPU/numpy and cast to the internal dtype + arr = x.detach().cpu().numpy().astype(self.dtype, copy=False) + return super().contains(arr) + + def seed(self, seed: int | np.random.Generator | None = None): + """ + Set the random seed for sampling. + + Args: + seed: The random seed to use. + + Returns: + List containing the seed. + """ + super().seed(seed) + return [seed] + + def __repr__(self) -> str: + """ + Return a string representation of the space. + + Returns: + Formatted string with space details. 
+ """ + return ( + f"TorchBox({self.low_repr}, {self.high_repr}, {self.shape}, " + f"np={self.dtype.name}, torch={self.torch_dtype}, device={self.device})" + ) + + +class TorchActionWrapper(gym.Wrapper): + """ + Wrapper that changes the action space to use PyTorch tensors. + + This wrapper modifies the action space to return PyTorch tensors when sampled + and handles converting PyTorch actions to NumPy when stepping the environment. + """ + + def __init__(self, env: gym.Env, device: str): + """ + Initialize the PyTorch action space wrapper. + + Args: + env: The environment to wrap. + device: The PyTorch device to use for tensor operations. + """ + super().__init__(env) + self.action_space = TorchBox( + low=env.action_space.low, + high=env.action_space.high, + shape=env.action_space.shape, + torch_dtype=torch.float32, + device=torch.device("cpu"), + ) + + def step(self, action: torch.Tensor): + """ + Step the environment with a PyTorch tensor action. + + This method handles conversion from PyTorch tensors to NumPy arrays + for compatibility with the underlying environment. + + Args: + action: PyTorch tensor action to take. + + Returns: + Tuple of (observation, reward, terminated, truncated, info). + """ + if action.dim() == 2: + action = action.squeeze(0) + action = action.detach().cpu().numpy() + return self.env.step(action) class RobotEnv(gym.Env): @@ -45,9 +173,9 @@ class RobotEnv(gym.Env): The environment is set up with a robot interface, which is used to capture observations and send joint commands. The setup supports both relative (delta) adjustments and absolute joint positions for controlling the robot. - cfg. + Args: robot: The robot interface object used to connect and interact with the physical robot. - display_cameras (bool): If True, the robot's camera feeds will be displayed during execution. + display_cameras: If True, the robot's camera feeds will be displayed during execution. """ super().__init__() @@ -113,14 +241,14 @@ class RobotEnv(gym.Env): Reset the environment to its initial state. This method resets the step counter and clears any episodic data. - cfg. - seed (Optional[int]): A seed for random number generation to ensure reproducibility. - options (Optional[dict]): Additional options to influence the reset behavior. + Args: + seed: A seed for random number generation to ensure reproducibility. + options: Additional options to influence the reset behavior. Returns: A tuple containing: - observation (dict): The initial sensor observation. - - info (dict): A dictionary with supplementary information, including the key "initial_position". + - info (dict): A dictionary with supplementary information, including the key "is_intervention". """ super().reset(seed=seed, options=options) @@ -140,16 +268,16 @@ class RobotEnv(gym.Env): The provided action is processed and sent to the robot as joint position commands that may be either absolute values or deltas based on the environment configuration. - cfg. - action (np.ndarray or torch.Tensor): The commanded joint positions. + Args: + action: The commanded joint positions as a numpy array or torch tensor. Returns: - tuple: A tuple containing: + A tuple containing: - observation (dict): The new sensor observation after taking the step. - reward (float): The step reward (default is 0.0 within this wrapper). - terminated (bool): True if the episode has reached a terminal state. - truncated (bool): True if the episode was truncated (e.g., time constraints). 
- - info (dict): Additional debugging information including: + - info (dict): Additional debugging information including intervention status. """ self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position") @@ -198,7 +326,23 @@ class RobotEnv(gym.Env): class AddJointVelocityToObservation(gym.ObservationWrapper): + """ + Wrapper that adds joint velocity information to the observation. + + This wrapper computes joint velocities by tracking changes in joint positions over time, + and extends the observation space to include these velocities. + """ + def __init__(self, env, joint_velocity_limits=100.0, fps=30, num_dof=6): + """ + Initialize the joint velocity wrapper. + + Args: + env: The environment to wrap. + joint_velocity_limits: Maximum expected joint velocity for space bounds. + fps: Frames per second used to calculate velocity (position delta / time). + num_dof: Number of degrees of freedom (joints) in the robot. + """ super().__init__(env) # Extend observation space to include joint velocities @@ -223,6 +367,15 @@ class AddJointVelocityToObservation(gym.ObservationWrapper): self.dt = 1.0 / fps def observation(self, observation): + """ + Add joint velocity information to the observation. + + Args: + observation: The original observation from the environment. + + Returns: + The modified observation with joint velocities. + """ joint_velocities = (observation["observation.state"] - self.last_joint_positions) / self.dt self.last_joint_positions = observation["observation.state"].clone() observation["observation.state"] = torch.cat( @@ -232,7 +385,22 @@ class AddJointVelocityToObservation(gym.ObservationWrapper): class AddCurrentToObservation(gym.ObservationWrapper): + """ + Wrapper that adds motor current information to the observation. + + This wrapper extends the observation space to include the current values + from each motor, providing information about the forces being applied. + """ + def __init__(self, env, max_current=500, num_dof=6): + """ + Initialize the current observation wrapper. + + Args: + env: The environment to wrap. + max_current: Maximum expected current for space bounds. + num_dof: Number of degrees of freedom (joints) in the robot. + """ super().__init__(env) # Extend observation space to include joint velocities @@ -253,6 +421,15 @@ class AddCurrentToObservation(gym.ObservationWrapper): ) def observation(self, observation): + """ + Add current information to the observation. + + Args: + observation: The original observation from the environment. + + Returns: + The modified observation with current values. + """ present_current = ( self.unwrapped.robot.follower_arms["main"].read("Present_Current").astype(np.float32) ) @@ -263,14 +440,14 @@ class AddCurrentToObservation(gym.ObservationWrapper): class RewardWrapper(gym.Wrapper): - def __init__(self, env, reward_classifier, device: torch.device = "cuda"): + def __init__(self, env, reward_classifier, device="cuda"): """ - Wrapper to add reward prediction to the environment, it use a trained classifier. + Wrapper to add reward prediction to the environment using a trained classifier. - cfg. - env: The environment to wrap - reward_classifier: The reward classifier model - device: The device to run the model on + Args: + env: The environment to wrap. + reward_classifier: The reward classifier model. + device: The device to run the model on. 
""" self.env = env @@ -280,21 +457,34 @@ class RewardWrapper(gym.Wrapper): self.reward_classifier.to(self.device) def step(self, action): + """ + Execute a step and compute the reward using the classifier. + + Args: + action: The action to take in the environment. + + Returns: + Tuple of (observation, reward, terminated, truncated, info). + """ observation, _, terminated, truncated, info = self.env.step(action) - images = { - key: observation[key].to(self.device, non_blocking=self.device.type == "cuda") - for key in observation - if "image" in key - } + + images = {} + for key in observation: + if "image" in key: + images[key] = observation[key].to(self.device, non_blocking=(self.device == "cuda")) + if images[key].dim() == 3: + images[key] = images[key].unsqueeze(0) + start_time = time.perf_counter() with torch.inference_mode(): success = ( - self.reward_classifier.predict_reward(images, threshold=0.8) + self.reward_classifier.predict_reward(images, threshold=0.7) if self.reward_classifier is not None else 0.0 ) info["Reward classifier frequency"] = 1 / (time.perf_counter() - start_time) + reward = 0.0 if success == 1.0: terminated = True reward = 1.0 @@ -302,11 +492,36 @@ class RewardWrapper(gym.Wrapper): return observation, reward, terminated, truncated, info def reset(self, seed=None, options=None): + """ + Reset the environment. + + Args: + seed: Random seed for reproducibility. + options: Additional reset options. + + Returns: + The initial observation and info from the wrapped environment. + """ return self.env.reset(seed=seed, options=options) class TimeLimitWrapper(gym.Wrapper): + """ + Wrapper that adds a time limit to episodes and tracks execution time. + + This wrapper terminates episodes after a specified time has elapsed, providing + better control over episode length. + """ + def __init__(self, env, control_time_s, fps): + """ + Initialize the time limit wrapper. + + Args: + env: The environment to wrap. + control_time_s: Maximum episode duration in seconds. + fps: Frames per second for calculating the maximum number of steps. + """ self.env = env self.control_time_s = control_time_s self.fps = fps @@ -319,6 +534,15 @@ class TimeLimitWrapper(gym.Wrapper): self.current_step = 0 def step(self, action): + """ + Step the environment and track time elapsed. + + Args: + action: The action to take in the environment. + + Returns: + Tuple of (observation, reward, terminated, truncated, info). + """ obs, reward, terminated, truncated, info = self.env.step(action) time_since_last_step = time.perf_counter() - self.last_timestamp self.episode_time_in_s += time_since_last_step @@ -333,6 +557,16 @@ class TimeLimitWrapper(gym.Wrapper): return obs, reward, terminated, truncated, info def reset(self, seed=None, options=None): + """ + Reset the environment and time tracking. + + Args: + seed: Random seed for reproducibility. + options: Additional reset options. + + Returns: + The initial observation and info from the wrapped environment. + """ self.episode_time_in_s = 0.0 self.last_timestamp = time.perf_counter() self.current_step = 0 @@ -340,12 +574,28 @@ class TimeLimitWrapper(gym.Wrapper): class ImageCropResizeWrapper(gym.Wrapper): + """ + Wrapper that crops and resizes image observations. + + This wrapper processes image observations to focus on relevant regions by + cropping and then resizing to a standard size. + """ + def __init__( self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None, ): + """ + Initialize the image crop and resize wrapper. 
+ + Args: + env: The environment to wrap. + crop_params_dict: Dictionary mapping image observation keys to crop parameters + (top, left, height, width). + resize_size: Target size for resized images (height, width). Defaults to (128, 128). + """ super().__init__(env) self.env = env self.crop_params_dict = crop_params_dict @@ -363,6 +613,15 @@ class ImageCropResizeWrapper(gym.Wrapper): self.resize_size = (128, 128) def step(self, action): + """ + Step the environment and process image observations. + + Args: + action: The action to take in the environment. + + Returns: + Tuple of (observation, reward, terminated, truncated, info) with processed images. + """ obs, reward, terminated, truncated, info = self.env.step(action) for k in self.crop_params_dict: device = obs[k].device @@ -393,6 +652,16 @@ class ImageCropResizeWrapper(gym.Wrapper): return obs, reward, terminated, truncated, info def reset(self, seed=None, options=None): + """ + Reset the environment and process image observations. + + Args: + seed: Random seed for reproducibility. + options: Additional reset options. + + Returns: + Tuple of (observation, info) with processed images. + """ obs, info = self.env.reset(seed=seed, options=options) for k in self.crop_params_dict: device = obs[k].device @@ -406,12 +675,35 @@ class ImageCropResizeWrapper(gym.Wrapper): class ConvertToLeRobotObservation(gym.ObservationWrapper): + """ + Wrapper that converts standard observations to LeRobot format. + + This wrapper processes observations to match the expected format for LeRobot, + including normalizing image values and moving tensors to the specified device. + """ + def __init__(self, env, device: str = "cpu"): + """ + Initialize the LeRobot observation converter. + + Args: + env: The environment to wrap. + device: Target device for the observation tensors. + """ super().__init__(env) self.device = torch.device(device) def observation(self, observation): + """ + Convert observations to LeRobot format. + + Args: + observation: The original observation from the environment. + + Returns: + The processed observation with normalized images and proper tensor formats. + """ for key in observation: observation[key] = observation[key].float() if "image" in key: @@ -426,18 +718,46 @@ class ConvertToLeRobotObservation(gym.ObservationWrapper): class ResetWrapper(gym.Wrapper): + """ + Wrapper that handles environment reset procedures. + + This wrapper provides additional functionality during environment reset, + including the option to reset to a fixed pose or allow manual reset. + """ + def __init__( self, env: RobotEnv, reset_pose: np.ndarray | None = None, reset_time_s: float = 5, ): + """ + Initialize the reset wrapper. + + Args: + env: The environment to wrap. + reset_pose: Fixed joint positions to reset to. If None, manual reset is used. + reset_time_s: Time in seconds to wait after reset or allowed for manual reset. + """ super().__init__(env) self.reset_time_s = reset_time_s self.reset_pose = reset_pose self.robot = self.unwrapped.robot def reset(self, *, seed=None, options=None): + """ + Reset the environment with either fixed or manual reset procedure. + + If reset_pose is provided, the robot will move to that position. + Otherwise, manual teleoperation control is allowed for reset_time_s seconds. + + Args: + seed: Random seed for reproducibility. + options: Additional reset options. + + Returns: + The initial observation and info from the wrapped environment. 
+ """ start_time = time.perf_counter() if self.reset_pose is not None: log_say("Reset the environment.", play_sounds=True) @@ -466,10 +786,32 @@ class ResetWrapper(gym.Wrapper): class BatchCompatibleWrapper(gym.ObservationWrapper): + """ + Wrapper that ensures observations are compatible with batch processing. + + This wrapper adds a batch dimension to observations that don't already have one, + making them compatible with models that expect batched inputs. + """ + def __init__(self, env): + """ + Initialize the batch compatibility wrapper. + + Args: + env: The environment to wrap. + """ super().__init__(env) def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Add batch dimensions to observations if needed. + + Args: + observation: Dictionary of observation tensors. + + Returns: + Dictionary of observation tensors with batch dimensions. + """ for key in observation: if "image" in key and observation[key].dim() == 3: observation[key] = observation[key].unsqueeze(0) @@ -481,12 +823,36 @@ class BatchCompatibleWrapper(gym.ObservationWrapper): class GripperPenaltyWrapper(gym.RewardWrapper): + """ + Wrapper that adds penalties for inefficient gripper commands. + + This wrapper modifies rewards to discourage excessive gripper movement + or commands that attempt to move the gripper beyond its physical limits. + """ + def __init__(self, env, penalty: float = -0.1): + """ + Initialize the gripper penalty wrapper. + + Args: + env: The environment to wrap. + penalty: Negative reward value to apply for inefficient gripper actions. + """ super().__init__(env) self.penalty = penalty self.last_gripper_state = None def reward(self, reward, action): + """ + Apply penalties to reward based on gripper actions. + + Args: + reward: The original reward from the environment. + action: The action that was taken. + + Returns: + Modified reward with penalty applied if necessary. + """ gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND action_normalized = action - 1.0 # action / MAX_GRIPPER_COMMAND @@ -498,6 +864,15 @@ class GripperPenaltyWrapper(gym.RewardWrapper): return reward + self.penalty * int(gripper_penalty_bool) def step(self, action): + """ + Step the environment and apply gripper penalties. + + Args: + action: The action to take in the environment. + + Returns: + Tuple of (observation, reward, terminated, truncated, info) with penalty applied. + """ self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1] gripper_action = action[-1] obs, reward, terminated, truncated, info = self.env.step(action) @@ -508,6 +883,15 @@ class GripperPenaltyWrapper(gym.RewardWrapper): return obs, reward, terminated, truncated, info def reset(self, **kwargs): + """ + Reset the environment and penalty tracking. + + Args: + **kwargs: Keyword arguments passed to the wrapped environment's reset. + + Returns: + The initial observation and info with gripper penalty initialized. + """ self.last_gripper_state = None obs, info = super().reset(**kwargs) info["gripper_penalty"] = 0.0 @@ -515,7 +899,22 @@ class GripperPenaltyWrapper(gym.RewardWrapper): class GripperActionWrapper(gym.ActionWrapper): + """ + Wrapper that processes gripper control commands. + + This wrapper quantizes and processes gripper commands, adding a sleep time between + consecutive gripper actions to prevent rapid toggling. 
+ """ + def __init__(self, env, quantization_threshold: float = 0.2, gripper_sleep: float = 0.0): + """ + Initialize the gripper action wrapper. + + Args: + env: The environment to wrap. + quantization_threshold: Threshold below which gripper commands are quantized to zero. + gripper_sleep: Minimum time in seconds between consecutive gripper commands. + """ super().__init__(env) self.quantization_threshold = quantization_threshold self.gripper_sleep = gripper_sleep @@ -523,6 +922,15 @@ class GripperActionWrapper(gym.ActionWrapper): self.last_gripper_action = None def action(self, action): + """ + Process gripper commands in the action. + + Args: + action: The original action from the agent. + + Returns: + Modified action with processed gripper command. + """ if self.gripper_sleep > 0.0: if ( self.last_gripper_action is not None @@ -550,6 +958,15 @@ class GripperActionWrapper(gym.ActionWrapper): return action def reset(self, **kwargs): + """ + Reset the gripper action tracking. + + Args: + **kwargs: Keyword arguments passed to the wrapped environment's reset. + + Returns: + The initial observation and info. + """ obs, info = super().reset(**kwargs) self.last_gripper_action_time = 0.0 self.last_gripper_action = None @@ -557,7 +974,22 @@ class GripperActionWrapper(gym.ActionWrapper): class EEActionWrapper(gym.ActionWrapper): + """ + Wrapper that converts end-effector space actions to joint space actions. + + This wrapper takes actions defined in cartesian space (x, y, z, gripper) and + converts them to joint space actions using inverse kinematics. + """ + def __init__(self, env, ee_action_space_params=None, use_gripper=False): + """ + Initialize the end-effector action wrapper. + + Args: + env: The environment to wrap. + ee_action_space_params: Parameters defining the end-effector action space. + use_gripper: Whether to include gripper control in the action space. + """ super().__init__(env) self.ee_action_space_params = ee_action_space_params self.use_gripper = use_gripper @@ -592,6 +1024,15 @@ class EEActionWrapper(gym.ActionWrapper): self.bounds = ee_action_space_params.bounds def action(self, action): + """ + Convert end-effector action to joint space action. + + Args: + action: End-effector action in cartesian space. + + Returns: + Converted action in joint space. + """ desired_ee_pos = np.eye(4) if self.use_gripper: @@ -618,7 +1059,21 @@ class EEActionWrapper(gym.ActionWrapper): class EEObservationWrapper(gym.ObservationWrapper): + """ + Wrapper that adds end-effector pose information to observations. + + This wrapper computes the end-effector pose using forward kinematics + and adds it to the observation space. + """ + def __init__(self, env, ee_pose_limits): + """ + Initialize the end-effector observation wrapper. + + Args: + env: The environment to wrap. + ee_pose_limits: Dictionary with 'min' and 'max' keys containing limits for EE pose. + """ super().__init__(env) # Extend observation space to include end effector pose @@ -637,6 +1092,15 @@ class EEObservationWrapper(gym.ObservationWrapper): self.fk_function = self.kinematics.fk_gripper_tip def observation(self, observation): + """ + Add end-effector pose to the observation. + + Args: + observation: Original observation from the environment. + + Returns: + Enhanced observation with end-effector pose information. 
+ """ current_joint_pos = self.unwrapped.robot.follower_arms["main"].read("Present_Position") current_ee_pos = self.fk_function(current_joint_pos) observation["observation.state"] = torch.cat( @@ -655,11 +1119,25 @@ class EEObservationWrapper(gym.ObservationWrapper): class BaseLeaderControlWrapper(gym.Wrapper): - """Base class for leader-follower robot control wrappers.""" + """ + Base class for leader-follower robot control wrappers. + + This wrapper enables human intervention through a leader-follower robot setup, + where the human can control a leader robot to guide the follower robot's movements. + """ def __init__( self, env, use_geared_leader_arm: bool = False, ee_action_space_params=None, use_gripper=False ): + """ + Initialize the base leader control wrapper. + + Args: + env: The environment to wrap. + use_geared_leader_arm: Whether to use a geared leader arm setup. + ee_action_space_params: Parameters defining the end-effector action space. + use_gripper: Whether to include gripper control. + """ super().__init__(env) self.robot_leader = env.unwrapped.robot.leader_arms["main"] self.robot_follower = env.unwrapped.robot.follower_arms["main"] @@ -692,7 +1170,12 @@ class BaseLeaderControlWrapper(gym.Wrapper): self._init_keyboard_listener() def _init_keyboard_events(self): - """Initialize the keyboard events dictionary - override in subclasses.""" + """ + Initialize the keyboard events dictionary. + + This method sets up tracking for keyboard events used for intervention control. + It should be overridden in subclasses to add additional events. + """ self.keyboard_events = { "episode_success": False, "episode_end": False, @@ -700,7 +1183,15 @@ class BaseLeaderControlWrapper(gym.Wrapper): } def _handle_key_press(self, key, keyboard): - """Handle key presses - override in subclasses for additional keys.""" + """ + Handle key press events. + + Args: + key: The key that was pressed. + keyboard: The keyboard module with key definitions. + + This method should be overridden in subclasses for additional key handling. + """ try: if key == keyboard.Key.esc: self.keyboard_events["episode_end"] = True @@ -716,7 +1207,11 @@ class BaseLeaderControlWrapper(gym.Wrapper): logging.error(f"Error handling key press: {e}") def _init_keyboard_listener(self): - """Initialize keyboard listener if not in headless mode""" + """ + Initialize the keyboard listener for intervention control. + + This method sets up keyboard event handling if not in headless mode. + """ if is_headless(): logging.warning( "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." @@ -737,11 +1232,26 @@ class BaseLeaderControlWrapper(gym.Wrapper): self.listener = None def _check_intervention(self): - """Check if intervention is needed - override in subclasses.""" + """ + Check if human intervention is needed. + + Returns: + Boolean indicating whether intervention is needed. + + This method should be overridden in subclasses with specific intervention logic. + """ return False def _handle_intervention(self, action): - """Process actions during intervention mode.""" + """ + Process actions during intervention mode. + + Args: + action: The original action from the agent. + + Returns: + Tuple of (modified_action, intervention_action). 
+ """ if self.leader_torque_enabled: self.robot_leader.write("Torque_Enable", 0) self.leader_torque_enabled = False @@ -784,7 +1294,11 @@ class BaseLeaderControlWrapper(gym.Wrapper): return action, action_intervention def _handle_leader_teleoperation(self): - """Handle leader teleoperation (non-intervention) operation.""" + """ + Handle leader teleoperation in non-intervention mode. + + This method synchronizes the leader robot position with the follower. + """ if not self.leader_torque_enabled: self.robot_leader.write("Torque_Enable", 1) self.leader_torque_enabled = True @@ -793,7 +1307,15 @@ class BaseLeaderControlWrapper(gym.Wrapper): self.robot_leader.write("Goal_Position", follower_pos) def step(self, action): - """Execute environment step with possible intervention.""" + """ + Execute a step with possible human intervention. + + Args: + action: The action to take in the environment. + + Returns: + Tuple of (observation, reward, terminated, truncated, info). + """ is_intervention = self._check_intervention() action_intervention = None @@ -821,29 +1343,60 @@ class BaseLeaderControlWrapper(gym.Wrapper): return obs, reward, terminated, truncated, info def reset(self, **kwargs): - """Reset the environment and internal state.""" + """ + Reset the environment and intervention state. + + Args: + **kwargs: Keyword arguments passed to the wrapped environment's reset. + + Returns: + The initial observation and info. + """ self.prev_leader_ee = None self.prev_leader_pos = None self.keyboard_events = dict.fromkeys(self.keyboard_events, False) return super().reset(**kwargs) def close(self): - """Clean up resources.""" + """ + Clean up resources, including stopping keyboard listener. + + Returns: + Result of closing the wrapped environment. + """ if hasattr(self, "listener") and self.listener is not None: self.listener.stop() return self.env.close() class GearedLeaderControlWrapper(BaseLeaderControlWrapper): - """Wrapper that enables manual intervention via keyboard.""" + """ + Wrapper that enables manual intervention via keyboard. + + This wrapper extends the BaseLeaderControlWrapper to allow explicit toggling + of human intervention mode with keyboard controls. + """ def _init_keyboard_events(self): - """Initialize keyboard events including human intervention flag.""" + """ + Initialize keyboard events including human intervention flag. + + Extends the base class dictionary with an additional flag for tracking + intervention state toggled by keyboard. + """ super()._init_keyboard_events() self.keyboard_events["human_intervention_step"] = False def _handle_key_press(self, key, keyboard): - """Handle key presses including space for intervention toggle.""" + """ + Handle key presses including space for intervention toggle. + + Args: + key: The key that was pressed. + keyboard: The keyboard module with key definitions. + + Extends the base handler to respond to space key for toggling intervention. + """ super()._handle_key_press(key, keyboard) if key == keyboard.Key.space: if not self.keyboard_events["human_intervention_step"]: @@ -859,12 +1412,22 @@ class GearedLeaderControlWrapper(BaseLeaderControlWrapper): log_say("Continuing with policy actions.", play_sounds=True) def _check_intervention(self): - """Check if human intervention is active.""" + """ + Check if human intervention is active based on keyboard toggle. + + Returns: + Boolean indicating whether intervention mode is active. 
+ """ return self.keyboard_events["human_intervention_step"] class GearedLeaderAutomaticControlWrapper(BaseLeaderControlWrapper): - """Wrapper with automatic intervention based on error thresholds.""" + """ + Wrapper with automatic intervention based on error thresholds. + + This wrapper monitors the error between leader and follower positions + and automatically triggers intervention when error exceeds thresholds. + """ def __init__( self, @@ -875,6 +1438,17 @@ class GearedLeaderAutomaticControlWrapper(BaseLeaderControlWrapper): release_threshold=0.01, queue_size=10, ): + """ + Initialize the automatic intervention wrapper. + + Args: + env: The environment to wrap. + ee_action_space_params: Parameters defining the end-effector action space. + use_gripper: Whether to include gripper control. + intervention_threshold: Error threshold to trigger intervention. + release_threshold: Error threshold to release intervention. + queue_size: Number of error measurements to track for smoothing. + """ super().__init__(env, ee_action_space_params=ee_action_space_params, use_gripper=use_gripper) # Error tracking parameters @@ -890,7 +1464,16 @@ class GearedLeaderAutomaticControlWrapper(BaseLeaderControlWrapper): self.start_time = time.perf_counter() def _check_intervention(self): - """Determine if intervention should occur based on leader-follower error.""" + """ + Determine if intervention should occur based on leader-follower error. + + This method monitors the error rate between leader and follower positions + and automatically triggers intervention when the error rate exceeds + the intervention threshold, releasing when it falls below the release threshold. + + Returns: + Boolean indicating whether intervention should be active. + """ # Skip intervention logic for the first few steps to collect data if time.perf_counter() - self.start_time < 1.0: # Wait 1 second before enabling return False @@ -932,7 +1515,15 @@ class GearedLeaderAutomaticControlWrapper(BaseLeaderControlWrapper): return self.is_intervention_active def reset(self, **kwargs): - """Reset error tracking on environment reset.""" + """ + Reset error tracking on environment reset. + + Args: + **kwargs: Keyword arguments passed to the wrapped environment's reset. + + Returns: + The initial observation and info. + """ self.error_queue.clear() self.error_over_time_queue.clear() self.previous_error = 0.0 @@ -962,15 +1553,14 @@ class GamepadControlWrapper(gym.Wrapper): """ Initialize the gamepad controller wrapper. - cfg. - env: The environment to wrap - x_step_size: Base movement step size for X axis in meters - y_step_size: Base movement step size for Y axis in meters - z_step_size: Base movement step size for Z axis in meters - vendor_id: USB vendor ID of the gamepad (default: Logitech) - product_id: USB product ID of the gamepad (default: RumblePad 2) - auto_reset: Whether to auto reset the environment when episode ends - input_threshold: Minimum movement delta to consider as active input + Args: + env: The environment to wrap. + x_step_size: Base movement step size for X axis in meters. + y_step_size: Base movement step size for Y axis in meters. + z_step_size: Base movement step size for Z axis in meters. + use_gripper: Whether to include gripper control. + auto_reset: Whether to auto reset the environment when episode ends. + input_threshold: Minimum movement delta to consider as active input. 
""" super().__init__(env) from lerobot.scripts.server.end_effector_control_utils import ( @@ -1011,7 +1601,12 @@ class GamepadControlWrapper(gym.Wrapper): Get the current action from the gamepad if any input is active. Returns: - Tuple of (is_active, action, terminate_episode, success) + Tuple containing: + - is_active: Whether gamepad input is active + - action: The action derived from gamepad input + - terminate_episode: Whether episode termination was requested + - success: Whether episode success was signaled + - rerecord_episode: Whether episode rerecording was requested """ # Update the controller to get fresh inputs self.controller.update() @@ -1052,11 +1647,11 @@ class GamepadControlWrapper(gym.Wrapper): """ Step the environment, using gamepad input to override actions when active. - cfg. - action: Original action from agent + Args: + action: Original action from agent. Returns: - observation, reward, terminated, truncated, info + Tuple of (observation, reward, terminated, truncated, info). """ # Get gamepad state and action ( @@ -1104,7 +1699,12 @@ class GamepadControlWrapper(gym.Wrapper): return obs, reward, terminated, truncated, info def close(self): - """Clean up resources when environment closes.""" + """ + Clean up resources when environment closes. + + Returns: + Result of closing the wrapped environment. + """ # Stop the controller if hasattr(self, "controller"): self.controller.stop() @@ -1113,70 +1713,6 @@ class GamepadControlWrapper(gym.Wrapper): return self.env.close() -class TorchBox(gym.spaces.Box): - """A version of gym.spaces.Box that handles PyTorch tensors. - - This class extends gym.spaces.Box to work with PyTorch tensors, - providing compatibility between NumPy arrays and PyTorch tensors. - """ - - def __init__( - self, - low: float | Sequence[float] | np.ndarray, - high: float | Sequence[float] | np.ndarray, - shape: Sequence[int] | None = None, - np_dtype: np.dtype | type = np.float32, - torch_dtype: torch.dtype = torch.float32, - device: str = "cpu", - seed: int | np.random.Generator | None = None, - ) -> None: - super().__init__(low, high, shape=shape, dtype=np_dtype, seed=seed) - self.torch_dtype = torch_dtype - self.device = device - - def sample(self) -> torch.Tensor: - arr = super().sample() - return torch.as_tensor(arr, dtype=self.torch_dtype, device=self.device) - - def contains(self, x: torch.Tensor) -> bool: - # Move to CPU/numpy and cast to the internal dtype - arr = x.detach().cpu().numpy().astype(self.dtype, copy=False) - return super().contains(arr) - - def seed(self, seed: int | np.random.Generator | None = None): - super().seed(seed) - return [seed] - - def __repr__(self) -> str: - return ( - f"TorchBox({self.low_repr}, {self.high_repr}, {self.shape}, " - f"np={self.dtype.name}, torch={self.torch_dtype}, device={self.device})" - ) - - -class TorchActionWrapper(gym.Wrapper): - """ - The goal of this wrapper is to change the action_space.sample() - to torch tensors. 
- """ - - def __init__(self, env: gym.Env, device: str): - super().__init__(env) - self.action_space = TorchBox( - low=env.action_space.low, - high=env.action_space.high, - shape=env.action_space.shape, - torch_dtype=torch.float32, - device=torch.device("cpu"), - ) - - def step(self, action: torch.Tensor): - if action.dim() == 2: - action = action.squeeze(0) - action = action.detach().cpu().numpy() - return self.env.step(action) - - ########################################################### # Factory functions ########################################################### @@ -1186,13 +1722,14 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: """ Factory function to create a vectorized robot environment. - cfg. - robot: Robot instance to control - reward_classifier: Classifier model for computing rewards - cfg: Configuration object containing environment parameters + This function builds a robot environment with all necessary wrappers + based on the provided configuration. + + Args: + cfg: Configuration object containing environment parameters. Returns: - A vectorized gym environment with all the necessary wrappers applied. + A vectorized gym environment with all necessary wrappers applied. """ robot = make_robot_from_config(cfg.robot) # Create base environment @@ -1276,10 +1813,10 @@ def init_reward_classifier(cfg): Load a reward classifier policy from a pretrained path if configured. Args: - cfg: The environment configuration containing classifier paths + cfg: The environment configuration containing classifier paths. Returns: - The loaded classifier model or None if not configured + The loaded classifier model or None if not configured. """ if cfg.reward_classifier_pretrained_path is None: return None @@ -1306,20 +1843,26 @@ def init_reward_classifier(cfg): ########################################################### -def record_dataset(env, policy, cfg): +def record_dataset(env, policy, cfg, success_collection_steps=15): """ Record a dataset of robot interactions using either a policy or teleop. - cfg. - env: The environment to record from - repo_id: Repository ID for dataset storage - root: Local root directory for dataset (optional) - num_episodes: Number of episodes to record - control_time_s: Maximum episode length in seconds - fps: Frames per second for recording - push_to_hub: Whether to push dataset to Hugging Face Hub - task_description: Description of the task being recorded - policy: Optional policy to generate actions (if None, uses teleop) + This function runs episodes in the environment and records the observations, + actions, and results for dataset creation. + + Args: + env: The environment to record from. + policy: Optional policy to generate actions (if None, uses teleop). + cfg: Configuration object containing recording parameters like: + - repo_id: Repository ID for dataset storage + - dataset_root: Local root directory for dataset + - num_episodes: Number of episodes to record + - fps: Frames per second for recording + - push_to_hub: Whether to push dataset to Hugging Face Hub + - task: Name/description of the task being recorded + success_collection_steps: Number of additional steps to continue recording after + a success (reward=1) is detected. This helps collect + more positive examples for reward classifier training. 
""" from lerobot.common.datasets.lerobot_dataset import LeRobotDataset @@ -1370,6 +1913,10 @@ def record_dataset(env, policy, cfg): start_episode_t = time.perf_counter() log_say(f"Recording episode {episode_index}", play_sounds=True) + # Track success state collection + success_detected = False + success_steps_collected = 0 + # Run episode steps while time.perf_counter() - start_episode_t < cfg.wrapper.control_time_s: start_loop_t = time.perf_counter() @@ -1391,12 +1938,29 @@ def record_dataset(env, policy, cfg): } # Process observation for dataset - obs = {k: v.cpu().squeeze(0).float() for k, v in obs.items()} + obs_processed = {k: v.cpu().squeeze(0).float() for k, v in obs.items()} - # Add frame to dataset - frame = {**obs, **recorded_action} - frame["next.reward"] = np.array([reward], dtype=np.float32) - frame["next.done"] = np.array([terminated or truncated], dtype=bool) + # Check if we've just detected success + if reward == 1.0 and not success_detected: + success_detected = True + logging.info("Success detected! Collecting additional success states.") + + # Add frame to dataset - continue marking as success even during extra collection steps + frame = {**obs_processed, **recorded_action} + + # If we're in the success collection phase, keep marking rewards as 1.0 + if success_detected: + frame["next.reward"] = np.array([1.0], dtype=np.float32) + else: + frame["next.reward"] = np.array([reward], dtype=np.float32) + + # Only mark as done if we're truly done (reached end or collected enough success states) + really_done = terminated or truncated + if success_detected: + success_steps_collected += 1 + really_done = success_steps_collected >= success_collection_steps + + frame["next.done"] = np.array([really_done], dtype=bool) frame["task"] = cfg.task dataset.add_frame(frame) @@ -1405,7 +1969,13 @@ def record_dataset(env, policy, cfg): dt_s = time.perf_counter() - start_loop_t busy_wait(1 / cfg.fps - dt_s) - if terminated or truncated: + # Check if we should end the episode + if (terminated or truncated) and not success_detected: + # Regular termination without success + break + elif success_detected and success_steps_collected >= success_collection_steps: + # We've collected enough success states + logging.info(f"Collected {success_steps_collected} additional success states") break # Handle episode recording @@ -1424,6 +1994,19 @@ def record_dataset(env, policy, cfg): def replay_episode(env, cfg): + """ + Replay a recorded episode in the environment. + + This function loads actions from a previously recorded episode + and executes them in the environment. + + Args: + env: The environment to replay in. + cfg: Configuration object containing replay parameters: + - repo_id: Repository ID for dataset + - dataset_root: Local root directory for dataset + - episode: Episode ID to replay + """ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset dataset = LeRobotDataset(cfg.repo_id, root=cfg.dataset_root, episodes=[cfg.episode]) @@ -1443,6 +2026,16 @@ def replay_episode(env, cfg): @parser.wrap() def main(cfg: EnvConfig): + """ + Main entry point for the robot environment script. + + This function runs the robot environment in one of several modes + based on the provided configuration. + + Args: + cfg: Configuration object defining the run parameters, + including mode (record, replay, random) and other settings. 
+    """
     env = make_robot_env(cfg)
 
     if cfg.mode == "record":
@@ -1454,10 +2047,12 @@
         policy.to(cfg.device)
         policy.eval()
 
+        # Keep recording for 15 extra steps after each success to gather positive examples
         record_dataset(
             env,
             policy=policy,
             cfg=cfg,
+            success_collection_steps=15,
         )
         exit()
 
diff --git a/tests/policies/hilserl/test_modeling_classifier.py b/tests/policies/hilserl/test_modeling_classifier.py
index dc03425e..a23111be 100644
--- a/tests/policies/hilserl/test_modeling_classifier.py
+++ b/tests/policies/hilserl/test_modeling_classifier.py
@@ -40,13 +40,13 @@ def test_binary_classifier_with_default_params():
     batch_size = 10
 
     input = {
-        "observation.image": torch.rand((batch_size, 3, 224, 224)),
+        "observation.image": torch.rand((batch_size, 3, 128, 128)),
         "next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(),
     }
 
     images, labels = classifier.extract_images_and_labels(input)
     assert len(images) == 1
-    assert images[0].shape == torch.Size([batch_size, 3, 224, 224])
+    assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
     assert labels.shape == torch.Size([batch_size])
 
     output = classifier.predict(images)
@@ -56,7 +56,7 @@ def test_binary_classifier_with_default_params():
     assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
     assert output.probabilities.shape == torch.Size([batch_size])
     assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
-    assert output.hidden_states.shape == torch.Size([batch_size, 512])
+    assert output.hidden_states.shape == torch.Size([batch_size, 256])
     assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
 
 
@@ -79,13 +79,13 @@ def test_multiclass_classifier():
     batch_size = 10
 
     input = {
-        "observation.image": torch.rand((batch_size, 3, 224, 224)),
+        "observation.image": torch.rand((batch_size, 3, 128, 128)),
         "next.reward": torch.rand((batch_size, num_classes)),
    }
 
     images, labels = classifier.extract_images_and_labels(input)
     assert len(images) == 1
-    assert images[0].shape == torch.Size([batch_size, 3, 224, 224])
+    assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
     assert labels.shape == torch.Size([batch_size, num_classes])
 
     output = classifier.predict(images)
@@ -95,7 +95,7 @@ def test_multiclass_classifier():
     assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
     assert output.probabilities.shape == torch.Size([batch_size, num_classes])
     assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
-    assert output.hidden_states.shape == torch.Size([batch_size, 512])
+    assert output.hidden_states.shape == torch.Size([batch_size, 256])
     assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
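
Reviewer note (not part of the patch): a minimal sketch to sanity-check the tensor shapes flowing
through the new SpatialLearnedEmbeddings module, assuming the resnet10 backbone's 4x4 spatial grid
with 512 channels used by _create_single_encoder. The SimpleNamespace object standing in for the
HF encoder output, and the concrete sizes, are illustrative assumptions, not code from this PR.

# Sketch: verify SpatialLearnedEmbeddings produces [B, C * num_features].
from types import SimpleNamespace

import torch

from lerobot.common.policies.reward_model.modeling_classifier import SpatialLearnedEmbeddings

emb = SpatialLearnedEmbeddings(height=4, width=4, channel=512, num_features=8)
# Stand-in for the HF ResNet output object; last_hidden_state is [B, C, H, W].
features = SimpleNamespace(last_hidden_state=torch.randn(2, 512, 4, 4))
out = emb(features)
assert out.shape == (2, 512 * 8)  # channel and per-channel feature dims flattened together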
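
Reviewer note (not part of the patch): a sketch of how the reworked predict_reward is driven at
runtime, mirroring init_reward_classifier and RewardWrapper.step above. The checkpoint path and
camera key are placeholders; the keys must match the input_features the classifier was trained
with, and the images must already be cropped/resized as in training (128x128 in the tests above).

# Sketch: score one batched observation with a trained reward classifier.
import torch

from lerobot.common.policies.reward_model.modeling_classifier import Classifier

device = "cuda" if torch.cuda.is_available() else "cpu"
classifier = Classifier.from_pretrained("outputs/reward_classifier")  # placeholder path
classifier.to(device)
classifier.eval()

# One batched, [0, 1]-scaled image tensor per camera key declared in config.input_features.
batch = {"observation.images.front": torch.rand(1, 3, 128, 128, device=device)}

with torch.inference_mode():
    # Binary case: returns 1.0 when P(success) exceeds the threshold, else 0.0,
    # matching the termination check in RewardWrapper.step.
    reward = classifier.predict_reward(batch, threshold=0.7)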