Add review feedback
@@ -46,6 +46,7 @@ repos:
     rev: v3.20.0
     hooks:
       - id: pyupgrade
+        # Exclude generated protobuf files
         exclude: '^(.*_pb2_grpc\.py|.*_pb2\.py$)'
   - repo: https://github.com/astral-sh/ruff-pre-commit
     rev: v0.11.11
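As a quick sanity check of the exclude pattern above, here is a small sketch (the file paths are made up); pre-commit treats `exclude` as a Python regular expression applied to file paths, and re.search-style matching is assumed here:

```python
import re

# The exclude pattern from the pyupgrade hook above; paths below are illustrative only.
pattern = re.compile(r"^(.*_pb2_grpc\.py|.*_pb2\.py$)")

for path in ["proto/service_pb2.py", "proto/service_pb2_grpc.py", "proto/service.py"]:
    print(path, bool(pattern.search(path)))
# proto/service_pb2.py True        (generated, skipped)
# proto/service_pb2_grpc.py True   (generated, skipped)
# proto/service.py False           (hand-written, still checked)
```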
@@ -168,13 +168,6 @@ class VideoRecordConfig:
     trajectory_name: str = "trajectory"
 
 
-@dataclass
-class WrapperConfig:
-    """Configuration for environment wrappers."""
-
-    joint_masking_action_space: list[bool] | None = None
-
-
 @dataclass
 class EEActionSpaceConfig:
     """Configuration parameters for end-effector action space."""
@@ -187,7 +180,7 @@ class EEActionSpaceConfig:
 
 
 @dataclass
-class EnvWrapperConfig:
+class EnvTransformConfig:
     """Configuration for environment wrappers."""
 
     ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig)
@@ -212,7 +205,7 @@ class HILSerlRobotEnvConfig(EnvConfig):
     """Configuration for the HILSerlRobotEnv environment."""
 
     robot: Optional[RobotConfig] = None
-    wrapper: Optional[EnvWrapperConfig] = None
+    wrapper: Optional[EnvTransformConfig] = None
     fps: int = 10
     name: str = "real_robot"
     mode: str = None  # Either "record", "replay", None
@@ -225,9 +218,8 @@ class HILSerlRobotEnvConfig(EnvConfig):
     push_to_hub: bool = True
     pretrained_policy_name_or_path: Optional[str] = None
     reward_classifier_pretrained_path: Optional[str] = None
-    number_of_steps_after_success: int = (
-        0  # For the reward classifier, to record more positive examples after a success
-    )
+    # For the reward classifier, to record more positive examples after a success
+    number_of_steps_after_success: int = 0
 
     def gym_kwargs(self) -> dict:
         return {}
@@ -266,7 +258,7 @@ class HILEnvConfig(EnvConfig):
     ################# args from hilserlrobotenv
     reward_classifier_pretrained_path: Optional[str] = None
     robot: Optional[RobotConfig] = None
-    wrapper: Optional[EnvWrapperConfig] = None
+    wrapper: Optional[EnvTransformConfig] = None
     mode: str = None  # Either "record", "replay", None
     repo_id: Optional[str] = None
     dataset_root: Optional[str] = None
@@ -151,6 +151,7 @@ class Normalize(nn.Module):
     # TODO(rcadene): should we remove torch.no_grad?
     @torch.no_grad
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+        # TODO: Remove this shallow copy
         batch = dict(batch)  # shallow copy avoids mutating the input batch
         for key, ft in self.features.items():
             if key not in batch:
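For context on the shallow copy that the TODO above targets, here is a minimal sketch (my illustration, not part of the diff): rebinding a key only affects the copy, so the caller's dict is untouched, while the tensors themselves stay shared.

```python
# Sketch only: what `batch = dict(batch)` protects against inside forward().
import torch

original = {"observation.state": torch.zeros(3)}
batch = dict(original)                       # shallow copy of the dict
batch["observation.state"] = torch.ones(3)   # rebind the key in the copy only
print(original["observation.state"])         # tensor([0., 0., 0.]) -- caller's dict untouched
print(batch["observation.state"])            # tensor([1., 1., 1.])
```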
@@ -337,7 +338,6 @@ class NormalizeBuffer(nn.Module):
         _initialize_stats_buffers(self, features, norm_map, stats)
 
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
-        batch = dict(batch)
         for key, ft in self.features.items():
             if key not in batch:
                 continue
@@ -386,7 +386,6 @@ class UnnormalizeBuffer(nn.Module):
         _initialize_stats_buffers(self, features, norm_map, stats)
 
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
-        batch = dict(batch)
         for key, ft in self.features.items():
             if key not in batch:
                 continue
@@ -115,8 +115,11 @@ def init_logging(log_file: Path | None = None, display_pid: bool = False):
         fnameline = f"{record.pathname}:{record.lineno}"
 
         # NOTE: Display PID is useful for multi-process logging.
-        pid_str = f"[PID: {os.getpid()}]" if display_pid else ""
-        message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}"
+        if display_pid:
+            pid_str = f"[PID: {os.getpid()}]"
+            message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}"
+        else:
+            message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
         return message
 
     logging.basicConfig(level=logging.INFO)
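A side effect of this change worth noting (my observation, not stated in the commit): with the old one-liner, disabling display_pid still substituted an empty pid_str into the f-string and left a double space in every log line; the explicit branch produces a clean message. A small sketch with made-up values:

```python
# Sketch comparing the old and new message formats when display_pid is False.
levelname, dt, fnameline, msg = "INFO", "2024-01-01 12:00:00", "utils.py:120", "hello"

pid_str = ""  # old ternary result when display_pid is False
old_message = f"{levelname} {pid_str} {dt} {fnameline[-15:]:>15} {msg}"
new_message = f"{levelname} {dt} {fnameline[-15:]:>15} {msg}"

print(repr(old_message))  # 'INFO  2024-01-01 12:00:00    utils.py:120 hello'  <- double space
print(repr(new_message))  # 'INFO 2024-01-01 12:00:00    utils.py:120 hello'
```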
@@ -131,7 +134,7 @@ def init_logging(log_file: Path | None = None, display_pid: bool = False):
     logging.getLogger().addHandler(console_handler)
 
     if log_file is not None:
-        # File handler
+        # Additionally write logs to file
         file_handler = logging.FileHandler(log_file)
         file_handler.setFormatter(formatter)
         logging.getLogger().addHandler(file_handler)
@@ -247,11 +250,23 @@ class TimerManager:
 
     Examples
     --------
-    >>> timer = TimerManager("Policy", log=False)
-    >>> for _ in range(3):
-    ...     with timer:
-    ...         time.sleep(0.01)
-    >>> print(timer.last, timer.fps_avg, timer.percentile(90))
+    ```python
+    # Example 1: Using context manager
+    timer = TimerManager("Policy", log=False)
+    for _ in range(3):
+        with timer:
+            time.sleep(0.01)
+    print(timer.last, timer.fps_avg, timer.percentile(90))  # Prints: 0.01 100.0 0.01
+    ```
+
+    ```python
+    # Example 2: Using start/stop methods
+    timer = TimerManager("Policy", log=False)
+    timer.start()
+    time.sleep(0.01)
+    timer.stop()
+    print(timer.last, timer.fps_avg, timer.percentile(90))  # Prints: 0.01 100.0 0.01
+    ```
     """
 
     def __init__(
@@ -149,12 +149,8 @@ class WandBLogger:
 
             if custom_step_key is not None:
                 value_custom_step = d[custom_step_key]
-                self._wandb.log(
-                    {
-                        f"{mode}/{k}": v,
-                        f"{mode}/{custom_step_key}": value_custom_step,
-                    }
-                )
+                data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}
+                self._wandb.log(data)
                 continue
 
             self._wandb.log(data={f"{mode}/{k}": v}, step=step)
@@ -34,7 +34,7 @@ TRAIN_CONFIG_NAME = "train_config.json"
 
 @dataclass
 class TrainPipelineConfig(HubMixin):
-    dataset: DatasetConfig | None = None  # NOTE: In RL, we don't need a dataset
+    dataset: DatasetConfig | None = None  # NOTE: In RL, we don't need an offline dataset
     env: envs.EnvConfig | None = None
     policy: PreTrainedConfig | None = None
     # Set `dir` to where you would like to save all of the run outputs. If you run another training session
     # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
@@ -124,10 +124,7 @@ class TrainPipelineConfig(HubMixin):
         return draccus.encode(self)
 
     def _save_pretrained(self, save_directory: Path) -> None:
-        with (
-            open(save_directory / TRAIN_CONFIG_NAME, "w") as f,
-            draccus.config_type("json"),
-        ):
+        with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
             draccus.dump(self, f, indent=4)
 
     @classmethod
@@ -720,7 +720,7 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
 
 
 if __name__ == "__main__":
-    from lerobot.common.envs.configs import EEActionSpaceConfig, EnvWrapperConfig, HILSerlRobotEnvConfig
+    from lerobot.common.envs.configs import EEActionSpaceConfig, EnvTransformConfig, HILSerlRobotEnvConfig
     from lerobot.common.robot_devices.robots.configs import RobotConfig
     from lerobot.common.robot_devices.robots.utils import make_robot_from_config
     from lerobot.scripts.server.gym_manipulator import make_robot_env
@@ -781,7 +781,7 @@ if __name__ == "__main__":
 
     elif args.mode in ["keyboard_gym", "gamepad_gym"]:
         # Gym environment control modes
-        cfg = HILSerlRobotEnvConfig(robot=robot_config, wrapper=EnvWrapperConfig())
+        cfg = HILSerlRobotEnvConfig(robot=robot_config, wrapper=EnvTransformConfig())
         cfg.wrapper.ee_action_space_params = EEActionSpaceConfig(
             x_step_size=0.03, y_step_size=0.03, z_step_size=0.03, bounds=bounds
         )