mirror of https://github.com/ml-explore/mlx.git, synced 2025-06-24 17:31:16 +08:00

parent 90d04072b7
commit ee0c2835c5
docs/.gitignore (vendored): 1 line changed

@@ -1 +1,2 @@
 src/python/_autosummary*/
+src/python/nn/_autosummary*/

@@ -61,7 +61,10 @@ set:
     def eval_fn(model, X, y):
         return mx.mean(mx.argmax(model(X), axis=1) == y)
 
-Next, setup the problem parameters and load the data:
+Next, set up the problem parameters and load the data. To load the data, you need our
+`mnist data loader
+<https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
+we will import as `mnist`.
 
 .. code-block:: python
 
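
For context, a minimal sketch of the step this hunk documents, assuming
`mnist.py` from mlx-examples is importable and that `mnist.mnist()` returns the
four train/test splits in the order shown (both are assumptions, not verified
against this commit):

.. code-block:: python

   import mlx.core as mx
   import mnist  # mnist.py from mlx-examples, assumed importable

   # Hypothetical problem parameters for the example
   num_classes = 10
   batch_size = 256

   # Assumed return order: train images/labels, then test images/labels
   train_x, train_y, test_x, test_y = map(mx.array, mnist.mnist())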

@@ -35,8 +35,7 @@ Probably you are using a non-native Python. The output of
 
 should be ``arm``. If it is ``i386`` (and you have an M series machine) then you
 are using a non-native Python. Switch your Python to a native Python. A good
-way to do this is with
-`Conda <https://stackoverflow.com/questions/65415996/how-to-specify-the-architecture-or-platform-for-a-new-conda-environment-apple>`_.
+way to do this is with `Conda <https://stackoverflow.com/q/65415996>`_.
 
 
 Build from source
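
The command whose output the hunk refers to is truncated in this view; a hedged
equivalent check from Python:

.. code-block:: python

   import platform

   # Expected on a native Apple silicon Python: "arm" (processor) and
   # "arm64" (machine); "i386" or "x86_64" indicates an Intel build or
   # an interpreter running under Rosetta.
   print(platform.processor())
   print(platform.machine())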

@@ -166,3 +165,27 @@ should point to the path to the built metal library.
 .. code-block:: shell
 
    xcrun -sdk macosx --show-sdk-version
+
+Troubleshooting
+^^^^^^^^^^^^^^^
+
+Metal not found
+~~~~~~~~~~~~~~~
+
+You see the following error when you try to build:
+
+.. code-block:: shell
+
+   error: unable to find utility "metal", not a developer tool or in PATH
+
+To fix this, first make sure you have Xcode installed:
+
+.. code-block:: shell
+
+   xcode-select --install
+
+Then set the active developer directory:
+
+.. code-block:: shell
+
+   sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer

@@ -64,7 +64,6 @@ Quick Start with Neural Networks
    # gradient with respect to `mlp.trainable_parameters()`
    loss_and_grad = nn.value_and_grad(mlp, l2_loss)
 
-
 .. _module_class:
 
 The Module Class
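
A hedged sketch of how the returned function is used (the ``l2_loss`` signature
and the ``X``, ``y`` arrays are assumed from the surrounding quick-start
example, which this hunk does not show):

.. code-block:: python

   # loss_and_grad wraps l2_loss: calling it runs the forward pass and
   # also returns gradients with respect to mlp.trainable_parameters().
   loss, grads = loss_and_grad(X, y)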

@@ -86,20 +85,58 @@ name should not start with ``_``). It can be arbitrarily nested in other
 :meth:`Module.parameters` can be used to extract a nested dictionary with all
 the parameters of a module and its submodules.
 
-A :class:`Module` can also keep track of "frozen" parameters.
-:meth:`Module.trainable_parameters` returns only the subset of
-:meth:`Module.parameters` that is not frozen. When using
-:meth:`mlx.nn.value_and_grad` the gradients returned will be with respect to these
-trainable parameters.
+A :class:`Module` can also keep track of "frozen" parameters. See the
+:meth:`Module.freeze` method for more details. When using :meth:`mlx.nn.value_and_grad`
+the gradients returned will be with respect to these trainable parameters.
 
-Updating the parameters
+Updating the Parameters
 ^^^^^^^^^^^^^^^^^^^^^^^
 
 MLX modules allow accessing and updating individual parameters. However, most
 times we need to update large subsets of a module's parameters. This action is
 performed by :meth:`Module.update`.
 
-Value and grad
+Inspecting Modules
+^^^^^^^^^^^^^^^^^^
+
+The simplest way to see the model architecture is to print it. Following along with
+the above example, you can print the ``MLP`` with:
+
+.. code-block:: python
+
+   print(mlp)
+
+This will display:
+
+.. code-block:: shell
+
+   MLP(
+     (layers.0): Linear(input_dims=2, output_dims=128, bias=True)
+     (layers.1): Linear(input_dims=128, output_dims=128, bias=True)
+     (layers.2): Linear(input_dims=128, output_dims=10, bias=True)
+   )
+
+To get more detailed information on the arrays in a :class:`Module` you can use
+:func:`mlx.utils.tree_map` on the parameters. For example, to see the shapes of
+all the parameters in a :class:`Module` do:
+
+.. code-block:: python
+
+   from mlx.utils import tree_map
+   shapes = tree_map(lambda p: p.shape, mlp.parameters())
+
+As another example, you can count the number of parameters in a :class:`Module`
+with:
+
+.. code-block:: python
+
+   from mlx.utils import tree_flatten
+   num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))
+
+
+Value and Grad
 --------------
 
 Using a :class:`Module` does not preclude using MLX's high order function
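
A hedged sketch of the freezing and :meth:`Module.update` behavior described in
this hunk (the layer sizes are arbitrary; ``unfreeze`` is assumed to exist as
the counterpart of ``freeze``):

.. code-block:: python

   import mlx.nn as nn
   from mlx.utils import tree_map

   mlp = nn.Sequential(nn.Linear(2, 128), nn.ReLU(), nn.Linear(128, 10))

   # Frozen parameters drop out of trainable_parameters(), so
   # mlx.nn.value_and_grad would return gradients for none of them.
   mlp.freeze()
   mlp.unfreeze()  # assumed counterpart of freeze

   # Module.update consumes a parameter tree; scaling every parameter
   # here stands in for an optimizer step.
   mlp.update(tree_map(lambda p: 0.9 * p, mlp.trainable_parameters()))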

@@ -133,62 +170,14 @@ In detail:
 :meth:`mlx.core.value_and_grad`
 
 .. autosummary::
+   :recursive:
    :toctree: _autosummary
 
    value_and_grad
+   Module
 
-Neural Network Layers
----------------------
-
-.. autosummary::
-   :toctree: _autosummary
-   :template: nn-module-template.rst
-
-   Embedding
-   ReLU
-   PReLU
-   GELU
-   SiLU
-   Step
-   SELU
-   Mish
-   Linear
-   Conv1d
-   Conv2d
-   LayerNorm
-   RMSNorm
-   GroupNorm
-   RoPE
-   MultiHeadAttention
-   Sequential
-
-Layers without parameters (e.g. activation functions) are also provided as
-simple functions.
-
-.. autosummary::
-   :toctree: _autosummary_functions
-   :template: nn-module-template.rst
-
-   gelu
-   gelu_approx
-   gelu_fast_approx
-   relu
-   prelu
-   silu
-   step
-   selu
-   mish
-
-Loss Functions
---------------
-
-.. autosummary::
-   :toctree: _autosummary_functions
-   :template: nn-module-template.rst
-
-   losses.cross_entropy
-   losses.binary_cross_entropy
-   losses.l1_loss
-   losses.mse_loss
-   losses.nll_loss
-   losses.kl_div_loss
+.. toctree::
+
+   nn/layers
+   nn/functions
+   nn/losses
docs/src/python/nn/functions.rst (new file): 23 lines

@@ -0,0 +1,23 @@
+.. _nn_functions:
+
+.. currentmodule:: mlx.nn
+
+Functions
+---------
+
+Layers without parameters (e.g. activation functions) are also provided as
+simple functions.
+
+.. autosummary::
+   :toctree: _autosummary_functions
+   :template: nn-module-template.rst
+
+   gelu
+   gelu_approx
+   gelu_fast_approx
+   relu
+   prelu
+   silu
+   step
+   selu
+   mish
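
A hedged illustration of the layer/function duality this new page documents
(the APIs are as in current ``mlx.nn``, not checked against this exact commit):

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

   x = mx.random.normal((4, 8))

   # The parameter-free layer and its functional form compute the same thing.
   print(mx.allclose(nn.relu(x), nn.ReLU()(x)))  # expected: True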

docs/src/python/nn/layers.rst (new file): 28 lines

@@ -0,0 +1,28 @@
+.. _layers:
+
+.. currentmodule:: mlx.nn
+
+Layers
+------
+
+.. autosummary::
+   :toctree: _autosummary
+   :template: nn-module-template.rst
+
+   Embedding
+   ReLU
+   PReLU
+   GELU
+   SiLU
+   Step
+   SELU
+   Mish
+   Linear
+   Conv1d
+   Conv2d
+   LayerNorm
+   RMSNorm
+   GroupNorm
+   RoPE
+   MultiHeadAttention
+   Sequential
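
For instance, composing several of the listed layers (a hedged sketch; the
dimensions are arbitrary):

.. code-block:: python

   import mlx.nn as nn

   # Sequential chains Linear layers with a ReLU, mirroring the MLP
   # printed in the docs hunk above.
   model = nn.Sequential(
       nn.Linear(2, 128),
       nn.ReLU(),
       nn.Linear(128, 10),
   )
   print(model)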

docs/src/python/nn/losses.rst (new file): 17 lines

@@ -0,0 +1,17 @@
+.. _losses:
+
+.. currentmodule:: mlx.nn.losses
+
+Loss Functions
+--------------
+
+.. autosummary::
+   :toctree: _autosummary_functions
+   :template: nn-module-template.rst
+
+   cross_entropy
+   binary_cross_entropy
+   l1_loss
+   mse_loss
+   nll_loss
+   kl_div_loss
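
A hedged usage sketch for one of the listed losses (the signature matches
current ``mlx.nn.losses``; the default per-example reduction is an assumption):

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

   logits = mx.array([[2.0, -1.0, 0.5], [0.1, 1.5, -0.2]])
   targets = mx.array([0, 1])

   # One cross-entropy value per example under the assumed default reduction.
   loss = nn.losses.cross_entropy(logits, targets)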

@@ -1,7 +0,0 @@
-mlx.nn.Module
-=============
-
-.. currentmodule:: mlx.nn
-
-.. autoclass:: Module
-   :members:

@@ -7,12 +7,21 @@ from mlx.nn.layers.base import Module
 
 
 class Linear(Module):
-    """Applies an affine transformation to the input.
+    r"""Applies an affine transformation to the input.
+
+    Concretely:
+
+    .. math::
+
+        y = W^\top x + b
+
+    where :math:`W` has shape ``[output_dims, input_dims]``.
 
     Args:
         input_dims (int): The dimensionality of the input features
         output_dims (int): The dimensionality of the output features
-        bias (bool): If set to False then the layer will not use a bias
+        bias (bool, optional): If set to ``False`` then the layer will
+            not use a bias. Default ``True``.
     """
 
     def __init__(self, input_dims: int, output_dims: int, bias: bool = True):
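
The docstring's affine map written out as a hedged numerical check (assumes the
layer exposes ``weight`` and ``bias`` attributes and computes exactly this
expression):

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

   layer = nn.Linear(input_dims=4, output_dims=3)
   x = mx.random.normal((2, 4))

   # The row-vector form of y = W^T x + b, with W of shape
   # [output_dims, input_dims].
   y_manual = x @ layer.weight.T + layer.bias
   print(mx.allclose(layer(x), y_manual))  # expected: True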