mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00

docs

committed by CircleCI Docs
parent f75712551d
commit 0250e203f6
@@ -26,6 +26,7 @@
~array.cumprod
~array.cumsum
~array.exp
~array.flatten
~array.item
~array.log
~array.log10
@@ -35,6 +36,7 @@
~array.max
~array.mean
~array.min
~array.moveaxis
~array.prod
~array.reciprocal
~array.reshape
@@ -45,6 +47,7 @@
~array.square
~array.squeeze
~array.sum
~array.swapaxes
~array.tolist
~array.transpose
~array.var
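These hunks add ``flatten``, ``moveaxis``, and ``swapaxes`` to the ``array`` method summary. For illustration, a minimal sketch of the methods listed above, assuming the usual NumPy-like semantics:

.. code-block:: python

   import mlx.core as mx

   x = mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

   x.flatten()       # shape (6,)
   x.moveaxis(0, 1)  # shape (3, 2)
   x.swapaxes(0, 1)  # shape (3, 2)
   x.cumsum(0)       # running sums along the first axis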
6 docs/build/html/_sources/python/_autosummary/mlx.core.ceil.rst vendored Normal file
@@ -0,0 +1,6 @@
mlx.core.ceil
=============

.. currentmodule:: mlx.core

.. autofunction:: ceil
6 docs/build/html/_sources/python/_autosummary/mlx.core.flatten.rst vendored Normal file
@@ -0,0 +1,6 @@
mlx.core.flatten
================

.. currentmodule:: mlx.core

.. autofunction:: flatten
6 docs/build/html/_sources/python/_autosummary/mlx.core.floor.rst vendored Normal file
@@ -0,0 +1,6 @@
mlx.core.floor
==============

.. currentmodule:: mlx.core

.. autofunction:: floor
6 docs/build/html/_sources/python/_autosummary/mlx.core.moveaxis.rst vendored Normal file
@@ -0,0 +1,6 @@
mlx.core.moveaxis
=================

.. currentmodule:: mlx.core

.. autofunction:: moveaxis
6 docs/build/html/_sources/python/_autosummary/mlx.core.simplify.rst vendored Normal file
@@ -0,0 +1,6 @@
mlx.core.simplify
=================

.. currentmodule:: mlx.core

.. autofunction:: simplify
6 docs/build/html/_sources/python/_autosummary/mlx.core.stack.rst vendored Normal file
@@ -0,0 +1,6 @@
mlx.core.stack
==============

.. currentmodule:: mlx.core

.. autofunction:: stack
6 docs/build/html/_sources/python/_autosummary/mlx.core.swapaxes.rst vendored Normal file
@@ -0,0 +1,6 @@
mlx.core.swapaxes
=================

.. currentmodule:: mlx.core

.. autofunction:: swapaxes
6 docs/build/html/_sources/python/_autosummary/mlx.core.tri.rst vendored Normal file
@@ -0,0 +1,6 @@
mlx.core.tri
============

.. currentmodule:: mlx.core

.. autofunction:: tri
6 docs/build/html/_sources/python/_autosummary/mlx.core.tril.rst vendored Normal file
@@ -0,0 +1,6 @@
mlx.core.tril
=============

.. currentmodule:: mlx.core

.. autofunction:: tril
6 docs/build/html/_sources/python/_autosummary/mlx.core.triu.rst vendored Normal file
@@ -0,0 +1,6 @@
mlx.core.triu
=============

.. currentmodule:: mlx.core

.. autofunction:: triu
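The new stubs above document a batch of ops added to ``mlx.core``. A short sketch of their expected behavior, assuming NumPy-style semantics:

.. code-block:: python

   import mlx.core as mx

   x = mx.array([[1.4, -2.7], [3.1, 0.5]])

   mx.ceil(x)                # element-wise ceiling
   mx.floor(x)               # element-wise floor
   mx.tril(x)                # zero entries above the diagonal
   mx.triu(x)                # zero entries below the diagonal
   mx.tri(3)                 # 3x3 lower-triangular matrix of ones
   mx.stack([x, x], axis=0)  # join arrays along a new axis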
58 docs/build/html/_sources/python/_autosummary/mlx.nn.Module.rst vendored Normal file
@@ -0,0 +1,58 @@
mlx.nn.Module
=============

.. currentmodule:: mlx.nn

.. autoclass:: Module

   .. automethod:: __init__

   .. rubric:: Methods

   .. autosummary::

      ~Module.__init__
      ~Module.apply
      ~Module.apply_to_modules
      ~Module.children
      ~Module.clear
      ~Module.copy
      ~Module.eval
      ~Module.filter_and_map
      ~Module.freeze
      ~Module.fromkeys
      ~Module.get
      ~Module.is_module
      ~Module.items
      ~Module.keys
      ~Module.leaf_modules
      ~Module.load_weights
      ~Module.modules
      ~Module.named_modules
      ~Module.parameters
      ~Module.pop
      ~Module.popitem
      ~Module.save_weights
      ~Module.setdefault
      ~Module.train
      ~Module.trainable_parameter_filter
      ~Module.trainable_parameters
      ~Module.unfreeze
      ~Module.update
      ~Module.valid_child_filter
      ~Module.valid_parameter_filter
      ~Module.values

   .. rubric:: Attributes

   .. autosummary::

      ~Module.training
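For context, ``Module`` is the base class for all layers; a minimal sketch of the usual subclassing pattern (layer sizes here are illustrative):

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

   class TinyMLP(nn.Module):
       def __init__(self):
           super().__init__()
           self.fc1 = nn.Linear(4, 8)
           self.fc2 = nn.Linear(8, 2)

       def __call__(self, x):
           return self.fc2(nn.relu(self.fc1(x)))

   model = TinyMLP()
   y = model(mx.ones((1, 4)))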
18 docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdaDelta.rst vendored Normal file
@@ -0,0 +1,18 @@
mlx.optimizers.AdaDelta
=======================

.. currentmodule:: mlx.optimizers

.. autoclass:: AdaDelta

   .. rubric:: Methods

   .. autosummary::

      ~AdaDelta.__init__
      ~AdaDelta.apply_single
18 docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adagrad.rst vendored Normal file
@@ -0,0 +1,18 @@
mlx.optimizers.Adagrad
======================

.. currentmodule:: mlx.optimizers

.. autoclass:: Adagrad

   .. rubric:: Methods

   .. autosummary::

      ~Adagrad.__init__
      ~Adagrad.apply_single
18 docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdamW.rst vendored Normal file
@@ -0,0 +1,18 @@
mlx.optimizers.AdamW
====================

.. currentmodule:: mlx.optimizers

.. autoclass:: AdamW

   .. rubric:: Methods

   .. autosummary::

      ~AdamW.__init__
      ~AdamW.apply_single
18 docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adamax.rst vendored Normal file
@@ -0,0 +1,18 @@
mlx.optimizers.Adamax
=====================

.. currentmodule:: mlx.optimizers

.. autoclass:: Adamax

   .. rubric:: Methods

   .. autosummary::

      ~Adamax.__init__
      ~Adamax.apply_single
18 docs/build/html/_sources/python/_autosummary/mlx.optimizers.RMSprop.rst vendored Normal file
@@ -0,0 +1,18 @@
mlx.optimizers.RMSprop
======================

.. currentmodule:: mlx.optimizers

.. autoclass:: RMSprop

   .. rubric:: Methods

   .. autosummary::

      ~RMSprop.__init__
      ~RMSprop.apply_single
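All of these optimizer stubs share the same interface: a constructor and an ``apply_single`` that updates one parameter. A sketch of typical usage, following the pattern from the MLX examples (shapes and hyperparameters are illustrative):

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn
   import mlx.optimizers as optim

   model = nn.Linear(4, 2)

   def loss_fn(model, x, y):
       return nn.losses.mse_loss(model(x), y)

   loss_and_grad = nn.value_and_grad(model, loss_fn)
   optimizer = optim.AdaDelta(learning_rate=1.0)

   x, y = mx.ones((8, 4)), mx.zeros((8, 2))
   loss, grads = loss_and_grad(model, x, y)
   optimizer.update(model, grads)  # calls apply_single for each parameter
   mx.eval(model.parameters(), optimizer.state)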
113 docs/build/html/_sources/python/nn.rst vendored
@@ -64,7 +64,6 @@ Quick Start with Neural Networks
   # gradient with respect to `mlp.trainable_parameters()`
   loss_and_grad = nn.value_and_grad(mlp, l2_loss)
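For completeness, a sketch of how the resulting ``loss_and_grad`` is called (``x`` and ``y`` are the inputs and targets from the surrounding quick-start example):

.. code-block:: python

   loss, grads = loss_and_grad(mlp, x, y)  # grads mirrors mlp.trainable_parameters()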

.. _module_class:

The Module Class
@@ -86,20 +85,58 @@ name should not start with ``_``). It can be arbitrarily nested in other
:meth:`Module.parameters` can be used to extract a nested dictionary with all
the parameters of a module and its submodules.

A :class:`Module` can also keep track of "frozen" parameters.
:meth:`Module.trainable_parameters` returns only the subset of
:meth:`Module.parameters` that is not frozen. When using
:meth:`mlx.nn.value_and_grad` the gradients returned will be with respect to these
trainable parameters.
A :class:`Module` can also keep track of "frozen" parameters. See the
:meth:`Module.freeze` method for more details. When using
:meth:`mlx.nn.value_and_grad`, the gradients returned will be with respect to
these trainable parameters.
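A minimal sketch of the freezing behavior described here:

.. code-block:: python

   import mlx.nn as nn

   model = nn.Linear(4, 2)
   model.freeze()                # all parameters are now frozen
   model.trainable_parameters()  # -> empty; gradients would be empty too
   model.unfreeze()              # everything is trainable again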

Updating the parameters

Updating the Parameters
^^^^^^^^^^^^^^^^^^^^^^^

MLX modules allow accessing and updating individual parameters. However, most
times we need to update large subsets of a module's parameters. This action is
performed by :meth:`Module.update`.
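For example, a sketch of a bulk update that rescales every parameter (the transformation is illustrative):

.. code-block:: python

   import mlx.nn as nn
   from mlx.utils import tree_map

   model = nn.Linear(4, 2)
   halved = tree_map(lambda p: 0.5 * p, model.parameters())
   model.update(halved)  # merges the new values into the module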

Value and grad

Inspecting Modules
^^^^^^^^^^^^^^^^^^

The simplest way to see the model architecture is to print it. Following along with
the above example, you can print the ``MLP`` with:

.. code-block:: python

   print(mlp)

This will display:

.. code-block:: shell

   MLP(
     (layers.0): Linear(input_dims=2, output_dims=128, bias=True)
     (layers.1): Linear(input_dims=128, output_dims=128, bias=True)
     (layers.2): Linear(input_dims=128, output_dims=10, bias=True)
   )

To get more detailed information on the arrays in a :class:`Module` you can use
:func:`mlx.utils.tree_map` on the parameters. For example, to see the shapes of
all the parameters in a :class:`Module` do:

.. code-block:: python

   from mlx.utils import tree_map
   shapes = tree_map(lambda p: p.shape, mlp.parameters())

As another example, you can count the number of parameters in a :class:`Module`
with:

.. code-block:: python

   from mlx.utils import tree_flatten
   num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))

Value and Grad
--------------

Using a :class:`Module` does not preclude using MLX's high order function
@@ -133,62 +170,14 @@ In detail:
:meth:`mlx.core.value_and_grad`

.. autosummary::
   :recursive:
   :toctree: _autosummary

   value_and_grad
   Module

Neural Network Layers
---------------------
.. toctree::

.. autosummary::
   :toctree: _autosummary
   :template: nn-module-template.rst

   Embedding
   ReLU
   PReLU
   GELU
   SiLU
   Step
   SELU
   Mish
   Linear
   Conv1d
   Conv2d
   LayerNorm
   RMSNorm
   GroupNorm
   RoPE
   MultiHeadAttention
   Sequential

Layers without parameters (e.g. activation functions) are also provided as
simple functions.

.. autosummary::
   :toctree: _autosummary_functions
   :template: nn-module-template.rst

   gelu
   gelu_approx
   gelu_fast_approx
   relu
   prelu
   silu
   step
   selu
   mish

Loss Functions
--------------

.. autosummary::
   :toctree: _autosummary_functions
   :template: nn-module-template.rst

   losses.cross_entropy
   losses.binary_cross_entropy
   losses.l1_loss
   losses.mse_loss
   losses.nll_loss
   losses.kl_div_loss
nn/layers
nn/functions
nn/losses
23 docs/build/html/_sources/python/nn/functions.rst vendored Normal file
@@ -0,0 +1,23 @@
.. _nn_functions:

.. currentmodule:: mlx.nn

Functions
---------

Layers without parameters (e.g. activation functions) are also provided as
simple functions.

.. autosummary::
   :toctree: _autosummary_functions
   :template: nn-module-template.rst

   gelu
   gelu_approx
   gelu_fast_approx
   relu
   prelu
   silu
   step
   selu
   mish
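A short sketch of the functional and module forms side by side:

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

   x = mx.array([-1.0, 0.0, 1.0])
   nn.relu(x)    # stateless function
   nn.ReLU()(x)  # equivalent parameter-free module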
28 docs/build/html/_sources/python/nn/layers.rst vendored Normal file
@@ -0,0 +1,28 @@
.. _layers:

.. currentmodule:: mlx.nn

Layers
------

.. autosummary::
   :toctree: _autosummary
   :template: nn-module-template.rst

   Embedding
   ReLU
   PReLU
   GELU
   SiLU
   Step
   SELU
   Mish
   Linear
   Conv1d
   Conv2d
   LayerNorm
   RMSNorm
   GroupNorm
   RoPE
   MultiHeadAttention
   Sequential
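As a usage sketch, the parameter-carrying layers above compose directly, e.g. with ``Sequential`` (sizes are illustrative):

.. code-block:: python

   import mlx.nn as nn

   model = nn.Sequential(
       nn.Linear(4, 8),
       nn.ReLU(),
       nn.Linear(8, 2),
   )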
17 docs/build/html/_sources/python/nn/losses.rst vendored Normal file
@@ -0,0 +1,17 @@
.. _losses:

.. currentmodule:: mlx.nn.losses

Loss Functions
--------------

.. autosummary::
   :toctree: _autosummary_functions
   :template: nn-module-template.rst

   cross_entropy
   binary_cross_entropy
   l1_loss
   mse_loss
   nll_loss
   kl_div_loss
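A minimal sketch of calling one of these losses on logits and integer targets (values are illustrative):

.. code-block:: python

   import mlx.core as mx
   from mlx.nn import losses

   logits = mx.array([[2.0, 0.5], [0.1, 1.5]])
   targets = mx.array([0, 1])
   loss = losses.cross_entropy(logits, targets)  # per-example losses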
@@ -1,7 +0,0 @@
mlx.nn.Module
=============

.. currentmodule:: mlx.nn

.. autoclass:: Module
   :members:
9 docs/build/html/_sources/python/ops.rst vendored
@@ -26,6 +26,7 @@ Operations
argsort
array_equal
broadcast_to
ceil
concatenate
convolve
conv1d
@@ -39,6 +40,8 @@ Operations
exp
expand_dims
eye
floor
flatten
full
greater
greater_equal
@@ -59,6 +62,7 @@ Operations
mean
min
minimum
moveaxis
multiply
negative
ones
@@ -82,14 +86,19 @@ Operations
sqrt
square
squeeze
stack
stop_gradient
subtract
sum
swapaxes
take
take_along_axis
tan
tanh
transpose
tri
tril
triu
var
where
zeros
@@ -38,4 +38,9 @@ model's parameters and the **optimizer state**.
OptimizerState
Optimizer
SGD
RMSprop
Adagrad
AdaDelta
Adam
AdamW
Adamax
@@ -14,3 +14,4 @@ Transforms
jvp
vjp
vmap
simplify
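For context, a hedged sketch of how the newly listed ``simplify`` transform is used: it takes output arrays and rewrites their computation graph to reuse duplicated work (the exact semantics are an assumption based on this era of MLX):

.. code-block:: python

   import mlx.core as mx

   x = mx.ones((10, 10))
   y = mx.matmul(x, x) + mx.matmul(x, x)  # the same matmul appears twice
   mx.simplify(y)                         # may rewrite the graph to compute it once
   mx.eval(y)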