[HIL-SERL] Remove overstrict pre-commit modifications (#1028)

Adil Zouitine
2025-04-24 13:48:52 +02:00
committed by GitHub
parent 671ac3411f
commit c58b504a9e
47 changed files with 163 additions and 757 deletions

View File

@@ -38,12 +38,7 @@ def get_motor_bus_cls(brand: str) -> tuple:
             FeetechMotorsBus,
         )
 
-        return (
-            FeetechMotorsBusConfig,
-            FeetechMotorsBus,
-            MODEL_BAUDRATE_TABLE,
-            SCS_SERIES_BAUDRATE_TABLE,
-        )
+        return FeetechMotorsBusConfig, FeetechMotorsBus, MODEL_BAUDRATE_TABLE, SCS_SERIES_BAUDRATE_TABLE
     elif brand == "dynamixel":
         from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
@@ -53,12 +48,7 @@ def get_motor_bus_cls(brand: str) -> tuple:
             DynamixelMotorsBus,
         )
 
-        return (
-            DynamixelMotorsBusConfig,
-            DynamixelMotorsBus,
-            MODEL_BAUDRATE_TABLE,
-            X_SERIES_BAUDRATE_TABLE,
-        )
+        return DynamixelMotorsBusConfig, DynamixelMotorsBus, MODEL_BAUDRATE_TABLE, X_SERIES_BAUDRATE_TABLE
     else:
         raise ValueError(
@@ -174,25 +164,12 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--port",
-        type=str,
-        required=True,
-        help="Motors bus port (e.g. dynamixel,feetech)",
-    )
+    parser.add_argument("--port", type=str, required=True, help="Motors bus port (e.g. dynamixel,feetech)")
     parser.add_argument("--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)")
     parser.add_argument("--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)")
-    parser.add_argument(
-        "--ID",
-        type=int,
-        required=True,
-        help="Desired ID of the current motor (e.g. 1,2,3)",
-    )
+    parser.add_argument("--ID", type=int, required=True, help="Desired ID of the current motor (e.g. 1,2,3)")
     parser.add_argument(
-        "--baudrate",
-        type=int,
-        default=1000000,
-        help="Desired baudrate for the motor (default: 1000000)",
+        "--baudrate", type=int, default=1000000, help="Desired baudrate for the motor (default: 1000000)"
     )
     args = parser.parse_args()
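
For reference, the collapsed returns above hand back a flat 4-tuple, so the multi-line and one-line forms are equivalent at every call site. A minimal sketch of how a caller unpacks them, with the brand string and command-line values purely illustrative (the script filename is assumed, not shown in this diff):

    # Unpack the 4-tuple returned by get_motor_bus_cls (names taken from the diff above).
    config_cls, bus_cls, model_baudrate_table, series_baudrate_table = get_motor_bus_cls("feetech")

    # Hypothetical invocation of the script whose argparse block is shown above:
    #   python configure_motor.py --port /dev/ttyUSB0 --brand feetech --model sts3215 --ID 1 --baudrate 1000000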

View File

@@ -149,11 +149,7 @@ def init_sim_calibration(robot, cfg):
     axis_directions = np.array(cfg.get("axis_directions", [1]))
     offsets = np.array(cfg.get("offsets", [0])) * np.pi
 
-    return {
-        "start_pos": start_pos,
-        "axis_directions": axis_directions,
-        "offsets": offsets,
-    }
+    return {"start_pos": start_pos, "axis_directions": axis_directions, "offsets": offsets}
 
 
 def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets):
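
The dict returned above feeds straight into real_positions_to_sim, whose body is not part of this diff. A plausible sketch of that conversion, assuming a simple linear mapping between real and simulated joint positions (an assumption, not the repository's actual implementation):

    import numpy as np

    def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets):
        # Assumed linear mapping: flip each axis direction, re-zero at start_pos,
        # then add the offsets (already scaled by pi in init_sim_calibration above).
        return axis_directions * (np.asarray(real_positions) - start_pos) + offsets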
@@ -203,7 +199,6 @@ def record(
     run_compute_stats: bool = True,
 ) -> LeRobotDataset:
     # Load pretrained policy
-    policy = None
     if pretrained_policy_name_or_path is not None:
         policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
@@ -246,11 +241,7 @@ def record(
         shape = env.observation_space[key].shape
         if not key.startswith("observation.image."):
             key = "observation.image." + key
-        features[key] = {
-            "dtype": "video",
-            "names": ["channels", "height", "width"],
-            "shape": shape,
-        }
+        features[key] = {"dtype": "video", "names": ["channels", "height", "width"], "shape": shape}
 
     for key, obs_key in state_keys_dict.items():
         features[key] = {
@@ -259,11 +250,7 @@ def record(
             "shape": env.observation_space[obs_key].shape,
         }
 
-    features["action"] = {
-        "dtype": "float32",
-        "shape": env.action_space.shape,
-        "names": None,
-    }
+    features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None}
 
     # Create empty dataset or load existing saved episodes
     sanity_check_dataset_name(repo_id, policy)
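
All three collapsed dict literals above share the same dataset feature schema: a dtype, a shape, and optional dimension names. An illustrative snapshot of the features mapping that record() assembles (the camera key, state width, and the state entry's dtype are made up for the example):

    features = {
        "observation.image.top": {"dtype": "video", "names": ["channels", "height", "width"], "shape": (3, 480, 640)},
        "observation.state": {"dtype": "float32", "shape": (6,), "names": None},
        "action": {"dtype": "float32", "shape": (6,), "names": None},
    }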
@@ -374,12 +361,7 @@ def record(
 def replay(
-    env,
-    root: Path,
-    repo_id: str,
-    episode: int,
-    fps: int | None = None,
-    local_files_only: bool = True,
+    env, root: Path, repo_id: str, episode: int, fps: int | None = None, local_files_only: bool = True
 ):
     env = env()
@@ -426,10 +408,7 @@ if __name__ == "__main__":
     parser_record = subparsers.add_parser("record", parents=[base_parser])
     parser_record.add_argument(
-        "--fps",
-        type=none_or_int,
-        default=None,
-        help="Frames per second (set to None to disable)",
+        "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
     )
     parser_record.add_argument(
         "--root",
@@ -507,19 +486,9 @@ if __name__ == "__main__":
         default=0,
         help="Resume recording on an existing dataset.",
     )
-    parser_record.add_argument(
-        "--assign-rewards",
-        type=int,
-        default=0,
-        help="Enables the assignation of rewards to frames (by default no assignation). When enabled, assign a 0 reward to frames until the space bar is pressed which assign a 1 reward. Press the space bar a second time to assign a 0 reward. The reward assigned is reset to 0 when the episode ends.",
-    )
 
     parser_replay = subparsers.add_parser("replay", parents=[base_parser])
     parser_replay.add_argument(
-        "--fps",
-        type=none_or_int,
-        default=None,
-        help="Frames per second (set to None to disable)",
+        "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
     )
     parser_replay.add_argument(
         "--root",

View File

@@ -293,8 +293,7 @@ def eval_policy(
             seeds = None
         else:
             seeds = range(
-                start_seed + (batch_ix * env.num_envs),
-                start_seed + ((batch_ix + 1) * env.num_envs),
+                start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
             )
         rollout_data = rollout(
             env,
@@ -414,11 +413,7 @@ def eval_policy(
 def _compile_episode_data(
-    rollout_data: dict,
-    done_indices: Tensor,
-    start_episode_index: int,
-    start_data_index: int,
-    fps: float,
+    rollout_data: dict, done_indices: Tensor, start_episode_index: int, start_data_index: int, fps: float
 ) -> dict:
     """Convenience function for `eval_policy(return_episode_data=True)`
@@ -486,10 +481,7 @@ def eval_main(cfg: EvalPipelineConfig):
     )
     policy.eval()
 
-    with (
-        torch.no_grad(),
-        torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
-    ):
+    with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
         info = eval_policy(
             env,
             policy,
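
Two details in this file are worth spelling out. First, the collapsed with statement: the multi-line form it replaces wraps the context managers in parentheses, a syntax only officially supported from Python 3.10, so the one-line form is also the more portable one. Second, the seeds expression partitions seeds contiguously across rollout batches; a quick worked example, assuming start_seed=1000 and env.num_envs=4:

    start_seed, num_envs = 1000, 4
    for batch_ix in range(2):
        seeds = range(start_seed + (batch_ix * num_envs), start_seed + ((batch_ix + 1) * num_envs))
        print(list(seeds))
    # batch_ix = 0 -> [1000, 1001, 1002, 1003]
    # batch_ix = 1 -> [1004, 1005, 1006, 1007]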

View File

@@ -196,11 +196,7 @@ def train(cfg: TrainPipelineConfig):
     }
     train_tracker = MetricsTracker(
-        cfg.batch_size,
-        dataset.num_frames,
-        dataset.num_episodes,
-        train_metrics,
-        initial_step=step,
+        cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
     )
 
     logging.info("Start offline training on a fixed dataset")
@@ -271,11 +267,7 @@ def train(cfg: TrainPipelineConfig):
                     "eval_s": AverageMeter("eval_s", ":.3f"),
                 }
                 eval_tracker = MetricsTracker(
-                    cfg.batch_size,
-                    dataset.num_frames,
-                    dataset.num_episodes,
-                    eval_metrics,
-                    initial_step=step,
+                    cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
                 )
                 eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
                 eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
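
Both call sites share the same positional signature, now visible on one line: batch size, dataset frame count, dataset episode count, the metrics dict, and the starting step. A pattern excerpt (not standalone: cfg, dataset, step, and eval_info come from train() above, and the import path for the metric helpers is assumed):

    # from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker  # assumed path
    eval_metrics = {"eval_s": AverageMeter("eval_s", ":.3f")}  # name and format string as in the diff
    eval_tracker = MetricsTracker(
        cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
    )
    eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")  # route aggregated stats into tracker fields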

View File

@@ -81,11 +81,7 @@ def run_server(
     static_folder: Path,
     template_folder: Path,
 ):
-    app = Flask(
-        __name__,
-        static_folder=static_folder.resolve(),
-        template_folder=template_folder.resolve(),
-    )
+    app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
     app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0  # specifying not to cache
 
     @app.route("/")
@@ -201,8 +197,7 @@ def run_server(
         ]
         response = requests.get(
-            f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl",
-            timeout=5,
+            f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
         )
         response.raise_for_status()
         # Split into lines and parse each line as JSON
@@ -287,8 +282,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
     repo_id = dataset.repo_id
     url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
-        episode_chunk=int(episode_index) // dataset.chunks_size,
-        episode_index=episode_index,
+        episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
     )
     df = pd.read_parquet(url)
     data = df[selected_columns]  # Select specific columns
@@ -337,8 +331,7 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
 def get_dataset_info(repo_id: str) -> IterableNamespace:
     response = requests.get(
-        f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json",
-        timeout=5,
+        f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5
     )
     response.raise_for_status()  # Raises an HTTPError for bad responses
     dataset_info = response.json()
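
All of these helpers hit the same Hugging Face Hub resolve endpoint with a 5-second timeout. A standalone sketch combining the two metadata fetches above (the repo_id is illustrative):

    import json

    import requests

    repo_id = "lerobot/pusht"  # illustrative dataset id
    base = f"https://huggingface.co/datasets/{repo_id}/resolve/main"

    info = requests.get(f"{base}/meta/info.json", timeout=5)
    info.raise_for_status()  # raises requests.HTTPError on bad responses
    dataset_info = info.json()

    episodes = requests.get(f"{base}/meta/episodes.jsonl", timeout=5)
    episodes.raise_for_status()
    # episodes.jsonl is JSON Lines: one JSON object per line
    episode_records = [json.loads(line) for line in episodes.text.splitlines() if line.strip()]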