diff --git a/docs/build/html/_sources/dev/extensions.rst b/docs/build/html/_sources/dev/extensions.rst
index 9482be725..9aae931a3 100644
--- a/docs/build/html/_sources/dev/extensions.rst
+++ b/docs/build/html/_sources/dev/extensions.rst
@@ -150,7 +150,7 @@ back and go to our example to give ourselves a more concrete image.
const std::vector<int>& argnums) override;
/**
- * The primitive must know how to vectorize itself accross
+ * The primitive must know how to vectorize itself across
* the given axes. The output is a pair containing the array
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
diff --git a/docs/build/html/_sources/examples/mlp.rst b/docs/build/html/_sources/examples/mlp.rst
index c003618ce..36890e95c 100644
--- a/docs/build/html/_sources/examples/mlp.rst
+++ b/docs/build/html/_sources/examples/mlp.rst
@@ -61,7 +61,10 @@ set:
def eval_fn(model, X, y):
return mx.mean(mx.argmax(model(X), axis=1) == y)
-Next, setup the problem parameters and load the data:
+Next, set up the problem parameters and load the data. To load the data, you need our
+`mnist data loader
+`_, which
+we will import as ``mnist``.
.. code-block:: python
diff --git a/docs/build/html/_sources/install.rst b/docs/build/html/_sources/install.rst
index 682f09f38..92669ab6e 100644
--- a/docs/build/html/_sources/install.rst
+++ b/docs/build/html/_sources/install.rst
@@ -15,11 +15,11 @@ To install from PyPI you must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.8
-- MacOS >= 13.3
+- macOS >= 13.3
.. note::
- MLX is only available on devices running MacOS >= 13.3
- It is highly recommended to use MacOS 14 (Sonoma)
+ MLX is only available on devices running macOS >= 13.3
+ It is highly recommended to use macOS 14 (Sonoma)
Troubleshooting
^^^^^^^^^^^^^^^
@@ -35,8 +35,7 @@ Probably you are using a non-native Python. The output of
should be ``arm``. If it is ``i386`` (and you have an M series machine) then you
are using a non-native Python. Switch your Python to a native Python. A good
-way to do this is with
-`Conda `_.
+way to do this is with `Conda `_.
Build from source
@@ -47,7 +46,7 @@ Build Requirements
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
- `cmake `_ -- version 3.24 or later, and ``make``
-- Xcode >= 14.3 (Xcode >= 15.0 for MacOS 14 and above)
+- Xcode >= 14.3 (Xcode >= 15.0 for macOS 14 and above)
Python API
@@ -88,6 +87,13 @@ To make sure the install is working run the tests with:
pip install ".[testing]"
python -m unittest discover python/tests
+Optional: Install stubs to enable auto completions and type checking from your IDE:
+
+.. code-block:: shell
+
+ pip install ".[dev]"
+ python setup.py generate_stubs
+
C++ API
^^^^^^^
@@ -154,8 +160,32 @@ should point to the path to the built metal library.
export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/"
Further, you can use the following command to find out which
- MacOS SDK will be used
+ macOS SDK will be used
.. code-block:: shell
xcrun -sdk macosx --show-sdk-version
+
+Troubleshooting
+^^^^^^^^^^^^^^^
+
+Metal not found
+~~~~~~~~~~~~~~~
+
+You see the following error when you try to build:
+
+.. code-block:: shell
+
+ error: unable to find utility "metal", not a developer tool or in PATH
+
+To fix this, first make sure you have Xcode installed:
+
+.. code-block:: shell
+
+ xcode-select --install
+
+Then set the active developer directory:
+
+.. code-block:: shell
+
+ sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.array.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.array.rst
index 21e66b5e4..a93bbadcd 100644
--- a/docs/build/html/_sources/python/_autosummary/mlx.core.array.rst
+++ b/docs/build/html/_sources/python/_autosummary/mlx.core.array.rst
@@ -26,6 +26,7 @@
~array.cumprod
~array.cumsum
~array.exp
+ ~array.flatten
~array.item
~array.log
~array.log10
@@ -35,6 +36,7 @@
~array.max
~array.mean
~array.min
+ ~array.moveaxis
~array.prod
~array.reciprocal
~array.reshape
@@ -45,6 +47,7 @@
~array.square
~array.squeeze
~array.sum
+ ~array.swapaxes
~array.tolist
~array.transpose
~array.var
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.ceil.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.ceil.rst
new file mode 100644
index 000000000..bbd0a6656
--- /dev/null
+++ b/docs/build/html/_sources/python/_autosummary/mlx.core.ceil.rst
@@ -0,0 +1,6 @@
+mlx.core.ceil
+=============
+
+.. currentmodule:: mlx.core
+
+.. autofunction:: ceil
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.flatten.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.flatten.rst
new file mode 100644
index 000000000..90470d914
--- /dev/null
+++ b/docs/build/html/_sources/python/_autosummary/mlx.core.flatten.rst
@@ -0,0 +1,6 @@
+mlx.core.flatten
+================
+
+.. currentmodule:: mlx.core
+
+.. autofunction:: flatten
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.floor.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.floor.rst
new file mode 100644
index 000000000..a05f6d451
--- /dev/null
+++ b/docs/build/html/_sources/python/_autosummary/mlx.core.floor.rst
@@ -0,0 +1,6 @@
+mlx.core.floor
+==============
+
+.. currentmodule:: mlx.core
+
+.. autofunction:: floor
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.moveaxis.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.moveaxis.rst
new file mode 100644
index 000000000..ed69d670c
--- /dev/null
+++ b/docs/build/html/_sources/python/_autosummary/mlx.core.moveaxis.rst
@@ -0,0 +1,6 @@
+mlx.core.moveaxis
+=================
+
+.. currentmodule:: mlx.core
+
+.. autofunction:: moveaxis
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.simplify.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.simplify.rst
new file mode 100644
index 000000000..c0b518497
--- /dev/null
+++ b/docs/build/html/_sources/python/_autosummary/mlx.core.simplify.rst
@@ -0,0 +1,6 @@
+mlx.core.simplify
+=================
+
+.. currentmodule:: mlx.core
+
+.. autofunction:: simplify
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.stack.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.stack.rst
new file mode 100644
index 000000000..fdb8721a2
--- /dev/null
+++ b/docs/build/html/_sources/python/_autosummary/mlx.core.stack.rst
@@ -0,0 +1,6 @@
+mlx.core.stack
+==============
+
+.. currentmodule:: mlx.core
+
+.. autofunction:: stack
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.swapaxes.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.swapaxes.rst
new file mode 100644
index 000000000..07b724a0f
--- /dev/null
+++ b/docs/build/html/_sources/python/_autosummary/mlx.core.swapaxes.rst
@@ -0,0 +1,6 @@
+mlx.core.swapaxes
+=================
+
+.. currentmodule:: mlx.core
+
+.. autofunction:: swapaxes
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.tri.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.tri.rst
new file mode 100644
index 000000000..ef760035b
--- /dev/null
+++ b/docs/build/html/_sources/python/_autosummary/mlx.core.tri.rst
@@ -0,0 +1,6 @@
+mlx.core.tri
+============
+
+.. currentmodule:: mlx.core
+
+.. autofunction:: tri
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.tril.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.tril.rst
new file mode 100644
index 000000000..89b45b090
--- /dev/null
+++ b/docs/build/html/_sources/python/_autosummary/mlx.core.tril.rst
@@ -0,0 +1,6 @@
+mlx.core.tril
+=============
+
+.. currentmodule:: mlx.core
+
+.. autofunction:: tril
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.triu.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.triu.rst
new file mode 100644
index 000000000..1d6aa7626
--- /dev/null
+++ b/docs/build/html/_sources/python/_autosummary/mlx.core.triu.rst
@@ -0,0 +1,6 @@
+mlx.core.triu
+=============
+
+.. currentmodule:: mlx.core
+
+.. autofunction:: triu
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Module.rst b/docs/build/html/_sources/python/_autosummary/mlx.nn.Module.rst
new file mode 100644
index 000000000..79f55b253
--- /dev/null
+++ b/docs/build/html/_sources/python/_autosummary/mlx.nn.Module.rst
@@ -0,0 +1,58 @@
+mlx.nn.Module
+=============
+
+.. currentmodule:: mlx.nn
+
+.. autoclass:: Module
+
+
+ .. automethod:: __init__
+
+
+ .. rubric:: Methods
+
+ .. autosummary::
+
+ ~Module.__init__
+ ~Module.apply
+ ~Module.apply_to_modules
+ ~Module.children
+ ~Module.clear
+ ~Module.copy
+ ~Module.eval
+ ~Module.filter_and_map
+ ~Module.freeze
+ ~Module.fromkeys
+ ~Module.get
+ ~Module.is_module
+ ~Module.items
+ ~Module.keys
+ ~Module.leaf_modules
+ ~Module.load_weights
+ ~Module.modules
+ ~Module.named_modules
+ ~Module.parameters
+ ~Module.pop
+ ~Module.popitem
+ ~Module.save_weights
+ ~Module.setdefault
+ ~Module.train
+ ~Module.trainable_parameter_filter
+ ~Module.trainable_parameters
+ ~Module.unfreeze
+ ~Module.update
+ ~Module.valid_child_filter
+ ~Module.valid_parameter_filter
+ ~Module.values
+
+
+
+
+
+ .. rubric:: Attributes
+
+ .. autosummary::
+
+ ~Module.training
+
+
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdaDelta.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdaDelta.rst
new file mode 100644
index 000000000..2ea7cda8a
--- /dev/null
+++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdaDelta.rst
@@ -0,0 +1,18 @@
+mlx.optimizers.AdaDelta
+=======================
+
+.. currentmodule:: mlx.optimizers
+
+.. autoclass:: AdaDelta
+
+
+
+
+ .. rubric:: Methods
+
+ .. autosummary::
+
+ ~AdaDelta.__init__
+ ~AdaDelta.apply_single
+
+
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adagrad.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adagrad.rst
new file mode 100644
index 000000000..8a12fc43c
--- /dev/null
+++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adagrad.rst
@@ -0,0 +1,18 @@
+mlx.optimizers.Adagrad
+======================
+
+.. currentmodule:: mlx.optimizers
+
+.. autoclass:: Adagrad
+
+
+
+
+ .. rubric:: Methods
+
+ .. autosummary::
+
+ ~Adagrad.__init__
+ ~Adagrad.apply_single
+
+
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdamW.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdamW.rst
new file mode 100644
index 000000000..b5259844f
--- /dev/null
+++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdamW.rst
@@ -0,0 +1,18 @@
+mlx.optimizers.AdamW
+====================
+
+.. currentmodule:: mlx.optimizers
+
+.. autoclass:: AdamW
+
+
+
+
+ .. rubric:: Methods
+
+ .. autosummary::
+
+ ~AdamW.__init__
+ ~AdamW.apply_single
+
+
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adamax.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adamax.rst
new file mode 100644
index 000000000..58e6c95ca
--- /dev/null
+++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adamax.rst
@@ -0,0 +1,18 @@
+mlx.optimizers.Adamax
+=====================
+
+.. currentmodule:: mlx.optimizers
+
+.. autoclass:: Adamax
+
+
+
+
+ .. rubric:: Methods
+
+ .. autosummary::
+
+ ~Adamax.__init__
+ ~Adamax.apply_single
+
+
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.RMSprop.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.RMSprop.rst
new file mode 100644
index 000000000..217b4619f
--- /dev/null
+++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.RMSprop.rst
@@ -0,0 +1,18 @@
+mlx.optimizers.RMSprop
+======================
+
+.. currentmodule:: mlx.optimizers
+
+.. autoclass:: RMSprop
+
+
+
+
+ .. rubric:: Methods
+
+ .. autosummary::
+
+ ~RMSprop.__init__
+ ~RMSprop.apply_single
+
+
diff --git a/docs/build/html/_sources/python/nn.rst b/docs/build/html/_sources/python/nn.rst
index 93cfd8c78..bc19a8162 100644
--- a/docs/build/html/_sources/python/nn.rst
+++ b/docs/build/html/_sources/python/nn.rst
@@ -64,7 +64,6 @@ Quick Start with Neural Networks
# gradient with respect to `mlp.trainable_parameters()`
loss_and_grad = nn.value_and_grad(mlp, l2_loss)
-
.. _module_class:
The Module Class
@@ -86,20 +85,58 @@ name should not start with ``_``). It can be arbitrarily nested in other
:meth:`Module.parameters` can be used to extract a nested dictionary with all
the parameters of a module and its submodules.
-A :class:`Module` can also keep track of "frozen" parameters.
-:meth:`Module.trainable_parameters` returns only the subset of
-:meth:`Module.parameters` that is not frozen. When using
-:meth:`mlx.nn.value_and_grad` the gradients returned will be with respect to these
-trainable parameters.
+A :class:`Module` can also keep track of "frozen" parameters. See the
+:meth:`Module.freeze` method for more details. When using :meth:`mlx.nn.value_and_grad`,
+the gradients returned will be with respect to the trainable (unfrozen) parameters.
-Updating the parameters
+
+Updating the Parameters
^^^^^^^^^^^^^^^^^^^^^^^
MLX modules allow accessing and updating individual parameters. However, most of
the time we need to update large subsets of a module's parameters. This action is
performed by :meth:`Module.update`.
-Value and grad
+
+Inspecting Modules
+^^^^^^^^^^^^^^^^^^
+
+The simplest way to see the model architecture is to print it. Following along with
+the above example, you can print the ``MLP`` with:
+
+.. code-block:: python
+
+ print(mlp)
+
+This will display:
+
+.. code-block:: shell
+
+ MLP(
+ (layers.0): Linear(input_dims=2, output_dims=128, bias=True)
+ (layers.1): Linear(input_dims=128, output_dims=128, bias=True)
+ (layers.2): Linear(input_dims=128, output_dims=10, bias=True)
+ )
+
+To get more detailed information on the arrays in a :class:`Module` you can use
+:func:`mlx.utils.tree_map` on the parameters. For example, to see the shapes of
+all the parameters in a :class:`Module` do:
+
+.. code-block:: python
+
+ from mlx.utils import tree_map
+ shapes = tree_map(lambda p: p.shape, mlp.parameters())
+
+As another example, you can count the number of parameters in a :class:`Module`
+with:
+
+.. code-block:: python
+
+ from mlx.utils import tree_flatten
+ num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))
+
+
+Value and Grad
--------------
Using a :class:`Module` does not preclude using MLX's higher-order function
@@ -133,62 +170,14 @@ In detail:
:meth:`mlx.core.value_and_grad`
.. autosummary::
+ :recursive:
:toctree: _autosummary
value_and_grad
+ Module
-Neural Network Layers
----------------------
+.. toctree::
-.. autosummary::
- :toctree: _autosummary
- :template: nn-module-template.rst
-
- Embedding
- ReLU
- PReLU
- GELU
- SiLU
- Step
- SELU
- Mish
- Linear
- Conv1d
- Conv2d
- LayerNorm
- RMSNorm
- GroupNorm
- RoPE
- MultiHeadAttention
- Sequential
-
-Layers without parameters (e.g. activation functions) are also provided as
-simple functions.
-
-.. autosummary::
- :toctree: _autosummary_functions
- :template: nn-module-template.rst
-
- gelu
- gelu_approx
- gelu_fast_approx
- relu
- prelu
- silu
- step
- selu
- mish
-
-Loss Functions
---------------
-
-.. autosummary::
- :toctree: _autosummary_functions
- :template: nn-module-template.rst
-
- losses.cross_entropy
- losses.binary_cross_entropy
- losses.l1_loss
- losses.mse_loss
- losses.nll_loss
- losses.kl_div_loss
+ nn/layers
+ nn/functions
+ nn/losses
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Conv1d.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv1d.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary/mlx.nn.Conv1d.rst
rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv1d.rst
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Conv2d.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv2d.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary/mlx.nn.Conv2d.rst
rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv2d.rst
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Embedding.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Embedding.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary/mlx.nn.Embedding.rst
rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Embedding.rst
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.GELU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GELU.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary/mlx.nn.GELU.rst
rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GELU.rst
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.GroupNorm.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GroupNorm.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary/mlx.nn.GroupNorm.rst
rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GroupNorm.rst
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.LayerNorm.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.LayerNorm.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary/mlx.nn.LayerNorm.rst
rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.LayerNorm.rst
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Linear.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Linear.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary/mlx.nn.Linear.rst
rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Linear.rst
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Mish.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Mish.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary/mlx.nn.Mish.rst
rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Mish.rst
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.MultiHeadAttention.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.MultiHeadAttention.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary/mlx.nn.MultiHeadAttention.rst
rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.MultiHeadAttention.rst
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.PReLU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.PReLU.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary/mlx.nn.PReLU.rst
rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.PReLU.rst
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.RMSNorm.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RMSNorm.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary/mlx.nn.RMSNorm.rst
rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RMSNorm.rst
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.ReLU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.ReLU.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary/mlx.nn.ReLU.rst
rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.ReLU.rst
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.RoPE.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RoPE.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary/mlx.nn.RoPE.rst
rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RoPE.rst
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.SELU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SELU.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary/mlx.nn.SELU.rst
rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SELU.rst
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Sequential.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Sequential.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary/mlx.nn.Sequential.rst
rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Sequential.rst
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.SiLU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SiLU.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary/mlx.nn.SiLU.rst
rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SiLU.rst
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Step.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Step.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary/mlx.nn.Step.rst
rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Step.rst
diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.gelu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.gelu.rst
rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu.rst
diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.gelu_approx.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_approx.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.gelu_approx.rst
rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_approx.rst
diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.gelu_fast_approx.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.gelu_fast_approx.rst
rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.rst
diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst
rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst
diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.cross_entropy.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.cross_entropy.rst
rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.rst
diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst
rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst
diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.l1_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.l1_loss.rst
rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.rst
diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.mse_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.mse_loss.rst
rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.rst
diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.nll_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.nll_loss.rst
rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.rst
diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.mish.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.mish.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.mish.rst
rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.mish.rst
diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.prelu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.prelu.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.prelu.rst
rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.prelu.rst
diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.relu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.relu.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.relu.rst
rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.relu.rst
diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.selu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.selu.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.selu.rst
rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.selu.rst
diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.silu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.silu.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.silu.rst
rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.silu.rst
diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.step.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.step.rst
similarity index 100%
rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.step.rst
rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.step.rst
diff --git a/docs/build/html/_sources/python/nn/functions.rst b/docs/build/html/_sources/python/nn/functions.rst
new file mode 100644
index 000000000..f13cbe7b4
--- /dev/null
+++ b/docs/build/html/_sources/python/nn/functions.rst
@@ -0,0 +1,23 @@
+.. _nn_functions:
+
+.. currentmodule:: mlx.nn
+
+Functions
+---------
+
+Layers without parameters (e.g. activation functions) are also provided as
+simple functions.
+
+.. autosummary::
+ :toctree: _autosummary_functions
+ :template: nn-module-template.rst
+
+ gelu
+ gelu_approx
+ gelu_fast_approx
+ relu
+ prelu
+ silu
+ step
+ selu
+ mish
diff --git a/docs/build/html/_sources/python/nn/layers.rst b/docs/build/html/_sources/python/nn/layers.rst
new file mode 100644
index 000000000..5628134d6
--- /dev/null
+++ b/docs/build/html/_sources/python/nn/layers.rst
@@ -0,0 +1,28 @@
+.. _layers:
+
+.. currentmodule:: mlx.nn
+
+Layers
+------
+
+.. autosummary::
+ :toctree: _autosummary
+ :template: nn-module-template.rst
+
+ Embedding
+ ReLU
+ PReLU
+ GELU
+ SiLU
+ Step
+ SELU
+ Mish
+ Linear
+ Conv1d
+ Conv2d
+ LayerNorm
+ RMSNorm
+ GroupNorm
+ RoPE
+ MultiHeadAttention
+ Sequential
diff --git a/docs/build/html/_sources/python/nn/losses.rst b/docs/build/html/_sources/python/nn/losses.rst
new file mode 100644
index 000000000..4808ce5ab
--- /dev/null
+++ b/docs/build/html/_sources/python/nn/losses.rst
@@ -0,0 +1,17 @@
+.. _losses:
+
+.. currentmodule:: mlx.nn.losses
+
+Loss Functions
+--------------
+
+.. autosummary::
+ :toctree: _autosummary_functions
+ :template: nn-module-template.rst
+
+ cross_entropy
+ binary_cross_entropy
+ l1_loss
+ mse_loss
+ nll_loss
+ kl_div_loss
diff --git a/docs/build/html/_sources/python/nn/module.rst b/docs/build/html/_sources/python/nn/module.rst
deleted file mode 100644
index e14ba96f4..000000000
--- a/docs/build/html/_sources/python/nn/module.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-mlx.nn.Module
-=============
-
-.. currentmodule:: mlx.nn
-
-.. autoclass:: Module
- :members:
diff --git a/docs/build/html/_sources/python/ops.rst b/docs/build/html/_sources/python/ops.rst
index b9a4c9066..ea25b90f9 100644
--- a/docs/build/html/_sources/python/ops.rst
+++ b/docs/build/html/_sources/python/ops.rst
@@ -26,6 +26,7 @@ Operations
argsort
array_equal
broadcast_to
+ ceil
concatenate
convolve
conv1d
@@ -39,6 +40,8 @@ Operations
exp
expand_dims
eye
+ floor
+ flatten
full
greater
greater_equal
@@ -59,6 +62,7 @@ Operations
mean
min
minimum
+ moveaxis
multiply
negative
ones
@@ -82,14 +86,19 @@ Operations
sqrt
square
squeeze
+ stack
stop_gradient
subtract
sum
+ swapaxes
take
take_along_axis
tan
tanh
transpose
+ tri
+ tril
+ triu
var
where
zeros
diff --git a/docs/build/html/_sources/python/optimizers.rst b/docs/build/html/_sources/python/optimizers.rst
index 7f5d3a067..b8e5cfea7 100644
--- a/docs/build/html/_sources/python/optimizers.rst
+++ b/docs/build/html/_sources/python/optimizers.rst
@@ -38,4 +38,9 @@ model's parameters and the **optimizer state**.
OptimizerState
Optimizer
SGD
+ RMSprop
+ Adagrad
+ AdaDelta
Adam
+ AdamW
+ Adamax
diff --git a/docs/build/html/_sources/python/transforms.rst b/docs/build/html/_sources/python/transforms.rst
index cc8d681d5..fa6d1d701 100644
--- a/docs/build/html/_sources/python/transforms.rst
+++ b/docs/build/html/_sources/python/transforms.rst
@@ -14,3 +14,4 @@ Transforms
jvp
vjp
vmap
+ simplify
diff --git a/docs/build/html/cpp/ops.html b/docs/build/html/cpp/ops.html
index cd7666c91..a6e660fe2 100644
--- a/docs/build/html/cpp/ops.html
+++ b/docs/build/html/cpp/ops.html
@@ -226,6 +226,7 @@
@@ -720,7 +745,7 @@ back and go to our example to give ourselves a more concrete image.
const std::vector<int>& argnums) override;
/**
- * The primitive must know how to vectorize itself accross
+ * The primitive must know how to vectorize itself across
* the given axes. The output is a pair containing the array
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
@@ -1445,7 +1470,7 @@ with the naive
We see some modest improvements right away!
This operation is now good to be used to build other operations,
-in mlx.nn.Module calls, and also as a part of graph
+in mlx.nn.Module calls, and also as a part of graph
transformations such as grad() and simplify()!
@@ -591,8 +616,8 @@ module to concisely define the model architecture.
positional encoding. [1] In addition, our attention layer will optionally use a
key/value cache that will be concatenated with the provided keys and values to
support efficient inference.
-Our implementation uses mlx.nn.Linear for all the projections and
-mlx.nn.RoPE for the positional encoding.
+Our implementation uses mlx.nn.Linear for all the projections and
+mlx.nn.RoPE for the positional encoding.
import mlx.core as mx
import mlx.nn as nn
@@ -650,7 +675,7 @@ support efficient inference.
The other component of the Llama model is the encoder layer which uses RMS
normalization [2] and SwiGLU. [3] For RMS normalization we will use
-mlx.nn.RMSNorm that is already provided in mlx.nn.
@@ -568,11 +593,11 @@ multi-layer perceptron to classify MNIST.
The model is defined as the MLP class which inherits from
-mlx.nn.Module. We follow the standard idiom to make a new module:
+mlx.nn.Module. We follow the standard idiom to make a new module:
Define an __init__ where the parameters and/or submodules are set up. See
the Module class docs for more information on how
-mlx.nn.Module registers parameters.
should be arm. If it is i386 (and you have an M series machine) then you
are using a non-native Python. Switch your Python to a native Python. A good
-way to do this is with
-Conda.
+way to do this is with Conda.
@@ -615,7 +643,7 @@ way to do this is with
A C++ compiler with C++17 support (e.g. Clang >= 5.0)
@@ -671,8 +704,8 @@ cmake..&&
directory as the executable statically linked to libmlx.a or the
preprocessor constant METAL_PATH should be defined at build time and it
should point to the path to the built metal library.
-