mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-16 14:18:12 +08:00
awni's commit files
This commit is contained in:
1
docs/.gitignore
vendored
Normal file
1
docs/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
src/python/_autosummary*/
|
36
docs/README.md
Normal file
36
docs/README.md
Normal file
@@ -0,0 +1,36 @@
|
||||
## Build the Docs
|
||||
|
||||
### Setup (do once)
|
||||
|
||||
Install [sphinx](https://www.sphinx-doc.org/en/master/usage/installation.html)
|
||||
for example with `conda`:
|
||||
|
||||
```
|
||||
conda install sphinx
|
||||
pip install sphinx-rtd-theme
|
||||
```
|
||||
|
||||
### Build
|
||||
|
||||
Build the docs from `mlx/docs/`
|
||||
|
||||
```
|
||||
make html
|
||||
```
|
||||
|
||||
View the docs by running a server in `mlx/docs/build/html/`:
|
||||
|
||||
```
|
||||
python -m http.server <port>
|
||||
```
|
||||
|
||||
and point your browser to `http://localhost:<port>`.
|
||||
|
||||
### Push to Github Pages
|
||||
|
||||
Check-out the `gh-pages` branch (`git switch gh-pages`) and build
|
||||
the docs. Then force add the `build/html` directory:
|
||||
|
||||
`git add -f build/html`
|
||||
|
||||
Commit and push the changes to the `gh-pages` branch.
|
20
docs/src/_templates/optimizers-template.rst
Normal file
20
docs/src/_templates/optimizers-template.rst
Normal file
@@ -0,0 +1,20 @@
|
||||
{{ fullname | escape | underline}}
|
||||
|
||||
.. currentmodule:: {{ module }}
|
||||
|
||||
.. autoclass:: {{ objname }}
|
||||
|
||||
{% block methods %}
|
||||
|
||||
{% if methods %}
|
||||
.. rubric:: {{ _('Methods') }}
|
||||
|
||||
.. autosummary::
|
||||
{% for item in methods %}
|
||||
{%- if item not in inherited_members %}
|
||||
~{{ name }}.{{ item }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{% endif %}
|
||||
{% endblock %}
|
||||
|
44
docs/src/conf.py
Normal file
44
docs/src/conf.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = "MLX"
|
||||
copyright = "2023, MLX Contributors"
|
||||
author = "MLX Contributors"
|
||||
version = "0.0.0"
|
||||
release = "0.0.0"
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
||||
extensions = [
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx.ext.autosummary",
|
||||
"sphinx.ext.intersphinx",
|
||||
"sphinx.ext.napoleon",
|
||||
]
|
||||
|
||||
python_use_unqualified_type_names = True
|
||||
autosummary_generate = True
|
||||
|
||||
intersphinx_mapping = {
|
||||
"https://docs.python.org/3": None,
|
||||
"https://numpy.org/doc/stable/": None,
|
||||
}
|
||||
|
||||
templates_path = ["_templates"]
|
||||
html_static_path = ["_static"]
|
||||
source_suffix = ".rst"
|
||||
master_doc = "index"
|
||||
highlight_language = "python"
|
||||
pygments_style = "sphinx"
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
|
||||
html_theme = "sphinx_rtd_theme"
|
||||
|
||||
# -- Options for HTMLHelp output ---------------------------------------------
|
||||
|
||||
htmlhelp_basename = "mlx_doc"
|
382
docs/src/examples/llama-inference.rst
Normal file
382
docs/src/examples/llama-inference.rst
Normal file
@@ -0,0 +1,382 @@
|
||||
LLM inference
|
||||
==============
|
||||
|
||||
MLX enables efficient inference of large-ish transformers on Apple silicon
|
||||
without compromising on ease of use. In this example we will create an
|
||||
inference script for the Llama family of transformer models in which the model
|
||||
is defined in less than 200 lines of python.
|
||||
|
||||
Implementing the model
|
||||
----------------------
|
||||
|
||||
We will use the neural network building blocks defined in the :mod:`mlx.nn`
|
||||
module to concisely define the model architecture.
|
||||
|
||||
Attention layer
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
We will start with the llama attention layer which notably uses the RoPE
|
||||
positional encoding. [1]_ In addition, our attention layer will optionally use a
|
||||
key/value cache that will be concatenated with the provided keys and values to
|
||||
support efficient inference.
|
||||
|
||||
Our implementation uses :class:`mlx.nn.Linear` for all the projections and
|
||||
:class:`mlx.nn.RoPE` for the positional encoding.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
class LlamaAttention(nn.Module):
|
||||
def __init__(self, dims: int, num_heads: int):
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.rope = nn.RoPE(dims // num_heads, traditional=True)
|
||||
self.query_proj = nn.Linear(dims, dims, bias=False)
|
||||
self.key_proj = nn.Linear(dims, dims, bias=False)
|
||||
self.value_proj = nn.Linear(dims, dims, bias=False)
|
||||
self.out_proj = nn.Linear(dims, dims, bias=False)
|
||||
|
||||
def __call__(self, queries, keys, values, mask=None, cache=None):
|
||||
queries = self.query_proj(queries)
|
||||
keys = self.key_proj(keys)
|
||||
values = self.value_proj(values)
|
||||
|
||||
# Extract some shapes
|
||||
num_heads = self.num_heads
|
||||
B, L, D = queries.shape
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
# Add RoPE to the queries and keys and combine them with the cache
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
# Finally perform the attention computation
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores = scores + mask
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
# Note that we return the keys and values to possibly be used as a cache
|
||||
return self.out_proj(values_hat), (keys, values)
|
||||
|
||||
Encoder layer
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
The other component of the Llama model is the encoder layer which uses RMS
|
||||
normalization [2]_ and SwiGLU. [3]_ For RMS normalization we will use
|
||||
:class:`mlx.nn.RMSNorm` that is already provided in :mod:`mlx.nn`.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class LlamaEncoderLayer(nn.Module):
|
||||
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
|
||||
super().__init__()
|
||||
|
||||
self.attention = LlamaAttention(dims, num_heads)
|
||||
|
||||
self.norm1 = nn.RMSNorm(dims)
|
||||
self.norm2 = nn.RMSNorm(dims)
|
||||
|
||||
self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
|
||||
self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
|
||||
self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
y = self.norm1(x)
|
||||
y, cache = self.attention(y, y, y, mask, cache)
|
||||
x = x + y
|
||||
|
||||
y = self.norm2(x)
|
||||
a = self.linear1(y)
|
||||
b = self.linear2(y)
|
||||
y = a * mx.sigmoid(a) * b
|
||||
y = self.linear3(y)
|
||||
x = x + y
|
||||
|
||||
return x, cache
|
||||
|
||||
Full model
|
||||
^^^^^^^^^^
|
||||
|
||||
To implement any Llama model we simply have to combine ``LlamaEncoderLayer``
|
||||
instances with an :class:`mlx.nn.Embedding` to embed the input tokens.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class Llama(nn.Module):
|
||||
def __init__(
|
||||
self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.embedding = nn.Embedding(vocab_size, dims)
|
||||
self.layers = [
|
||||
LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(dims)
|
||||
self.out_proj = nn.Linear(dims, vocab_size, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
mask = mask.astype(self.embedding.weight.dtype)
|
||||
|
||||
x = self.embedding(x)
|
||||
for l in self.layers:
|
||||
x, _ = l(x, mask)
|
||||
x = self.norm(x)
|
||||
return self.out_proj(x)
|
||||
|
||||
Note that in the implementation above we use a simple list to hold the encoder
|
||||
layers but using ``model.parameters()`` will still consider these layers.
|
||||
|
||||
Generation
|
||||
^^^^^^^^^^^
|
||||
|
||||
Our ``Llama`` module can be used for training but not inference as the
|
||||
``__call__`` method above processes one input, completely ignores the cache and
|
||||
performs no sampling whatsoever. In the rest of this subsection, we will
|
||||
implement the inference function as a python generator that processes the
|
||||
prompt and then autoregressively yields tokens one at a time.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class Llama(nn.Module):
|
||||
...
|
||||
|
||||
def generate(self, x, temp=1.0):
|
||||
cache = []
|
||||
|
||||
# Make an additive causal mask. We will need that to process the prompt.
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
mask = mask.astype(self.embedding.weight.dtype)
|
||||
|
||||
# First we process the prompt x the same way as in __call__ but
|
||||
# save the caches in cache
|
||||
x = self.embedding(x)
|
||||
for l in self.layers:
|
||||
x, c = l(x, mask=mask)
|
||||
cache.append(c) # <--- we store the per layer cache in a
|
||||
# simple python list
|
||||
x = self.norm(x)
|
||||
y = self.out_proj(x[:, -1]) # <--- we only care about the last logits
|
||||
# that generate the next token
|
||||
y = mx.random.categorical(y * (1/temp))
|
||||
|
||||
# y now has size [1]
|
||||
# Since MLX is lazily evaluated nothing is computed yet.
|
||||
# Calling y.item() would force the computation to happen at
|
||||
# this point but we can also choose not to do that and let the
|
||||
# user choose when to start the computation.
|
||||
yield y
|
||||
|
||||
# Now we parsed the prompt and generated the first token we
|
||||
# need to feed it back into the model and loop to generate the
|
||||
# rest.
|
||||
while True:
|
||||
# Unsqueezing the last dimension to add a sequence length
|
||||
# dimension of 1
|
||||
x = y[:, None]
|
||||
|
||||
x = self.embedding(x)
|
||||
for i in range(len(cache)):
|
||||
# We are overwriting the arrays in the cache list. When
|
||||
# the computation will happen, MLX will be discarding the
|
||||
# old cache the moment it is not needed anymore.
|
||||
x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
|
||||
x = self.norm(x)
|
||||
y = self.out_proj(x[:, -1])
|
||||
y = mx.random.categorical(y * (1/temp))
|
||||
|
||||
yield y
|
||||
|
||||
Putting it all together
|
||||
^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
We now have everything we need to create a Llama model and sample tokens from
|
||||
it. In the following code, we randomly initialize a small Llama model, process
|
||||
6 tokens of prompt and generate 10 tokens.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model = Llama(num_layers=12, vocab_size=8192, dims=512, mlp_dims=1024, num_heads=8)
|
||||
|
||||
# Since MLX is lazily evaluated nothing has actually been materialized yet.
|
||||
# We could have set the `dims` to 20_000 on a machine with 8GB of RAM and the
|
||||
# code above would still run. Let's actually materialize the model.
|
||||
mx.eval(model.parameters())
|
||||
|
||||
prompt = mx.array([[1, 10, 8, 32, 44, 7]]) # <-- Note the double brackets because we
|
||||
# have a batch dimension even
|
||||
# though it is 1 in this case
|
||||
|
||||
generated = [t for i, t in zip(range(10), model.generate(prompt, 0.8))]
|
||||
|
||||
# Since we haven't evaluated anything, nothing is computed yet. The list
|
||||
# `generated` contains the arrays that hold the computation graph for the
|
||||
# full processing of the prompt and the generation of 10 tokens.
|
||||
#
|
||||
# We can evaluate them one at a time, or all together. Concatenate them or
|
||||
# print them. They would all result in very similar runtimes and give exactly
|
||||
# the same results.
|
||||
mx.eval(generated)
|
||||
|
||||
Converting the weights
|
||||
----------------------
|
||||
|
||||
This section assumes that you have access to the original Llama weights and the
|
||||
SentencePiece model that comes with them. We will write a small script to
|
||||
convert the PyTorch weights to MLX compatible ones and write them in a NPZ file
|
||||
that can be loaded directly by MLX.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import argparse
|
||||
from itertools import starmap
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
def map_torch_to_mlx(key, value):
|
||||
if "tok_embedding" in key:
|
||||
key = "embedding.weight"
|
||||
|
||||
elif "norm" in key:
|
||||
key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
|
||||
|
||||
elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
|
||||
key = key.replace("wq", "query_proj")
|
||||
key = key.replace("wk", "key_proj")
|
||||
key = key.replace("wv", "value_proj")
|
||||
key = key.replace("wo", "out_proj")
|
||||
|
||||
elif "w1" in key or "w2" in key or "w3" in key:
|
||||
# The FFN is a separate submodule in PyTorch
|
||||
key = key.replace("feed_forward.w1", "linear1")
|
||||
key = key.replace("feed_forward.w3", "linear2")
|
||||
key = key.replace("feed_forward.w2", "linear3")
|
||||
|
||||
elif "output" in key:
|
||||
key = key.replace("output", "out_proj")
|
||||
|
||||
elif "rope" in key:
|
||||
return None, None
|
||||
|
||||
return key, value.numpy()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
|
||||
parser.add_argument("torch_weights")
|
||||
parser.add_argument("output_file")
|
||||
args = parser.parse_args()
|
||||
|
||||
state = torch.load(args.torch_weights)
|
||||
np.savez(
|
||||
args.output_file,
|
||||
**{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
|
||||
)
|
||||
|
||||
|
||||
Weight loading and benchmarking
|
||||
-------------------------------
|
||||
|
||||
After converting the weights to be compatible to our implementation, all that is
|
||||
left is to load them from disk and we can finally use the LLM to generate text.
|
||||
We can load numpy format files using the :func:`mlx.core.load` operation.
|
||||
|
||||
To create a parameter dictionary from the key/value representation of NPZ files
|
||||
we will use the :func:`mlx.utils.tree_unflatten` helper method as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from mlx.utils import tree_unflatten
|
||||
|
||||
model.update(tree_unflatten(list(mx.load(weight_file).items())))
|
||||
|
||||
:meth:`mlx.utils.tree_unflatten` will take keys from the NPZ file that look
|
||||
like ``layers.2.attention.query_proj.weight`` and will transform them to
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{"layers": [..., ..., {"attention": {"query_proj": {"weight": ...}}}]}
|
||||
|
||||
which can then be used to update the model. Note that the method above incurs
|
||||
several unnecessary copies from disk to numpy and then from numpy to MLX. It
|
||||
will be replaced in the future with direct loading to MLX.
|
||||
|
||||
You can download the full example code in `mlx-examples <code>`_. Assuming, the
|
||||
existence of ``weights.pth`` and ``tokenizer.model`` in the current working
|
||||
directory we can play around with our inference script as follows (the timings
|
||||
are representative of an M1 Ultra and the 7B parameter Llama model):
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ python convert.py weights.pth llama-7B.mlx.npz
|
||||
$ python llama.py llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely'
|
||||
[INFO] Loading model from disk: 5.247 s
|
||||
Press enter to start generation
|
||||
------
|
||||
, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down,
|
||||
------
|
||||
[INFO] Prompt processing: 0.437 s
|
||||
[INFO] Full generation: 4.330 s
|
||||
|
||||
We observe that 4.3 seconds are required to generate 100 tokens and 0.4 seconds
|
||||
of those are spent processing the prompt. This amounts to a little over **39 ms
|
||||
per token**.
|
||||
|
||||
By running with a much bigger prompt we can see that the per token generation
|
||||
time as well as the prompt processing time remains almost constant.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ python llama.py llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. Still he made no immediate reply, but again put his face down upon his hands, and did not'
|
||||
[INFO] Loading model from disk: 5.247 s
|
||||
Press enter to start generation
|
||||
------
|
||||
take his eyes from the ground. “What is it you are waiting for?” said I. “I am not accustomed to be thus questioned,” said he. “You look like a reasonable man—tell me, then, what are you waiting for?” “You would not understand,” he replied; “and how could you help me, if I were to tell you?” “I should not only understand, but would do all that I could,” said I. He did not
|
||||
------
|
||||
[INFO] Prompt processing: 0.579 s
|
||||
[INFO] Full generation: 4.690 s
|
||||
$ python llama.py --num-tokens 500 llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. Still he made no immediate reply, but again put his face down upon his hands, and did not'
|
||||
[INFO] Loading model from disk: 5.628 s
|
||||
Press enter to start generation
|
||||
------
|
||||
take his eyes from the ground. “What is it you are waiting for?” said I. “I am not accustomed to be thus questioned,” said he. “You look like a reasonable man—tell me, then, what are you waiting for?” “You would not understand,” he replied; “and how could you help me, if I were to tell you?” “I should not only understand, but would do all that I could,” said I. He did not reply, but still went on looking at the ground, and took hold of his bundle with a nervous trembling. I waited some time, and then resumed. “It is of no use to say you would not understand, if I were to tell you,” said he. “I have not told you why I am waiting for him,” said I. “And I am sure I should not understand,” replied he. “I will tell you then,” said I, “and, perhaps, you would not be surprised.” “No matter,” said he, “I shall be surprised anyhow; so tell me why you are waiting for him.” “He is my friend,” said I. “Yes,” said he, with a slight smile, “I know.” “He has been kind to me,” said I, “and I am waiting for him. I want to see him, and could have waited as I am now, for a much longer time.” “He will not soon come,” said he. “Unless he sees you here, he will not know of your having waited, and he will be very unlikely to come.” “No matter,” said I, “I shall wait for him.” “This is a strange thing,” said he, still with the same amused smile. “How did you know,” said I, “that he was coming? How should you be waiting?” “That is my secret,” said he. “And you expect him?” “Yes,” said I. “Are you disappointed then, if he does not come?” “No,” said I, “it is his secret, not mine.” “If he comes,” said he, “do you mean to go straight away?” “Yes,” said I, “I cannot be happy if I do not go straight away after him.” “Did you know this place before?” asked he. “Yes,” said I. “Is there any shop to buy food here?” “
|
||||
------
|
||||
[INFO] Prompt processing: 0.633 s
|
||||
[INFO] Full generation: 21.475 s
|
||||
|
||||
Scripts
|
||||
-------
|
||||
|
||||
.. admonition:: Download the code
|
||||
|
||||
The full example code is available in `mlx-examples <code>`_.
|
||||
|
||||
.. code: `https://github.com/ml-explore/mlx-examples/tree/main/llama`_
|
||||
|
||||
.. [1] Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B. and Liu, Y., 2021.
|
||||
Roformer: Enhanced transformer with rotary position embedding. arXiv
|
||||
preprint arXiv:2104.09864.
|
||||
.. [2] Zhang, B. and Sennrich, R., 2019. Root mean square layer normalization.
|
||||
Advances in Neural Information Processing Systems, 32.
|
||||
.. [3] Shazeer, N., 2020. Glu variants improve transformer. arXiv preprint
|
||||
arXiv:2002.05202.
|
49
docs/src/index.rst
Normal file
49
docs/src/index.rst
Normal file
@@ -0,0 +1,49 @@
|
||||
MLX
|
||||
===
|
||||
|
||||
.. toctree::
|
||||
:caption: Install
|
||||
:maxdepth: 1
|
||||
|
||||
install
|
||||
|
||||
.. toctree::
|
||||
:caption: Usage
|
||||
:maxdepth: 1
|
||||
|
||||
quick_start
|
||||
using_streams
|
||||
|
||||
.. toctree::
|
||||
:caption: Examples
|
||||
:maxdepth: 1
|
||||
|
||||
examples/linear_regression
|
||||
examples/mlp
|
||||
examples/llama-inference
|
||||
|
||||
.. toctree::
|
||||
:caption: Further Reading
|
||||
:maxdepth: 1
|
||||
|
||||
dev/extensions
|
||||
|
||||
.. toctree::
|
||||
:caption: Python API Reference
|
||||
:maxdepth: 1
|
||||
|
||||
python/array
|
||||
python/devices_and_streams
|
||||
python/ops
|
||||
python/random
|
||||
python/transforms
|
||||
python/fft
|
||||
python/nn
|
||||
python/optimizers
|
||||
python/tree_utils
|
||||
|
||||
.. toctree::
|
||||
:caption: C++ API Reference
|
||||
:maxdepth: 1
|
||||
|
||||
cpp/ops
|
102
docs/src/install.rst
Normal file
102
docs/src/install.rst
Normal file
@@ -0,0 +1,102 @@
|
||||
Build and Install
|
||||
=================
|
||||
|
||||
Install from PyPI
|
||||
-----------------
|
||||
|
||||
MLX is available at Apple's internal PyPI repository. All you have to do to use
|
||||
MLX with your own Apple silicon computer is
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install apple-mlx -i https://pypi.apple.com/simple
|
||||
|
||||
Build from source
|
||||
-----------------
|
||||
|
||||
Build Requirements
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
|
||||
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
|
||||
|
||||
|
||||
Python API
|
||||
^^^^^^^^^^
|
||||
|
||||
To build and install the MLX python library from source, first, clone MLX from
|
||||
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
|
||||
|
||||
Make sure that you have `pybind11 <https://pybind11.readthedocs.io/en/stable/index.html>`_
|
||||
installed. You can install ``pybind11`` with ``pip``, ``brew`` or ``conda`` as follows:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install "pybind11[global]"
|
||||
conda install pybind11
|
||||
brew install pybind11
|
||||
|
||||
Then simply build and install it using pip:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
|
||||
|
||||
|
||||
C++ API
|
||||
^^^^^^^
|
||||
|
||||
Currently, MLX must be built and installed from source.
|
||||
|
||||
Similarly to the python library, to build and install the MLX C++ library start
|
||||
by cloning MLX from `its GitHub repo
|
||||
<https://github.com/ml-explore/mlx>`_:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
|
||||
|
||||
Create a build directory and run CMake and make:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
mkdir -p build && cd build
|
||||
cmake .. && make -j
|
||||
|
||||
Run tests with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
make test
|
||||
|
||||
Install with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
make install
|
||||
|
||||
Note that the built ``mlx.metallib`` file should be either at the same
|
||||
directory as the executable statically linked to ``libmlx.a`` or the
|
||||
preprocessor constant ``METAL_PATH`` should be defined at build time and it
|
||||
should point to the path to the built metal library.
|
||||
|
||||
.. list-table:: Build Options
|
||||
:widths: 25 8
|
||||
:header-rows: 1
|
||||
|
||||
* - Option
|
||||
- Default
|
||||
* - MLX_BUILD_TESTS
|
||||
- ON
|
||||
* - MLX_BUILD_EXAMPLES
|
||||
- OFF
|
||||
* - MLX_BUILD_BENCHMARKS
|
||||
- OFF
|
||||
* - MLX_BUILD_METAL
|
||||
- ON
|
||||
* - MLX_BUILD_PYTHON_BINDINGS
|
||||
- OFF
|
22
docs/src/python/fft.rst
Normal file
22
docs/src/python/fft.rst
Normal file
@@ -0,0 +1,22 @@
|
||||
.. _fft:
|
||||
|
||||
FFT
|
||||
===
|
||||
|
||||
.. currentmodule:: mlx.core.fft
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
fft
|
||||
ifft
|
||||
fft2
|
||||
ifft2
|
||||
fftn
|
||||
ifftn
|
||||
rfft
|
||||
irfft
|
||||
rfft2
|
||||
irfft2
|
||||
rfftn
|
||||
irfftn
|
172
docs/src/python/nn.rst
Normal file
172
docs/src/python/nn.rst
Normal file
@@ -0,0 +1,172 @@
|
||||
.. _nn:
|
||||
|
||||
.. currentmodule:: mlx.nn
|
||||
|
||||
Neural Networks
|
||||
===============
|
||||
|
||||
Writing arbitrarily complex neural networks in MLX can be done using only
|
||||
:class:`mlx.core.array` and :meth:`mlx.core.value_and_grad`. However, this requires the
|
||||
user to write again and again the same simple neural network operations as well
|
||||
as handle all the parameter state and initialization manually and explicitly.
|
||||
|
||||
The module :mod:`mlx.nn` solves this problem by providing an intuitive way of
|
||||
composing neural network layers, initializing their parameters, freezing them
|
||||
for finetuning and more.
|
||||
|
||||
Quick Start with Neural Networks
|
||||
---------------------------------
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, in_dims: int, out_dims: int):
|
||||
super().__init__()
|
||||
|
||||
self.layers = [
|
||||
nn.Linear(in_dims, 128),
|
||||
nn.Linear(128, 128),
|
||||
nn.Linear(128, out_dims),
|
||||
]
|
||||
|
||||
def __call__(self, x):
|
||||
for i, l in enumerate(self.layers):
|
||||
x = mx.maximum(x, 0) if i > 0 else x
|
||||
x = l(x)
|
||||
return x
|
||||
|
||||
# The model is created with all its parameters but nothing is initialized
|
||||
# yet because MLX is lazily evaluated
|
||||
mlp = MLP(2, 10)
|
||||
|
||||
# We can access its parameters by calling mlp.parameters()
|
||||
params = mlp.parameters()
|
||||
print(params["layers"][0]["weight"].shape)
|
||||
|
||||
# Printing a parameter will cause it to be evaluated and thus initialized
|
||||
print(params["layers"][0])
|
||||
|
||||
# We can also force evaluate all parameters to initialize the model
|
||||
mx.eval(mlp.parameters())
|
||||
|
||||
# A simple loss function.
|
||||
# NOTE: It doesn't matter how it uses the mlp model. It currently captures
|
||||
# it from the local scope. It could be a positional argument or a
|
||||
# keyword argument.
|
||||
def l2_loss(x, y):
|
||||
y_hat = mlp(x)
|
||||
return (y_hat - y).square().mean()
|
||||
|
||||
# Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
|
||||
# gradient with respect to `mlp.trainable_parameters()`
|
||||
loss_and_grad = nn.value_and_grad(mlp, l2_loss)
|
||||
|
||||
|
||||
.. _module_class:
|
||||
|
||||
The Module Class
|
||||
----------------
|
||||
|
||||
The workhorse of any neural network library is the :class:`Module` class. In
|
||||
MLX the :class:`Module` class is a container of :class:`mlx.core.array` or
|
||||
:class:`Module` instances. Its main function is to provide a way to
|
||||
recursively **access** and **update** its parameters and those of its
|
||||
submodules.
|
||||
|
||||
Parameters
|
||||
^^^^^^^^^^
|
||||
|
||||
A parameter of a module is any public member of type :class:`mlx.core.array` (its
|
||||
name should not start with ``_``). It can be arbitrarily nested in other
|
||||
:class:`Module` instances or lists and dictionaries.
|
||||
|
||||
:meth:`Module.parameters` can be used to extract a nested dictionary with all
|
||||
the parameters of a module and its submodules.
|
||||
|
||||
A :class:`Module` can also keep track of "frozen" parameters.
|
||||
:meth:`Module.trainable_parameters` returns only the subset of
|
||||
:meth:`Module.parameters` that is not frozen. When using
|
||||
:meth:`mlx.nn.value_and_grad` the gradients returned will be with respect to these
|
||||
trainable parameters.
|
||||
|
||||
Updating the parameters
|
||||
^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
MLX modules allow accessing and updating individual parameters. However, most
|
||||
times we need to update large subsets of a module's parameters. This action is
|
||||
performed by :meth:`Module.update`.
|
||||
|
||||
Value and grad
|
||||
--------------
|
||||
|
||||
Using a :class:`Module` does not preclude using MLX's high order function
|
||||
transformations (:meth:`mlx.core.value_and_grad`, :meth:`mlx.core.grad`, etc.). However,
|
||||
these function transformations assume pure functions, namely the parameters
|
||||
should be passed as an argument to the function being transformed.
|
||||
|
||||
There is an easy pattern to achieve that with MLX modules
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model = ...
|
||||
|
||||
def f(params, other_inputs):
|
||||
model.update(params) # <---- Necessary to make the model use the passed parameters
|
||||
return model(other_inputs)
|
||||
|
||||
f(model.trainable_parameters(), mx.zeros((10,)))
|
||||
|
||||
However, :meth:`mlx.nn.value_and_grad` provides precisely this pattern and only
|
||||
computes the gradients with respect to the trainable parameters of the model.
|
||||
|
||||
In detail:
|
||||
|
||||
- it wraps the passed function with a function that calls :meth:`Module.update`
|
||||
to make sure the model is using the provided parameters.
|
||||
- it calls :meth:`mlx.core.value_and_grad` to transform the function into a function
|
||||
that also computes the gradients with respect to the passed parameters.
|
||||
- it wraps the returned function with a function that passes the trainable
|
||||
parameters as the first argument to the function returned by
|
||||
:meth:`mlx.core.value_and_grad`
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
value_and_grad
|
||||
|
||||
Neural Network Layers
|
||||
---------------------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
:template: nn-module-template.rst
|
||||
|
||||
Embedding
|
||||
ReLU
|
||||
GELU
|
||||
SiLU
|
||||
Linear
|
||||
Conv1d
|
||||
Conv2d
|
||||
LayerNorm
|
||||
RMSNorm
|
||||
GroupNorm
|
||||
RoPE
|
||||
MultiHeadAttention
|
||||
Sequential
|
||||
|
||||
Layers without parameters (e.g. activation functions) are also provided as
|
||||
simple functions.
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary_functions
|
||||
:template: nn-module-template.rst
|
||||
|
||||
gelu
|
||||
gelu_approx
|
||||
gelu_fast_approx
|
||||
relu
|
||||
silu
|
7
docs/src/python/nn/module.rst
Normal file
7
docs/src/python/nn/module.rst
Normal file
@@ -0,0 +1,7 @@
|
||||
mlx.nn.Module
|
||||
=============
|
||||
|
||||
.. currentmodule:: mlx.nn
|
||||
|
||||
.. autoclass:: Module
|
||||
:members:
|
41
docs/src/python/optimizers.rst
Normal file
41
docs/src/python/optimizers.rst
Normal file
@@ -0,0 +1,41 @@
|
||||
.. _optimizers:
|
||||
|
||||
Optimizers
|
||||
==========
|
||||
|
||||
The optimizers in MLX can be used both with :mod:`mlx.nn` but also with pure
|
||||
:mod:`mlx.core` functions. A typical example involves calling
|
||||
:meth:`Optimizer.update` to update a model's parameters based on the loss
|
||||
gradients and subsequently calling :func:`mlx.core.eval` to evaluate both the
|
||||
model's parameters and the **optimizer state**.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Create a model
|
||||
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
|
||||
mx.eval(model.parameters())
|
||||
|
||||
# Create the gradient function and the optimizer
|
||||
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
||||
optimizer = optim.SGD(learning_rate=learning_rate)
|
||||
|
||||
for e in range(num_epochs):
|
||||
for X, y in batch_iterate(batch_size, train_images, train_labels):
|
||||
loss, grads = loss_and_grad_fn(model, X, y)
|
||||
|
||||
# Update the model with the gradients. So far no computation has happened.
|
||||
optimizer.update(model, grads)
|
||||
|
||||
# Compute the new parameters but also the optimizer state.
|
||||
mx.eval(model.parameters(), optimizer.state)
|
||||
|
||||
.. currentmodule:: mlx.optimizers
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
:template: optimizers-template.rst
|
||||
|
||||
OptimizerState
|
||||
Optimizer
|
||||
SGD
|
||||
Adam
|
45
docs/src/python/random.rst
Normal file
45
docs/src/python/random.rst
Normal file
@@ -0,0 +1,45 @@
|
||||
.. _random:
|
||||
|
||||
Random
|
||||
======
|
||||
|
||||
Random sampling functions in MLX use an implicit global PRNG state by default.
|
||||
However, all function take an optional ``key`` keyword argument for when more
|
||||
fine-grained control or explicit state management is needed.
|
||||
|
||||
For example, you can generate random numbers with:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
for _ in range(3):
|
||||
print(mx.random.uniform())
|
||||
|
||||
which will print a sequence of unique pseudo random numbers. Alternatively you
|
||||
can explicitly set the key:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
key = mx.random.key(0)
|
||||
for _ in range(3):
|
||||
print(mx.random.uniform(key=key))
|
||||
|
||||
which will yield the same pseudo random number at each iteration.
|
||||
|
||||
Following `JAX's PRNG design <https://jax.readthedocs.io/en/latest/jep/263-prng.html>`_
|
||||
we use a splittable version of Threefry, which is a counter-based PRNG.
|
||||
|
||||
.. currentmodule:: mlx.core.random
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
seed
|
||||
key
|
||||
split
|
||||
bernoulli
|
||||
categorical
|
||||
gumbel
|
||||
normal
|
||||
randint
|
||||
uniform
|
||||
truncated_normal
|
93
docs/src/quick_start.rst
Normal file
93
docs/src/quick_start.rst
Normal file
@@ -0,0 +1,93 @@
|
||||
Quick Start Guide
|
||||
=================
|
||||
|
||||
MLX is a NumPy-like array framework designed for efficient and flexible
|
||||
machine learning on Apple silicon. The Python API closely follows NumPy with
|
||||
a few exceptions. MLX also has a fully featured C++ API which closely follows
|
||||
the Python API.
|
||||
|
||||
The main differences between MLX and NumPy are:
|
||||
|
||||
- **Composable function transformations**: MLX has composable function
|
||||
transformations for automatic differentiation, automatic vectorization,
|
||||
and computation graph optimization.
|
||||
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
|
||||
materialized when needed.
|
||||
- **Multi-device**: Operations can run on any of the suppoorted devices (CPU,
|
||||
GPU, ...)
|
||||
|
||||
The design of MLX is strongly inspired by frameworks like `PyTorch
|
||||
<https://pytorch.org/>`_, `Jax <https://github.com/google/jax>`_, and
|
||||
`ArrayFire <https://arrayfire.org/>`_. A noteable difference from these
|
||||
frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared
|
||||
memory. Operations on MLX arrays can be performed on any of the supported
|
||||
device types without performing data copies. Currently supported device types
|
||||
are the CPU and GPU.
|
||||
|
||||
Basics
|
||||
------
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
Import ``mlx.core`` and make an :class:`array`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
>> import mlx.core as mx
|
||||
>> a = mx.array([1, 2, 3, 4])
|
||||
>> a.shape
|
||||
[4]
|
||||
>> a.dtype
|
||||
int32
|
||||
>> b = mx.array([1.0, 2.0, 3.0, 4.0])
|
||||
>> b.dtype
|
||||
float32
|
||||
|
||||
Operations in MLX are lazy. The outputs of MLX operations are not computed
|
||||
until they are needed. To force an array to be evaluated use
|
||||
:func:`eval`. Arrays will automatically be evaluated in a few cases. For
|
||||
example, inspecting a scalar with :meth:`array.item`, printing an array,
|
||||
or converting an array from :class:`array` to :class:`numpy.ndarray` all
|
||||
automatically evaluate the array.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
>> c = a + b # c not yet evaluated
|
||||
>> mx.eval(c) # evaluates c
|
||||
>> c = a + b
|
||||
>> print(c) # Also evaluates c
|
||||
array([2, 4, 6, 8], dtype=float32)
|
||||
>> c = a + b
|
||||
>> import numpy as np
|
||||
>> np.array(c) # Also evaluates c
|
||||
array([2., 4., 6., 8.], dtype=float32)
|
||||
|
||||
Function and Graph Transformations
|
||||
----------------------------------
|
||||
|
||||
MLX has standard function transformations like :func:`grad` and :func:`vmap`.
|
||||
Transformations can be composed arbitrarily. For example
|
||||
``grad(vmap(grad(fn)))`` (or any other composition) is allowed.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
>> x = mx.array(0.0)
|
||||
>> mx.sin(x)
|
||||
array(0, dtype=float32)
|
||||
>> mx.grad(mx.sin)(x)
|
||||
array(1, dtype=float32)
|
||||
>> mx.grad(mx.grad(mx.sin))(x)
|
||||
array(-0, dtype=float32)
|
||||
|
||||
Other gradient transformations include :func:`vjp` for vector-Jacobian products
|
||||
and :func:`jvp` for Jacobian-vector products.
|
||||
|
||||
Use :func:`value_and_grad` to efficiently compute both a function's output and
|
||||
gradient with respect to the function's input.
|
||||
|
||||
|
||||
Devices and Streams
|
||||
-------------------
|
||||
|
||||
|
||||
|
16
docs/src/using_streams.rst
Normal file
16
docs/src/using_streams.rst
Normal file
@@ -0,0 +1,16 @@
|
||||
Using Streams
|
||||
=============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
Specifying the :obj:`Stream`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
All operations (including random number generation) take an optional
|
||||
keyword argument ``stream``. The ``stream`` kwarg specifies which
|
||||
:obj:`Stream` the operation should run on. If the stream is unspecified then
|
||||
the operation is run on the default stream of the default device:
|
||||
``mx.default_stream(mx.default_device())``. The ``stream`` kwarg can also
|
||||
be a :obj:`Device` (e.g. ``stream=my_device``) in which case the operation is
|
||||
run on the default stream of the provided device
|
||||
``mx.default_stream(my_device)``.
|
Reference in New Issue
Block a user