pre-commit run -a
This commit is contained in:
@@ -137,13 +137,16 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
mean_std = self._compute_or_load_mean_std(storage)
|
||||
mean_std["next", "observation", "image"] = mean_std["observation", "image"]
|
||||
mean_std["next", "observation", "state"] = mean_std["observation", "state"]
|
||||
transform = NormalizeTransform(mean_std, in_keys=[
|
||||
("observation", "image"),
|
||||
("observation", "state"),
|
||||
("next", "observation", "image"),
|
||||
("next", "observation", "state"),
|
||||
("action"),
|
||||
])
|
||||
transform = NormalizeTransform(
|
||||
mean_std,
|
||||
in_keys=[
|
||||
("observation", "image"),
|
||||
("observation", "state"),
|
||||
("next", "observation", "image"),
|
||||
("next", "observation", "state"),
|
||||
("action"),
|
||||
],
|
||||
)
|
||||
|
||||
if writer is None:
|
||||
writer = ImmutableDatasetWriter()
|
||||
@@ -185,7 +188,9 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
download_and_extract_zip(PUSHT_URL, raw_dir)
|
||||
|
||||
# load
|
||||
dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])
|
||||
dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(
|
||||
zarr_path
|
||||
) # , keys=['img', 'state', 'action'])
|
||||
|
||||
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
|
||||
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
||||
@@ -291,8 +296,8 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
action_mean = torch.zeros(batch["action"].shape[1])
|
||||
action_std = torch.zeros(batch["action"].shape[1])
|
||||
|
||||
for i in tqdm.tqdm(range(num_batch)):
|
||||
image_mean += einops.reduce(batch["observation", "image"], 'b c h w -> c', reduction='mean')
|
||||
for _ in tqdm.tqdm(range(num_batch)):
|
||||
image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
|
||||
state_mean += batch["observation", "state"].mean(dim=0)
|
||||
action_mean += batch["action"].mean(dim=0)
|
||||
batch = rb.sample()
|
||||
@@ -302,25 +307,25 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
action_mean /= num_batch
|
||||
|
||||
for i in tqdm.tqdm(range(num_batch)):
|
||||
image_mean_batch = einops.reduce(batch["observation", "image"], 'b c h w -> c', reduction='mean')
|
||||
image_mean_batch = einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
|
||||
image_std += (image_mean_batch - image_mean) ** 2
|
||||
state_std += (batch["observation", "state"].mean(dim=0) - state_mean) ** 2
|
||||
action_std += (batch["action"].mean(dim=0) - action_mean) ** 2
|
||||
if i < num_batch - 1:
|
||||
batch = rb.sample()
|
||||
|
||||
|
||||
image_std = torch.sqrt(image_std / num_batch)
|
||||
state_std = torch.sqrt(state_std / num_batch)
|
||||
action_std = torch.sqrt(action_std / num_batch)
|
||||
|
||||
mean_std = TensorDict(
|
||||
{
|
||||
("observation", "image", "mean"): image_mean[None,:,None,None],
|
||||
("observation", "image", "std"): image_std[None,:,None,None],
|
||||
("observation", "state", "mean"): state_mean[None,:],
|
||||
("observation", "state", "std"): state_std[None,:],
|
||||
("action", "mean"): action_mean[None,:],
|
||||
("action", "std"): action_std[None,:],
|
||||
("observation", "image", "mean"): image_mean[None, :, None, None],
|
||||
("observation", "image", "std"): image_std[None, :, None, None],
|
||||
("observation", "state", "mean"): state_mean[None, :],
|
||||
("observation", "state", "std"): state_std[None, :],
|
||||
("action", "mean"): action_mean[None, :],
|
||||
("action", "std"): action_std[None, :],
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
|
||||
@@ -3,11 +3,9 @@ import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
from tensordict import TensorDictBase
|
||||
import tqdm
|
||||
|
||||
|
||||
|
||||
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||
print(f"downloading from {url}")
|
||||
response = requests.get(url, stream=True)
|
||||
@@ -30,4 +28,3 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user