"""Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/)
|
|
|
|
Copied from the original Diffusion Policy repository and used in our `download_and_upload_dataset.py` script.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import math
|
|
import numbers
|
|
import os
|
|
from functools import cached_property
|
|
|
|
import numcodecs
|
|
import numpy as np
|
|
import zarr
|
|
|
|
|
|
def check_chunks_compatible(chunks: tuple, shape: tuple):
    assert len(shape) == len(chunks)
    for c in chunks:
        assert isinstance(c, numbers.Integral)
        assert c > 0


def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compressor=None, tmp_key="_temp"):
    old_arr = group[name]
    if chunks is None:
        chunks = (chunk_length,) + old_arr.chunks[1:] if chunk_length is not None else old_arr.chunks
    check_chunks_compatible(chunks, old_arr.shape)

    if compressor is None:
        compressor = old_arr.compressor

    if (chunks == old_arr.chunks) and (compressor == old_arr.compressor):
        # no change needed
        return old_arr

    # move the array aside, then copy it back with the new chunks/compressor
    group.move(name, tmp_key)
    old_arr = group[tmp_key]
    n_copied, n_skipped, n_bytes_copied = zarr.copy(
        source=old_arr,
        dest=group,
        name=name,
        chunks=chunks,
        compressor=compressor,
    )
    del group[tmp_key]
    arr = group[name]
    return arr


def get_optimal_chunks(shape, dtype, target_chunk_bytes=2e6, max_chunk_length=None):
    """Pick a chunk shape whose byte size is close to `target_chunk_bytes`.

    Trailing dimensions are kept whole, one dimension is split to meet the
    target, and any remaining leading dimensions get chunk size 1.
    Common shapes: (T, D), (T, N, D), (T, H, W, C), (T, N, H, W, C).
    """
    itemsize = np.dtype(dtype).itemsize
    # work on the reversed shape so the time dimension comes last
    rshape = list(shape[::-1])
    if max_chunk_length is not None:
        rshape[-1] = int(max_chunk_length)
    split_idx = len(shape) - 1
    for i in range(len(shape) - 1):
        this_chunk_bytes = itemsize * np.prod(rshape[:i])
        next_chunk_bytes = itemsize * np.prod(rshape[: i + 1])
        if this_chunk_bytes <= target_chunk_bytes and next_chunk_bytes > target_chunk_bytes:
            split_idx = i

    rchunks = rshape[:split_idx]
    item_chunk_bytes = itemsize * np.prod(rshape[:split_idx])
    this_max_chunk_length = rshape[split_idx]
    next_chunk_length = min(this_max_chunk_length, math.ceil(target_chunk_bytes / item_chunk_bytes))
    rchunks.append(next_chunk_length)
    len_diff = len(shape) - len(rchunks)
    rchunks.extend([1] * len_diff)
    chunks = tuple(rchunks[::-1])
    # print(np.prod(chunks) * itemsize / target_chunk_bytes)
    return chunks


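# A minimal worked example of the sizing logic above (shape and dtype are
# assumed for illustration, not taken from the original repo): float32 frames
# of shape (1000, 96, 96, 3) take 96 * 96 * 3 * 4 = 110592 bytes per step, so
# with the default 2e6-byte target only the time axis is split:
#
#   get_optimal_chunks(shape=(1000, 96, 96, 3), dtype=np.float32)
#   # -> (19, 96, 96, 3), since min(1000, ceil(2e6 / 110592)) == 19

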
class ReplayBuffer:
    """
    Zarr-based temporal datastructure.
    Assumes the first dimension is time. Chunks only along the time dimension.
    """

    def __init__(self, root: zarr.Group | dict[str, dict]):
        """
        Dummy constructor. Use the copy_from* and create_from* class methods instead.
        """
        assert "data" in root
        assert "meta" in root
        assert "episode_ends" in root["meta"]
        for value in root["data"].values():
            assert value.shape[0] == root["meta"]["episode_ends"][-1]
        self.root = root

    # ============= create constructors ===============
    @classmethod
    def create_empty_zarr(cls, storage=None, root=None):
        if root is None:
            if storage is None:
                storage = zarr.MemoryStore()
            root = zarr.group(store=storage)
        root.require_group("data", overwrite=False)
        meta = root.require_group("meta", overwrite=False)
        if "episode_ends" not in meta:
            meta.zeros("episode_ends", shape=(0,), dtype=np.int64, compressor=None, overwrite=False)
        return cls(root=root)

    @classmethod
    def create_empty_numpy(cls):
        root = {"data": {}, "meta": {"episode_ends": np.zeros((0,), dtype=np.int64)}}
        return cls(root=root)

    @classmethod
    def create_from_group(cls, group, **kwargs):
        if "data" not in group:
            # create from scratch
            buffer = cls.create_empty_zarr(root=group, **kwargs)
        else:
            # already exists
            buffer = cls(root=group, **kwargs)
        return buffer

    @classmethod
    def create_from_path(cls, zarr_path, mode="r", **kwargs):
        """
        Open an on-disk zarr directly (for datasets larger than memory).
        Slower.
        """
        group = zarr.open(os.path.expanduser(zarr_path), mode)
        return cls.create_from_group(group, **kwargs)

    # ============= copy constructors ===============
    @classmethod
    def copy_from_store(
        cls,
        src_store,
        store=None,
        keys=None,
        chunks: dict[str, tuple] | None = None,
        compressors: dict | str | numcodecs.abc.Codec | None = None,
        if_exists="replace",
        **kwargs,
    ):
        """
        Load the source store into memory.
        """
        src_root = zarr.group(src_store)
        if chunks is None:
            chunks = {}
        if compressors is None:
            compressors = {}
        root = None
        if store is None:
            # numpy backend
            meta = {}
            for key, value in src_root["meta"].items():
                if len(value.shape) == 0:
                    meta[key] = np.array(value)
                else:
                    meta[key] = value[:]

            if keys is None:
                keys = src_root["data"].keys()
            data = {}
            for key in keys:
                arr = src_root["data"][key]
                data[key] = arr[:]

            root = {"meta": meta, "data": data}
        else:
            root = zarr.group(store=store)
            # copy without recompression
            n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
                source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
            )
            data_group = root.create_group("data", overwrite=True)
            if keys is None:
                keys = src_root["data"].keys()
            for key in keys:
                value = src_root["data"][key]
                cks = cls._resolve_array_chunks(chunks=chunks, key=key, array=value)
                cpr = cls._resolve_array_compressor(compressors=compressors, key=key, array=value)
                if cks == value.chunks and cpr == value.compressor:
                    # copy without recompression
                    this_path = "/data/" + key
                    n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
                        source=src_store,
                        dest=store,
                        source_path=this_path,
                        dest_path=this_path,
                        if_exists=if_exists,
                    )
                else:
                    # copy with recompression
                    n_copied, n_skipped, n_bytes_copied = zarr.copy(
                        source=value,
                        dest=data_group,
                        name=key,
                        chunks=cks,
                        compressor=cpr,
                        if_exists=if_exists,
                    )
        buffer = cls(root=root)
        return buffer

    @classmethod
    def copy_from_path(
        cls,
        zarr_path,
        backend=None,
        store=None,
        keys=None,
        chunks: dict[str, tuple] | None = None,
        compressors: dict | str | numcodecs.abc.Codec | None = None,
        if_exists="replace",
        **kwargs,
    ):
        """
        Copy an on-disk zarr into memory, compressed.
        Recommended.
        """
        if chunks is None:
            chunks = {}
        if compressors is None:
            compressors = {}
        if backend == "numpy":
            print("backend argument is deprecated!")
            store = None
        group = zarr.open(os.path.expanduser(zarr_path), "r")
        return cls.copy_from_store(
            src_store=group.store,
            store=store,
            keys=keys,
            chunks=chunks,
            compressors=compressors,
            if_exists=if_exists,
            **kwargs,
        )

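    # A hypothetical usage sketch (the path is assumed, not from the original
    # repo): load an on-disk zarr into memory with the numpy backend, then
    # inspect it.
    #
    #   buffer = ReplayBuffer.copy_from_path("~/data/pusht/pusht_replay.zarr")
    #   print(buffer.n_episodes, buffer.n_steps)
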
    # ============= save methods ===============
    def save_to_store(
        self,
        store,
        chunks: dict[str, tuple] | None = None,
        compressors: str | numcodecs.abc.Codec | dict | None = None,
        if_exists="replace",
        **kwargs,
    ):
        root = zarr.group(store)
        if chunks is None:
            chunks = {}
        if compressors is None:
            compressors = {}
        if self.backend == "zarr":
            # recompression-free copy
            n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
                source=self.root.store,
                dest=store,
                source_path="/meta",
                dest_path="/meta",
                if_exists=if_exists,
            )
        else:
            meta_group = root.create_group("meta", overwrite=True)
            # save meta, no chunking
            for key, value in self.root["meta"].items():
                _ = meta_group.array(name=key, data=value, shape=value.shape, chunks=value.shape)

        # save data, chunked
        data_group = root.create_group("data", overwrite=True)
        for key, value in self.root["data"].items():
            cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
            cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
            if isinstance(value, zarr.Array):
                if cks == value.chunks and cpr == value.compressor:
                    # copy without recompression
                    this_path = "/data/" + key
                    n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
                        source=self.root.store,
                        dest=store,
                        source_path=this_path,
                        dest_path=this_path,
                        if_exists=if_exists,
                    )
                else:
                    # copy with recompression
                    n_copied, n_skipped, n_bytes_copied = zarr.copy(
                        source=value,
                        dest=data_group,
                        name=key,
                        chunks=cks,
                        compressor=cpr,
                        if_exists=if_exists,
                    )
            else:
                # numpy array: write it as a new zarr array
                _ = data_group.array(name=key, data=value, chunks=cks, compressor=cpr)
        return store

    def save_to_path(
        self,
        zarr_path,
        chunks: dict[str, tuple] | None = None,
        compressors: str | numcodecs.abc.Codec | dict | None = None,
        if_exists="replace",
        **kwargs,
    ):
        if chunks is None:
            chunks = {}
        if compressors is None:
            compressors = {}
        store = zarr.DirectoryStore(os.path.expanduser(zarr_path))
        return self.save_to_store(
            store, chunks=chunks, compressors=compressors, if_exists=if_exists, **kwargs
        )

    @staticmethod
    def resolve_compressor(compressor="default"):
        if compressor == "default":
            # fast, lightly-compressed preset for in-memory use
            compressor = numcodecs.Blosc(cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE)
        elif compressor == "disk":
            # smaller but slower preset for on-disk storage
            compressor = numcodecs.Blosc("zstd", clevel=5, shuffle=numcodecs.Blosc.BITSHUFFLE)
        return compressor

    @classmethod
    def _resolve_array_compressor(cls, compressors: dict | str | numcodecs.abc.Codec, key, array):
        # "nil" sentinel allows compressor to be explicitly set to None
        cpr = "nil"
        if isinstance(compressors, dict):
            if key in compressors:
                cpr = cls.resolve_compressor(compressors[key])
            elif isinstance(array, zarr.Array):
                cpr = array.compressor
        else:
            cpr = cls.resolve_compressor(compressors)
        # backup default
        if cpr == "nil":
            cpr = cls.resolve_compressor("default")
        return cpr

    @classmethod
    def _resolve_array_chunks(cls, chunks: dict | tuple, key, array):
        cks = None
        if isinstance(chunks, dict):
            if key in chunks:
                cks = chunks[key]
            elif isinstance(array, zarr.Array):
                cks = array.chunks
        elif isinstance(chunks, tuple):
            cks = chunks
        else:
            raise TypeError(f"Unsupported chunks type {type(chunks)}")
        # backup default
        if cks is None:
            cks = get_optimal_chunks(shape=array.shape, dtype=array.dtype)
        # check
        check_chunks_compatible(chunks=cks, shape=array.shape)
        return cks

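    # For example (key and shape assumed for illustration): with
    # chunks={"img": (32, 96, 96, 3)}, _resolve_array_chunks returns that tuple
    # for key "img"; with chunks={} and a numpy array it falls back to
    # get_optimal_chunks(array.shape, array.dtype).
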
    # ============= properties =================
    @cached_property
    def data(self):
        return self.root["data"]

    @cached_property
    def meta(self):
        return self.root["meta"]

    def update_meta(self, data):
        # sanitize data
        np_data = {}
        for key, value in data.items():
            if isinstance(value, np.ndarray):
                np_data[key] = value
            else:
                arr = np.array(value)
                if arr.dtype == object:
                    raise TypeError(f"Invalid value type {type(value)}")
                np_data[key] = arr

        meta_group = self.meta
        if self.backend == "zarr":
            for key, value in np_data.items():
                _ = meta_group.array(
                    name=key, data=value, shape=value.shape, chunks=value.shape, overwrite=True
                )
        else:
            meta_group.update(np_data)

        return meta_group

    @property
    def episode_ends(self):
        return self.meta["episode_ends"]

    def get_episode_idxs(self):
        import numba

        @numba.jit(nopython=True)
        def _get_episode_idxs(episode_ends):
            result = np.zeros((episode_ends[-1],), dtype=np.int64)
            for i in range(len(episode_ends)):
                start = 0
                if i > 0:
                    start = episode_ends[i - 1]
                end = episode_ends[i]
                for idx in range(start, end):
                    result[idx] = i
            return result

        return _get_episode_idxs(self.episode_ends)

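    # For example (values assumed for illustration): with episode_ends = [3, 5],
    # get_episode_idxs() returns [0, 0, 0, 1, 1] -- each of the 5 steps is
    # labeled with the index of the episode it belongs to.
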
    @property
    def backend(self):
        backend = "numpy"
        if isinstance(self.root, zarr.Group):
            backend = "zarr"
        return backend

    # =========== dict-like API ==============
    def __repr__(self) -> str:
        if self.backend == "zarr":
            return str(self.root.tree())
        else:
            return super().__repr__()

    def keys(self):
        return self.data.keys()

    def values(self):
        return self.data.values()

    def items(self):
        return self.data.items()

    def __getitem__(self, key):
        return self.data[key]

    def __contains__(self, key):
        return key in self.data

    # =========== our API ==============
    @property
    def n_steps(self):
        if len(self.episode_ends) == 0:
            return 0
        return self.episode_ends[-1]

    @property
    def n_episodes(self):
        return len(self.episode_ends)

    @property
    def chunk_size(self):
        if self.backend == "zarr":
            return next(iter(self.data.arrays()))[-1].chunks[0]
        return None

    @property
    def episode_lengths(self):
        ends = self.episode_ends[:]
        ends = np.insert(ends, 0, 0)
        lengths = np.diff(ends)
        return lengths

    def add_episode(
        self,
        data: dict[str, np.ndarray],
        chunks: dict[str, tuple] | None = None,
        compressors: str | numcodecs.abc.Codec | dict | None = None,
    ):
        if chunks is None:
            chunks = {}
        if compressors is None:
            compressors = {}
        assert len(data) > 0
        is_zarr = self.backend == "zarr"

        curr_len = self.n_steps
        episode_length = None
        for value in data.values():
            assert len(value.shape) >= 1
            if episode_length is None:
                episode_length = len(value)
            else:
                assert episode_length == len(value)
        new_len = curr_len + episode_length

        for key, value in data.items():
            new_shape = (new_len,) + value.shape[1:]
            # create the array if this key is new
            if key not in self.data:
                if is_zarr:
                    cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
                    cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
                    arr = self.data.zeros(
                        name=key, shape=new_shape, chunks=cks, dtype=value.dtype, compressor=cpr
                    )
                else:
                    # copy data to prevent modifying the input
                    arr = np.zeros(shape=new_shape, dtype=value.dtype)
                    self.data[key] = arr
            else:
                arr = self.data[key]
                assert value.shape[1:] == arr.shape[1:]
                # grow the array along the time dimension
                if is_zarr:
                    arr.resize(new_shape)
                else:
                    arr.resize(new_shape, refcheck=False)
            # copy data
            arr[-value.shape[0] :] = value

        # append to episode ends
        episode_ends = self.episode_ends
        if is_zarr:
            episode_ends.resize(episode_ends.shape[0] + 1)
        else:
            episode_ends.resize(episode_ends.shape[0] + 1, refcheck=False)
        episode_ends[-1] = new_len

        # rechunk with headroom so frequent appends don't rechunk every time
        if is_zarr and episode_ends.chunks[0] < episode_ends.shape[0]:
            rechunk_recompress_array(self.meta, "episode_ends", chunk_length=int(episode_ends.shape[0] * 1.5))

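    # A hypothetical usage sketch (shapes, keys, and path assumed): append one
    # 10-step episode, then write the buffer to disk.
    #
    #   buffer = ReplayBuffer.create_empty_numpy()
    #   buffer.add_episode({
    #       "state": np.zeros((10, 5), dtype=np.float32),
    #       "action": np.zeros((10, 2), dtype=np.float32),
    #   })
    #   buffer.save_to_path("~/data/pusht_demo.zarr")
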
    def drop_episode(self):
        is_zarr = self.backend == "zarr"
        episode_ends = self.episode_ends[:].copy()
        assert len(episode_ends) > 0
        start_idx = 0
        if len(episode_ends) > 1:
            start_idx = episode_ends[-2]
        for value in self.data.values():
            new_shape = (start_idx,) + value.shape[1:]
            if is_zarr:
                value.resize(new_shape)
            else:
                value.resize(new_shape, refcheck=False)
        if is_zarr:
            self.episode_ends.resize(len(episode_ends) - 1)
        else:
            self.episode_ends.resize(len(episode_ends) - 1, refcheck=False)

    def pop_episode(self):
        assert self.n_episodes > 0
        episode = self.get_episode(self.n_episodes - 1, copy=True)
        self.drop_episode()
        return episode

    def extend(self, data):
        self.add_episode(data)

    def get_episode(self, idx, copy=False):
        # normalize negative indices (e.g. -1 for the last episode)
        idx = list(range(len(self.episode_ends)))[idx]
        start_idx = 0
        if idx > 0:
            start_idx = self.episode_ends[idx - 1]
        end_idx = self.episode_ends[idx]
        result = self.get_steps_slice(start_idx, end_idx, copy=copy)
        return result

    def get_episode_slice(self, idx):
        start_idx = 0
        if idx > 0:
            start_idx = self.episode_ends[idx - 1]
        end_idx = self.episode_ends[idx]
        return slice(start_idx, end_idx)

    def get_steps_slice(self, start, stop, step=None, copy=False):
        _slice = slice(start, stop, step)

        result = {}
        for key, value in self.data.items():
            x = value[_slice]
            if copy and isinstance(value, np.ndarray):
                x = x.copy()
            result[key] = x
        return result

    # =========== chunking =============
    def get_chunks(self) -> dict:
        assert self.backend == "zarr"
        chunks = {}
        for key, value in self.data.items():
            chunks[key] = value.chunks
        return chunks

    def set_chunks(self, chunks: dict):
        assert self.backend == "zarr"
        for key, value in chunks.items():
            if key in self.data:
                arr = self.data[key]
                if value != arr.chunks:
                    check_chunks_compatible(chunks=value, shape=arr.shape)
                    rechunk_recompress_array(self.data, key, chunks=value)

    def get_compressors(self) -> dict:
        assert self.backend == "zarr"
        compressors = {}
        for key, value in self.data.items():
            compressors[key] = value.compressor
        return compressors

    def set_compressors(self, compressors: dict):
        assert self.backend == "zarr"
        for key, value in compressors.items():
            if key in self.data:
                arr = self.data[key]
                compressor = self.resolve_compressor(value)
                if compressor != arr.compressor:
                    rechunk_recompress_array(self.data, key, compressor=compressor)