Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-26 15:58:14 +08:00)
awni's commit files
docs/src/python/fft.rst (new file, 22 lines)
@@ -0,0 +1,22 @@

.. _fft:

FFT
===

.. currentmodule:: mlx.core.fft

.. autosummary::
   :toctree: _autosummary

   fft
   ifft
   fft2
   ifft2
   fftn
   ifftn
   rfft
   irfft
   rfft2
   irfft2
   rfftn
   irfftn
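
The naming mirrors NumPy's ``numpy.fft`` module. A minimal sketch of the
one-dimensional transforms, assuming the usual NumPy-style default of
transforming along the last axis:

.. code-block:: python

   import mlx.core as mx

   x = mx.random.normal((8,))

   # The complex transform and its inverse round-trip the signal
   # (up to numerical error).
   y = mx.fft.fft(x)
   x_back = mx.fft.ifft(y)

   # The real-input transform keeps only the n // 2 + 1 non-negative
   # frequency bins, here 5 of them.
   y_real = mx.fft.rfft(x)
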
docs/src/python/nn.rst (new file, 172 lines)
@@ -0,0 +1,172 @@

.. _nn:

.. currentmodule:: mlx.nn

Neural Networks
===============

Writing arbitrarily complex neural networks in MLX can be done using only
:class:`mlx.core.array` and :meth:`mlx.core.value_and_grad`. However, this
requires the user to write the same simple neural network operations over and
over and to handle all the parameter state and initialization manually and
explicitly.

The module :mod:`mlx.nn` solves this problem by providing an intuitive way of
composing neural network layers, initializing their parameters, freezing them
for fine-tuning, and more.

Quick Start with Neural Networks
---------------------------------

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

   class MLP(nn.Module):
       def __init__(self, in_dims: int, out_dims: int):
           super().__init__()

           self.layers = [
               nn.Linear(in_dims, 128),
               nn.Linear(128, 128),
               nn.Linear(128, out_dims),
           ]

       def __call__(self, x):
           for i, l in enumerate(self.layers):
               x = mx.maximum(x, 0) if i > 0 else x
               x = l(x)
           return x

   # The model is created with all its parameters, but nothing is initialized
   # yet because MLX is lazily evaluated.
   mlp = MLP(2, 10)

   # We can access its parameters by calling mlp.parameters()
   params = mlp.parameters()
   print(params["layers"][0]["weight"].shape)

   # Printing a parameter will cause it to be evaluated and thus initialized.
   print(params["layers"][0])

   # We can also force evaluate all parameters to initialize the model.
   mx.eval(mlp.parameters())

   # A simple loss function.
   # NOTE: It doesn't matter how it uses the mlp model. It currently captures
   # it from the local scope. It could be a positional argument or a
   # keyword argument.
   def l2_loss(x, y):
       y_hat = mlp(x)
       return (y_hat - y).square().mean()

   # Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
   # gradient with respect to `mlp.trainable_parameters()`.
   loss_and_grad = nn.value_and_grad(mlp, l2_loss)

.. _module_class:

The Module Class
----------------

The workhorse of any neural network library is the :class:`Module` class. In
MLX the :class:`Module` class is a container of :class:`mlx.core.array` or
:class:`Module` instances. Its main function is to provide a way to
recursively **access** and **update** its parameters and those of its
submodules.

Parameters
^^^^^^^^^^

A parameter of a module is any public member of type :class:`mlx.core.array`
(its name should not start with ``_``). It can be arbitrarily nested in other
:class:`Module` instances, lists, and dictionaries.

:meth:`Module.parameters` can be used to extract a nested dictionary with all
the parameters of a module and its submodules.

A :class:`Module` can also keep track of "frozen" parameters.
:meth:`Module.trainable_parameters` returns only the subset of
:meth:`Module.parameters` that is not frozen. When using
:meth:`mlx.nn.value_and_grad` the gradients returned will be with respect to
these trainable parameters.
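
For example, here is a minimal sketch of inspecting and freezing parameters;
:meth:`Module.freeze` is assumed here as the mechanism that marks parameters
as frozen:

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

   model = nn.Linear(4, 8)

   # parameters() is a nested dictionary of arrays; a Linear layer stores at
   # least a "weight".
   print(model.parameters()["weight"].shape)

   # Freeze the whole module (``freeze`` is assumed); its parameters no longer
   # appear in trainable_parameters() and therefore receive no gradients from
   # nn.value_and_grad.
   model.freeze()
   print(model.trainable_parameters())  # an empty parameter tree
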
Updating the parameters
^^^^^^^^^^^^^^^^^^^^^^^

MLX modules allow accessing and updating individual parameters. However, most
of the time we need to update large subsets of a module's parameters. This
action is performed by :meth:`Module.update`.
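
A minimal sketch of such a bulk update, assuming a :class:`Linear` layer
stores its parameters under ``weight`` and ``bias``:

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

   model = nn.Linear(4, 4)

   # Push a parameter tree with the same structure and shapes into the module.
   model.update({"weight": mx.zeros((4, 4)), "bias": mx.zeros((4,))})
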
Value and grad
--------------

Using a :class:`Module` does not preclude using MLX's higher-order function
transformations (:meth:`mlx.core.value_and_grad`, :meth:`mlx.core.grad`, etc.).
However, these function transformations assume pure functions, namely that the
parameters should be passed as an argument to the function being transformed.

There is an easy pattern to achieve that with MLX modules:

.. code-block:: python

   model = ...

   def f(params, other_inputs):
       model.update(params)  # <---- Necessary to make the model use the passed parameters
       return model(other_inputs)

   f(model.trainable_parameters(), mx.zeros((10,)))

However, :meth:`mlx.nn.value_and_grad` provides precisely this pattern and only
computes the gradients with respect to the trainable parameters of the model.

In detail:

- it wraps the passed function with a function that calls :meth:`Module.update`
  to make sure the model is using the provided parameters.
- it calls :meth:`mlx.core.value_and_grad` to transform the function into one
  that also computes the gradients with respect to the passed parameters.
- it wraps the returned function with a function that passes the trainable
  parameters as the first argument to the function returned by
  :meth:`mlx.core.value_and_grad`.
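
Putting these steps together, a simplified sketch of such a wrapper (not the
actual :meth:`mlx.nn.value_and_grad` implementation) could look like:

.. code-block:: python

   import mlx.core as mx

   def sketch_value_and_grad(model, fn):
       # 1. Make the wrapped function use whatever parameters it is given.
       def inner(params, *args, **kwargs):
           model.update(params)
           return fn(*args, **kwargs)

       # 2. Differentiate with respect to the first argument (the parameters).
       vag = mx.value_and_grad(inner)

       # 3. Always pass the trainable parameters as that first argument.
       def wrapped(*args, **kwargs):
           return vag(model.trainable_parameters(), *args, **kwargs)

       return wrapped
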
.. autosummary::
   :toctree: _autosummary

   value_and_grad

Neural Network Layers
---------------------

.. autosummary::
   :toctree: _autosummary
   :template: nn-module-template.rst

   Embedding
   ReLU
   GELU
   SiLU
   Linear
   Conv1d
   Conv2d
   LayerNorm
   RMSNorm
   GroupNorm
   RoPE
   MultiHeadAttention
   Sequential

Layers without parameters (e.g. activation functions) are also provided as
simple functions.

.. autosummary::
   :toctree: _autosummary_functions
   :template: nn-module-template.rst

   gelu
   gelu_approx
   gelu_fast_approx
   relu
   silu
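
As an illustration, the layer classes above compose directly and the
parameter-free activations apply to plain arrays; a minimal sketch with
arbitrarily chosen shapes:

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

   # A small feed-forward network built from the layers listed above.
   model = nn.Sequential(
       nn.Linear(2, 32),
       nn.ReLU(),
       nn.Linear(32, 1),
   )

   x = mx.random.uniform(shape=(4, 2))
   y = model(x)  # shape (4, 1)

   # The parameter-free activations are also available as plain functions.
   z = nn.relu(mx.array([-1.0, 0.5, 2.0]))
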
docs/src/python/nn/module.rst (new file, 7 lines)
@@ -0,0 +1,7 @@

mlx.nn.Module
=============

.. currentmodule:: mlx.nn

.. autoclass:: Module
   :members:

docs/src/python/optimizers.rst (new file, 41 lines)
@@ -0,0 +1,41 @@

.. _optimizers:

Optimizers
==========

The optimizers in MLX can be used both with :mod:`mlx.nn` and with pure
:mod:`mlx.core` functions. A typical example involves calling
:meth:`Optimizer.update` to update a model's parameters based on the loss
gradients and subsequently calling :func:`mlx.core.eval` to evaluate both the
model's parameters and the **optimizer state**.

.. code-block:: python

   # Create a model
   model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
   mx.eval(model.parameters())

   # Create the gradient function and the optimizer
   loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
   optimizer = optim.SGD(learning_rate=learning_rate)

   for e in range(num_epochs):
       for X, y in batch_iterate(batch_size, train_images, train_labels):
           loss, grads = loss_and_grad_fn(model, X, y)

           # Update the model with the gradients. So far no computation has happened.
           optimizer.update(model, grads)

           # Compute the new parameters but also the optimizer state.
           mx.eval(model.parameters(), optimizer.state)
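
To illustrate the pure :mod:`mlx.core` use case, here is a minimal sketch that
keeps the parameters in a plain dictionary; it assumes
:meth:`Optimizer.apply_gradients`, taken here to apply a gradient tree to a
matching parameter tree and return the updated tree:

.. code-block:: python

   import mlx.core as mx
   import mlx.optimizers as optim

   # The "model" is just a dictionary of arrays.
   params = {"w": mx.zeros((3,)), "b": mx.array(0.0)}

   def loss_fn(params, x, y):
       return ((x @ params["w"] + params["b"] - y) ** 2).mean()

   x = mx.random.normal((16, 3))
   y = mx.random.normal((16,))

   optimizer = optim.SGD(learning_rate=0.1)
   loss, grads = mx.value_and_grad(loss_fn)(params, x, y)

   # apply_gradients is assumed to return the updated parameter tree.
   params = optimizer.apply_gradients(grads, params)
   mx.eval(params, optimizer.state)
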
.. currentmodule:: mlx.optimizers

.. autosummary::
   :toctree: _autosummary
   :template: optimizers-template.rst

   OptimizerState
   Optimizer
   SGD
   Adam

docs/src/python/random.rst (new file, 45 lines)
@@ -0,0 +1,45 @@

.. _random:

Random
======

Random sampling functions in MLX use an implicit global PRNG state by default.
However, all functions take an optional ``key`` keyword argument for when more
fine-grained control or explicit state management is needed.

For example, you can generate random numbers with:

.. code-block:: python

   for _ in range(3):
       print(mx.random.uniform())

which will print a sequence of unique pseudo-random numbers. Alternatively, you
can explicitly set the key:

.. code-block:: python

   key = mx.random.key(0)
   for _ in range(3):
       print(mx.random.uniform(key=key))

which will yield the same pseudo-random number at each iteration.

Following `JAX's PRNG design <https://jax.readthedocs.io/en/latest/jep/263-prng.html>`_
we use a splittable version of Threefry, which is a counter-based PRNG.
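
Because the PRNG is splittable, a key can be split into independent subkeys,
each driving its own deterministic stream; a minimal sketch:

.. code-block:: python

   import mlx.core as mx

   key = mx.random.key(0)

   # Split the key into two independent subkeys.
   keys = mx.random.split(key, 2)
   print(mx.random.uniform(key=keys[0]))
   print(mx.random.uniform(key=keys[1]))
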
.. currentmodule:: mlx.core.random

.. autosummary::
   :toctree: _autosummary

   seed
   key
   split
   bernoulli
   categorical
   gumbel
   normal
   randint
   uniform
   truncated_normal