forked from zarr-developers/VirtualiZarr
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patharray_api.py
More file actions
338 lines (261 loc) · 11.4 KB
/
array_api.py
File metadata and controls
338 lines (261 loc) · 11.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
import itertools
from typing import TYPE_CHECKING, Any, Callable, Union, cast
import numpy as np
from virtualizarr.utils import determine_chunk_grid_shape
from .manifest import ChunkManifest
from .utils import (
check_combinable_zarr_arrays,
check_no_partial_chunks_on_concat_axis,
check_same_ndims,
check_same_shapes,
check_same_shapes_except_on_concat_axis,
copy_and_replace_metadata,
)
if TYPE_CHECKING:
from .array import ManifestArray
# Registry mapping NumPy API functions (e.g. ``np.concatenate``) to the
# ManifestArray-aware implementations defined in this module. Keys are the NumPy
# function objects themselves (``implements`` stores ``numpy_function`` as-is),
# not their names.
MANIFESTARRAY_HANDLED_ARRAY_FUNCTIONS: dict[
    Callable, Callable
] = {}  # populated by the @implements decorators below


def implements(numpy_function: Callable) -> Callable[[Callable], Callable]:
    """
    Register an __array_function__ implementation for ManifestArray objects.

    Parameters
    ----------
    numpy_function
        The NumPy function being overridden; used as the registry key.

    Returns
    -------
    A decorator that records the decorated function in
    ``MANIFESTARRAY_HANDLED_ARRAY_FUNCTIONS`` under ``numpy_function`` and
    returns the decorated function unchanged.
    """

    def decorator(func: Callable) -> Callable:
        MANIFESTARRAY_HANDLED_ARRAY_FUNCTIONS[numpy_function] = func
        return func

    return decorator
@implements(np.result_type)
def result_type(*arrays_and_dtypes: Union["ManifestArray", np.dtype]) -> np.dtype:
    """
    Called by xarray to ensure all arguments to concat have the same dtype.

    Unlike ``np.result_type`` this performs no type promotion: every argument
    (a ManifestArray, a dtype, or anything ``np.dtype`` accepts) must resolve
    to the same dtype, and that dtype is returned.

    Raises
    ------
    ValueError
        If no arguments are given, or if the arguments resolve to more than
        one distinct dtype.
    """
    from virtualizarr.manifests.array import ManifestArray

    # Materialize into a list: the dtypes are consumed twice (once to compare,
    # once to build the error message). A generator would already be exhausted
    # by the unpacking below, so the error would always report an empty set.
    dtypes = [
        obj.dtype if isinstance(obj, ManifestArray) else np.dtype(obj)
        for obj in arrays_and_dtypes
    ]
    if not dtypes:
        raise ValueError("at least one array or dtype is required")

    first_dtype, *other_dtypes = dtypes
    if any(other_dtype != first_dtype for other_dtype in other_dtypes):
        unique_dtypes = set(dtypes)
        raise ValueError(
            f"Cannot combine arrays with inconsistent dtypes, but got {len(unique_dtypes)} distinct dtypes: {unique_dtypes}"
        )
    return first_dtype
@implements(np.concatenate)
def concatenate(
    arrays: tuple["ManifestArray", ...] | list["ManifestArray"],
    /,
    *,
    axis: int | None = 0,
) -> "ManifestArray":
    """
    Concatenate ManifestArrays by merging their chunk manifests.

    The signature of this function is array API compliant, so that it can be called by `xarray.concat`.

    Parameters
    ----------
    arrays
        The arrays to join. They are validated with
        ``check_combinable_zarr_arrays``, ``check_same_ndims``,
        ``check_same_shapes_except_on_concat_axis`` and
        ``check_no_partial_chunks_on_concat_axis`` before anything is merged.
    axis
        Existing axis along which to join. Negative values count back from
        the last axis.

    Returns
    -------
    A new ManifestArray whose manifest is the concatenation of the input
    manifests, with metadata copied from the first array and its shape
    updated along ``axis``.

    Raises
    ------
    NotImplementedError
        If ``axis`` is None (that would require flattening, i.e. a reshape).
    TypeError
        If ``axis`` is not an int.
    ValueError
        If ``arrays`` is empty.
    numpy.exceptions.AxisError
        If ``axis`` is outside ``[-ndim, ndim)``.
    """
    from .array import ManifestArray

    if axis is None:
        raise NotImplementedError(
            "If axis=None the array API requires flattening, which is a reshape, which can't be implemented on a ManifestArray."
        )
    elif not isinstance(axis, int):
        raise TypeError(f"axis must be an int or None, but got {type(axis).__name__}")
    if not arrays:
        raise ValueError("need at least one array to concatenate")

    # ensure dtypes, shapes, codecs etc. are consistent
    check_combinable_zarr_arrays(arrays)
    check_same_ndims([arr.ndim for arr in arrays])

    first_arr = arrays[0]
    ndim = first_arr.ndim
    # Validate before normalizing: ``axis % ndim`` on its own would silently
    # wrap an out-of-range negative axis (e.g. -3 for 2-D arrays) onto a valid
    # one and concatenate along the wrong dimension.
    if not -ndim <= axis < ndim:
        raise np.exceptions.AxisError(axis, ndim)
    # Ensure we handle axis being passed as a negative integer
    axis %= ndim

    arr_shapes = [arr.shape for arr in arrays]
    arr_chunks = [arr.chunks for arr in arrays]
    check_same_shapes_except_on_concat_axis(arr_shapes, axis)
    check_no_partial_chunks_on_concat_axis(arr_shapes, arr_chunks, axis)

    # find what new array shape must be
    new_shape = list(arr_shapes[0])
    new_shape[axis] = sum(shape[axis] for shape in arr_shapes)

    # do concatenation of entries in manifest
    concatenated_manifest = _concat_manifests(
        [arr.manifest for arr in arrays], axis=axis
    )

    new_metadata = copy_and_replace_metadata(
        old_metadata=first_arr.metadata, new_shape=new_shape
    )
    return ManifestArray(chunkmanifest=concatenated_manifest, metadata=new_metadata)
@implements(np.stack)
def stack(
    arrays: tuple["ManifestArray", ...] | list["ManifestArray"],
    /,
    *,
    axis: int = 0,
) -> "ManifestArray":
    """
    Stack ManifestArrays by merging their chunk manifests.

    The signature of this function is array API compliant, so that it can be called by `xarray.stack`.

    Parameters
    ----------
    arrays
        The arrays to stack. They are validated with
        ``check_combinable_zarr_arrays``, ``check_same_ndims`` and
        ``check_same_shapes`` before anything is merged.
    axis
        Position of the new axis in the *result*. The result has one more
        dimension than the inputs, so for inputs of rank ``ndim`` the valid
        range is ``[-(ndim + 1), ndim]``; ``-1`` places the new axis last.

    Returns
    -------
    A new ManifestArray with a new axis of length ``len(arrays)`` inserted at
    ``axis``, and a chunk length of 1 along that axis.

    Raises
    ------
    TypeError
        If ``axis`` is not an int.
    ValueError
        If ``arrays`` is empty.
    numpy.exceptions.AxisError
        If ``axis`` is outside the valid range described above.
    """
    from .array import ManifestArray

    if not isinstance(axis, int):
        raise TypeError(f"axis must be an int, but got {type(axis).__name__}")
    if not arrays:
        raise ValueError("need at least one array to stack")

    # ensure dtypes, shapes, codecs etc. are consistent
    check_combinable_zarr_arrays(arrays)
    check_same_ndims([arr.ndim for arr in arrays])
    arr_shapes = [arr.shape for arr in arrays]
    check_same_shapes(arr_shapes)

    first_arr = arrays[0]
    # The new axis indexes the result, which has ndim + 1 dimensions, so a
    # negative axis must be resolved against ndim + 1 (as np.stack does).
    # Resolving against ndim would put axis=-1 *before* the last input axis.
    result_ndim = first_arr.ndim + 1
    if not -result_ndim <= axis < result_ndim:
        raise np.exceptions.AxisError(axis, result_ndim)
    axis %= result_ndim

    # find what new array shape must be
    new_shape = list(arr_shapes[0])
    new_shape.insert(axis, len(arrays))

    # do stacking of entries in manifest
    stacked_manifest = _stack_manifests([arr.manifest for arr in arrays], axis=axis)

    # chunk shape has changed because a length-1 axis has been inserted
    new_chunks = list(first_arr.chunks)
    new_chunks.insert(axis, 1)

    new_metadata = copy_and_replace_metadata(
        old_metadata=first_arr.metadata, new_shape=new_shape, new_chunks=new_chunks
    )
    return ManifestArray(chunkmanifest=stacked_manifest, metadata=new_metadata)
