Build Custom Dask.Array Function
As discussed in the array design document, to create a dask Array object we need the following:
- A dask graph
- A name specifying a set of keys within that graph
- A chunks tuple giving chunk shape information
- A NumPy dtype
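To make the roles of the name and the chunks tuple concrete, the sketch below uses plain Python (no dask required) to derive the array shape and the full set of block keys that a graph would need to define. The helper names shape_from_chunks and block_keys are hypothetical and exist only for this illustration; they are not part of dask's API.

```python
from itertools import product


def shape_from_chunks(chunks):
    """Total array shape: the sum of the block sizes along each dimension."""
    return tuple(sum(dim) for dim in chunks)


def block_keys(name, chunks):
    """Every key the graph must define: (name, i, j, ...) for each block."""
    block_counts = [range(len(dim)) for dim in chunks]
    return [(name,) + idx for idx in product(*block_counts)]


chunks = ((2, 3, 4), (2, 3, 4))   # block sizes per dimension
shape = shape_from_chunks(chunks)
keys = block_keys("x", chunks)    # "x" is a placeholder name

print(shape)      # (9, 9)
print(len(keys))  # 9
print(keys[:3])   # [('x', 0, 0), ('x', 0, 1), ('x', 0, 2)]
```

Each key pairs the array's name with a block index, so a 3 x 3 grid of blocks needs nine entries in the graph.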
Often dask.array functions take other Array objects as inputs along with parameters, add tasks to a new dask dictionary, create a new chunks tuple, and then construct and return a new Array object. The hard parts are invariably creating the right tasks and creating a new chunks tuple. Careful review of the array design document is suggested.
Example eye
Consider this simple example with the eye function.
import numpy as np
from dask.array import Array
from dask.base import tokenize

def eye(n, blocksize):
    # Assumes n is a multiple of blocksize
    nblocks = n // blocksize
    chunks = ((blocksize,) * nblocks,
              (blocksize,) * nblocks)

    name = 'eye-' + tokenize(n, blocksize)  # unique identifier

    # Diagonal blocks are identity matrices; all other blocks are zeros
    dsk = {(name, i, j): (np.eye, blocksize)
                         if i == j else
                         (np.zeros, (blocksize, blocksize))
           for i in range(nblocks)
           for j in range(nblocks)}

    dtype = np.eye(0).dtype  # take dtype default from numpy

    return Array(dsk, name, chunks, dtype)
This example is particularly simple because it doesn’t take any Array objects as input.
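To see what such a graph describes, the following sketch builds the same kind of dictionary for n = 4 and blocksize = 2 and evaluates it by hand with NumPy alone. The run_task helper is a hypothetical, heavily simplified stand-in for a dask scheduler (it only handles tasks whose arguments are literal values), and the name 'eye-demo' is a placeholder for the tokenized name.

```python
import numpy as np

n, blocksize = 4, 2
nblocks = n // blocksize
name = "eye-demo"  # placeholder; the real function appends tokenize(...)

# Same task layout as eye(): identity on the diagonal, zeros elsewhere
dsk = {(name, i, j): (np.eye, blocksize)
                     if i == j else
                     (np.zeros, (blocksize, blocksize))
       for i in range(nblocks)
       for j in range(nblocks)}


def run_task(task):
    """Call the task's function on its literal arguments."""
    func, *args = task
    return func(*args)


# Evaluate every block, then stitch the blocks together in key order
blocks = [[run_task(dsk[(name, i, j)]) for j in range(nblocks)]
          for i in range(nblocks)]
result = np.block(blocks)

print(np.array_equal(result, np.eye(n)))  # True
```

In practice dask does this evaluation itself, lazily and possibly in parallel, when compute() is called on the returned Array.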
Example diag
Consider the function diag that takes a 1d vector and produces a 2d matrix with the values of the vector along the diagonal. Consider the case where the input is a 1d array with chunk sizes (2, 3, 4) in the first dimension like this:
[x_0, x_1], [x_2, x_3, x_4], [x_5, x_6, x_7, x_8]
We need to create a 2d matrix with chunks equal to ((2, 3, 4), (2, 3, 4)) where the ith block along the diagonal of the output is the result of calling np.diag on the ith block of the input and all other blocks are zero.
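The blockwise rule above can be checked with NumPy alone before writing any graph code. This sketch splits a length-9 vector into blocks of sizes (2, 3, 4), applies np.diag to each block, fills the off-diagonal positions with zeros of the matching shape, and compares the assembled result with np.diag applied to the whole vector. The variable names are illustrative only.

```python
import numpy as np

x = np.arange(9)
sizes = (2, 3, 4)  # chunk sizes along the only dimension

# Split x at the cumulative chunk boundaries: [0:2], [2:5], [5:9]
boundaries = np.cumsum(sizes)[:-1]
pieces = np.split(x, boundaries)

# Block (i, j) is np.diag(pieces[i]) on the diagonal, zeros elsewhere
blocks = [[np.diag(pieces[i]) if i == j
           else np.zeros((sizes[i], sizes[j]), dtype=x.dtype)
           for j in range(len(sizes))]
          for i in range(len(sizes))]
assembled = np.block(blocks)

print(assembled.shape)                        # (9, 9)
print(np.array_equal(assembled, np.diag(x)))  # True
```

Note that off-diagonal block (i, j) has shape (sizes[i], sizes[j]), so it is generally not square when the chunks are uneven.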
import numpy as np
from dask.array import Array
from dask.base import tokenize

def diag(v):
    """Construct a diagonal array, with ``v`` on the diagonal."""
    assert v.ndim == 1
    chunks = (v.chunks[0], v.chunks[0])  # repeat chunks twice

    name = 'diag-' + tokenize(v)  # unique identifier

    # Diagonal blocks call np.diag on the matching input block;
    # off-diagonal blocks are zeros with the input's dtype
    dsk = {(name, i, j): (np.diag, (v.name, i))
                         if i == j else
                         (np.zeros, (v.chunks[0][i], v.chunks[0][j]), v.dtype)
           for i in range(len(v.chunks[0]))
           for j in range(len(v.chunks[0]))}

    dsk.update(v.dask)  # include dask graph of the input

    dtype = v.dtype  # output has the same dtype as the input

    return Array(dsk, name, chunks, dtype)
>>> import dask.array as da
>>> x = da.arange(9, chunks=((2, 3, 4),))
>>> x
dask.array<arange-1, shape=(9,), chunks=((2, 3, 4),), dtype=int64>
>>> M = diag(x)
>>> M
dask.array<diag-2, shape=(9, 9), chunks=((2, 3, 4), (2, 3, 4)), dtype=int64>
>>> M.compute()
array([[0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 1, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 2, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 3, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 4, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 5, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 6, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 7, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 8]])