awni's commit files

Awni Hannun
2023-11-29 10:30:41 -08:00
parent e411fcae68
commit 8ca7f9e8e9
130 changed files with 30159 additions and 0 deletions

22  docs/src/python/fft.rst  Normal file

@@ -0,0 +1,22 @@
.. _fft:

FFT
===

.. currentmodule:: mlx.core.fft

.. autosummary::
   :toctree: _autosummary

   fft
   ifft
   fft2
   ifft2
   fftn
   ifftn
   rfft
   irfft
   rfft2
   irfft2
   rfftn
   irfftn

172  docs/src/python/nn.rst  Normal file

@@ -0,0 +1,172 @@
.. _nn:

.. currentmodule:: mlx.nn

Neural Networks
===============

Writing arbitrarily complex neural networks in MLX can be done using only
:class:`mlx.core.array` and :meth:`mlx.core.value_and_grad`. However, this requires the
user to write the same simple neural network operations over and over and to
handle all the parameter state and initialization manually and explicitly. The
module :mod:`mlx.nn` solves this problem by providing an intuitive way of
composing neural network layers, initializing their parameters, freezing them
for finetuning, and more.

Quick Start with Neural Networks
---------------------------------

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

   class MLP(nn.Module):
       def __init__(self, in_dims: int, out_dims: int):
           super().__init__()

           self.layers = [
               nn.Linear(in_dims, 128),
               nn.Linear(128, 128),
               nn.Linear(128, out_dims),
           ]

       def __call__(self, x):
           for i, l in enumerate(self.layers):
               x = mx.maximum(x, 0) if i > 0 else x
               x = l(x)
           return x

   # The model is created with all its parameters but nothing is initialized
   # yet because MLX is lazily evaluated
   mlp = MLP(2, 10)

   # We can access its parameters by calling mlp.parameters()
   params = mlp.parameters()
   print(params["layers"][0]["weight"].shape)

   # Printing a parameter will cause it to be evaluated and thus initialized
   print(params["layers"][0])

   # We can also force evaluate all parameters to initialize the model
   mx.eval(mlp.parameters())

   # A simple loss function.
   # NOTE: It doesn't matter how it uses the mlp model. It currently captures
   # it from the local scope. It could be a positional argument or a
   # keyword argument.
   def l2_loss(x, y):
       y_hat = mlp(x)
       return (y_hat - y).square().mean()

   # Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
   # gradient with respect to `mlp.trainable_parameters()`
   loss_and_grad = nn.value_and_grad(mlp, l2_loss)

.. _module_class:

The Module Class
----------------

The workhorse of any neural network library is the :class:`Module` class. In
MLX the :class:`Module` class is a container of :class:`mlx.core.array` or
:class:`Module` instances. Its main function is to provide a way to
recursively **access** and **update** its parameters and those of its
submodules.

Parameters
^^^^^^^^^^

A parameter of a module is any public member of type :class:`mlx.core.array` (its
name should not start with ``_``). It can be arbitrarily nested in other
:class:`Module` instances or lists and dictionaries.

:meth:`Module.parameters` can be used to extract a nested dictionary with all
the parameters of a module and its submodules.

A :class:`Module` can also keep track of "frozen" parameters.
:meth:`Module.trainable_parameters` returns only the subset of
:meth:`Module.parameters` that is not frozen. When using
:meth:`mlx.nn.value_and_grad` the gradients returned will be with respect to these
trainable parameters.

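For example, the following is a minimal sketch of inspecting the nested
parameter dictionary and excluding a submodule from the trainable parameters.
It assumes a :meth:`Module.freeze` method is available for marking a module's
parameters as frozen; the submodule names are purely illustrative.

.. code-block:: python

   import mlx.nn as nn

   class Encoder(nn.Module):
       def __init__(self):
           super().__init__()
           self.embed = nn.Linear(16, 32)
           self.out = nn.Linear(32, 8)

   model = Encoder()

   # parameters() returns a nested dictionary keyed by member name
   print(model.parameters()["embed"]["weight"].shape)

   # Assumption: freeze() marks the submodule's parameters as frozen so they
   # are excluded from trainable_parameters()
   model.embed.freeze()
   print(model.trainable_parameters())
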
Updating the parameters
^^^^^^^^^^^^^^^^^^^^^^^

MLX modules allow accessing and updating individual parameters. However, most
of the time we need to update large subsets of a module's parameters. This
action is performed by :meth:`Module.update`.

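As a minimal sketch, the snippet below fetches a layer's parameter dictionary,
rescales one entry, and writes the result back with :meth:`Module.update` (the
layer and the scaling are placeholders chosen for illustration):

.. code-block:: python

   import mlx.nn as nn

   model = nn.Linear(4, 4)

   # Get the nested parameter dictionary, modify it, and write it back
   params = model.parameters()
   params["weight"] = params["weight"] * 0.5
   model.update(params)
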
Value and grad
--------------

Using a :class:`Module` does not preclude using MLX's higher-order function
transformations (:meth:`mlx.core.value_and_grad`, :meth:`mlx.core.grad`, etc.). However,
these function transformations assume pure functions, meaning that the parameters
should be passed as an argument to the function being transformed.

There is an easy pattern to achieve that with MLX modules:

.. code-block:: python

   model = ...

   def f(params, other_inputs):
       model.update(params)  # <---- Necessary to make the model use the passed parameters
       return model(other_inputs)

   f(model.trainable_parameters(), mx.zeros((10,)))

However, :meth:`mlx.nn.value_and_grad` provides precisely this pattern and only
computes the gradients with respect to the trainable parameters of the model.

In detail (see the usage sketch below), it:

- wraps the passed function with a function that calls :meth:`Module.update`
  to make sure the model is using the provided parameters.
- calls :meth:`mlx.core.value_and_grad` to transform the function into a
  function that also computes the gradients with respect to the passed
  parameters.
- wraps the returned function with a function that passes the trainable
  parameters as the first argument to the function returned by
  :meth:`mlx.core.value_and_grad`.

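The resulting call pattern looks like the following minimal sketch; the model,
loss function, and data below are placeholders chosen for illustration.

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

   model = nn.Linear(2, 1)

   def loss_fn(x, y):
       # As in the quick start example, the model is captured from the
       # enclosing scope
       return (model(x) - y).square().mean()

   loss_and_grad = nn.value_and_grad(model, loss_fn)

   x = mx.zeros((4, 2))
   y = mx.zeros((4, 1))

   # The wrapped function has the same signature as loss_fn and returns the
   # loss together with gradients for model.trainable_parameters()
   loss, grads = loss_and_grad(x, y)
   print(grads["weight"].shape)
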
.. autosummary::
   :toctree: _autosummary

   value_and_grad

Neural Network Layers
---------------------

.. autosummary::
   :toctree: _autosummary
   :template: nn-module-template.rst

   Embedding
   ReLU
   GELU
   SiLU
   Linear
   Conv1d
   Conv2d
   LayerNorm
   RMSNorm
   GroupNorm
   RoPE
   MultiHeadAttention
   Sequential

Layers without parameters (e.g. activation functions) are also provided as
simple functions; a short comparison of the two forms follows the list below.

.. autosummary::
   :toctree: _autosummary_functions
   :template: nn-module-template.rst

   gelu
   gelu_approx
   gelu_fast_approx
   relu
   silu
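
For instance, here is a minimal sketch contrasting the function form with the
corresponding parameter-free module (the input array is a placeholder):

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

   x = mx.array([-1.0, 0.0, 1.0])

   # The function form applies the activation directly ...
   y1 = nn.relu(x)

   # ... while the class form is a parameter-free Module, convenient when
   # composing layers
   y2 = nn.ReLU()(x)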


@@ -0,0 +1,7 @@
mlx.nn.Module
=============

.. currentmodule:: mlx.nn

.. autoclass:: Module
   :members:


@@ -0,0 +1,41 @@
.. _optimizers:

Optimizers
==========

The optimizers in MLX can be used both with :mod:`mlx.nn` and with pure
:mod:`mlx.core` functions. A typical example involves calling
:meth:`Optimizer.update` to update a model's parameters based on the loss
gradients and subsequently calling :func:`mlx.core.eval` to evaluate both the
model's parameters and the **optimizer state**.

.. code-block:: python

   # Create a model
   model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
   mx.eval(model.parameters())

   # Create the gradient function and the optimizer
   loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
   optimizer = optim.SGD(learning_rate=learning_rate)

   for e in range(num_epochs):
       for X, y in batch_iterate(batch_size, train_images, train_labels):
           loss, grads = loss_and_grad_fn(model, X, y)

           # Update the model with the gradients. So far no computation has happened.
           optimizer.update(model, grads)

           # Compute the new parameters but also the optimizer state.
           mx.eval(model.parameters(), optimizer.state)

.. currentmodule:: mlx.optimizers

.. autosummary::
   :toctree: _autosummary
   :template: optimizers-template.rst

   OptimizerState
   Optimizer
   SGD
   Adam


@@ -0,0 +1,45 @@
.. _random:

Random
======

Random sampling functions in MLX use an implicit global PRNG state by default.
However, all functions take an optional ``key`` keyword argument for when more
fine-grained control or explicit state management is needed.

For example, you can generate random numbers with:

.. code-block:: python

   for _ in range(3):
       print(mx.random.uniform())

which will print a sequence of unique pseudo random numbers. Alternatively you
can explicitly set the key:

.. code-block:: python

   key = mx.random.key(0)
   for _ in range(3):
       print(mx.random.uniform(key=key))

which will yield the same pseudo random number at each iteration.

Following `JAX's PRNG design <https://jax.readthedocs.io/en/latest/jep/263-prng.html>`_
we use a splittable version of Threefry, which is a counter-based PRNG.

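For example, a key can be split into independent subkeys, each driving its own
stream of random numbers. A minimal sketch (assuming the subkeys returned by
``split`` can be indexed along the first axis):

.. code-block:: python

   key = mx.random.key(0)

   # Split one key into two subkeys for two independent random streams
   keys = mx.random.split(key)

   a = mx.random.uniform(key=keys[0])
   b = mx.random.uniform(key=keys[1])
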
.. currentmodule:: mlx.core.random

.. autosummary::
   :toctree: _autosummary

   seed
   key
   split
   bernoulli
   categorical
   gumbel
   normal
   randint
   uniform
   truncated_normal