From 46d8b16ab4d548ce539a927d72c1a6c42a6ec74c Mon Sep 17 00:00:00 2001 From: Chris Offner Date: Sun, 3 Nov 2024 01:44:14 +0100 Subject: [PATCH] Fix vmap example in docs (#1556) --- docs/src/usage/function_transforms.rst | 4 ++-- docs/src/usage/indexing.rst | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/src/usage/function_transforms.rst b/docs/src/usage/function_transforms.rst index 9a15bbf1c..9769fceaa 100644 --- a/docs/src/usage/function_transforms.rst +++ b/docs/src/usage/function_transforms.rst @@ -161,7 +161,7 @@ A naive way to add the elements from two sets of vectors is with a loop: ys = mx.random.uniform(shape=(100, 4096)) def naive_add(xs, ys): - return [xs[i] + ys[:, i] for i in range(xs.shape[1])] + return [xs[i] + ys[:, i] for i in range(xs.shape[0])] Instead you can use :func:`vmap` to automatically vectorize the addition: @@ -169,7 +169,7 @@ Instead you can use :func:`vmap` to automatically vectorize the addition: # Vectorize over the second dimension of x and the # first dimension of y - vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0)) + vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1)) The ``in_axes`` parameter can be used to specify which dimensions of the corresponding input to vectorize over. Similarly, use ``out_axes`` to specify diff --git a/docs/src/usage/indexing.rst b/docs/src/usage/indexing.rst index 62994a0fb..c74e357fa 100644 --- a/docs/src/usage/indexing.rst +++ b/docs/src/usage/indexing.rst @@ -77,7 +77,7 @@ from the GPU. Performing bounds checking for array indices before launching the kernel would be extremely inefficient. Indexing with boolean masks is something that MLX may support in the future. In -general, MLX has limited support for operations for which outputs +general, MLX has limited support for operations for which output *shapes* are dependent on input *data*. Other examples of these types of operations which MLX does not yet support include :func:`numpy.nonzero` and the single input version of :func:`numpy.where`.