[pre-commit.ci] auto fixes from pre-commit.com hooks

For more information, see https://pre-commit.ci.
This commit is contained in:
pre-commit-ci[bot]
2025-03-04 13:38:47 +00:00
committed by AdilZouitine
parent 76df8a31b3
commit 38f5fa4523
79 changed files with 2782 additions and 788 deletions

View File

@@ -172,7 +172,9 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
# Test updating the policy (and test that it does not mutate the batch)
batch_ = deepcopy(batch)
policy.forward(batch)
assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass."
assert set(batch) == set(
batch_
), "Batch keys are not the same after a forward pass."
assert all(
torch.equal(batch[k], batch_[k]) if isinstance(batch[k], torch.Tensor) else batch[k] == batch_[k]
for k in batch
@@ -186,7 +188,9 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
observation = preprocess_observation(observation)
# send observation to device/gpu
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
observation = {
key: observation[key].to(DEVICE, non_blocking=True) for key in observation
}
# get the next action for the environment (also check that the observation batch is not modified)
observation_ = deepcopy(observation)
@@ -452,7 +456,9 @@ def test_act_temporal_ensembler():
batch_size = batch_seq.shape[0]
# Exponential weighting (normalized). Unsqueeze once to match the position of the `episode_length`
# dimension of `batch_seq`.
weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(-1)
weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(
-1
)
# Simulate stepping through a rollout and computing a batch of actions with model on each step.
for i in range(episode_length):
@@ -475,7 +481,8 @@ def test_act_temporal_ensembler():
episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :]
seq_slice = batch_seq[:, episode_step_indices, chunk_indices]
offline_avg = (
einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") / weights[: i + 1].sum()
einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum")
/ weights[: i + 1].sum()
)
# Sanity check. The average should be between the extrema.
assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg)