pre-commit run -a

This commit is contained in:
Remi Cadene
2024-03-02 15:58:21 +00:00
parent 1ae6205269
commit 45b4ecb727
6 changed files with 44 additions and 43 deletions

View File

@@ -137,13 +137,16 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
mean_std = self._compute_or_load_mean_std(storage)
mean_std["next", "observation", "image"] = mean_std["observation", "image"]
mean_std["next", "observation", "state"] = mean_std["observation", "state"]
transform = NormalizeTransform(mean_std, in_keys=[
("observation", "image"),
("observation", "state"),
("next", "observation", "image"),
("next", "observation", "state"),
("action"),
])
transform = NormalizeTransform(
mean_std,
in_keys=[
("observation", "image"),
("observation", "state"),
("next", "observation", "image"),
("next", "observation", "state"),
("action"),
],
)
if writer is None:
writer = ImmutableDatasetWriter()
@@ -185,7 +188,9 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
download_and_extract_zip(PUSHT_URL, raw_dir)
# load
dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])
dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(
zarr_path
) # , keys=['img', 'state', 'action'])
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
@@ -291,8 +296,8 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
action_mean = torch.zeros(batch["action"].shape[1])
action_std = torch.zeros(batch["action"].shape[1])
for i in tqdm.tqdm(range(num_batch)):
image_mean += einops.reduce(batch["observation", "image"], 'b c h w -> c', reduction='mean')
for _ in tqdm.tqdm(range(num_batch)):
image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
state_mean += batch["observation", "state"].mean(dim=0)
action_mean += batch["action"].mean(dim=0)
batch = rb.sample()
@@ -302,25 +307,25 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
action_mean /= num_batch
for i in tqdm.tqdm(range(num_batch)):
image_mean_batch = einops.reduce(batch["observation", "image"], 'b c h w -> c', reduction='mean')
image_mean_batch = einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
image_std += (image_mean_batch - image_mean) ** 2
state_std += (batch["observation", "state"].mean(dim=0) - state_mean) ** 2
action_std += (batch["action"].mean(dim=0) - action_mean) ** 2
if i < num_batch - 1:
batch = rb.sample()
image_std = torch.sqrt(image_std / num_batch)
state_std = torch.sqrt(state_std / num_batch)
action_std = torch.sqrt(action_std / num_batch)
mean_std = TensorDict(
{
("observation", "image", "mean"): image_mean[None,:,None,None],
("observation", "image", "std"): image_std[None,:,None,None],
("observation", "state", "mean"): state_mean[None,:],
("observation", "state", "std"): state_std[None,:],
("action", "mean"): action_mean[None,:],
("action", "std"): action_std[None,:],
("observation", "image", "mean"): image_mean[None, :, None, None],
("observation", "image", "std"): image_std[None, :, None, None],
("observation", "state", "mean"): state_mean[None, :],
("observation", "state", "std"): state_std[None, :],
("action", "mean"): action_mean[None, :],
("action", "std"): action_std[None, :],
},
batch_size=[],
)

View File

@@ -3,11 +3,9 @@ import zipfile
from pathlib import Path
import requests
from tensordict import TensorDictBase
import tqdm
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
print(f"downloading from {url}")
response = requests.get(url, stream=True)
@@ -30,4 +28,3 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
return True
else:
return False