Update pre-commits (#733)

This commit is contained in:
Simon Alibert
2025-02-15 15:51:17 +01:00
committed by GitHub
parent 2cb0bf5d41
commit c4c2ce04e7
16 changed files with 69 additions and 69 deletions

View File

@@ -193,12 +193,12 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
observation_ = deepcopy(observation)
with torch.inference_mode():
action = policy.select_action(observation).cpu().numpy()
assert set(observation) == set(
observation_
), "Observation batch keys are not the same after a forward pass."
assert all(
torch.equal(observation[k], observation_[k]) for k in observation
), "Observation batch values are not the same after a forward pass."
assert set(observation) == set(observation_), (
"Observation batch keys are not the same after a forward pass."
)
assert all(torch.equal(observation[k], observation_[k]) for k in observation), (
"Observation batch values are not the same after a forward pass."
)
# Test step through policy
env.step(action)