@implements(np.expand_dims)
def expand_dims(x: "ManifestArray", /, *, axis: int = 0) -> "ManifestArray":
    """
    Insert a new length-one axis into ``x`` at position ``axis``.

    This is stacking a single array, so ``axis`` is interpreted exactly as it
    is by ``stack``.
    """
    single_array = [x]
    return stack(single_array, axis=axis)
@implements(np.broadcast_to)
def broadcast_to(x: "ManifestArray", /, shape: tuple[int, ...]) -> "ManifestArray":
    """
    Broadcasts a ManifestArray to a specified shape, by either adjusting chunk keys or copying chunk manifest entries.

    Parameters
    ----------
    x
        Array to broadcast.
    shape
        Target array shape. Any sequence of ints, or a single int, is accepted
        and normalized to a tuple.

    Returns
    -------
    A new ManifestArray of shape ``shape``. Its chunk shape is the old chunk
    shape with length-1 axes prepended.

    Raises
    ------
    ValueError
        If ``x.shape`` cannot be broadcast to ``shape``, including the case
        where both would only broadcast to some third shape.
    """
    from .array import ManifestArray

    # Normalize to a tuple: np.broadcast_shapes always returns a tuple, so a
    # list-valued ``shape`` would never compare equal below and every
    # broadcast would be rejected.
    new_shape = (shape,) if isinstance(shape, int) else tuple(shape)

    # check its actually possible to broadcast to this new shape
    # (np.broadcast_shapes itself raises ValueError for incompatible shapes)
    mutually_broadcastable_shape = np.broadcast_shapes(x.shape, new_shape)
    if mutually_broadcastable_shape != new_shape:
        # we're not trying to broadcast both shapes to a third shape
        raise ValueError(
            f"array of shape {x.shape} cannot be broadcast to shape {new_shape}"
        )

    # new chunk_shape is old chunk_shape with singleton dimensions prepended
    # (chunk shape can never change by more than adding length-1 axes because each chunk represents a fixed number of array elements)
    new_chunk_shape = _prepend_singleton_dimensions(x.chunks, ndim=len(new_shape))

    # find new chunk grid shape by dividing new array shape by new chunk shape
    new_chunk_grid_shape = determine_chunk_grid_shape(new_shape, new_chunk_shape)

    # do broadcasting of entries in manifest
    broadcasted_manifest = _broadcast_manifest(x.manifest, shape=new_chunk_grid_shape)

    new_metadata = copy_and_replace_metadata(
        old_metadata=x.metadata,
        new_shape=list(new_shape),
        new_chunks=list(new_chunk_shape),
    )
    return ManifestArray(chunkmanifest=broadcasted_manifest, metadata=new_metadata)
def _concat_manifests(manifests: list[ChunkManifest], axis: int) -> ChunkManifest:
    """
    Join chunk manifests end-to-end along an existing grid axis.

    The paths/offsets/lengths arrays are concatenated with ``np.concatenate``.
    Inlined chunk keys are shifted along ``axis`` by the total grid extent of
    all preceding manifests, so they keep pointing at the same chunk.
    """
    joined_paths = cast(
        np.ndarray[Any, np.dtypes.StringDType],
        np.concatenate([manifest._paths for manifest in manifests], axis=axis),
    )
    joined_offsets = np.concatenate(
        [manifest._offsets for manifest in manifests], axis=axis
    )
    joined_lengths = np.concatenate(
        [manifest._lengths for manifest in manifests], axis=axis
    )

    joined_inlined: dict[tuple[int, ...], bytes] = {}
    # running start index of the current manifest within the joined grid
    start_along_axis = 0
    for manifest in manifests:
        for chunk_key, chunk_bytes in manifest._inlined.items():
            moved_key = list(chunk_key)
            moved_key[axis] = moved_key[axis] + start_along_axis
            joined_inlined[tuple(moved_key)] = chunk_bytes
        start_along_axis += manifest._paths.shape[axis]

    return ChunkManifest.from_arrays(
        paths=joined_paths,
        offsets=joined_offsets,
        lengths=joined_lengths,
        validate_paths=False,
        inlined=joined_inlined or None,
    )
