mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 02:38:09 +08:00
docs update
This commit is contained in:

committed by
CircleCI Docs

parent
de4f3e72fd
commit
f1dfa257d2
141
docs/build/html/python/nn.html
vendored
141
docs/build/html/python/nn.html
vendored
@@ -9,7 +9,7 @@
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
|
||||
|
||||
<title>Neural Networks — MLX 0.0.6 documentation</title>
|
||||
<title>Neural Networks — MLX 0.0.7 documentation</title>
|
||||
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@
|
||||
<link rel="index" title="Index" href="../genindex.html" />
|
||||
<link rel="search" title="Search" href="../search.html" />
|
||||
<link rel="next" title="mlx.nn.value_and_grad" href="_autosummary/mlx.nn.value_and_grad.html" />
|
||||
<link rel="prev" title="mlx.core.fft.irfftn" href="_autosummary/mlx.core.fft.irfftn.html" />
|
||||
<link rel="prev" title="mlx.core.linalg.norm" href="_autosummary/mlx.core.linalg.norm.html" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||||
<meta name="docsearch:language" content="en"/>
|
||||
</head>
|
||||
@@ -134,8 +134,8 @@
|
||||
|
||||
|
||||
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.0.6 documentation - Home"/>
|
||||
<script>document.write(`<img src="../_static/mlx_logo.png" class="logo__image only-dark" alt="MLX 0.0.6 documentation - Home"/>`);</script>
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.0.7 documentation - Home"/>
|
||||
<script>document.write(`<img src="../_static/mlx_logo.png" class="logo__image only-dark" alt="MLX 0.0.7 documentation - Home"/>`);</script>
|
||||
|
||||
|
||||
</a></div>
|
||||
@@ -278,12 +278,14 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.quantize.html">mlx.core.quantize</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.quantized_matmul.html">mlx.core.quantized_matmul</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.reciprocal.html">mlx.core.reciprocal</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.repeat.html">mlx.core.repeat</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.reshape.html">mlx.core.reshape</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.round.html">mlx.core.round</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.rsqrt.html">mlx.core.rsqrt</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.save.html">mlx.core.save</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.savez.html">mlx.core.savez</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.savez_compressed.html">mlx.core.savez_compressed</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.save_safetensors.html">mlx.core.save_safetensors</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.sigmoid.html">mlx.core.sigmoid</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.sin.html">mlx.core.sin</a></li>
|
||||
@@ -303,6 +305,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.take_along_axis.html">mlx.core.take_along_axis</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.tan.html">mlx.core.tan</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.tanh.html">mlx.core.tanh</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.tensordot.html">mlx.core.tensordot</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.transpose.html">mlx.core.transpose</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
|
||||
@@ -351,11 +354,35 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.fft.irfftn.html">mlx.core.fft.irfftn</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l1 current active has-children"><a class="current reference internal" href="#">Neural Networks</a><input checked="" class="toctree-checkbox" id="toctree-checkbox-7" name="toctree-checkbox-7" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-7"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="linalg.html">Linear Algebra</a><input class="toctree-checkbox" id="toctree-checkbox-7" name="toctree-checkbox-7" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-7"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.core.linalg.norm.html">mlx.core.linalg.norm</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l1 current active has-children"><a class="current reference internal" href="#">Neural Networks</a><input checked="" class="toctree-checkbox" id="toctree-checkbox-8" name="toctree-checkbox-8" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-8"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.nn.value_and_grad.html">mlx.nn.value_and_grad</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.nn.Module.html">mlx.nn.Module</a></li>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="nn/layers.html">Layers</a><input class="toctree-checkbox" id="toctree-checkbox-8" name="toctree-checkbox-8" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-8"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Embedding.html">mlx.nn.Embedding</a></li>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="nn/module.html">Module</a><input class="toctree-checkbox" id="toctree-checkbox-9" name="toctree-checkbox-9" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-9"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.training.html">mlx.nn.Module.training</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.apply.html">mlx.nn.Module.apply</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.apply_to_modules.html">mlx.nn.Module.apply_to_modules</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.children.html">mlx.nn.Module.children</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.eval.html">mlx.nn.Module.eval</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.filter_and_map.html">mlx.nn.Module.filter_and_map</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.freeze.html">mlx.nn.Module.freeze</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.leaf_modules.html">mlx.nn.Module.leaf_modules</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.load_weights.html">mlx.nn.Module.load_weights</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.modules.html">mlx.nn.Module.modules</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.named_modules.html">mlx.nn.Module.named_modules</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.parameters.html">mlx.nn.Module.parameters</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.save_weights.html">mlx.nn.Module.save_weights</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.train.html">mlx.nn.Module.train</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.trainable_parameters.html">mlx.nn.Module.trainable_parameters</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.unfreeze.html">mlx.nn.Module.unfreeze</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.update.html">mlx.nn.Module.update</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.update_modules.html">mlx.nn.Module.update_modules</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="nn/layers.html">Layers</a><input class="toctree-checkbox" id="toctree-checkbox-10" name="toctree-checkbox-10" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-10"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Sequential.html">mlx.nn.Sequential</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.ReLU.html">mlx.nn.ReLU</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.PReLU.html">mlx.nn.PReLU</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.GELU.html">mlx.nn.GELU</a></li>
|
||||
@@ -363,19 +390,27 @@
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Step.html">mlx.nn.Step</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.SELU.html">mlx.nn.SELU</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Mish.html">mlx.nn.Mish</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Embedding.html">mlx.nn.Embedding</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Linear.html">mlx.nn.Linear</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.QuantizedLinear.html">mlx.nn.QuantizedLinear</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Conv1d.html">mlx.nn.Conv1d</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Conv2d.html">mlx.nn.Conv2d</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.BatchNorm.html">mlx.nn.BatchNorm</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.LayerNorm.html">mlx.nn.LayerNorm</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.RMSNorm.html">mlx.nn.RMSNorm</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.GroupNorm.html">mlx.nn.GroupNorm</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.RoPE.html">mlx.nn.RoPE</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.InstanceNorm.html">mlx.nn.InstanceNorm</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Dropout.html">mlx.nn.Dropout</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Dropout2d.html">mlx.nn.Dropout2d</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Dropout3d.html">mlx.nn.Dropout3d</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Transformer.html">mlx.nn.Transformer</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.MultiHeadAttention.html">mlx.nn.MultiHeadAttention</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.Sequential.html">mlx.nn.Sequential</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.QuantizedLinear.html">mlx.nn.QuantizedLinear</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.ALiBi.html">mlx.nn.ALiBi</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.RoPE.html">mlx.nn.RoPE</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.html">mlx.nn.SinusoidalPositionalEncoding</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="nn/functions.html">Functions</a><input class="toctree-checkbox" id="toctree-checkbox-9" name="toctree-checkbox-9" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-9"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="nn/functions.html">Functions</a><input class="toctree-checkbox" id="toctree-checkbox-11" name="toctree-checkbox-11" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-11"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.gelu.html">mlx.nn.gelu</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.gelu_approx.html">mlx.nn.gelu_approx</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.gelu_fast_approx.html">mlx.nn.gelu_fast_approx</a></li>
|
||||
@@ -387,7 +422,7 @@
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.mish.html">mlx.nn.mish</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="nn/losses.html">Loss Functions</a><input class="toctree-checkbox" id="toctree-checkbox-10" name="toctree-checkbox-10" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-10"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l2 has-children"><a class="reference internal" href="nn/losses.html">Loss Functions</a><input class="toctree-checkbox" id="toctree-checkbox-12" name="toctree-checkbox-12" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-12"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.html">mlx.nn.losses.binary_cross_entropy</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.cross_entropy.html">mlx.nn.losses.cross_entropy</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.html">mlx.nn.losses.kl_div_loss</a></li>
|
||||
@@ -396,11 +431,14 @@
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.nll_loss.html">mlx.nn.losses.nll_loss</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.html">mlx.nn.losses.smooth_l1_loss</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.triplet_loss.html">mlx.nn.losses.triplet_loss</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.hinge_loss.html">mlx.nn.losses.hinge_loss</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.huber_loss.html">mlx.nn.losses.huber_loss</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.html">mlx.nn.losses.log_cosh_loss</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="optimizers.html">Optimizers</a><input class="toctree-checkbox" id="toctree-checkbox-11" name="toctree-checkbox-11" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-11"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="optimizers.html">Optimizers</a><input class="toctree-checkbox" id="toctree-checkbox-13" name="toctree-checkbox-13" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-13"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.optimizers.OptimizerState.html">mlx.optimizers.OptimizerState</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.optimizers.Optimizer.html">mlx.optimizers.Optimizer</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.optimizers.SGD.html">mlx.optimizers.SGD</a></li>
|
||||
@@ -413,7 +451,7 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.optimizers.Lion.html">mlx.optimizers.Lion</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="tree_utils.html">Tree Utils</a><input class="toctree-checkbox" id="toctree-checkbox-12" name="toctree-checkbox-12" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-12"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l1 has-children"><a class="reference internal" href="tree_utils.html">Tree Utils</a><input class="toctree-checkbox" id="toctree-checkbox-14" name="toctree-checkbox-14" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-14"><i class="fa-solid fa-chevron-down"></i></label><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.utils.tree_flatten.html">mlx.utils.tree_flatten</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.utils.tree_unflatten.html">mlx.utils.tree_unflatten</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="_autosummary/mlx.utils.tree_map.html">mlx.utils.tree_map</a></li>
|
||||
@@ -671,27 +709,27 @@ for finetuning and more.</p>
|
||||
</section>
|
||||
<section id="the-module-class">
|
||||
<span id="module-class"></span><h2>The Module Class<a class="headerlink" href="#the-module-class" title="Permalink to this heading">#</a></h2>
|
||||
<p>The workhorse of any neural network library is the <a class="reference internal" href="_autosummary/mlx.nn.Module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> class. In
|
||||
MLX the <a class="reference internal" href="_autosummary/mlx.nn.Module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> class is a container of <a class="reference internal" href="_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a> or
|
||||
<a class="reference internal" href="_autosummary/mlx.nn.Module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> instances. Its main function is to provide a way to
|
||||
<p>The workhorse of any neural network library is the <a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> class. In
|
||||
MLX the <a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> class is a container of <a class="reference internal" href="_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a> or
|
||||
<a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> instances. Its main function is to provide a way to
|
||||
recursively <strong>access</strong> and <strong>update</strong> its parameters and those of its
|
||||
submodules.</p>
|
||||
<section id="parameters">
|
||||
<h3>Parameters<a class="headerlink" href="#parameters" title="Permalink to this heading">#</a></h3>
|
||||
<p>A parameter of a module is any public member of type <a class="reference internal" href="_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a> (its
|
||||
name should not start with <code class="docutils literal notranslate"><span class="pre">_</span></code>). It can be arbitrarily nested in other
|
||||
<a class="reference internal" href="_autosummary/mlx.nn.Module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> instances or lists and dictionaries.</p>
|
||||
<p><code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.parameters()</span></code> can be used to extract a nested dictionary with all
|
||||
<a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> instances or lists and dictionaries.</p>
|
||||
<p><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.parameters.html#mlx.nn.Module.parameters" title="mlx.nn.Module.parameters"><code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.parameters()</span></code></a> can be used to extract a nested dictionary with all
|
||||
the parameters of a module and its submodules.</p>
|
||||
<p>A <a class="reference internal" href="_autosummary/mlx.nn.Module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> can also keep track of “frozen” parameters. See the
|
||||
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.freeze()</span></code> method for more details. <a class="reference internal" href="_autosummary/mlx.nn.value_and_grad.html#mlx.nn.value_and_grad" title="mlx.nn.value_and_grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.nn.value_and_grad()</span></code></a>
|
||||
<p>A <a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> can also keep track of “frozen” parameters. See the
|
||||
<a class="reference internal" href="nn/_autosummary/mlx.nn.Module.freeze.html#mlx.nn.Module.freeze" title="mlx.nn.Module.freeze"><code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.freeze()</span></code></a> method for more details. <a class="reference internal" href="_autosummary/mlx.nn.value_and_grad.html#mlx.nn.value_and_grad" title="mlx.nn.value_and_grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.nn.value_and_grad()</span></code></a>
|
||||
the gradients returned will be with respect to these trainable parameters.</p>
|
||||
</section>
|
||||
<section id="updating-the-parameters">
|
||||
<h3>Updating the Parameters<a class="headerlink" href="#updating-the-parameters" title="Permalink to this heading">#</a></h3>
|
||||
<p>MLX modules allow accessing and updating individual parameters. However, most
|
||||
times we need to update large subsets of a module’s parameters. This action is
|
||||
performed by <code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.update()</span></code>.</p>
|
||||
performed by <a class="reference internal" href="nn/_autosummary/mlx.nn.Module.update.html#mlx.nn.Module.update" title="mlx.nn.Module.update"><code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.update()</span></code></a>.</p>
|
||||
</section>
|
||||
<section id="inspecting-modules">
|
||||
<h3>Inspecting Modules<a class="headerlink" href="#inspecting-modules" title="Permalink to this heading">#</a></h3>
|
||||
@@ -708,14 +746,14 @@ the above example, you can print the <code class="docutils literal notranslate">
|
||||
<span class="o">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>To get more detailed information on the arrays in a <a class="reference internal" href="_autosummary/mlx.nn.Module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> you can use
|
||||
<p>To get more detailed information on the arrays in a <a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> you can use
|
||||
<a class="reference internal" href="_autosummary/mlx.utils.tree_map.html#mlx.utils.tree_map" title="mlx.utils.tree_map"><code class="xref py py-func docutils literal notranslate"><span class="pre">mlx.utils.tree_map()</span></code></a> on the parameters. For example, to see the shapes of
|
||||
all the parameters in a <a class="reference internal" href="_autosummary/mlx.nn.Module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> do:</p>
|
||||
all the parameters in a <a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> do:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">mlx.utils</span> <span class="kn">import</span> <span class="n">tree_map</span>
|
||||
<span class="n">shapes</span> <span class="o">=</span> <span class="n">tree_map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">p</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">mlp</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>As another example, you can count the number of parameters in a <a class="reference internal" href="_autosummary/mlx.nn.Module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a>
|
||||
<p>As another example, you can count the number of parameters in a <a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a>
|
||||
with:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">mlx.utils</span> <span class="kn">import</span> <span class="n">tree_flatten</span>
|
||||
<span class="n">num_params</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">size</span> <span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">tree_flatten</span><span class="p">(</span><span class="n">mlp</span><span class="o">.</span><span class="n">parameters</span><span class="p">()))</span>
|
||||
@@ -725,7 +763,7 @@ with:</p>
|
||||
</section>
|
||||
<section id="value-and-grad">
|
||||
<h2>Value and Grad<a class="headerlink" href="#value-and-grad" title="Permalink to this heading">#</a></h2>
|
||||
<p>Using a <a class="reference internal" href="_autosummary/mlx.nn.Module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> does not preclude using MLX’s high order function
|
||||
<p>Using a <a class="reference internal" href="nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a> does not preclude using MLX’s high order function
|
||||
transformations (<a class="reference internal" href="_autosummary/mlx.core.value_and_grad.html#mlx.core.value_and_grad" title="mlx.core.value_and_grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.core.value_and_grad()</span></code></a>, <a class="reference internal" href="_autosummary/mlx.core.grad.html#mlx.core.grad" title="mlx.core.grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.core.grad()</span></code></a>, etc.). However,
|
||||
these function transformations assume pure functions, namely the parameters
|
||||
should be passed as an argument to the function being transformed.</p>
|
||||
@@ -743,7 +781,7 @@ should be passed as an argument to the function being transformed.</p>
|
||||
computes the gradients with respect to the trainable parameters of the model.</p>
|
||||
<p>In detail:</p>
|
||||
<ul class="simple">
|
||||
<li><p>it wraps the passed function with a function that calls <code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.update()</span></code>
|
||||
<li><p>it wraps the passed function with a function that calls <a class="reference internal" href="nn/_autosummary/mlx.nn.Module.update.html#mlx.nn.Module.update" title="mlx.nn.Module.update"><code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.update()</span></code></a>
|
||||
to make sure the model is using the provided parameters.</p></li>
|
||||
<li><p>it calls <a class="reference internal" href="_autosummary/mlx.core.value_and_grad.html#mlx.core.value_and_grad" title="mlx.core.value_and_grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.core.value_and_grad()</span></code></a> to transform the function into a function
|
||||
that also computes the gradients with respect to the passed parameters.</p></li>
|
||||
@@ -756,15 +794,33 @@ parameters as the first argument to the function returned by
|
||||
<tr class="row-odd"><td><p><a class="reference internal" href="_autosummary/mlx.nn.value_and_grad.html#mlx.nn.value_and_grad" title="mlx.nn.value_and_grad"><code class="xref py py-obj docutils literal notranslate"><span class="pre">value_and_grad</span></code></a>(model, fn)</p></td>
|
||||
<td><p>Transform the passed function <code class="docutils literal notranslate"><span class="pre">fn</span></code> to a function that computes the gradients of <code class="docutils literal notranslate"><span class="pre">fn</span></code> with respect to the model's trainable parameters and also its value.</p></td>
|
||||
</tr>
|
||||
<tr class="row-even"><td><p><a class="reference internal" href="_autosummary/mlx.nn.Module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-obj docutils literal notranslate"><span class="pre">Module</span></code></a>()</p></td>
|
||||
<td><p>Base class for building neural networks with MLX.</p></td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
<div class="toctree-wrapper compound">
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="nn/module.html">Module</a><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.training.html">mlx.nn.Module.training</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.apply.html">mlx.nn.Module.apply</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.apply_to_modules.html">mlx.nn.Module.apply_to_modules</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.children.html">mlx.nn.Module.children</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.eval.html">mlx.nn.Module.eval</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.filter_and_map.html">mlx.nn.Module.filter_and_map</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.freeze.html">mlx.nn.Module.freeze</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.leaf_modules.html">mlx.nn.Module.leaf_modules</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.load_weights.html">mlx.nn.Module.load_weights</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.modules.html">mlx.nn.Module.modules</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.named_modules.html">mlx.nn.Module.named_modules</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.parameters.html">mlx.nn.Module.parameters</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.save_weights.html">mlx.nn.Module.save_weights</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.train.html">mlx.nn.Module.train</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.trainable_parameters.html">mlx.nn.Module.trainable_parameters</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.unfreeze.html">mlx.nn.Module.unfreeze</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.update.html">mlx.nn.Module.update</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Module.update_modules.html">mlx.nn.Module.update_modules</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="nn/layers.html">Layers</a><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Embedding.html">mlx.nn.Embedding</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Sequential.html">mlx.nn.Sequential</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.ReLU.html">mlx.nn.ReLU</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.PReLU.html">mlx.nn.PReLU</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.GELU.html">mlx.nn.GELU</a></li>
|
||||
@@ -772,16 +828,24 @@ parameters as the first argument to the function returned by
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Step.html">mlx.nn.Step</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.SELU.html">mlx.nn.SELU</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Mish.html">mlx.nn.Mish</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Embedding.html">mlx.nn.Embedding</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Linear.html">mlx.nn.Linear</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.QuantizedLinear.html">mlx.nn.QuantizedLinear</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Conv1d.html">mlx.nn.Conv1d</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Conv2d.html">mlx.nn.Conv2d</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.BatchNorm.html">mlx.nn.BatchNorm</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.LayerNorm.html">mlx.nn.LayerNorm</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.RMSNorm.html">mlx.nn.RMSNorm</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.GroupNorm.html">mlx.nn.GroupNorm</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.RoPE.html">mlx.nn.RoPE</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.InstanceNorm.html">mlx.nn.InstanceNorm</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Dropout.html">mlx.nn.Dropout</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Dropout2d.html">mlx.nn.Dropout2d</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Dropout3d.html">mlx.nn.Dropout3d</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Transformer.html">mlx.nn.Transformer</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.MultiHeadAttention.html">mlx.nn.MultiHeadAttention</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.Sequential.html">mlx.nn.Sequential</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.QuantizedLinear.html">mlx.nn.QuantizedLinear</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.ALiBi.html">mlx.nn.ALiBi</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.RoPE.html">mlx.nn.RoPE</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.html">mlx.nn.SinusoidalPositionalEncoding</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="nn/functions.html">Functions</a><ul>
|
||||
@@ -805,6 +869,9 @@ parameters as the first argument to the function returned by
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.nll_loss.html">mlx.nn.losses.nll_loss</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.html">mlx.nn.losses.smooth_l1_loss</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.triplet_loss.html">mlx.nn.losses.triplet_loss</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.hinge_loss.html">mlx.nn.losses.hinge_loss</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.huber_loss.html">mlx.nn.losses.huber_loss</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.html">mlx.nn.losses.log_cosh_loss</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
@@ -824,12 +891,12 @@ parameters as the first argument to the function returned by
|
||||
|
||||
<div class="prev-next-area">
|
||||
<a class="left-prev"
|
||||
href="_autosummary/mlx.core.fft.irfftn.html"
|
||||
href="_autosummary/mlx.core.linalg.norm.html"
|
||||
title="previous page">
|
||||
<i class="fa-solid fa-angle-left"></i>
|
||||
<div class="prev-next-info">
|
||||
<p class="prev-next-subtitle">previous</p>
|
||||
<p class="prev-next-title">mlx.core.fft.irfftn</p>
|
||||
<p class="prev-next-title">mlx.core.linalg.norm</p>
|
||||
</div>
|
||||
</a>
|
||||
<a class="right-next"
|
||||
|
Reference in New Issue
Block a user