Fix vmap example in docs (#1556)

This commit is contained in:
Chris Offner
2024-11-03 01:44:14 +01:00
committed by GitHub
parent 42533931fa
commit 46d8b16ab4
2 changed files with 3 additions and 3 deletions

View File

@@ -161,7 +161,7 @@ A naive way to add the elements from two sets of vectors is with a loop:
ys = mx.random.uniform(shape=(100, 4096))
def naive_add(xs, ys):
return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
return [xs[i] + ys[:, i] for i in range(xs.shape[0])]
Instead you can use :func:`vmap` to automatically vectorize the addition:
@@ -169,7 +169,7 @@ Instead you can use :func:`vmap` to automatically vectorize the addition:
# Vectorize over the second dimension of x and the
# first dimension of y
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
The ``in_axes`` parameter can be used to specify which dimensions of the
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify