mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
parent
3b4f066dac
commit
c92a134b0d
@ -36,9 +36,12 @@ are the CPU and GPU.
|
|||||||
:maxdepth: 1
|
:maxdepth: 1
|
||||||
|
|
||||||
usage/quick_start
|
usage/quick_start
|
||||||
|
usage/lazy_evaluation
|
||||||
usage/unified_memory
|
usage/unified_memory
|
||||||
usage/using_streams
|
usage/indexing
|
||||||
|
usage/saving_and_loading
|
||||||
usage/numpy
|
usage/numpy
|
||||||
|
usage/using_streams
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:caption: Examples
|
:caption: Examples
|
||||||
|
123
docs/src/usage/indexing.rst
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
.. _indexing:
|
||||||
|
|
||||||
|
Indexing Arrays
|
||||||
|
===============
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
For the most part, indexing an MLX :obj:`array` works the same as indexing a
|
||||||
|
NumPy :obj:`numpy.ndarray`. See the `NumPy documentation
|
||||||
|
<https://numpy.org/doc/stable/user/basics.indexing.html>`_ for more details on
|
||||||
|
how that works.
|
||||||
|
|
||||||
|
For example, you can use regular integers and slices (:obj:`slice`) to index arrays:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
>>> arr = mx.arange(10)
|
||||||
|
>>> arr[3]
|
||||||
|
array(3, dtype=int32)
|
||||||
|
>>> arr[-2] # negative indexing works
|
||||||
|
array(8, dtype=int32)
|
||||||
|
>>> arr[2:8:2] # start, stop, stride
|
||||||
|
array([2, 4, 6], dtype=int32)
|
||||||
|
|
||||||
|
For multi-dimensional arrays, the ``...`` or :obj:`Ellipsis` syntax works as in NumPy:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
>>> arr = mx.arange(8).reshape(2, 2, 2)
|
||||||
|
>>> arr[:, :, 0]
|
||||||
|
array([[0, 2],
|
||||||
|
[4, 6]], dtype=int32)
|
||||||
|
>>> arr[..., 0]
|
||||||
|
array([[0, 2],
|
||||||
|
[4, 6]], dtype=int32)
|
||||||
|
|
||||||
|
You can index with ``None`` to create a new axis:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
>>> arr = mx.arange(8)
|
||||||
|
>>> arr.shape
|
||||||
|
[8]
|
||||||
|
>>> arr[None].shape
|
||||||
|
[1, 8]
|
||||||
|
|
||||||
|
|
||||||
|
You can also use an :obj:`array` to index another :obj:`array`:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
>>> arr = mx.arange(10)
|
||||||
|
>>> idx = mx.array([5, 7])
|
||||||
|
>>> arr[idx]
|
||||||
|
array([5, 7], dtype=int32)
|
||||||
|
|
||||||
|
Mixing and matching integers, :obj:`slice`, ``...``, and :obj:`array` indices
|
||||||
|
works just as in NumPy.
|
||||||
|
|
||||||
|
Other functions which may be useful for indexing arrays are :func:`take` and
|
||||||
|
:func:`take_along_axis`.
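As a rough illustration of what these two functions do, here is a pure-Python model on nested lists. It assumes they follow the gather semantics of the same-named NumPy functions; ``take_1d`` and ``take_along_last_axis`` are hypothetical helper names used only for this sketch, not MLX APIs.

```python
# Pure-Python model of NumPy-style take / take_along_axis on nested lists.
# `take_1d` and `take_along_last_axis` are hypothetical names for illustration.

def take_1d(values, indices):
    """Gather values[i] for each i in indices (like take on a 1-D array)."""
    return [values[i] for i in indices]


def take_along_last_axis(rows, index_rows):
    """For each row, gather the entries selected by the matching index row."""
    return [take_1d(row, idx) for row, idx in zip(rows, index_rows)]


data = [[10, 20, 30],
        [40, 50, 60]]

picked = take_1d(data[0], [2, 0])                  # same indices for one row
per_row = take_along_last_axis(data, [[1], [2]])   # different index per row
```

The difference to notice: the first helper applies one set of indices, while the second applies a separate set of indices to each row.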
|
||||||
|
|
||||||
|
Differences from NumPy
|
||||||
|
----------------------
|
||||||
|
|
||||||
|
.. Note::
|
||||||
|
|
||||||
|
MLX indexing is different from NumPy indexing in two important ways:
|
||||||
|
|
||||||
|
* Indexing does not perform bounds checking. Indexing out of bounds is
|
||||||
|
undefined behavior.
|
||||||
|
* Boolean mask based indexing is not yet supported.
|
||||||
|
|
||||||
|
The reason for the lack of bounds checking is that exceptions cannot propagate
|
||||||
|
from the GPU. Performing bounds checking for array indices before launching the
|
||||||
|
kernel would be extremely inefficient.
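When the indices start out as plain Python integers, one inexpensive option is to validate them on the host before turning them into an index array. The sketch below is only that: ``checked_indices`` is a hypothetical helper, not part of MLX, and it does not help when the indices are themselves the result of a lazy computation.

```python
def checked_indices(indices, size):
    """Return the indices as a list, raising IndexError if any is out of range.

    Negative indices are accepted in the usual Python sense (-size <= i < size).
    """
    for i in indices:
        if not -size <= i < size:
            raise IndexError(f"index {i} is out of bounds for size {size}")
    return list(indices)


# In-range indices pass through unchanged.
safe = checked_indices([5, 7, -1], size=10)

# An out-of-range index is rejected before it could reach an indexing kernel.
try:
    checked_indices([5, 12], size=10)
    rejected = False
except IndexError:
    rejected = True
```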
|
||||||
|
|
||||||
|
Indexing with boolean masks is something that MLX may support in the future. In
|
||||||
|
general, MLX has limited support for operations for which outputs
|
||||||
|
*shapes* are dependent on input *data*. Other examples of these types of
|
||||||
|
operations which MLX does not yet support include :func:`numpy.nonzero` and the
|
||||||
|
single input version of :func:`numpy.where`.
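To see why mask-based selection is data dependent, compare it with an element-wise select. The sketch uses plain Python lists as a stand-in for arrays; it illustrates the shape argument only and is not MLX code.

```python
values = [3, -1, 4, -5]
mask = [v > 0 for v in values]

# Mask-style selection: the output length depends on how many entries are
# True, so it cannot be known from the input shape alone.
selected = [v for v, keep in zip(values, mask) if keep]

# Element-wise select (the three-input form of "where"): the output always
# has the same length as the input, whatever the data is.
replaced = [v if keep else 0 for v, keep in zip(values, mask)]
```

When a fixed-size result is acceptable, the second form avoids the data-dependent shape entirely.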
|
||||||
|
|
||||||
|
In Place Updates
|
||||||
|
----------------
|
||||||
|
|
||||||
|
In place updates to indexed arrays are possible in MLX. For example:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
>>> a = mx.array([1, 2, 3])
|
||||||
|
>>> a[2] = 0
|
||||||
|
>>> a
|
||||||
|
array([1, 2, 0], dtype=int32)
|
||||||
|
|
||||||
|
Just as in NumPy, in place updates will be reflected in all references to the
|
||||||
|
same array:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
>>> a = mx.array([1, 2, 3])
|
||||||
|
>>> b = a
|
||||||
|
>>> b[2] = 0
|
||||||
|
>>> b
|
||||||
|
array([1, 2, 0], dtype=int32)
|
||||||
|
>>> a
|
||||||
|
array([1, 2, 0], dtype=int32)
|
||||||
|
|
||||||
|
Transformations of functions which use in-place updates are allowed and work as
|
||||||
|
expected. For example:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
def fun(x, idx):
|
||||||
|
x[idx] = 2.0
|
||||||
|
return x.sum()
|
||||||
|
|
||||||
|
dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
|
||||||
|
print(dfdx) # Prints: array([1, 0, 1], dtype=float32)
|
||||||
|
|
||||||
|
In the above, ``dfdx`` will have the correct gradient, namely zeros at ``idx``
|
||||||
|
and ones elsewhere.
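The expected gradient can be sanity-checked without MLX by using finite differences on plain Python lists. This is a sketch of the reasoning only: ``finite_difference_grad`` is a hypothetical helper, and the step size of 0.5 is chosen so the arithmetic is exact for this linear function.

```python
def fun(x, idx):
    x = list(x)      # copy so the caller's list is not modified
    x[idx] = 2.0     # overwrite one entry, as in the example above
    return sum(x)


def finite_difference_grad(f, x, idx, eps=0.5):
    """Approximate d f / d x[i] by perturbing each entry in turn."""
    base = f(x, idx)
    grads = []
    for i in range(len(x)):
        bumped = list(x)
        bumped[i] += eps
        grads.append((f(bumped, idx) - base) / eps)
    return grads


grads = finite_difference_grad(fun, [1.0, 2.0, 3.0], idx=1)
```

Perturbing the overwritten entry has no effect on the sum, which is why its gradient is zero.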
|
144
docs/src/usage/lazy_evaluation.rst
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
.. _lazy eval:
|
||||||
|
|
||||||
|
Lazy Evaluation
|
||||||
|
===============
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
Why Lazy Evaluation
|
||||||
|
-------------------
|
||||||
|
|
||||||
|
When you perform operations in MLX, no computation actually happens. Instead, a
|
||||||
|
compute graph is recorded. The actual computation only happens if an
|
||||||
|
:func:`eval` is performed.
|
||||||
|
|
||||||
|
MLX uses lazy evaluation because it has some nice features, some of which we
|
||||||
|
describe below.
|
||||||
|
|
||||||
|
Transforming Compute Graphs
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Lazy evaluation lets us record a compute graph without actually doing any
|
||||||
|
computations. This is useful for function transformations like :func:`grad` and
|
||||||
|
:func:`vmap` and graph optimizations like :func:`simplify`.
|
||||||
|
|
||||||
|
Currently, MLX does not compile and rerun compute graphs. They are all
|
||||||
|
generated dynamically. However, lazy evaluation makes it much easier to
|
||||||
|
integrate compilation for future performance enhancements.
|
||||||
|
|
||||||
|
Only Compute What You Use
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
In MLX you do not need to worry as much about computing outputs that are never
|
||||||
|
used. For example:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
def fun(x):
|
||||||
|
a = fun1(x)
|
||||||
|
b = expensive_fun(a)
|
||||||
|
return a, b
|
||||||
|
|
||||||
|
y, _ = fun(x)
|
||||||
|
|
||||||
|
Here, we never actually compute the output of ``expensive_fun``. Use this
|
||||||
|
pattern with care though, as the graph of ``expensive_fun`` is still built, and
|
||||||
|
that has some cost associated with it.
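The behavior can be modeled with a toy deferred-value class in pure Python. This is a conceptual sketch, not how MLX is implemented: ``Lazy`` and the ``calls`` log are hypothetical names introduced for illustration.

```python
class Lazy:
    """Toy deferred value: records a function and its inputs, runs on eval()."""

    def __init__(self, fn, *inputs):
        self.fn = fn
        self.inputs = inputs
        self.value = None
        self.done = False

    def eval(self):
        # Evaluate inputs first, then this node; cache so repeats are no-ops.
        if not self.done:
            args = [x.eval() if isinstance(x, Lazy) else x for x in self.inputs]
            self.value = self.fn(*args)
            self.done = True
        return self.value


calls = []  # records which functions actually ran


def fun1(x):
    calls.append("fun1")
    return x + 1


def expensive_fun(a):
    calls.append("expensive_fun")
    return a * 1000


def fun(x):
    a = Lazy(fun1, x)
    b = Lazy(expensive_fun, a)  # the graph node is built, but nothing runs
    return a, b


y, _ = fun(1)
calls_before_eval = list(calls)  # nothing has run yet
result = y.eval()                # only fun1 runs; expensive_fun never does
```

Building the node for ``expensive_fun`` still costs an object allocation here, which mirrors the graph-construction cost mentioned above.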
|
||||||
|
|
||||||
|
Similarly, lazy evaluation can be beneficial for saving memory while keeping
|
||||||
|
code simple. Say you have a very large model ``Model`` derived from
|
||||||
|
:obj:`mlx.nn.Module`. You can instantiate this model with ``model = Model()``.
|
||||||
|
Typically, this will initialize all of the weights as ``float32``, but the
|
||||||
|
initialization does not actually compute anything until you perform an
|
||||||
|
:func:`eval`. If you update the model with ``float16`` weights, your maximum
|
||||||
|
consumed memory will be half that required if eager computation were used
|
||||||
|
instead.
|
||||||
|
|
||||||
|
This pattern is simple to do in MLX thanks to lazy computation:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
model = Model() # no memory used yet
|
||||||
|
model.load_weights("weights_fp16.safetensors")
|
||||||
|
|
||||||
|
When to Evaluate
|
||||||
|
----------------
|
||||||
|
|
||||||
|
A common question is when to use :func:`eval`. The trade-off is between
|
||||||
|
letting graphs get too large and not batching enough useful work.
|
||||||
|
|
||||||
|
For example:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
for _ in range(100):
|
||||||
|
a = a + b
|
||||||
|
mx.eval(a)
|
||||||
|
b = b * 2
|
||||||
|
mx.eval(b)
|
||||||
|
|
||||||
|
This is a bad idea because there is some fixed overhead with each graph
|
||||||
|
evaluation. On the other hand, there is some slight overhead which grows with
|
||||||
|
the compute graph size, so extremely large graphs (while computationally
|
||||||
|
correct) can be costly.
|
||||||
|
|
||||||
|
Luckily, a wide range of compute graph sizes work pretty well with MLX:
|
||||||
|
anything from a few tens of operations to many thousands of operations per
|
||||||
|
evaluation should be okay.
|
||||||
|
|
||||||
|
Most numerical computations have an iterative outer loop (e.g. the iteration in
|
||||||
|
stochastic gradient descent). A natural and usually efficient place to use
|
||||||
|
:func:`eval` is at each iteration of this outer loop.
|
||||||
|
|
||||||
|
Here is a concrete example:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
for batch in dataset:
|
||||||
|
|
||||||
|
# Nothing has been evaluated yet
|
||||||
|
loss, grad = value_and_grad_fn(model, batch)
|
||||||
|
|
||||||
|
# Still nothing has been evaluated
|
||||||
|
optimizer.update(model, grad)
|
||||||
|
|
||||||
|
# Evaluate the loss and the new parameters which will
|
||||||
|
# run the full gradient computation and optimizer update
|
||||||
|
mx.eval(loss, model.parameters())
|
||||||
|
|
||||||
|
|
||||||
|
An important behavior to be aware of is when the graph will be implicitly
|
||||||
|
evaluated. Anytime you ``print`` an array, convert it to an
|
||||||
|
:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
|
||||||
|
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
|
||||||
|
saving functions) will also evaluate the array.
|
||||||
|
|
||||||
|
|
||||||
|
Calling :func:`array.item` on a scalar array will also evaluate it. In the
|
||||||
|
example above, printing the loss (``print(loss)``) or adding the loss scalar to
|
||||||
|
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
|
||||||
|
these lines are before ``mx.eval(loss, model.parameters())`` then this
|
||||||
|
will be a partial evaluation, computing only the forward pass.
|
||||||
|
|
||||||
|
Also, calling :func:`eval` on an array or set of arrays multiple times is
|
||||||
|
perfectly fine. This is effectively a no-op.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
Using scalar arrays for control-flow will cause an evaluation.
|
||||||
|
|
||||||
|
Here is an example:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
def fun(x):
|
||||||
|
h, y = first_layer(x)
|
||||||
|
if y > 0: # An evaluation is done here!
|
||||||
|
z = second_layer_a(h)
|
||||||
|
else:
|
||||||
|
z = second_layer_b(h)
|
||||||
|
return z
|
||||||
|
|
||||||
|
Using arrays for control flow should be done with care. The above example works
|
||||||
|
and can even be used with gradient transformations. However, this can be very
|
||||||
|
inefficient if evaluations are done too frequently.
|
@ -62,6 +62,11 @@ even though no in-place operations on MLX memory are executed.
|
|||||||
PyTorch
|
PyTorch
|
||||||
-------
|
-------
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
PyTorch support for :obj:`memoryview` is experimental and can break for
|
||||||
|
multi-dimensional arrays. Casting to NumPy first is advised for now.
|
||||||
|
|
||||||
PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
|
PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
@ -40,6 +40,9 @@ automatically evaluate the array.
|
|||||||
>> np.array(c) # Also evaluates c
|
>> np.array(c) # Also evaluates c
|
||||||
array([2., 4., 6., 8.], dtype=float32)
|
array([2., 4., 6., 8.], dtype=float32)
|
||||||
|
|
||||||
|
|
||||||
|
See the page on :ref:`Lazy Evaluation <lazy eval>` for more details.
|
||||||
|
|
||||||
Function and Graph Transformations
|
Function and Graph Transformations
|
||||||
----------------------------------
|
----------------------------------
|
||||||
|
|
||||||
|
81
docs/src/usage/saving_and_loading.rst
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
.. _saving_and_loading:
|
||||||
|
|
||||||
|
Saving and Loading Arrays
|
||||||
|
=========================
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
MLX supports multiple array serialization formats.
|
||||||
|
|
||||||
|
.. list-table:: Serialization Formats
|
||||||
|
:widths: 20 8 25 25
|
||||||
|
:header-rows: 1
|
||||||
|
|
||||||
|
* - Format
|
||||||
|
- Extension
|
||||||
|
- Function
|
||||||
|
- Notes
|
||||||
|
* - NumPy
|
||||||
|
- ``.npy``
|
||||||
|
- :func:`save`
|
||||||
|
- Single arrays only
|
||||||
|
* - NumPy archive
|
||||||
|
- ``.npz``
|
||||||
|
- :func:`savez` and :func:`savez_compressed`
|
||||||
|
- Multiple arrays
|
||||||
|
* - Safetensors
|
||||||
|
- ``.safetensors``
|
||||||
|
- :func:`save_safetensors`
|
||||||
|
- Multiple arrays
|
||||||
|
* - GGUF
|
||||||
|
- ``.gguf``
|
||||||
|
- :func:`save_gguf`
|
||||||
|
- Multiple arrays
|
||||||
|
|
||||||
|
The :func:`load` function will load any of the supported serialization
|
||||||
|
formats. It determines the format from the file extension. The output of
|
||||||
|
:func:`load` depends on the format.
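The extension-based dispatch can be pictured with a small lookup, sketched here in pure Python. ``expected_load_result`` is a hypothetical helper, not an MLX function, and the entries for ``.safetensors`` and ``.gguf`` assume those formats come back as a dictionary in the same way ``.npz`` does.

```python
from pathlib import Path

# Kind of result per extension, following the table above.
# Assumption: safetensors and GGUF files load as dictionaries, like ``.npz``.
RESULT_KIND = {
    ".npy": "single array",
    ".npz": "dict of name -> array",
    ".safetensors": "dict of name -> array",
    ".gguf": "dict of name -> array",
}


def expected_load_result(filename):
    """Describe what loading `filename` should produce, based on its extension."""
    suffix = Path(filename).suffix
    if suffix not in RESULT_KIND:
        raise ValueError(f"unsupported extension: {suffix!r}")
    return RESULT_KIND[suffix]


kind_npy = expected_load_result("array.npy")
kind_npz = expected_load_result("arrays.npz")
```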
|
||||||
|
|
||||||
|
Here's an example of saving a single array to a file:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
>>> a = mx.array([1.0])
|
||||||
|
>>> mx.save("array", a)
|
||||||
|
|
||||||
|
The array ``a`` will be saved in the file ``array.npy``. Including the
|
||||||
|
extension is optional; if it is missing, it is added automatically. You can
|
||||||
|
load the array with:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
>>> mx.load("array.npy")
|
||||||
|
array([1], dtype=float32)
|
||||||
|
|
||||||
|
Here's an example of saving several arrays to a single file:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
>>> a = mx.array([1.0])
|
||||||
|
>>> b = mx.array([2.0])
|
||||||
|
>>> mx.savez("arrays", a, b=b)
|
||||||
|
|
||||||
|
For compatibility with :func:`numpy.savez` the MLX :func:`savez` takes arrays
|
||||||
|
as arguments. If the keywords are missing, then default names will be
|
||||||
|
provided. This can be loaded with:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
>>> mx.load("arrays.npz")
|
||||||
|
{'b': array([2], dtype=float32), 'arr_0': array([1], dtype=float32)}
|
||||||
|
|
||||||
|
In this case :func:`load` returns a dictionary of names to arrays.
|
||||||
|
|
||||||
|
The functions :func:`save_safetensors` and :func:`save_gguf` are similar to
|
||||||
|
:func:`savez`, but they take as input a :obj:`dict` of string names to arrays:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
>>> a = mx.array([1.0])
|
||||||
|
>>> b = mx.array([2.0])
|
||||||
|
>>> mx.save_safetensors("arrays", {"a": a, "b": b})
|