diff --git a/docs/build/html/.buildinfo b/docs/build/html/.buildinfo
index 4c56d0ae4..0241e00e4 100644
--- a/docs/build/html/.buildinfo
+++ b/docs/build/html/.buildinfo
@@ -1,4 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
-config: df22fdae6eaa6299681f0aab7c5d6029
+config: b49cb089891263e82aedf5bc4cacbe8a
tags: 645f666f9bcd5a90fca523b33c5a78b7
diff --git a/docs/build/html/_sources/dev/extensions.rst b/docs/build/html/_sources/dev/extensions.rst
index a7880e396..3563305bf 100644
--- a/docs/build/html/_sources/dev/extensions.rst
+++ b/docs/build/html/_sources/dev/extensions.rst
@@ -677,9 +677,9 @@ Let's look at the overall directory structure first.
Binding to Python
^^^^^^^^^^^^^^^^^^
-We use PyBind11_ to build a Python API for the C++ library. Since bindings
-for all needed components such as `mlx.core.array`, `mlx.core.stream`, etc.
-are already provided, adding our :meth:`axpby` becomes very simple!
+We use PyBind11_ to build a Python API for the C++ library. Since bindings for
+components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
+already provided, adding our :meth:`axpby` is simple!
.. code-block:: C++
@@ -927,18 +927,18 @@ Results:
We see some modest improvements right away!
-This operation is now good to be used to build other operations,
-in :class:`mlx.nn.Module` calls, and also as a part of graph
-transformations like :meth:`grad`!
+This operation is now good to be used to build other operations, in
+:class:`mlx.nn.Module` calls, and also as a part of graph transformations like
+:meth:`grad`!
Scripts
-------
.. admonition:: Download the code
- The full example code is available in `mlx-examples `_.
+ The full example code is available in `mlx `_.
-.. code: `TODO_LINK/extensions`_
+.. code: `https://github.com/ml-explore/mlx/tree/main/examples/extensions/`_
.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
.. _Metal: https://developer.apple.com/documentation/metal?language=objc
diff --git a/docs/build/html/_sources/index.rst b/docs/build/html/_sources/index.rst
index 4f4411758..50dfe9083 100644
--- a/docs/build/html/_sources/index.rst
+++ b/docs/build/html/_sources/index.rst
@@ -41,6 +41,7 @@ are the CPU and GPU.
usage/indexing
usage/saving_and_loading
usage/function_transforms
+ usage/compile
usage/numpy
usage/using_streams
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.compile.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.compile.rst
new file mode 100644
index 000000000..3ccea976d
--- /dev/null
+++ b/docs/build/html/_sources/python/_autosummary/mlx.core.compile.rst
@@ -0,0 +1,6 @@
+mlx.core.compile
+================
+
+.. currentmodule:: mlx.core
+
+.. autofunction:: compile
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.disable_compile.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.disable_compile.rst
new file mode 100644
index 000000000..913574b97
--- /dev/null
+++ b/docs/build/html/_sources/python/_autosummary/mlx.core.disable_compile.rst
@@ -0,0 +1,6 @@
+mlx.core.disable\_compile
+=========================
+
+.. currentmodule:: mlx.core
+
+.. autofunction:: disable_compile
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.enable_compile.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.enable_compile.rst
new file mode 100644
index 000000000..c991ee8cb
--- /dev/null
+++ b/docs/build/html/_sources/python/_autosummary/mlx.core.enable_compile.rst
@@ -0,0 +1,6 @@
+mlx.core.enable\_compile
+========================
+
+.. currentmodule:: mlx.core
+
+.. autofunction:: enable_compile
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.simplify.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.simplify.rst
deleted file mode 100644
index c0b518497..000000000
--- a/docs/build/html/_sources/python/_autosummary/mlx.core.simplify.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-mlx.core.simplify
-=================
-
-.. currentmodule:: mlx.core
-
-.. autofunction:: simplify
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdaDelta.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdaDelta.rst
index 2ea7cda8a..55792c434 100644
--- a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdaDelta.rst
+++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdaDelta.rst
@@ -14,5 +14,6 @@
~AdaDelta.__init__
~AdaDelta.apply_single
+ ~AdaDelta.init_single
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adafactor.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adafactor.rst
index b0e5e5c30..9047eea41 100644
--- a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adafactor.rst
+++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adafactor.rst
@@ -14,5 +14,6 @@
~Adafactor.__init__
~Adafactor.apply_single
+ ~Adafactor.init_single
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adagrad.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adagrad.rst
index 8a12fc43c..c12713e8a 100644
--- a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adagrad.rst
+++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adagrad.rst
@@ -14,5 +14,6 @@
~Adagrad.__init__
~Adagrad.apply_single
+ ~Adagrad.init_single
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adam.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adam.rst
index 074080ea6..9ca26adfa 100644
--- a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adam.rst
+++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adam.rst
@@ -14,5 +14,6 @@
~Adam.__init__
~Adam.apply_single
+ ~Adam.init_single
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adamax.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adamax.rst
index 58e6c95ca..73dc7314d 100644
--- a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adamax.rst
+++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adamax.rst
@@ -14,5 +14,6 @@
~Adamax.__init__
~Adamax.apply_single
+ ~Adamax.init_single
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Lion.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Lion.rst
index a00dc50f0..1454aada1 100644
--- a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Lion.rst
+++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Lion.rst
@@ -14,5 +14,6 @@
~Lion.__init__
~Lion.apply_single
+ ~Lion.init_single
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.apply_gradients.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.apply_gradients.rst
new file mode 100644
index 000000000..763eeb293
--- /dev/null
+++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.apply_gradients.rst
@@ -0,0 +1,6 @@
+mlx.optimizers.Optimizer.apply\_gradients
+=========================================
+
+.. currentmodule:: mlx.optimizers
+
+.. automethod:: Optimizer.apply_gradients
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.init.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.init.rst
new file mode 100644
index 000000000..e0245cf02
--- /dev/null
+++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.init.rst
@@ -0,0 +1,6 @@
+mlx.optimizers.Optimizer.init
+=============================
+
+.. currentmodule:: mlx.optimizers
+
+.. automethod:: Optimizer.init
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.rst
deleted file mode 100644
index 613eb02cf..000000000
--- a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.rst
+++ /dev/null
@@ -1,20 +0,0 @@
-mlx.optimizers.Optimizer
-========================
-
-.. currentmodule:: mlx.optimizers
-
-.. autoclass:: Optimizer
-
-
-
-
- .. rubric:: Methods
-
- .. autosummary::
-
- ~Optimizer.__init__
- ~Optimizer.apply_gradients
- ~Optimizer.apply_single
- ~Optimizer.update
-
-
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.state.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.state.rst
new file mode 100644
index 000000000..e0bf31dbe
--- /dev/null
+++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.state.rst
@@ -0,0 +1,6 @@
+mlx.optimizers.Optimizer.state
+==============================
+
+.. currentmodule:: mlx.optimizers
+
+.. autoproperty:: Optimizer.state
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.update.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.update.rst
new file mode 100644
index 000000000..e7610999e
--- /dev/null
+++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.update.rst
@@ -0,0 +1,6 @@
+mlx.optimizers.Optimizer.update
+===============================
+
+.. currentmodule:: mlx.optimizers
+
+.. automethod:: Optimizer.update
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.OptimizerState.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.OptimizerState.rst
deleted file mode 100644
index b319b6d09..000000000
--- a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.OptimizerState.rst
+++ /dev/null
@@ -1,17 +0,0 @@
-mlx.optimizers.OptimizerState
-=============================
-
-.. currentmodule:: mlx.optimizers
-
-.. autoclass:: OptimizerState
-
-
-
-
- .. rubric:: Methods
-
- .. autosummary::
-
- ~OptimizerState.get
-
-
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.RMSprop.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.RMSprop.rst
index 217b4619f..d9ba20078 100644
--- a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.RMSprop.rst
+++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.RMSprop.rst
@@ -14,5 +14,6 @@
~RMSprop.__init__
~RMSprop.apply_single
+ ~RMSprop.init_single
diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.SGD.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.SGD.rst
index 35a9759ed..4b6f397ec 100644
--- a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.SGD.rst
+++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.SGD.rst
@@ -14,5 +14,6 @@
~SGD.__init__
~SGD.apply_single
+ ~SGD.init_single
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.ALiBi.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.ALiBi.rst
index 284b453cf..9159bb888 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.ALiBi.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.ALiBi.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: ALiBi
-
-
\ No newline at end of file
+.. autoclass:: ALiBi
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.BatchNorm.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.BatchNorm.rst
index b94d82e7f..d085d5af5 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.BatchNorm.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.BatchNorm.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: BatchNorm
-
-
\ No newline at end of file
+.. autoclass:: BatchNorm
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv1d.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv1d.rst
index c4128b83b..0fb6ff201 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv1d.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv1d.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: Conv1d
-
-
\ No newline at end of file
+.. autoclass:: Conv1d
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv2d.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv2d.rst
index 7bd1f08bb..566e5d1e1 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv2d.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv2d.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: Conv2d
-
-
\ No newline at end of file
+.. autoclass:: Conv2d
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout.rst
index d1a68e793..2ec3556e1 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: Dropout
-
-
\ No newline at end of file
+.. autoclass:: Dropout
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout2d.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout2d.rst
index 8bf18deb8..d643adcb9 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout2d.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout2d.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: Dropout2d
-
-
\ No newline at end of file
+.. autoclass:: Dropout2d
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout3d.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout3d.rst
index d513a3d61..f386030ee 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout3d.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout3d.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: Dropout3d
-
-
\ No newline at end of file
+.. autoclass:: Dropout3d
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Embedding.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Embedding.rst
index ad2f3f2ce..0f29f593d 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Embedding.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Embedding.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: Embedding
-
-
\ No newline at end of file
+.. autoclass:: Embedding
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GELU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GELU.rst
index c963c84f2..c6ca7a28c 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GELU.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GELU.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: GELU
-
-
\ No newline at end of file
+.. autoclass:: GELU
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GroupNorm.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GroupNorm.rst
index 762d9ffea..982103df5 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GroupNorm.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GroupNorm.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: GroupNorm
-
-
\ No newline at end of file
+.. autoclass:: GroupNorm
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.InstanceNorm.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.InstanceNorm.rst
index 92152b356..66d01967f 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.InstanceNorm.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.InstanceNorm.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: InstanceNorm
-
-
\ No newline at end of file
+.. autoclass:: InstanceNorm
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.LayerNorm.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.LayerNorm.rst
index cc0ac117d..817f9551e 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.LayerNorm.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.LayerNorm.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: LayerNorm
-
-
\ No newline at end of file
+.. autoclass:: LayerNorm
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Linear.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Linear.rst
index 627e6e6e6..53be170e4 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Linear.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Linear.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: Linear
-
-
\ No newline at end of file
+.. autoclass:: Linear
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Mish.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Mish.rst
index bf5397852..bd10864be 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Mish.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Mish.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: Mish
-
-
\ No newline at end of file
+.. autoclass:: Mish
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Module.state.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Module.state.rst
new file mode 100644
index 000000000..7f4819837
--- /dev/null
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Module.state.rst
@@ -0,0 +1,6 @@
+mlx.nn.Module.state
+===================
+
+.. currentmodule:: mlx.nn
+
+.. autoproperty:: Module.state
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.MultiHeadAttention.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.MultiHeadAttention.rst
index 2c3a8fcc1..0a3f8d184 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.MultiHeadAttention.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.MultiHeadAttention.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: MultiHeadAttention
-
-
\ No newline at end of file
+.. autoclass:: MultiHeadAttention
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.PReLU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.PReLU.rst
index 2de33a688..4583c2d65 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.PReLU.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.PReLU.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: PReLU
-
-
\ No newline at end of file
+.. autoclass:: PReLU
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.QuantizedLinear.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.QuantizedLinear.rst
index ccbde4340..00688282e 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.QuantizedLinear.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.QuantizedLinear.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: QuantizedLinear
-
-
\ No newline at end of file
+.. autoclass:: QuantizedLinear
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RMSNorm.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RMSNorm.rst
index 474b1355d..d4501bd36 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RMSNorm.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RMSNorm.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: RMSNorm
-
-
\ No newline at end of file
+.. autoclass:: RMSNorm
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.ReLU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.ReLU.rst
index 944917de9..6707e757e 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.ReLU.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.ReLU.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: ReLU
-
-
\ No newline at end of file
+.. autoclass:: ReLU
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RoPE.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RoPE.rst
index 392fbab7b..fca09a4eb 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RoPE.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RoPE.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: RoPE
-
-
\ No newline at end of file
+.. autoclass:: RoPE
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SELU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SELU.rst
index 9fe57cdea..fa7477246 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SELU.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SELU.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: SELU
-
-
\ No newline at end of file
+.. autoclass:: SELU
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Sequential.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Sequential.rst
index af6ee04ab..5ae61b025 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Sequential.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Sequential.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: Sequential
-
-
\ No newline at end of file
+.. autoclass:: Sequential
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SiLU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SiLU.rst
index 85069c9d5..57d18df4f 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SiLU.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SiLU.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: SiLU
-
-
\ No newline at end of file
+.. autoclass:: SiLU
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.rst
index bfdd633a5..30b7a1f90 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: SinusoidalPositionalEncoding
-
-
\ No newline at end of file
+.. autoclass:: SinusoidalPositionalEncoding
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Softshrink.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Softshrink.rst
index 464c3451b..5e17a5199 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Softshrink.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Softshrink.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: Softshrink
-
-
\ No newline at end of file
+.. autoclass:: Softshrink
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Step.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Step.rst
index 688313628..204f8cbb1 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Step.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Step.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: Step
-
-
\ No newline at end of file
+.. autoclass:: Step
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Transformer.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Transformer.rst
index 01dc3a841..f7e800eff 100644
--- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Transformer.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Transformer.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: Transformer
-
-
\ No newline at end of file
+.. autoclass:: Transformer
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu.rst
index 3e1668eb6..616cb1e22 100644
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: gelu
-
-
\ No newline at end of file
+.. autofunction:: gelu
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_approx.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_approx.rst
index de08dc88c..d634ee1de 100644
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_approx.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_approx.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: gelu_approx
-
-
\ No newline at end of file
+.. autofunction:: gelu_approx
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.rst
index c84114e6c..36cc04480 100644
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: gelu_fast_approx
-
-
\ No newline at end of file
+.. autofunction:: gelu_fast_approx
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.constant.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.constant.rst
deleted file mode 100644
index e61e02905..000000000
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.constant.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-mlx.nn.init.constant
-====================
-
-.. currentmodule:: mlx.nn.init
-
-.. autofunction:: constant
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.glorot_normal.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.glorot_normal.rst
deleted file mode 100644
index b500f578d..000000000
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.glorot_normal.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-mlx.nn.init.glorot\_normal
-==========================
-
-.. currentmodule:: mlx.nn.init
-
-.. autofunction:: glorot_normal
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.glorot_uniform.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.glorot_uniform.rst
deleted file mode 100644
index b266fc94a..000000000
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.glorot_uniform.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-mlx.nn.init.glorot\_uniform
-===========================
-
-.. currentmodule:: mlx.nn.init
-
-.. autofunction:: glorot_uniform
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.he_normal.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.he_normal.rst
deleted file mode 100644
index 51c3287a7..000000000
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.he_normal.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-mlx.nn.init.he\_normal
-======================
-
-.. currentmodule:: mlx.nn.init
-
-.. autofunction:: he_normal
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.he_uniform.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.he_uniform.rst
deleted file mode 100644
index ee299e247..000000000
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.he_uniform.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-mlx.nn.init.he\_uniform
-=======================
-
-.. currentmodule:: mlx.nn.init
-
-.. autofunction:: he_uniform
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.identity.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.identity.rst
deleted file mode 100644
index a5772adfa..000000000
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.identity.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-mlx.nn.init.identity
-====================
-
-.. currentmodule:: mlx.nn.init
-
-.. autofunction:: identity
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.normal.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.normal.rst
deleted file mode 100644
index 6f9ce0023..000000000
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.normal.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-mlx.nn.init.normal
-==================
-
-.. currentmodule:: mlx.nn.init
-
-.. autofunction:: normal
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.uniform.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.uniform.rst
deleted file mode 100644
index 7d3b82560..000000000
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.uniform.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-mlx.nn.init.uniform
-===================
-
-.. currentmodule:: mlx.nn.init
-
-.. autofunction:: uniform
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.constant.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.constant.rst
deleted file mode 100644
index 7e983ec9c..000000000
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.constant.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-mlx.nn.initializers.constant
-============================
-
-.. currentmodule:: mlx.nn.initializers
-
-.. autofunction:: constant
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.glorot_normal.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.glorot_normal.rst
deleted file mode 100644
index 1860f0f1a..000000000
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.glorot_normal.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-mlx.nn.initializers.glorot\_normal
-==================================
-
-.. currentmodule:: mlx.nn.initializers
-
-.. autofunction:: glorot_normal
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.glorot_uniform.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.glorot_uniform.rst
deleted file mode 100644
index 1693bb019..000000000
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.glorot_uniform.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-mlx.nn.initializers.glorot\_uniform
-===================================
-
-.. currentmodule:: mlx.nn.initializers
-
-.. autofunction:: glorot_uniform
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.he_normal.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.he_normal.rst
deleted file mode 100644
index 76e5d0ac7..000000000
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.he_normal.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-mlx.nn.initializers.he\_normal
-==============================
-
-.. currentmodule:: mlx.nn.initializers
-
-.. autofunction:: he_normal
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.he_uniform.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.he_uniform.rst
deleted file mode 100644
index 7482519a1..000000000
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.he_uniform.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-mlx.nn.initializers.he\_uniform
-===============================
-
-.. currentmodule:: mlx.nn.initializers
-
-.. autofunction:: he_uniform
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.identity.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.identity.rst
deleted file mode 100644
index 8548c4439..000000000
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.identity.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-mlx.nn.initializers.identity
-============================
-
-.. currentmodule:: mlx.nn.initializers
-
-.. autofunction:: identity
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.normal.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.normal.rst
deleted file mode 100644
index 3e82a3645..000000000
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.normal.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-mlx.nn.initializers.normal
-==========================
-
-.. currentmodule:: mlx.nn.initializers
-
-.. autofunction:: normal
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.uniform.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.uniform.rst
deleted file mode 100644
index 28c504bd1..000000000
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.uniform.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-mlx.nn.initializers.uniform
-===========================
-
-.. currentmodule:: mlx.nn.initializers
-
-.. autofunction:: uniform
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst
index be553e4c0..ba5254eff 100644
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
-.. autoclass:: binary_cross_entropy
-
-
\ No newline at end of file
+.. autofunction:: binary_cross_entropy
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.rst
index 7970aaca7..27b5d4a8e 100644
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
-.. autoclass:: cosine_similarity_loss
-
-
\ No newline at end of file
+.. autofunction:: cosine_similarity_loss
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.rst
index 9c50fd349..fd7f9e6f6 100644
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
-.. autoclass:: cross_entropy
-
-
\ No newline at end of file
+.. autofunction:: cross_entropy
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss.rst
index 63cc52978..a481e2317 100644
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
-.. autoclass:: gaussian_nll_loss
-
-
\ No newline at end of file
+.. autofunction:: gaussian_nll_loss
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.rst
index 3b94ae64c..092dcd383 100644
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
-.. autoclass:: hinge_loss
-
-
\ No newline at end of file
+.. autofunction:: hinge_loss
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.huber_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.huber_loss.rst
index 5b5dc918e..da5e4d417 100644
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.huber_loss.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.huber_loss.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
-.. autoclass:: huber_loss
-
-
\ No newline at end of file
+.. autofunction:: huber_loss
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst
index 11e070650..04d2fcce3 100644
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
-.. autoclass:: kl_div_loss
-
-
\ No newline at end of file
+.. autofunction:: kl_div_loss
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.rst
index 34ae66d69..950aff725 100644
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
-.. autoclass:: l1_loss
-
-
\ No newline at end of file
+.. autofunction:: l1_loss
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.rst
index b00c1a51f..b7a7461c9 100644
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
-.. autoclass:: log_cosh_loss
-
-
\ No newline at end of file
+.. autofunction:: log_cosh_loss
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.margin_ranking_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.margin_ranking_loss.rst
new file mode 100644
index 000000000..ede1b9f0a
--- /dev/null
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.margin_ranking_loss.rst
@@ -0,0 +1,6 @@
+mlx.nn.losses.margin\_ranking\_loss
+===================================
+
+.. currentmodule:: mlx.nn.losses
+
+.. autofunction:: margin_ranking_loss
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.rst
index 534ed1e14..c4e44ddb1 100644
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
-.. autoclass:: mse_loss
-
-
\ No newline at end of file
+.. autofunction:: mse_loss
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.rst
index c94eb82a1..e64b55dfc 100644
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
-.. autoclass:: nll_loss
-
-
\ No newline at end of file
+.. autofunction:: nll_loss
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.rst
index 00a647a75..d96bb5823 100644
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
-.. autoclass:: smooth_l1_loss
-
-
\ No newline at end of file
+.. autofunction:: smooth_l1_loss
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.rst
index 4698d6155..f52eaab92 100644
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
-.. autoclass:: triplet_loss
-
-
\ No newline at end of file
+.. autofunction:: triplet_loss
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.mish.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.mish.rst
index 85bf0899b..49c1cfcb9 100644
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.mish.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.mish.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: mish
-
-
\ No newline at end of file
+.. autofunction:: mish
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.prelu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.prelu.rst
index f3757c1c3..51085ec23 100644
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.prelu.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.prelu.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: prelu
-
-
\ No newline at end of file
+.. autofunction:: prelu
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.relu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.relu.rst
index 93a69272a..f1c28d7aa 100644
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.relu.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.relu.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: relu
-
-
\ No newline at end of file
+.. autofunction:: relu
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.selu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.selu.rst
index 00c1d0923..f1530a805 100644
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.selu.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.selu.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: selu
-
-
\ No newline at end of file
+.. autofunction:: selu
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.silu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.silu.rst
index b30c17b06..cd5ff218e 100644
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.silu.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.silu.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: silu
-
-
\ No newline at end of file
+.. autofunction:: silu
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.softshrink.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.softshrink.rst
index e6af930b6..b844f9242 100644
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.softshrink.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.softshrink.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: softshrink
-
-
\ No newline at end of file
+.. autofunction:: softshrink
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.step.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.step.rst
index 1395bd012..0ad2c19e6 100644
--- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.step.rst
+++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.step.rst
@@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
-.. autoclass:: step
-
-
\ No newline at end of file
+.. autofunction:: step
\ No newline at end of file
diff --git a/docs/build/html/_sources/python/nn/initializers.rst b/docs/build/html/_sources/python/nn/initializers.rst
deleted file mode 100644
index 59dddbe22..000000000
--- a/docs/build/html/_sources/python/nn/initializers.rst
+++ /dev/null
@@ -1,18 +0,0 @@
-.. _initializers:
-
-.. currentmodule:: mlx.nn.initializers
-
-Initializers
---------------
-
-.. autosummary::
- :toctree: _autosummary_functions
-
- constant
- normal
- uniform
- identity
- glorot_normal
- glorot_uniform
- he_normal
- he_uniform
diff --git a/docs/build/html/_sources/python/nn/losses.rst b/docs/build/html/_sources/python/nn/losses.rst
index 6c4327eb8..6a2e128c5 100644
--- a/docs/build/html/_sources/python/nn/losses.rst
+++ b/docs/build/html/_sources/python/nn/losses.rst
@@ -18,6 +18,7 @@ Loss Functions
kl_div_loss
l1_loss
log_cosh_loss
+ margin_ranking_loss
mse_loss
nll_loss
smooth_l1_loss
diff --git a/docs/build/html/_sources/python/nn/module.rst b/docs/build/html/_sources/python/nn/module.rst
index 042a88028..c3a4dfa62 100644
--- a/docs/build/html/_sources/python/nn/module.rst
+++ b/docs/build/html/_sources/python/nn/module.rst
@@ -11,6 +11,7 @@ Module
:toctree: _autosummary
Module.training
+ Module.state
.. rubric:: Methods
diff --git a/docs/build/html/_sources/python/optimizer.rst b/docs/build/html/_sources/python/optimizer.rst
new file mode 100644
index 000000000..cf6034dee
--- /dev/null
+++ b/docs/build/html/_sources/python/optimizer.rst
@@ -0,0 +1,23 @@
+Optimizer
+=========
+
+.. currentmodule:: mlx.optimizers
+
+.. autoclass:: Optimizer
+
+
+ .. rubric:: Attributes
+
+ .. autosummary::
+ :toctree: _autosummary
+
+ Optimizer.state
+
+ .. rubric:: Methods
+
+ .. autosummary::
+ :toctree: _autosummary
+
+ Optimizer.apply_gradients
+ Optimizer.init
+ Optimizer.update
diff --git a/docs/build/html/_sources/python/optimizers.rst b/docs/build/html/_sources/python/optimizers.rst
index fe8632a7e..4ef43d50f 100644
--- a/docs/build/html/_sources/python/optimizers.rst
+++ b/docs/build/html/_sources/python/optimizers.rst
@@ -29,14 +29,16 @@ model's parameters and the **optimizer state**.
# Compute the new parameters but also the optimizer state.
mx.eval(model.parameters(), optimizer.state)
+.. toctree::
+
+ optimizer
+
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
:template: optimizers-template.rst
- OptimizerState
- Optimizer
SGD
RMSprop
Adagrad
diff --git a/docs/build/html/_sources/python/transforms.rst b/docs/build/html/_sources/python/transforms.rst
index cc8d681d5..ad9ba579b 100644
--- a/docs/build/html/_sources/python/transforms.rst
+++ b/docs/build/html/_sources/python/transforms.rst
@@ -9,6 +9,9 @@ Transforms
:toctree: _autosummary
eval
+ compile
+ disable_compile
+ enable_compile
grad
value_and_grad
jvp
diff --git a/docs/build/html/_sources/usage/compile.rst b/docs/build/html/_sources/usage/compile.rst
new file mode 100644
index 000000000..97d5503a3
--- /dev/null
+++ b/docs/build/html/_sources/usage/compile.rst
@@ -0,0 +1,430 @@
+.. _compile:
+
+Compilation
+===========
+
+.. currentmodule:: mlx.core
+
+MLX has a :func:`compile` function transformation which compiles computation
+graphs. Function compilation results in smaller graphs by merging common work
+and fusing certain operations. In many cases this can lead to big improvements
+in run-time and memory use.
+
+Getting started with :func:`compile` is simple, but there are some edge cases
+that are good to be aware of for more complex graphs and advanced usage.
+
+Basics of Compile
+-----------------
+
+Let's start with a simple example:
+
+.. code-block:: python
+
+ def fun(x, y):
+ return mx.exp(-x) + y
+
+ x = mx.array(1.0)
+ y = mx.array(2.0)
+
+ # Regular call, no compilation
+ # Prints: array(2.36788, dtype=float32)
+ print(fun(x, y))
+
+ # Compile the function
+ compiled_fun = mx.compile(fun)
+
+ # Prints: array(2.36788, dtype=float32)
+ print(compiled_fun(x, y))
+
+The output of both the regular function and the compiled function is the same
+up to numerical precision.
+
+The first time you call a compiled function, MLX will build the compute
+graph, optimize it, and generate and compile code. This can be relatively
+slow. However, MLX will cache compiled functions, so calling a compiled
+function multiple times will not initiate a new compilation. This means you
+should typically compile functions that you plan to use more than once.
+
+.. code-block:: python
+
+ def fun(x, y):
+ return mx.exp(-x) + y
+
+ x = mx.array(1.0)
+ y = mx.array(2.0)
+
+ compiled_fun = mx.compile(fun)
+
+ # Compiled here
+ compiled_fun(x, y)
+
+ # Not compiled again
+ compiled_fun(x, y)
+
+ # Not compiled again
+ mx.compile(fun)(x, y)
+
+There are some important cases to be aware of that can cause a function to
+be recompiled:
+
+* Changing the shape or number of dimensions
+* Changing the type of any of the inputs
+* Changing the number of inputs to the function
+
+In certain cases only some of the compilation stack will be rerun (for
+example when changing the shapes) and in other cases the full compilation
+stack will be rerun (for example when changing the types). In general you
+should avoid compiling functions too frequently.
+
+Another idiom to watch out for is compiling functions which get created and
+destroyed frequently. This can happen, for example, when compiling an anonymous
+function in a loop:
+
+.. code-block:: python
+
+ a = mx.array(1.0)
+ # Don't do this, compiles lambda at each iteration
+ for _ in range(5):
+ mx.compile(lambda x: mx.exp(mx.abs(x)))(a)
+
+Example Speedup
+---------------
+
+The :func:`mlx.nn.gelu` is a nonlinear activation function commonly used with
+Transformer-based models. The implementation involves several unary and binary
+element-wise operations:
+
+.. code-block:: python
+
+ def gelu(x):
+ return x * (1 + mx.erf(x / math.sqrt(2))) / 2
+
+If you use this function with small arrays, it will be overhead bound. If you
+use it with large arrays it will be memory bandwidth bound. However, all of
+the operations in the ``gelu`` are fusible into a single kernel with
+:func:`compile`. This can speedup both cases considerably.
+
+Let's compare the runtime of the regular function versus the compiled
+function. We'll use the following timing helper which does a warm up and
+handles synchronization:
+
+.. code-block:: python
+
+ import time
+
+ def timeit(fun, x):
+ # warm up
+ for _ in range(10):
+ mx.eval(fun(x))
+
+ tic = time.perf_counter()
+ for _ in range(100):
+ mx.eval(fun(x))
+ toc = time.perf_counter()
+ tpi = 1e3 * (toc - tic) / 100
+ print(f"Time per iteration {tpi:.3f} (ms)")
+
+
+Now make an array, and benchmark both functions:
+
+.. code-block:: python
+
+ x = mx.random.uniform(shape=(32, 1000, 4096))
+ timeit(nn.gelu, x)
+ timeit(mx.compile(nn.gelu), x)
+
+On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
+five times faster.
+
+.. note::
+
+ As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
+ functions can still be helpful, but won't typically result in as large a
+ speedup as compiling operations that run on the GPU.
+
+
+Debugging
+---------
+
+When a compiled function is first called, it is traced with placeholder
+inputs. This means you can't evaluate arrays (for example to print their
+contents) inside compiled functions.
+
+.. code-block:: python
+
+ @mx.compile
+ def fun(x):
+ z = -x
+ print(z) # Crash
+ return mx.exp(z)
+
+ fun(mx.array(5.0))
+
+For debugging, inspecting arrays can be helpful. One way to do that is to
+globally disable compilation using the :func:`disable_compile` function or
+``MLX_DISABLE_COMPILE`` flag. For example the following is okay even though
+``fun`` is compiled:
+
+.. code-block:: python
+
+ @mx.compile
+ def fun(x):
+ z = -x
+ print(z) # Okay
+ return mx.exp(z)
+
+ mx.disable_compile()
+ fun(mx.array(5.0))
+
+
+Pure Functions
+--------------
+
+Compiled functions are intended to be *pure*; that is they should not have side
+effects. For example:
+
+.. code-block:: python
+
+ state = []
+
+ @mx.compile
+ def fun(x, y):
+ z = x + y
+ state.append(z)
+ return mx.exp(z)
+
+ fun(mx.array(1.0), mx.array(2.0))
+ # Crash!
+ print(state)
+
+After the first call of ``fun``, the ``state`` list will hold a placeholder
+array. The placeholder does not have any data; it is only used to build the
+computation graph. Printing such an array results in a crash.
+
+You have two options to deal with this. The first option is to simply return
+``state`` as an output:
+
+.. code-block:: python
+
+ state = []
+
+ @mx.compile
+ def fun(x, y):
+ z = x + y
+ state.append(z)
+ return mx.exp(z), state
+
+ _, state = fun(mx.array(1.0), mx.array(2.0))
+ # Prints [array(3, dtype=float32)]
+ print(state)
+
+In some cases returning updated state can be pretty inconvenient. Hence,
+:func:`compile` has a parameter to capture implicit outputs:
+
+.. code-block:: python
+
+ from functools import partial
+
+ state = []
+
+ # Tell compile to capture state as an output
+ @partial(mx.compile, outputs=state)
+ def fun(x, y):
+ z = x + y
+ state.append(z)
+ return mx.exp(z), state
+
+ fun(mx.array(1.0), mx.array(2.0))
+ # Prints [array(3, dtype=float32)]
+ print(state)
+
+This is particularly useful for compiling a function which includes an update
+to a container of arrays, as is commonly done when training the parameters of a
+:class:`mlx.nn.Module`.
+
+Compiled functions will also treat any inputs not in the parameter list as
+constants. For example:
+
+.. code-block:: python
+
+ state = [mx.array(1.0)]
+
+ @mx.compile
+ def fun(x):
+ return x + state[0]
+
+ # Prints array(2, dtype=float32)
+ print(fun(mx.array(1.0)))
+
+ # Update state
+ state[0] = mx.array(5.0)
+
+ # Still prints array(2, dtype=float32)
+ print(fun(mx.array(1.0)))
+
+In order to have the change of state reflected in the outputs of ``fun`` you
+again have two options. The first option is to simply pass ``state`` as input
+to the function. In some cases this can be pretty inconvenient. Hence,
+:func:`compile` also has a parameter to capture implicit inputs:
+
+.. code-block:: python
+
+ from functools import partial
+ state = [mx.array(1.0)]
+
+ # Tell compile to capture state as an input
+ @partial(mx.compile, inputs=state)
+ def fun(x):
+ return x + state[0]
+
+ # Prints array(2, dtype=float32)
+ print(fun(mx.array(1.0)))
+
+ # Update state
+ state[0] = mx.array(5.0)
+
+ # Prints array(6, dtype=float32)
+ print(fun(mx.array(1.0)))
+
+
+Compiling Training Graphs
+-------------------------
+
+This section will step through how to use :func:`compile` with a simple example
+of a common setup: training a model with :obj:`mlx.nn.Module` using an
+:obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the
+full forward, backward, and update with :func:`compile`.
+
+To start, here is the simple example without any compilation:
+
+.. code-block:: python
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import mlx.optimizers as optim
+
+ # 4 examples with 10 features each
+ x = mx.random.uniform(shape=(4, 10))
+
+ # 0, 1 targets
+ y = mx.array([0, 1, 0, 1])
+
+ # Simple linear model
+ model = nn.Linear(10, 1)
+
+ # SGD with momentum
+ optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
+
+ def loss_fn(model, x, y):
+ logits = model(x).squeeze()
+ return nn.losses.binary_cross_entropy(logits, y)
+
+ loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+
+ # Perform 10 steps of gradient descent
+ for it in range(10):
+ loss, grads = loss_and_grad_fn(model, x, y)
+ optimizer.update(model, grads)
+ mx.eval(model.parameters(), optimizer.state)
+
+To compile the update we can put it all in a function and compile it with the
+appropriate input and output captures. Here's the same example but compiled:
+
+.. code-block:: python
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import mlx.optimizers as optim
+ from functools import partial
+
+ # 4 examples with 10 features each
+ x = mx.random.uniform(shape=(4, 10))
+
+ # 0, 1 targets
+ y = mx.array([0, 1, 0, 1])
+
+ # Simple linear model
+ model = nn.Linear(10, 1)
+
+ # SGD with momentum
+ optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
+
+ def loss_fn(model, x, y):
+ logits = model(x).squeeze()
+ return nn.losses.binary_cross_entropy(logits, y)
+
+ # The state that will be captured as input and output
+ state = [model.state, optimizer.state]
+
+ @partial(mx.compile, inputs=state, outputs=state)
+ def step(x, y):
+ loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+ loss, grads = loss_and_grad_fn(model, x, y)
+ optimizer.update(model, grads)
+ return loss
+
+ # Perform 10 steps of gradient descent
+ for it in range(10):
+ loss = step(x, y)
+ # Evaluate the model and optimizer state
+ mx.eval(state)
+ print(loss)
+
+
+.. note::
+
+ If you are using a module which performs random sampling such as
+ :func:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in the
+ ``state`` captured by :func:`compile`, i.e. ``state = [model.state,
+ optimizer.state, mx.random.state]``.
+
+
+.. note::
+
+ For more examples of compiling full training graphs checkout the `MLX
+ Examples `_ GitHub repo.
+
+Transformations with Compile
+----------------------------
+
+In MLX function transformations are composable. You can apply any function
+transformation to the output of any other function transformation. For more on
+this, see the documentation on :ref:`function transforms
+`.
+
+Compiling transformed functions works just as expected:
+
+.. code-block:: python
+
+ grad_fn = mx.grad(mx.exp)
+
+ compiled_grad_fn = mx.compile(grad_fn)
+
+ # Prints: array(2.71828, dtype=float32)
+ print(grad_fn(mx.array(1.0)))
+
+ # Also prints: array(2.71828, dtype=float32)
+ print(compiled_grad_fn(mx.array(1.0)))
+
+.. note::
+
+ In order to compile as much as possible, a transformation of a compiled
+ function will not by default be compiled. To compile the transformed
+ function simply pass it through :func:`compile`.
+
+You can also compile functions which themselves call compiled functions. A
+good practice is to compile the outer most function to give :func:`compile`
+the most opportunity to optimize the computation graph:
+
+.. code-block:: python
+
+ @mx.compile
+ def inner(x):
+ return mx.exp(-mx.abs(x))
+
+ def outer(x):
+ inner(inner(x))
+
+ # Compiling the outer function is good to do as it will likely
+ # be faster even though the inner functions are compiled
+ fun = mx.compile(outer)
diff --git a/docs/build/html/_sources/usage/function_transforms.rst b/docs/build/html/_sources/usage/function_transforms.rst
index 72a313f97..02c5dec48 100644
--- a/docs/build/html/_sources/usage/function_transforms.rst
+++ b/docs/build/html/_sources/usage/function_transforms.rst
@@ -5,9 +5,12 @@ Function Transforms
.. currentmodule:: mlx.core
-MLX uses composable function transformations for automatic differentiation and
-vectorization. The key idea behind composable function transformations is that
-every transformation returns a function which can be further transformed.
+MLX uses composable function transformations for automatic differentiation,
+vectorization, and compute graph optimizations. To see the complete list of
+function transformations check-out the :ref:`API documentation `.
+
+The key idea behind composable function transformations is that every
+transformation returns a function which can be further transformed.
Here is a simple example:
@@ -36,10 +39,10 @@ Using :func:`grad` on the output of :func:`grad` is always ok. You keep
getting higher order derivatives.
Any of the MLX function transformations can be composed in any order to any
-depth. To see the complete list of function transformations check-out the
-:ref:`API documentation `. See the following sections for more
-information on :ref:`automatic differentiaion ` and
-:ref:`automatic vectorization `.
+depth. See the following sections for more information on :ref:`automatic
+differentiaion ` and :ref:`automatic vectorization `.
+For more information on :func:`compile` see the :ref:`compile documentation `.
+
Automatic Differentiation
-------------------------
diff --git a/docs/build/html/_static/documentation_options.js b/docs/build/html/_static/documentation_options.js
index 5404195e5..ff2bcd035 100644
--- a/docs/build/html/_static/documentation_options.js
+++ b/docs/build/html/_static/documentation_options.js
@@ -1,6 +1,6 @@
var DOCUMENTATION_OPTIONS = {
URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'),
- VERSION: '0.1.0',
+ VERSION: '0.2.0',
LANGUAGE: 'en',
COLLAPSE_INDEX: false,
BUILDER: 'html',
diff --git a/docs/build/html/cpp/ops.html b/docs/build/html/cpp/ops.html
index 81be1fd39..fc78b923d 100644
--- a/docs/build/html/cpp/ops.html
+++ b/docs/build/html/cpp/ops.html
@@ -9,7 +9,7 @@
- Operations — MLX 0.1.0 documentation
+ Operations — MLX 0.2.0 documentation
@@ -134,8 +134,8 @@
-
-
+
+
@@ -153,6 +153,7 @@
We use PyBind11 to build a Python API for the C++ library. Since bindings
-for all needed components such as mlx.core.array, mlx.core.stream, etc.
-are already provided, adding our axpby() becomes very simple!
+
We use PyBind11 to build a Python API for the C++ library. Since bindings for
+components such as mlx.core.array, mlx.core.stream, etc. are
+already provided, adding our axpby() is simple!
PYBIND11_MODULE(mlx_sample_extensions,m){m.doc()="Sample C++ and metal extensions for MLX";
@@ -1552,16 +1563,16 @@ with the naive
We see some modest improvements right away!
-
This operation is now good to be used to build other operations,
-in mlx.nn.Module calls, and also as a part of graph
-transformations like grad()!
+
This operation is now good to be used to build other operations, in
+mlx.nn.Module calls, and also as a part of graph transformations like
+grad()!
Returns a compiled function which produces the same output as fun.
+
+
Parameters:
+
+
fun (function) – A function which takes a variable number of
+array or trees of array and returns
+a variable number of array or trees of array.
+
inputs (list or dict, optional) – These inputs will be captured during
+the function compilation along with the inputs to fun. The inputs
+can be a list or a dict containing arbitrarily nested
+lists, dictionaries, or arrays. Leaf nodes that are not
+array are ignored. Default: None
+
outputs (list or dict, optional) – These outputs will be captured and
+updated in a compiled function. The outputs can be a
+list or a dict containing arbitrarily nested lists,
+dictionaries, or arrays. Leaf nodes that are not array are ignored.
+Default: None
+
+
+
Returns:
+
A compiled function which has the same input arguments
+as fun and returns the the same output(s).
Run a few fast graph simplification operations to reuse computation and
-reduce memory consumption. This function is meant to be run every time
-so its overhead should be small, approximately 1ms for a graph with a
-few thousand nodes.
-
importmlx.coreasmx
-
-deffoo(x):
- y=x@x
- z=x@x
- returny+z
-
-x=mx.ones((10,10))
-y=foo(x)
-z=foo(x)
-
-# Computes the matmul twice
-mx.eval(y)
-
-# Computes the matmul once
-mx.simplify(z)
-mx.eval(z)
-
-
-
-
Parameters:
-
args – Any number of arrays and/or trees of arrays to be simplified.
*args (arrays or trees of arrays) – Each argument can be a single array
or a tree of arrays. If a tree is given the nodes can be a Python
-list, tuple or dict but the leafs must all be
-an array.
+list, tuple or dict. Leaves which are not
+arrays are ignored.