From c92a134b0d8f0e9fbc90b49dae7edbad8da18d58 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Wed, 10 Jan 2024 14:04:12 -0800
Subject: [PATCH] more docs (#421)

* more docs

* fix link

* nits + comments
---
 docs/src/index.rst                    |   5 +-
 docs/src/usage/indexing.rst           | 123 ++++++++++++++++++++++
 docs/src/usage/lazy_evaluation.rst    | 144 ++++++++++++++++++++++++++
 docs/src/usage/numpy.rst              |   5 +
 docs/src/usage/quick_start.rst        |   3 +
 docs/src/usage/saving_and_loading.rst |  81 +++++++++++++++
 6 files changed, 360 insertions(+), 1 deletion(-)
 create mode 100644 docs/src/usage/indexing.rst
 create mode 100644 docs/src/usage/lazy_evaluation.rst
 create mode 100644 docs/src/usage/saving_and_loading.rst

diff --git a/docs/src/index.rst b/docs/src/index.rst
index f1fe468ca..cd3db34b3 100644
--- a/docs/src/index.rst
+++ b/docs/src/index.rst
@@ -36,9 +36,12 @@ are the CPU and GPU.
    :maxdepth: 1
 
    usage/quick_start
+   usage/lazy_evaluation
    usage/unified_memory
-   usage/using_streams
+   usage/indexing
+   usage/saving_and_loading
    usage/numpy
+   usage/using_streams
 
 .. toctree::
    :caption: Examples

diff --git a/docs/src/usage/indexing.rst b/docs/src/usage/indexing.rst
new file mode 100644
index 000000000..458541923
--- /dev/null
+++ b/docs/src/usage/indexing.rst
@@ -0,0 +1,123 @@
.. _indexing:

Indexing Arrays
===============

.. currentmodule:: mlx.core

For the most part, indexing an MLX :obj:`array` works the same as indexing a
NumPy :obj:`numpy.ndarray`. See the `NumPy documentation
<https://numpy.org/doc/stable/user/basics.indexing.html>`_ for more details on
how that works.

For example, you can use regular integers and slices (:obj:`slice`) to index arrays:

.. code-block:: shell

    >>> arr = mx.arange(10)
    >>> arr[3]
    array(3, dtype=int32)
    >>> arr[-2]  # negative indexing works
    array(8, dtype=int32)
    >>> arr[2:8:2]  # start, stop, stride
    array([2, 4, 6], dtype=int32)

For multi-dimensional arrays, the ``...`` or :obj:`Ellipsis` syntax works as in NumPy:

.. code-block:: shell

    >>> arr = mx.arange(8).reshape(2, 2, 2)
    >>> arr[:, :, 0]
    array([[0, 2],
           [4, 6]], dtype=int32)
    >>> arr[..., 0]
    array([[0, 2],
           [4, 6]], dtype=int32)

You can index with ``None`` to create a new axis:

.. code-block:: shell

    >>> arr = mx.arange(8)
    >>> arr.shape
    [8]
    >>> arr[None].shape
    [1, 8]


You can also use an :obj:`array` to index another :obj:`array`:

.. code-block:: shell

    >>> arr = mx.arange(10)
    >>> idx = mx.array([5, 7])
    >>> arr[idx]
    array([5, 7], dtype=int32)

Mixing and matching integers, :obj:`slice`, ``...``, and :obj:`array` indices
works just as in NumPy.

Other functions which may be useful for indexing arrays are :func:`take` and
:func:`take_along_axis`.

Differences from NumPy
----------------------

.. note::

   MLX indexing is different from NumPy indexing in two important ways:

   * Indexing does not perform bounds checking. Indexing out of bounds is
     undefined behavior.
   * Boolean mask based indexing is not yet supported.

The reason for the lack of bounds checking is that exceptions cannot propagate
from the GPU. Performing bounds checking for array indices before launching the
kernel would be extremely inefficient.

Indexing with boolean masks is something that MLX may support in the future. In
general, MLX has limited support for operations for which output *shapes* are
dependent on input *data*. Other examples of these types of operations which
MLX does not yet support include :func:`numpy.nonzero` and the single input
version of :func:`numpy.where`.
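Until boolean masks are supported, a shape-preserving alternative that often
suffices is the three-argument :func:`where`, which selects elementwise
between two arrays. Here is a minimal sketch; note that, unlike NumPy's
``x[mask]``, it keeps the full shape and fills the masked-out positions with a
constant rather than compacting to the selected elements:

.. code-block:: shell

    >>> x = mx.array([-1.0, 2.0, -3.0, 4.0])
    >>> mx.where(x > 0, x, 0.0)  # zero out masked positions instead of dropping them
    array([0, 2, 0, 4], dtype=float32)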
In Place Updates
----------------

In-place updates to indexed arrays are possible in MLX. For example:

.. code-block:: shell

    >>> a = mx.array([1, 2, 3])
    >>> a[2] = 0
    >>> a
    array([1, 2, 0], dtype=int32)

Just as in NumPy, in-place updates will be reflected in all references to the
same array:

.. code-block:: shell

    >>> a = mx.array([1, 2, 3])
    >>> b = a
    >>> b[2] = 0
    >>> b
    array([1, 2, 0], dtype=int32)
    >>> a
    array([1, 2, 0], dtype=int32)

Transformations of functions which use in-place updates are allowed and work
as expected. For example:

.. code-block:: python

    def fun(x, idx):
        x[idx] = 2.0
        return x.sum()

    dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
    print(dfdx)  # Prints: array([1, 0, 1], dtype=float32)

In the above, ``dfdx`` will have the correct gradient, namely zeros at ``idx``
and ones elsewhere.
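An :obj:`array` of indices can also appear on the left-hand side of an
assignment, scattering values into the selected positions. A small sketch,
assuming array indices can be used in assignments the same way they are used
in the indexing examples above:

.. code-block:: shell

    >>> a = mx.zeros((4,))
    >>> idx = mx.array([0, 2])
    >>> a[idx] = 1.0  # scatter a scalar into the selected positions
    >>> a
    array([1, 0, 1, 0], dtype=float32)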
diff --git a/docs/src/usage/lazy_evaluation.rst b/docs/src/usage/lazy_evaluation.rst
new file mode 100644
index 000000000..4f14ceeed
--- /dev/null
+++ b/docs/src/usage/lazy_evaluation.rst
@@ -0,0 +1,144 @@
.. _lazy eval:

Lazy Evaluation
===============

.. currentmodule:: mlx.core

Why Lazy Evaluation
-------------------

When you perform operations in MLX, no computation actually happens. Instead,
a compute graph is recorded. The actual computation only happens when an
:func:`eval` is performed.

MLX uses lazy evaluation because it has some nice features, some of which we
describe below.

Transforming Compute Graphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^

Lazy evaluation lets us record a compute graph without actually doing any
computations. This is useful for function transformations like :func:`grad`
and :func:`vmap` and graph optimizations like :func:`simplify`.

Currently, MLX does not compile and rerun compute graphs. They are all
generated dynamically. However, lazy evaluation makes it much easier to
integrate compilation for future performance enhancements.

Only Compute What You Use
^^^^^^^^^^^^^^^^^^^^^^^^^

In MLX you do not need to worry as much about computing outputs that are never
used. For example:

.. code-block:: python

    def fun(x):
        a = fun1(x)
        b = expensive_fun(a)
        return a, b

    y, _ = fun(x)

Here, we never actually compute the output of ``expensive_fun``. Use this
pattern with care though, as the graph of ``expensive_fun`` is still built, and
that has some cost associated to it.

Similarly, lazy evaluation can be beneficial for saving memory while keeping
code simple. Say you have a very large model ``Model`` derived from
:obj:`mlx.nn.Module`. You can instantiate this model with ``model = Model()``.
Typically, this will initialize all of the weights as ``float32``, but the
initialization does not actually compute anything until you perform an
:func:`eval`. If you update the model with ``float16`` weights, your maximum
consumed memory will be half that required if eager computation was used
instead.

This pattern is simple to do in MLX thanks to lazy computation:

.. code-block:: python

    model = Model()  # no memory used yet
    model.load_weights("weights_fp16.safetensors")

When to Evaluate
----------------

A common question is when to use :func:`eval`. The trade-off is between
letting graphs get too large and not batching enough useful work.

For example:

.. code-block:: python

    for _ in range(100):
        a = a + b
        mx.eval(a)
        b = b * 2
        mx.eval(b)

This is a bad idea because there is some fixed overhead with each graph
evaluation. On the other hand, there is some slight overhead which grows with
the compute graph size, so extremely large graphs (while computationally
correct) can be costly.

Luckily, a wide range of compute graph sizes work pretty well with MLX:
anything from a few tens of operations to many thousands of operations per
evaluation should be okay.

Most numerical computations have an iterative outer loop (e.g. the iteration
in stochastic gradient descent). A natural and usually efficient place to use
:func:`eval` is at each iteration of this outer loop.

Here is a concrete example:

.. code-block:: python

    for batch in dataset:

        # Nothing has been evaluated yet
        loss, grad = value_and_grad_fn(model, batch)

        # Still nothing has been evaluated
        optimizer.update(model, grad)

        # Evaluate the loss and the new parameters which will
        # run the full gradient computation and optimizer update
        mx.eval(loss, model.parameters())


An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you ``print`` an array, convert it to a
:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
saving function) will also evaluate the array.


Calling :func:`array.item` on a scalar array will also evaluate it. In the
example above, printing the loss (``print(loss)``) or adding the loss scalar
to a list (``losses.append(loss.item())``) would cause a graph evaluation. If
these lines come before ``mx.eval(loss, model.parameters())``, then this will
be a partial evaluation, computing only the forward pass.

Also, calling :func:`eval` on an array or set of arrays multiple times is
perfectly fine. This is effectively a no-op.

.. warning::

   Using scalar arrays for control flow will cause an evaluation.

Here is an example:

.. code-block:: python

    def fun(x):
        h, y = first_layer(x)
        if y > 0:  # An evaluation is done here!
            z = second_layer_a(h)
        else:
            z = second_layer_b(h)
        return z

Using arrays for control flow should be done with care. The above example
works and can even be used with gradient transformations. However, this can be
very inefficient if evaluations are done too frequently.
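When the branch results are cheap to compute and broadcast-compatible, one way
to avoid the evaluation entirely is to compute both branches lazily and select
between them with the three-argument :func:`where`. This is a minimal sketch
of that pattern, at the cost of computing both branches:

.. code-block:: python

    def fun(x):
        h, y = first_layer(x)
        # Both branches stay in the lazy graph; the data-dependent
        # condition never forces an evaluation.
        return mx.where(y > 0, second_layer_a(h), second_layer_b(h))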
diff --git a/docs/src/usage/numpy.rst b/docs/src/usage/numpy.rst
index ef075ad0c..1ed801454 100644
--- a/docs/src/usage/numpy.rst
+++ b/docs/src/usage/numpy.rst
@@ -62,6 +62,11 @@ even though no in-place operations on MLX memory are executed.
 PyTorch
 -------
 
+.. warning::
+
+   PyTorch support for :obj:`memoryview` is experimental and can break for
+   multi-dimensional arrays. Casting to NumPy first is advised for now.
+
 PyTorch supports the buffer protocol, but it requires an explicit
 :obj:`memoryview`.
 
 .. code-block:: python

diff --git a/docs/src/usage/quick_start.rst b/docs/src/usage/quick_start.rst
index 9ffd29ae6..251f5344c 100644
--- a/docs/src/usage/quick_start.rst
+++ b/docs/src/usage/quick_start.rst
@@ -40,6 +40,9 @@ automatically evaluate the array.
 
    >>> np.array(c)  # Also evaluates c
    array([2., 4., 6., 8.], dtype=float32)
+
+See the page on :ref:`Lazy Evaluation <lazy eval>` for more details.
+
 Function and Graph Transformations
 ----------------------------------

diff --git a/docs/src/usage/saving_and_loading.rst b/docs/src/usage/saving_and_loading.rst
new file mode 100644
index 000000000..895ca342f
--- /dev/null
+++ b/docs/src/usage/saving_and_loading.rst
@@ -0,0 +1,81 @@
.. _saving_and_loading:

Saving and Loading Arrays
=========================

.. currentmodule:: mlx.core

MLX supports multiple array serialization formats.

.. list-table:: Serialization Formats
   :widths: 20 8 25 25
   :header-rows: 1

   * - Format
     - Extension
     - Function
     - Notes
   * - NumPy
     - ``.npy``
     - :func:`save`
     - Single arrays only
   * - NumPy archive
     - ``.npz``
     - :func:`savez` and :func:`savez_compressed`
     - Multiple arrays
   * - Safetensors
     - ``.safetensors``
     - :func:`save_safetensors`
     - Multiple arrays
   * - GGUF
     - ``.gguf``
     - :func:`save_gguf`
     - Multiple arrays

The :func:`load` function will load any of the supported serialization
formats. It determines the format from the file extension. The output of
:func:`load` depends on the format.

Here's an example of saving a single array to a file:

.. code-block:: shell

    >>> a = mx.array([1.0])
    >>> mx.save("array", a)

The array ``a`` will be saved in the file ``array.npy``. Including the
extension is optional; if it is missing it will be added automatically. You
can load the array with:

.. code-block:: shell

    >>> mx.load("array.npy")
    array([1], dtype=float32)

Here's an example of saving several arrays to a single file:

.. code-block:: shell

    >>> a = mx.array([1.0])
    >>> b = mx.array([2.0])
    >>> mx.savez("arrays", a, b=b)

For compatibility with :func:`numpy.savez`, the MLX :func:`savez` takes arrays
as positional or keyword arguments. Arrays passed positionally are given
default names (``arr_0``, ``arr_1``, and so on). The file can be loaded with:

.. code-block:: shell

    >>> mx.load("arrays.npz")
    {'b': array([2], dtype=float32), 'arr_0': array([1], dtype=float32)}

In this case :func:`load` returns a dictionary of names to arrays.

The functions :func:`save_safetensors` and :func:`save_gguf` are similar to
:func:`savez`, but they take as input a :obj:`dict` of string names to arrays:

.. code-block:: shell

    >>> a = mx.array([1.0])
    >>> b = mx.array([2.0])
    >>> mx.save_safetensors("arrays", {"a": a, "b": b})
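Loading the result back mirrors the ``.npz`` case. A quick sketch; the
dictionary keys are the names given at save time:

.. code-block:: shell

    >>> mx.load("arrays.safetensors")
    {'a': array([1], dtype=float32), 'b': array([2], dtype=float32)}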