diff --git a/docs/build/html/_sources/dev/extensions.rst b/docs/build/html/_sources/dev/extensions.rst index 9482be725..9aae931a3 100644 --- a/docs/build/html/_sources/dev/extensions.rst +++ b/docs/build/html/_sources/dev/extensions.rst @@ -150,7 +150,7 @@ back and go to our example to give ourselves a more concrete image. const std::vector& argnums) override; /** - * The primitive must know how to vectorize itself accross + * The primitive must know how to vectorize itself across * the given axes. The output is a pair containing the array * representing the vectorized computation and the axis which * corresponds to the output vectorized dimension. diff --git a/docs/build/html/_sources/examples/mlp.rst b/docs/build/html/_sources/examples/mlp.rst index c003618ce..36890e95c 100644 --- a/docs/build/html/_sources/examples/mlp.rst +++ b/docs/build/html/_sources/examples/mlp.rst @@ -61,7 +61,10 @@ set: def eval_fn(model, X, y): return mx.mean(mx.argmax(model(X), axis=1) == y) -Next, setup the problem parameters and load the data: +Next, setup the problem parameters and load the data. To load the data, you need our +`mnist data loader +`_, which +we will import as `mnist`. .. code-block:: python diff --git a/docs/build/html/_sources/install.rst b/docs/build/html/_sources/install.rst index 682f09f38..92669ab6e 100644 --- a/docs/build/html/_sources/install.rst +++ b/docs/build/html/_sources/install.rst @@ -15,11 +15,11 @@ To install from PyPI you must meet the following requirements: - Using an M series chip (Apple silicon) - Using a native Python >= 3.8 -- MacOS >= 13.3 +- macOS >= 13.3 .. note:: - MLX is only available on devices running MacOS >= 13.3 - It is highly recommended to use MacOS 14 (Sonoma) + MLX is only available on devices running macOS >= 13.3 + It is highly recommended to use macOS 14 (Sonoma) Troubleshooting ^^^^^^^^^^^^^^^ @@ -35,8 +35,7 @@ Probably you are using a non-native Python. The output of should be ``arm``. If it is ``i386`` (and you have M series machine) then you are using a non-native Python. Switch your Python to a native Python. A good -way to do this is with -`Conda `_. +way to do this is with `Conda `_. Build from source @@ -47,7 +46,7 @@ Build Requirements - A C++ compiler with C++17 support (e.g. Clang >= 5.0) - `cmake `_ -- version 3.24 or later, and ``make`` -- Xcode >= 14.3 (Xcode >= 15.0 for MacOS 14 and above) +- Xcode >= 14.3 (Xcode >= 15.0 for macOS 14 and above) Python API @@ -88,6 +87,13 @@ To make sure the install is working run the tests with: pip install ".[testing]" python -m unittest discover python/tests +Optional: Install stubs to enable auto completions and type checking from your IDE: + +.. code-block:: shell + + pip install ".[dev]" + python setup.py generate_stubs + C++ API ^^^^^^^ @@ -154,8 +160,32 @@ should point to the path to the built metal library. export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/" Further, you can use the following command to find out which - MacOS SDK will be used + macOS SDK will be used .. code-block:: shell xcrun -sdk macosx --show-sdk-version + +Troubleshooting +^^^^^^^^^^^^^^^ + +Metal not found +~~~~~~~~~~~~~~~ + +You see the following error when you try to build: + +.. code-block:: shell + + error: unable to find utility "metal", not a developer tool or in PATH + +To fix this, first make sure you have Xcode installed: + +.. code-block:: shell + + xcode-select --install + +Then set the active developer directory: + +.. code-block:: shell + + sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.array.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.array.rst index 21e66b5e4..a93bbadcd 100644 --- a/docs/build/html/_sources/python/_autosummary/mlx.core.array.rst +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.array.rst @@ -26,6 +26,7 @@ ~array.cumprod ~array.cumsum ~array.exp + ~array.flatten ~array.item ~array.log ~array.log10 @@ -35,6 +36,7 @@ ~array.max ~array.mean ~array.min + ~array.moveaxis ~array.prod ~array.reciprocal ~array.reshape @@ -45,6 +47,7 @@ ~array.square ~array.squeeze ~array.sum + ~array.swapaxes ~array.tolist ~array.transpose ~array.var diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.ceil.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.ceil.rst new file mode 100644 index 000000000..bbd0a6656 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.ceil.rst @@ -0,0 +1,6 @@ +mlx.core.ceil +============= + +.. currentmodule:: mlx.core + +.. autofunction:: ceil \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.flatten.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.flatten.rst new file mode 100644 index 000000000..90470d914 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.flatten.rst @@ -0,0 +1,6 @@ +mlx.core.flatten +================ + +.. currentmodule:: mlx.core + +.. autofunction:: flatten \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.floor.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.floor.rst new file mode 100644 index 000000000..a05f6d451 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.floor.rst @@ -0,0 +1,6 @@ +mlx.core.floor +============== + +.. currentmodule:: mlx.core + +.. autofunction:: floor \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.moveaxis.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.moveaxis.rst new file mode 100644 index 000000000..ed69d670c --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.moveaxis.rst @@ -0,0 +1,6 @@ +mlx.core.moveaxis +================= + +.. currentmodule:: mlx.core + +.. autofunction:: moveaxis \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.simplify.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.simplify.rst new file mode 100644 index 000000000..c0b518497 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.simplify.rst @@ -0,0 +1,6 @@ +mlx.core.simplify +================= + +.. currentmodule:: mlx.core + +.. autofunction:: simplify \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.stack.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.stack.rst new file mode 100644 index 000000000..fdb8721a2 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.stack.rst @@ -0,0 +1,6 @@ +mlx.core.stack +============== + +.. currentmodule:: mlx.core + +.. autofunction:: stack \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.swapaxes.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.swapaxes.rst new file mode 100644 index 000000000..07b724a0f --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.swapaxes.rst @@ -0,0 +1,6 @@ +mlx.core.swapaxes +================= + +.. currentmodule:: mlx.core + +.. autofunction:: swapaxes \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.tri.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.tri.rst new file mode 100644 index 000000000..ef760035b --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.tri.rst @@ -0,0 +1,6 @@ +mlx.core.tri +============ + +.. currentmodule:: mlx.core + +.. autofunction:: tri \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.tril.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.tril.rst new file mode 100644 index 000000000..89b45b090 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.tril.rst @@ -0,0 +1,6 @@ +mlx.core.tril +============= + +.. currentmodule:: mlx.core + +.. autofunction:: tril \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.triu.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.triu.rst new file mode 100644 index 000000000..1d6aa7626 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.triu.rst @@ -0,0 +1,6 @@ +mlx.core.triu +============= + +.. currentmodule:: mlx.core + +.. autofunction:: triu \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Module.rst b/docs/build/html/_sources/python/_autosummary/mlx.nn.Module.rst new file mode 100644 index 000000000..79f55b253 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.nn.Module.rst @@ -0,0 +1,58 @@ +mlx.nn.Module +============= + +.. currentmodule:: mlx.nn + +.. autoclass:: Module + + + .. automethod:: __init__ + + + .. rubric:: Methods + + .. autosummary:: + + ~Module.__init__ + ~Module.apply + ~Module.apply_to_modules + ~Module.children + ~Module.clear + ~Module.copy + ~Module.eval + ~Module.filter_and_map + ~Module.freeze + ~Module.fromkeys + ~Module.get + ~Module.is_module + ~Module.items + ~Module.keys + ~Module.leaf_modules + ~Module.load_weights + ~Module.modules + ~Module.named_modules + ~Module.parameters + ~Module.pop + ~Module.popitem + ~Module.save_weights + ~Module.setdefault + ~Module.train + ~Module.trainable_parameter_filter + ~Module.trainable_parameters + ~Module.unfreeze + ~Module.update + ~Module.valid_child_filter + ~Module.valid_parameter_filter + ~Module.values + + + + + + .. rubric:: Attributes + + .. autosummary:: + + ~Module.training + + \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdaDelta.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdaDelta.rst new file mode 100644 index 000000000..2ea7cda8a --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdaDelta.rst @@ -0,0 +1,18 @@ +mlx.optimizers.AdaDelta +======================= + +.. currentmodule:: mlx.optimizers + +.. autoclass:: AdaDelta + + + + + .. rubric:: Methods + + .. autosummary:: + + ~AdaDelta.__init__ + ~AdaDelta.apply_single + + diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adagrad.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adagrad.rst new file mode 100644 index 000000000..8a12fc43c --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adagrad.rst @@ -0,0 +1,18 @@ +mlx.optimizers.Adagrad +====================== + +.. currentmodule:: mlx.optimizers + +.. autoclass:: Adagrad + + + + + .. rubric:: Methods + + .. autosummary:: + + ~Adagrad.__init__ + ~Adagrad.apply_single + + diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdamW.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdamW.rst new file mode 100644 index 000000000..b5259844f --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdamW.rst @@ -0,0 +1,18 @@ +mlx.optimizers.AdamW +==================== + +.. currentmodule:: mlx.optimizers + +.. autoclass:: AdamW + + + + + .. rubric:: Methods + + .. autosummary:: + + ~AdamW.__init__ + ~AdamW.apply_single + + diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adamax.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adamax.rst new file mode 100644 index 000000000..58e6c95ca --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adamax.rst @@ -0,0 +1,18 @@ +mlx.optimizers.Adamax +===================== + +.. currentmodule:: mlx.optimizers + +.. autoclass:: Adamax + + + + + .. rubric:: Methods + + .. autosummary:: + + ~Adamax.__init__ + ~Adamax.apply_single + + diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.RMSprop.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.RMSprop.rst new file mode 100644 index 000000000..217b4619f --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.RMSprop.rst @@ -0,0 +1,18 @@ +mlx.optimizers.RMSprop +====================== + +.. currentmodule:: mlx.optimizers + +.. autoclass:: RMSprop + + + + + .. rubric:: Methods + + .. autosummary:: + + ~RMSprop.__init__ + ~RMSprop.apply_single + + diff --git a/docs/build/html/_sources/python/nn.rst b/docs/build/html/_sources/python/nn.rst index 93cfd8c78..bc19a8162 100644 --- a/docs/build/html/_sources/python/nn.rst +++ b/docs/build/html/_sources/python/nn.rst @@ -64,7 +64,6 @@ Quick Start with Neural Networks # gradient with respect to `mlp.trainable_parameters()` loss_and_grad = nn.value_and_grad(mlp, l2_loss) - .. _module_class: The Module Class @@ -86,20 +85,58 @@ name should not start with ``_``). It can be arbitrarily nested in other :meth:`Module.parameters` can be used to extract a nested dictionary with all the parameters of a module and its submodules. -A :class:`Module` can also keep track of "frozen" parameters. -:meth:`Module.trainable_parameters` returns only the subset of -:meth:`Module.parameters` that is not frozen. When using -:meth:`mlx.nn.value_and_grad` the gradients returned will be with respect to these -trainable parameters. +A :class:`Module` can also keep track of "frozen" parameters. See the +:meth:`Module.freeze` method for more details. :meth:`mlx.nn.value_and_grad` +the gradients returned will be with respect to these trainable parameters. -Updating the parameters + +Updating the Parameters ^^^^^^^^^^^^^^^^^^^^^^^ MLX modules allow accessing and updating individual parameters. However, most times we need to update large subsets of a module's parameters. This action is performed by :meth:`Module.update`. -Value and grad + +Inspecting Modules +^^^^^^^^^^^^^^^^^^ + +The simplest way to see the model architecture is to print it. Following along with +the above example, you can print the ``MLP`` with: + +.. code-block:: python + + print(mlp) + +This will display: + +.. code-block:: shell + + MLP( + (layers.0): Linear(input_dims=2, output_dims=128, bias=True) + (layers.1): Linear(input_dims=128, output_dims=128, bias=True) + (layers.2): Linear(input_dims=128, output_dims=10, bias=True) + ) + +To get more detailed information on the arrays in a :class:`Module` you can use +:func:`mlx.utils.tree_map` on the parameters. For example, to see the shapes of +all the parameters in a :class:`Module` do: + +.. code-block:: python + + from mlx.utils import tree_map + shapes = tree_map(lambda p: p.shape, mlp.parameters()) + +As another example, you can count the number of parameters in a :class:`Module` +with: + +.. code-block:: python + + from mlx.utils import tree_flatten + num_params = sum(v.size for _, v in tree_flatten(mlp.parameters())) + + +Value and Grad -------------- Using a :class:`Module` does not preclude using MLX's high order function @@ -133,62 +170,14 @@ In detail: :meth:`mlx.core.value_and_grad` .. autosummary:: + :recursive: :toctree: _autosummary value_and_grad + Module -Neural Network Layers ---------------------- +.. toctree:: -.. autosummary:: - :toctree: _autosummary - :template: nn-module-template.rst - - Embedding - ReLU - PReLU - GELU - SiLU - Step - SELU - Mish - Linear - Conv1d - Conv2d - LayerNorm - RMSNorm - GroupNorm - RoPE - MultiHeadAttention - Sequential - -Layers without parameters (e.g. activation functions) are also provided as -simple functions. - -.. autosummary:: - :toctree: _autosummary_functions - :template: nn-module-template.rst - - gelu - gelu_approx - gelu_fast_approx - relu - prelu - silu - step - selu - mish - -Loss Functions --------------- - -.. autosummary:: - :toctree: _autosummary_functions - :template: nn-module-template.rst - - losses.cross_entropy - losses.binary_cross_entropy - losses.l1_loss - losses.mse_loss - losses.nll_loss - losses.kl_div_loss + nn/layers + nn/functions + nn/losses diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Conv1d.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv1d.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.Conv1d.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv1d.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Conv2d.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv2d.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.Conv2d.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv2d.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Embedding.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Embedding.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.Embedding.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Embedding.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.GELU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GELU.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.GELU.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GELU.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.GroupNorm.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GroupNorm.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.GroupNorm.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GroupNorm.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.LayerNorm.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.LayerNorm.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.LayerNorm.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.LayerNorm.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Linear.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Linear.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.Linear.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Linear.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Mish.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Mish.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.Mish.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Mish.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.MultiHeadAttention.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.MultiHeadAttention.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.MultiHeadAttention.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.MultiHeadAttention.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.PReLU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.PReLU.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.PReLU.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.PReLU.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.RMSNorm.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RMSNorm.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.RMSNorm.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RMSNorm.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.ReLU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.ReLU.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.ReLU.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.ReLU.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.RoPE.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RoPE.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.RoPE.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RoPE.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.SELU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SELU.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.SELU.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SELU.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Sequential.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Sequential.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.Sequential.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Sequential.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.SiLU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SiLU.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.SiLU.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SiLU.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Step.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Step.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.Step.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Step.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.gelu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.gelu.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.gelu_approx.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_approx.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.gelu_approx.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_approx.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.gelu_fast_approx.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.gelu_fast_approx.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.cross_entropy.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.cross_entropy.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.l1_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.l1_loss.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.mse_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.mse_loss.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.nll_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.nll_loss.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.mish.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.mish.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.mish.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.mish.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.prelu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.prelu.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.prelu.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.prelu.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.relu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.relu.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.relu.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.relu.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.selu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.selu.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.selu.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.selu.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.silu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.silu.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.silu.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.silu.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.step.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.step.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.step.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.step.rst diff --git a/docs/build/html/_sources/python/nn/functions.rst b/docs/build/html/_sources/python/nn/functions.rst new file mode 100644 index 000000000..f13cbe7b4 --- /dev/null +++ b/docs/build/html/_sources/python/nn/functions.rst @@ -0,0 +1,23 @@ +.. _nn_functions: + +.. currentmodule:: mlx.nn + +Functions +--------- + +Layers without parameters (e.g. activation functions) are also provided as +simple functions. + +.. autosummary:: + :toctree: _autosummary_functions + :template: nn-module-template.rst + + gelu + gelu_approx + gelu_fast_approx + relu + prelu + silu + step + selu + mish diff --git a/docs/build/html/_sources/python/nn/layers.rst b/docs/build/html/_sources/python/nn/layers.rst new file mode 100644 index 000000000..5628134d6 --- /dev/null +++ b/docs/build/html/_sources/python/nn/layers.rst @@ -0,0 +1,28 @@ +.. _layers: + +.. currentmodule:: mlx.nn + +Layers +------ + +.. autosummary:: + :toctree: _autosummary + :template: nn-module-template.rst + + Embedding + ReLU + PReLU + GELU + SiLU + Step + SELU + Mish + Linear + Conv1d + Conv2d + LayerNorm + RMSNorm + GroupNorm + RoPE + MultiHeadAttention + Sequential diff --git a/docs/build/html/_sources/python/nn/losses.rst b/docs/build/html/_sources/python/nn/losses.rst new file mode 100644 index 000000000..4808ce5ab --- /dev/null +++ b/docs/build/html/_sources/python/nn/losses.rst @@ -0,0 +1,17 @@ +.. _losses: + +.. currentmodule:: mlx.nn.losses + +Loss Functions +-------------- + +.. autosummary:: + :toctree: _autosummary_functions + :template: nn-module-template.rst + + cross_entropy + binary_cross_entropy + l1_loss + mse_loss + nll_loss + kl_div_loss diff --git a/docs/build/html/_sources/python/nn/module.rst b/docs/build/html/_sources/python/nn/module.rst deleted file mode 100644 index e14ba96f4..000000000 --- a/docs/build/html/_sources/python/nn/module.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlx.nn.Module -============= - -.. currentmodule:: mlx.nn - -.. autoclass:: Module - :members: diff --git a/docs/build/html/_sources/python/ops.rst b/docs/build/html/_sources/python/ops.rst index b9a4c9066..ea25b90f9 100644 --- a/docs/build/html/_sources/python/ops.rst +++ b/docs/build/html/_sources/python/ops.rst @@ -26,6 +26,7 @@ Operations argsort array_equal broadcast_to + ceil concatenate convolve conv1d @@ -39,6 +40,8 @@ Operations exp expand_dims eye + floor + flatten full greater greater_equal @@ -59,6 +62,7 @@ Operations mean min minimum + moveaxis multiply negative ones @@ -82,14 +86,19 @@ Operations sqrt square squeeze + stack stop_gradient subtract sum + swapaxes take take_along_axis tan tanh transpose + tri + tril + triu var where zeros diff --git a/docs/build/html/_sources/python/optimizers.rst b/docs/build/html/_sources/python/optimizers.rst index 7f5d3a067..b8e5cfea7 100644 --- a/docs/build/html/_sources/python/optimizers.rst +++ b/docs/build/html/_sources/python/optimizers.rst @@ -38,4 +38,9 @@ model's parameters and the **optimizer state**. OptimizerState Optimizer SGD + RMSprop + Adagrad + AdaDelta Adam + AdamW + Adamax diff --git a/docs/build/html/_sources/python/transforms.rst b/docs/build/html/_sources/python/transforms.rst index cc8d681d5..fa6d1d701 100644 --- a/docs/build/html/_sources/python/transforms.rst +++ b/docs/build/html/_sources/python/transforms.rst @@ -14,3 +14,4 @@ Transforms jvp vjp vmap + simplify diff --git a/docs/build/html/cpp/ops.html b/docs/build/html/cpp/ops.html index cd7666c91..a6e660fe2 100644 --- a/docs/build/html/cpp/ops.html +++ b/docs/build/html/cpp/ops.html @@ -226,6 +226,7 @@
  • mlx.core.argsort
  • mlx.core.array_equal
  • mlx.core.broadcast_to
  • +
  • mlx.core.ceil
  • mlx.core.concatenate
  • mlx.core.convolve
  • mlx.core.conv1d
  • @@ -239,6 +240,8 @@
  • mlx.core.exp
  • mlx.core.expand_dims
  • mlx.core.eye
  • +
  • mlx.core.floor
  • +
  • mlx.core.flatten
  • mlx.core.full
  • mlx.core.greater
  • mlx.core.greater_equal
  • @@ -259,6 +262,7 @@
  • mlx.core.mean
  • mlx.core.min
  • mlx.core.minimum
  • +
  • mlx.core.moveaxis
  • mlx.core.multiply
  • mlx.core.negative
  • mlx.core.ones
  • @@ -282,14 +286,19 @@
  • mlx.core.sqrt
  • mlx.core.square
  • mlx.core.squeeze
  • +
  • mlx.core.stack
  • mlx.core.stop_gradient
  • mlx.core.subtract
  • mlx.core.sum
  • +
  • mlx.core.swapaxes
  • mlx.core.take
  • mlx.core.take_along_axis
  • mlx.core.tan
  • mlx.core.tanh
  • mlx.core.transpose
  • +
  • mlx.core.tri
  • +
  • mlx.core.tril
  • +
  • mlx.core.triu
  • mlx.core.var
  • mlx.core.where
  • mlx.core.zeros
  • @@ -316,6 +325,7 @@
  • mlx.core.jvp
  • mlx.core.vjp
  • mlx.core.vmap
  • +
  • mlx.core.simplify
  • FFT