.. _numpy:

Conversion to NumPy and Other Frameworks
========================================

MLX arrays support conversion to and from other frameworks using either:

* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.

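
As a rough illustration of what these two mechanisms provide, the sketch below
uses only NumPy on both the exporting and the importing side (no MLX is
involved, and the variable names are illustrative). It assumes NumPy 1.22 or
later, where ``np.from_dlpack`` is available, and shows that data exchanged
through either mechanism can be shared rather than copied:

.. code-block:: python

  import numpy as np

  src = np.arange(3, dtype=np.int32)

  # Buffer protocol: a memoryview exposes the underlying bytes without copying.
  via_buffer = np.asarray(memoryview(src))

  # DLPack: np.from_dlpack consumes any object that implements __dlpack__.
  via_dlpack = np.from_dlpack(src)

  # A write to the source is visible through both consumers.
  src[0] = 7
  print(via_buffer[0], via_dlpack[0])
  print(np.shares_memory(src, via_buffer), np.shares_memory(src, via_dlpack))

Whether a particular consumer shares memory or copies is decided by the
consumer, which is why the sections below differ from framework to framework.
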

Let's convert an array to NumPy and back.

.. code-block:: python

  import mlx.core as mx
  import numpy as np

  a = mx.arange(3)
  b = np.array(a) # copy of a
  c = mx.array(b) # copy of b


.. note::

    Since NumPy does not support ``bfloat16`` arrays, you will need to convert
    to ``float16`` or ``float32`` first: ``np.array(a.astype(mx.float32))``.
    Otherwise, you will receive an error like: ``Item size 2 for PEP 3118
    buffer format string does not match the dtype V item size 0.``

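
A minimal round-trip sketch of the workaround described in the note. The
input values are arbitrary, and the cast back to ``bfloat16`` at the end is
optional and shown only for completeness:

.. code-block:: python

  import mlx.core as mx
  import numpy as np

  a = mx.array([1.0, 2.0, 3.0], dtype=mx.bfloat16)

  # Upcast inside MLX first; NumPy has no bfloat16 dtype to receive the data.
  b = np.array(a.astype(mx.float32))

  # Going back, the NumPy float32 data can be cast down again in MLX.
  c = mx.array(b).astype(mx.bfloat16)

  print(b.dtype, c.dtype)
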

By default, NumPy copies data to a new array. This can be prevented by creating
an array view:

.. code-block:: python

  a = mx.arange(3)
  a_view = np.array(a, copy=False)
  print(a_view.flags.owndata) # False
  a_view[0] = 1
  print(a[0].item()) # 1

.. note::

    NumPy arrays with type ``float64`` are converted by default to MLX arrays
    with type ``float32``.

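
A short sketch of the default conversion described in the note, using an
arbitrary ``float64`` input:

.. code-block:: python

  import mlx.core as mx
  import numpy as np

  b = np.arange(3, dtype=np.float64)

  # The float64 input is narrowed to float32 on the way into MLX.
  c = mx.array(b)
  print(b.dtype, c.dtype)

Values that need more than ``float32`` precision are rounded by this
conversion, so it is worth checking the dtype when exact values matter.
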

A NumPy array view is a normal NumPy array, except that it does not own its
memory. This means writing to the view is reflected in the original array.

While views are a powerful way to avoid copying arrays, note that external
changes to an array's memory are not reflected in gradients.

Let's demonstrate this in an example:

.. code-block:: python

  def f(x):
      x_view = np.array(x, copy=False)
      x_view[:] *= x_view # modify memory without telling mx
      return x.sum()

  x = mx.array([3.0])
  y, df = mx.value_and_grad(f)(x)
  print("f(x) = x² =", y.item()) # 9.0
  print("f'(x) = 2x !=", df.item()) # 1.0


The function ``f`` indirectly modifies the array ``x`` through a memory view.
However, this modification is not reflected in the gradient, as seen in the
last line, which outputs ``1.0``: the gradient of the sum operation alone. The
squaring of ``x`` occurs outside of MLX, so it contributes nothing to the
gradient. A similar issue arises during array conversion and copying. For
instance, a function defined as ``mx.array(np.array(x)**2).sum()`` would also
result in an incorrect gradient, even though no in-place operations on MLX
memory are executed.

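
For contrast, here is a minimal sketch of the same function with the squaring
done by MLX itself, so that the operation is visible to the gradient
transformation. The function body is an illustrative rewrite, not code from
the example above:

.. code-block:: python

  import mlx.core as mx

  def f(x):
      # The multiplication is an MLX operation, so it is differentiated.
      return (x * x).sum()

  x = mx.array([3.0])
  y, df = mx.value_and_grad(f)(x)
  print("f(x) = x² =", y.item()) # 9.0
  print("f'(x) = 2x =", df.item()) # 6.0
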
PyTorch
-------

.. warning::

   PyTorch support for :obj:`memoryview` is experimental and can break for
   multi-dimensional arrays. Casting to NumPy first is advised for now.

PyTorch supports the buffer protocol, but it requires an explicit
:obj:`memoryview`.

.. code-block:: python

  import mlx.core as mx
  import torch

  a = mx.arange(3)
  b = torch.tensor(memoryview(a))
  c = mx.array(b.numpy())

Conversion from PyTorch tensors back to MLX arrays must be done via
intermediate NumPy arrays with ``numpy()``.

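
The warning in this section advises casting to NumPy first. One way to do
that is sketched below for a two-dimensional array, assuming
``torch.from_numpy`` as the NumPy-to-PyTorch step. That choice is an
illustration, not something this guide prescribes:

.. code-block:: python

  import mlx.core as mx
  import numpy as np
  import torch

  a = mx.arange(6).reshape(2, 3)

  # MLX -> NumPy (copy) -> PyTorch (shares memory with the NumPy copy)
  b = torch.from_numpy(np.array(a))

  # PyTorch -> NumPy -> MLX
  c = mx.array(b.numpy())

  print(tuple(b.shape), mx.array_equal(a, c).item())
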

JAX
---

JAX fully supports the buffer protocol.

.. code-block:: python

  import mlx.core as mx
  import jax.numpy as jnp

  a = mx.arange(3)
  b = jnp.array(a)
  c = mx.array(b)

TensorFlow
----------

TensorFlow supports the buffer protocol, but it requires an explicit
:obj:`memoryview`.

.. code-block:: python

  import mlx.core as mx
  import tensorflow as tf

  a = mx.arange(3)
  b = tf.constant(memoryview(a))
  c = mx.array(b)