Make sure policies don't mutate the batch (#323)

This commit is contained in:
Alexander Soare
2024-07-22 20:38:33 +01:00
committed by GitHub
parent 0b21210d72
commit abbb1d2367
6 changed files with 27 additions and 5 deletions

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from copy import deepcopy
from pathlib import Path
import einops
@@ -161,8 +162,13 @@ def test_policy(env_name, policy_name, extra_overrides):
for key in batch:
batch[key] = batch[key].to(DEVICE, non_blocking=True)
# Test updating the policy
# Test updating the policy (and test that it does not mutate the batch)
batch_ = deepcopy(batch)
policy.forward(batch)
assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass."
assert all(
torch.equal(batch[k], batch_[k]) for k in batch
), "Batch values are not the same after a forward pass."
# reset the policy and environment
policy.reset()
@@ -174,9 +180,16 @@ def test_policy(env_name, policy_name, extra_overrides):
# send observation to device/gpu
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
# get the next action for the environment
# get the next action for the environment (also check that the observation batch is not modified)
observation_ = deepcopy(observation)
with torch.inference_mode():
action = policy.select_action(observation).cpu().numpy()
assert set(observation) == set(
observation_
), "Observation batch keys are not the same after a forward pass."
assert all(
torch.equal(observation[k], observation_[k]) for k in observation
), "Observation batch values are not the same after a forward pass."
# Test step through policy
env.step(action)