def _stack_manifests(manifests: list[ChunkManifest], axis: int) -> ChunkManifest:
    """
    Combine chunk manifests along a brand-new grid axis inserted at ``axis``.

    The paths/offsets/lengths arrays are stacked with ``np.stack``. Each
    inlined chunk key from the ``i``-th manifest gains index ``i`` at position
    ``axis``.
    """
    paths_out = cast(
        np.ndarray[Any, np.dtypes.StringDType],
        np.stack([manifest._paths for manifest in manifests], axis=axis),
    )
    offsets_out = np.stack([manifest._offsets for manifest in manifests], axis=axis)
    lengths_out = np.stack([manifest._lengths for manifest in manifests], axis=axis)

    inlined_out: dict[tuple[int, ...], bytes] = {}
    for position, manifest in enumerate(manifests):
        for chunk_key, chunk_bytes in manifest._inlined.items():
            key_items = list(chunk_key)
            # slice-and-splice is equivalent to list.insert(axis, position)
            new_key = (*key_items[:axis], position, *key_items[axis:])
            inlined_out[new_key] = chunk_bytes

    return ChunkManifest.from_arrays(
        paths=paths_out,
        offsets=offsets_out,
        lengths=lengths_out,
        validate_paths=False,
        inlined=inlined_out or None,
    )
def _broadcast_manifest(
    manifest: ChunkManifest, shape: tuple[int, ...]
) -> ChunkManifest:
    """
    Expand a manifest's chunk grid to ``shape`` using NumPy broadcasting rules.

    The paths/offsets/lengths arrays go through ``np.broadcast_to``. Each
    inlined chunk key is left-padded with zeros to the target rank, then
    copied to every target position along axes whose source extent is 1,
    matching what ``np.broadcast_to`` does to the other arrays.
    """
    paths_b = cast(
        np.ndarray[Any, np.dtypes.StringDType],
        np.broadcast_to(manifest._paths, shape=shape),
    )
    offsets_b = np.broadcast_to(manifest._offsets, shape=shape)
    lengths_b = np.broadcast_to(manifest._lengths, shape=shape)

    inlined_b: dict[tuple[int, ...], bytes] = {}
    if manifest._inlined:
        target_ndim = len(shape)
        extra_dims = target_ndim - manifest._paths.ndim
        padded_source_shape = (1,) * extra_dims + manifest._paths.shape
        for chunk_key, chunk_bytes in manifest._inlined.items():
            padded_key = (0,) * extra_dims + chunk_key
            per_axis_choices = []
            for dim in range(target_ndim):
                if padded_source_shape[dim] == 1:
                    # a length-1 source axis is repeated across the whole target axis
                    per_axis_choices.append(range(shape[dim]))
                else:
                    per_axis_choices.append((padded_key[dim],))
            for target_key in itertools.product(*per_axis_choices):
                inlined_b[target_key] = chunk_bytes

    return ChunkManifest.from_arrays(
        paths=paths_b,
        offsets=offsets_b,
        lengths=lengths_b,
        validate_paths=False,
        inlined=inlined_b or None,
    )
def _prepend_singleton_dimensions(shape: tuple[int, ...], ndim: int) -> tuple[int, ...]:
"""Prepend as many new length-1 axes to shape as necessary such that the result has ndim number of axes."""
n_prepended_dims = ndim - len(shape)
return tuple([1] * n_prepended_dims + list(shape))
# TODO broadcast_arrays, squeeze, permute_dims
@implements(np.full_like)
def full_like(
    x: "ManifestArray", /, fill_value: Any, *, dtype: np.dtype | None = None
) -> np.ndarray:
    """
    Returns a new array filled with fill_value and having the same shape as an input array x.

    Returns a numpy array instead of a ManifestArray.

    Only implemented to get past some checks deep inside xarray, see https://github.com/zarr-developers/VirtualiZarr/issues/29.

    Parameters
    ----------
    x
        Array whose shape (and, by default, dtype) is used.
    fill_value
        Value written into every element.
    dtype
        Output dtype. Defaults to ``x.dtype``, matching ``np.full_like``. The
        default also lets ``np.full_like(x, value)`` be called without an
        explicit ``dtype``.
    """
    return np.full(
        shape=x.shape,
        fill_value=fill_value,
        dtype=x.dtype if dtype is None else dtype,
    )
@implements(np.isnan)
def isnan(x: "ManifestArray", /) -> np.ndarray:
    """
    Report that no element of ``x`` is NaN.

    Always returns an all-False numpy boolean array with the shape of ``x``;
    only ``x.shape`` is consulted. Exists solely to get past checks deep
    inside xarray, see https://github.com/zarr-developers/VirtualiZarr/issues/29.
    """
    array_shape = x.shape
    return _isnan(array_shape)
def _isnan(shape: tuple):
return np.full(shape=shape, fill_value=False, dtype=np.dtype(bool))