mlx/docs/build/html/python/nn/module.html
Awni Hannun fbd10a48d4 docs
2025-06-04 01:01:47 +00:00

<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>mlx.nn.Module &mdash; MLX 0.0.0 documentation</title>
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
<script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script>
<script src="../../_static/_sphinx_javascript_frameworks_compat.js"></script>
<script src="../../_static/doctools.js"></script>
<script src="../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../index.html" class="icon icon-home">
MLX
</a>
<div class="version">
0.0.0
</div>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
<p class="caption" role="heading"><span class="caption-text">Install</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../install.html">Build and Install</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Usage</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../quick_start.html">Quick Start Guide</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../using_streams.html">Using Streams</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Examples</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../examples/linear_regression.html">Linear Regression</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../examples/mlp.html">Multi-Layer Perceptron</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../examples/llama-inference.html">LLM inference</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Further Reading</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../dev/extensions.html">Developer Documentation</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API Reference</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../array.html">Array</a></li>
<li class="toctree-l1"><a class="reference internal" href="../devices_and_streams.html">Devices and Streams</a></li>
<li class="toctree-l1"><a class="reference internal" href="../ops.html">Operations</a></li>
<li class="toctree-l1"><a class="reference internal" href="../random.html">Random</a></li>
<li class="toctree-l1"><a class="reference internal" href="../transforms.html">Transforms</a></li>
<li class="toctree-l1"><a class="reference internal" href="../fft.html">FFT</a></li>
<li class="toctree-l1"><a class="reference internal" href="../nn.html">Neural Networks</a></li>
<li class="toctree-l1"><a class="reference internal" href="../optimizers.html">Optimizers</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tree_utils.html">Tree Utils</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">C++ API Reference</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../cpp/ops.html">Operations</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../index.html">MLX</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="Page navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../index.html" class="icon icon-home" aria-label="Home"></a></li>
<li class="breadcrumb-item active">mlx.nn.Module</li>
<li class="wy-breadcrumbs-aside">
<a href="../../_sources/python/nn/module.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<section id="mlx-nn-module">
<h1>mlx.nn.Module<a class="headerlink" href="#mlx-nn-module" title="Permalink to this heading"></a></h1>
<dl class="py class">
<dt class="sig sig-object py" id="mlx.nn.Module">
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">mlx.nn.</span></span><span class="sig-name descname"><span class="pre">Module</span></span><a class="headerlink" href="#mlx.nn.Module" title="Permalink to this definition"></a></dt>
<dd><p>Base class for building neural networks with MLX.</p>
<p>All the layers provided in <code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.nn.layers</span></code> subclass this class and
your models should do the same.</p>
<p>A <code class="docutils literal notranslate"><span class="pre">Module</span></code> can contain other <code class="docutils literal notranslate"><span class="pre">Module</span></code> instances or <a class="reference internal" href="../_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a>
instances in arbitrarily nested Python lists or dicts. The <code class="docutils literal notranslate"><span class="pre">Module</span></code>
then allows recursively extracting all the <a class="reference internal" href="../_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a> instances
using <a class="reference internal" href="#mlx.nn.Module.parameters" title="mlx.nn.Module.parameters"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.nn.Module.parameters()</span></code></a>.</p>
<p>In addition, the <code class="docutils literal notranslate"><span class="pre">Module</span></code> has the concept of trainable and non-trainable
parameters (called “frozen”). When using <a class="reference internal" href="../_autosummary/mlx.nn.value_and_grad.html#mlx.nn.value_and_grad" title="mlx.nn.value_and_grad"><code class="xref py py-func docutils literal notranslate"><span class="pre">mlx.nn.value_and_grad()</span></code></a>
the gradients are returned only with respect to the trainable parameters.
All arrays in a module are trainable unless they are added to the “frozen”
set by calling <a class="reference internal" href="#mlx.nn.Module.freeze" title="mlx.nn.Module.freeze"><code class="xref py py-meth docutils literal notranslate"><span class="pre">freeze()</span></code></a>.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
<span class="kn">import</span> <span class="nn">mlx.nn</span> <span class="k">as</span> <span class="nn">nn</span>

<span class="k">class</span> <span class="nc">MyMLP</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">hidden_dims</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">16</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">in_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_dims</span><span class="p">,</span> <span class="n">hidden_dims</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">out_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">hidden_dims</span><span class="p">,</span> <span class="n">out_dims</span><span class="p">)</span>

    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">in_proj</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">out_proj</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>

<span class="n">model</span> <span class="o">=</span> <span class="n">MyMLP</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="c1"># All the model parameters are created but since MLX is lazy by</span>
<span class="c1"># default, they are not evaluated yet. Calling `mx.eval` actually</span>
<span class="c1"># allocates memory and initializes the parameters.</span>
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span>
<span class="c1"># Setting a parameter to a new value is as simple as accessing that</span>
<span class="c1"># parameter and assigning a new array to it.</span>
<span class="n">model</span><span class="o">.</span><span class="n">in_proj</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">in_proj</span><span class="o">.</span><span class="n">weight</span> <span class="o">*</span> <span class="mi">2</span>
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span>
</pre></div>
</div>
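<p>The sketch below illustrates how frozen parameters interact with <code class="docutils literal notranslate"><span class="pre">mlx.nn.value_and_grad()</span></code>. It redefines <code class="docutils literal notranslate"><span class="pre">MyMLP</span></code> so that it runs on its own; <code class="docutils literal notranslate"><span class="pre">loss_fn</span></code> is a hypothetical mean-squared-error loss, and <code class="docutils literal notranslate"><span class="pre">mlx.utils.tree_flatten</span></code> (not documented on this page) is used only to list the gradient keys. After freezing <code class="docutils literal notranslate"><span class="pre">in_proj</span></code>, only the <code class="docutils literal notranslate"><span class="pre">out_proj</span></code> gradients are expected in the result.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten


class MyMLP(nn.Module):
    def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
        super().__init__()
        self.in_proj = nn.Linear(in_dims, hidden_dims)
        self.out_proj = nn.Linear(hidden_dims, out_dims)

    def __call__(self, x):
        x = mx.maximum(self.in_proj(x), 0)
        return self.out_proj(x)


# Hypothetical mean-squared-error loss, used only for illustration.
def loss_fn(model, x, y):
    return mx.mean(mx.square(model(x) - y))


model = MyMLP(2, 1)
model.in_proj.freeze()  # in_proj's weight and bias stop receiving gradients

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
x = mx.random.normal((4, 2))
y = mx.zeros((4, 1))
loss, grads = loss_and_grad_fn(model, x, y)

# Only the trainable (non-frozen) parameters appear in the gradients.
grad_keys = sorted(k for k, _ in tree_flatten(grads))
print(grad_keys)  # ['out_proj.bias', 'out_proj.weight']
</pre></div>
</div>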
<dl class="py method">
<dt class="sig sig-object py" id="mlx.nn.Module.apply">
<span class="sig-name descname"><span class="pre">apply</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">map_fn</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Callable" title="(in Python v3.12)"><span class="pre">Callable</span></a><span class="p"><span class="pre">[</span></span><span class="p"><span class="pre">[</span></span><a class="reference internal" href="../_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><span class="pre">array</span></a><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference internal" href="../_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><span class="pre">array</span></a><span class="p"><span class="pre">]</span></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">filter_fn</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Optional" title="(in Python v3.12)"><span class="pre">Optional</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Callable" title="(in Python v3.12)"><span class="pre">Callable</span></a><span class="p"><span class="pre">[</span></span><span class="p"><span class="pre">[</span></span><a class="reference internal" href="#mlx.nn.Module" title="mlx.nn.Module"><span class="pre">Module</span></a><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><span class="pre">str</span></a><span class="p"><span 
class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Any" title="(in Python v3.12)"><span class="pre">Any</span></a><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.12)"><span class="pre">bool</span></a><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">]</span></span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.apply" title="Permalink to this definition"></a></dt>
<dd><p>Map all the parameters using the provided <code class="docutils literal notranslate"><span class="pre">map_fn</span></code> and immediately
update the module with the mapped parameters.</p>
<p>For instance, running <code class="docutils literal notranslate"><span class="pre">model.apply(lambda</span> <span class="pre">x:</span> <span class="pre">x.astype(mx.float16))</span></code>
casts all parameters to 16-bit floats.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>map_fn</strong> (<em>Callable</em>) – Maps an array to another array.</p></li>
<li><p><strong>filter_fn</strong> (<em>Callable</em><em>, </em><em>optional</em>) – Filter to select which arrays to
map (default: <code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.valid_parameter_filter()</span></code>).</p></li>
</ul>
</dd>
</dl>
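<p>A minimal sketch of the casting example above, applied to a single <code class="docutils literal notranslate"><span class="pre">nn.Linear</span></code> layer (chosen here only for brevity). After <code class="docutils literal notranslate"><span class="pre">apply()</span></code>, both the weight and the bias are expected to be <code class="docutils literal notranslate"><span class="pre">float16</span></code>.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(4, 2)

# Cast every parameter (here the weight and the bias) to 16-bit floats in place.
model.apply(lambda x: x.astype(mx.float16))
print(model.weight.dtype, model.bias.dtype)
</pre></div>
</div>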
</dd></dl>
<dl class="py method">
<dt class="sig sig-object py" id="mlx.nn.Module.apply_to_modules">
<span class="sig-name descname"><span class="pre">apply_to_modules</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">apply_fn</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Callable" title="(in Python v3.12)"><span class="pre">Callable</span></a><span class="p"><span class="pre">[</span></span><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><span class="pre">str</span></a><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference internal" href="#mlx.nn.Module" title="mlx.nn.Module"><span class="pre">Module</span></a><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Any" title="(in Python v3.12)"><span class="pre">Any</span></a><span class="p"><span class="pre">]</span></span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.apply_to_modules" title="Permalink to this definition"></a></dt>
<dd><p>Apply a function to all the modules in this instance (including this
instance).</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
<dd class="field-odd"><p><strong>apply_fn</strong> (<em>Callable</em>) – The function to apply to the modules. It is called with the dotted name of each module and the module itself.</p>
</dd>
</dl>
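<p>A small sketch of the names passed to <code class="docutils literal notranslate"><span class="pre">apply_fn</span></code>, assuming a hypothetical two-level model (<code class="docutils literal notranslate"><span class="pre">Block</span></code> and <code class="docutils literal notranslate"><span class="pre">Model</span></code> are illustrative names, not part of MLX). In this sketch, submodules stored in a list are named by their index, and this instance itself is visited under the empty name.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>import mlx.nn as nn


# Hypothetical two-level model used only to show the naming scheme.
class Block(nn.Module):
    def __init__(self, dims: int = 4):
        super().__init__()
        self.attention = nn.Linear(dims, dims)
        self.mlp = nn.Linear(dims, dims)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = [Block(), Block()]


model = Model()
names = []
# apply_fn receives (dotted name, module); the model itself is visited
# with the empty name "", and list entries are named by their index.
model.apply_to_modules(lambda name, module: names.append(name))
print(sorted(names))
# ['', 'blocks.0', 'blocks.0.attention', 'blocks.0.mlp',
#  'blocks.1', 'blocks.1.attention', 'blocks.1.mlp']
</pre></div>
</div>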
</dd></dl>
<dl class="py method">
<dt class="sig sig-object py" id="mlx.nn.Module.children">
<span class="sig-name descname"><span class="pre">children</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.children" title="Permalink to this definition"></a></dt>
<dd><p>Return the direct descendants of this Module instance.</p>
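<p>For illustration, with a hypothetical <code class="docutils literal notranslate"><span class="pre">MyMLP</span></code> holding two <code class="docutils literal notranslate"><span class="pre">nn.Linear</span></code> layers, only those two layers are returned. The sketch assumes the result is a dict keyed by attribute name, as produced by <code class="docutils literal notranslate"><span class="pre">filter_and_map()</span></code>; this page does not spell out the return type.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>import mlx.nn as nn


# Hypothetical model with two direct submodules.
class MyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(2, 16)
        self.out_proj = nn.Linear(16, 1)


model = MyMLP()
children = model.children()
print(sorted(children.keys()))             # ['in_proj', 'out_proj']
print(type(children["in_proj"]).__name__)  # Linear
</pre></div>
</div>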
</dd></dl>
<dl class="py method">
<dt class="sig sig-object py" id="mlx.nn.Module.filter_and_map">
<span class="sig-name descname"><span class="pre">filter_and_map</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">filter_fn</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Callable" title="(in Python v3.12)"><span class="pre">Callable</span></a><span class="p"><span class="pre">[</span></span><span class="p"><span class="pre">[</span></span><a class="reference internal" href="#mlx.nn.Module" title="mlx.nn.Module"><span class="pre">Module</span></a><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><span class="pre">str</span></a><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Any" title="(in Python v3.12)"><span class="pre">Any</span></a><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.12)"><span class="pre">bool</span></a><span class="p"><span class="pre">]</span></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">map_fn</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Optional" title="(in Python v3.12)"><span class="pre">Optional</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Callable" title="(in Python v3.12)"><span class="pre">Callable</span></a><span class="p"><span 
class="pre">]</span></span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">is_leaf_fn</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Optional" title="(in Python v3.12)"><span class="pre">Optional</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Callable" title="(in Python v3.12)"><span class="pre">Callable</span></a><span class="p"><span class="pre">[</span></span><span class="p"><span class="pre">[</span></span><a class="reference internal" href="#mlx.nn.Module" title="mlx.nn.Module"><span class="pre">Module</span></a><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><span class="pre">str</span></a><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Any" title="(in Python v3.12)"><span class="pre">Any</span></a><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.12)"><span class="pre">bool</span></a><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">]</span></span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" 
href="#mlx.nn.Module.filter_and_map" title="Permalink to this definition"></a></dt>
<dd><p>Recursively filter the contents of the module using <code class="docutils literal notranslate"><span class="pre">filter_fn</span></code>,
that is, select only the keys and values for which <code class="docutils literal notranslate"><span class="pre">filter_fn</span></code> returns True.</p>
<p>This is used to implement <a class="reference internal" href="#mlx.nn.Module.parameters" title="mlx.nn.Module.parameters"><code class="xref py py-meth docutils literal notranslate"><span class="pre">parameters()</span></code></a> and <a class="reference internal" href="#mlx.nn.Module.trainable_parameters" title="mlx.nn.Module.trainable_parameters"><code class="xref py py-meth docutils literal notranslate"><span class="pre">trainable_parameters()</span></code></a>
but it can also be used to extract any subset of the module's parameters.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>filter_fn</strong> (<em>Callable</em>) – Called with the containing module, the key under
which a value is found and the value itself; decides whether to keep the
value or drop it.</p></li>
<li><p><strong>map_fn</strong> (<em>Callable</em><em>, </em><em>optional</em>) – Optionally transform the value before
returning it.</p></li>
<li><p><strong>is_leaf_fn</strong> (<em>Callable</em><em>, </em><em>optional</em>) – Called with the containing module, the key
and the value; decides whether the value is a leaf.</p></li>
</ul>
</dd>
<dt class="field-even">Returns<span class="colon">:</span></dt>
<dd class="field-even"><p>A dictionary containing the recursively filtered contents of the module.</p>
</dd>
</dl>
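<p>As a sketch, the following extracts only the biases of a hypothetical <code class="docutils literal notranslate"><span class="pre">MyMLP</span></code> and maps each one to its size. In the implementation this sketch assumes, the filter is also applied to submodules, so it must keep them (return True) for the recursion to reach the arrays inside; the filter name <code class="docutils literal notranslate"><span class="pre">keep_biases</span></code> is illustrative.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>import mlx.nn as nn


class MyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(2, 16)
        self.out_proj = nn.Linear(16, 1)


model = MyMLP()


# Hypothetical filter, called as keep_biases(module, key, value).
# Submodules are kept so that the recursion can descend into them.
def keep_biases(module, key, value):
    return isinstance(value, nn.Module) or key == "bias"


bias_sizes = model.filter_and_map(keep_biases, map_fn=lambda b: b.size)
print(bias_sizes)  # {'in_proj': {'bias': 16}, 'out_proj': {'bias': 1}}
</pre></div>
</div>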
</dd></dl>
<dl class="py method">
<dt class="sig sig-object py" id="mlx.nn.Module.freeze">
<span class="sig-name descname"><span class="pre">freeze</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">recurse</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.12)"><span class="pre">bool</span></a></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">keys</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Optional" title="(in Python v3.12)"><span class="pre">Optional</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Union" title="(in Python v3.12)"><span class="pre">Union</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><span class="pre">str</span></a><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.List" title="(in Python v3.12)"><span class="pre">List</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><span class="pre">str</span></a><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">]</span></span></span><span class="w"> 
</span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">strict</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.12)"><span class="pre">bool</span></a></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.freeze" title="Permalink to this definition"></a></dt>
<dd><p>Freeze the Module's parameters, or a subset of them. Freezing a parameter means
that no gradients are computed for it.</p>
<p>This function is idempotent, i.e., freezing a frozen model is a no-op.</p>
<p>For instance, to train only the attention parameters of a transformer:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>model = ...
model.freeze()
model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)
</pre></div>
</div>
<dl class="field-list simple">
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>recurse</strong> (<a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.12)"><em>bool</em></a><em>, </em><em>optional</em>) – If True, also freeze the parameters of the
submodules (default: True).</p></li>
<li><p><strong>keys</strong> (<a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><em>str</em></a><em> or </em><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#list" title="(in Python v3.12)"><em>list</em></a><em>[</em><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><em>str</em></a><em>]</em><em>, </em><em>optional</em>) – If provided, only these
parameters are frozen; otherwise all the parameters of the
module are frozen. For instance, freeze all biases by calling
<code class="docutils literal notranslate"><span class="pre">module.freeze(keys=&quot;bias&quot;)</span></code>.</p></li>
<li><p><strong>strict</strong> (<a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.12)"><em>bool</em></a><em>, </em><em>optional</em>) – If True, validate that the passed keys exist
(default: False).</p></li>
</ul>
</dd>
</dl>
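<p>A runnable version of the attention-only recipe above, assuming a hypothetical toy transformer whose blocks hold <code class="docutils literal notranslate"><span class="pre">attention</span></code> and <code class="docutils literal notranslate"><span class="pre">mlp</span></code> layers (plain <code class="docutils literal notranslate"><span class="pre">nn.Linear</span></code> stand-ins for real layers), followed by freezing every bias through <code class="docutils literal notranslate"><span class="pre">keys</span></code>. The helper <code class="docutils literal notranslate"><span class="pre">trainable_keys</span></code> is illustrative and relies on <code class="docutils literal notranslate"><span class="pre">mlx.utils.tree_flatten</span></code> to list dotted parameter names.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>import mlx.nn as nn
from mlx.utils import tree_flatten


# Hypothetical toy transformer; `attention` and `mlp` stand in for real layers.
class Block(nn.Module):
    def __init__(self, dims: int = 4):
        super().__init__()
        self.attention = nn.Linear(dims, dims)
        self.mlp = nn.Linear(dims, dims)


class ToyTransformer(nn.Module):
    def __init__(self, num_blocks: int = 2):
        super().__init__()
        self.blocks = [Block() for _ in range(num_blocks)]


def trainable_keys(model):
    # Dotted names of all parameters that would still receive gradients.
    return sorted(k for k, _ in tree_flatten(model.trainable_parameters()))


# Train only the attention parameters.
model = ToyTransformer()
model.freeze()
model.apply_to_modules(
    lambda k, v: v.unfreeze() if k.endswith("attention") else None
)
attention_only = trainable_keys(model)
print(attention_only)
# ['blocks.0.attention.bias', 'blocks.0.attention.weight',
#  'blocks.1.attention.bias', 'blocks.1.attention.weight']

# Freeze the biases in every submodule (recurse=True is the default).
model = ToyTransformer()
model.freeze(keys=["bias"])
weights_only = trainable_keys(model)
print(weights_only)
</pre></div>
</div>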
</dd></dl>
<dl class="py method">
<dt class="sig sig-object py" id="mlx.nn.Module.leaf_modules">
<span class="sig-name descname"><span class="pre">leaf_modules</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.leaf_modules" title="Permalink to this definition"></a></dt>
<dd><p>Return the submodules that do not contain other modules.</p>
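<p>As an illustration with a hypothetical nested model, <code class="docutils literal notranslate"><span class="pre">encoder</span></code> contains other modules and is therefore not a leaf, while its two layers and <code class="docutils literal notranslate"><span class="pre">head</span></code> are. The sketch assumes the result mirrors the nesting of the model (as <code class="docutils literal notranslate"><span class="pre">filter_and_map()</span></code> output does), so it flattens it with <code class="docutils literal notranslate"><span class="pre">mlx.utils.tree_flatten</span></code>, stopping at modules.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>import mlx.nn as nn
from mlx.utils import tree_flatten


# Hypothetical nested model: `encoder` holds two layers, `head` is one layer.
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.attention = nn.Linear(4, 4)
        self.mlp = nn.Linear(4, 4)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.head = nn.Linear(4, 1)


model = Model()
leaves = model.leaf_modules()
# Flatten the nested result, treating each module as a leaf of the tree.
leaf_names = sorted(
    k for k, _ in tree_flatten(leaves, is_leaf=lambda v: isinstance(v, nn.Module))
)
print(leaf_names)  # ['encoder.attention', 'encoder.mlp', 'head']
</pre></div>
</div>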
</dd></dl>
<dl class="py method">
<dt class="sig sig-object py" id="mlx.nn.Module.load_weights">
<span class="sig-name descname"><span class="pre">load_weights</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">file</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><span class="pre">str</span></a></span></em><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.load_weights" title="Permalink to this definition"></a></dt>
<dd><p>Load and update the model's weights from a <cite>.npz</cite> file.</p>
</dd></dl>
<dl class="py method">
<dt class="sig sig-object py" id="mlx.nn.Module.modules">
<span class="sig-name descname"><span class="pre">modules</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.modules" title="Permalink to this definition"></a></dt>
<dd><p>Return a list with all the modules in this instance.</p>
<dl class="field-list simple">
<dt class="field-odd">Returns<span class="colon">:</span></dt>
<dd class="field-odd"><p>A list of <a class="reference internal" href="#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Module</span></code></a> instances.</p>
</dd>
</dl>
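<p>A short sketch with a hypothetical two-layer <code class="docutils literal notranslate"><span class="pre">MyMLP</span></code>: the returned list includes the model itself in addition to its two <code class="docutils literal notranslate"><span class="pre">nn.Linear</span></code> submodules.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>import mlx.nn as nn


class MyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(2, 16)
        self.out_proj = nn.Linear(16, 1)


model = MyMLP()
mods = model.modules()
print(len(mods))  # 3: the model itself plus in_proj and out_proj
num_linear = sum(isinstance(m, nn.Linear) for m in mods)
print(num_linear)  # 2
</pre></div>
</div>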
</dd></dl>
<dl class="py method">
<dt class="sig sig-object py" id="mlx.nn.Module.named_modules">
<span class="sig-name descname"><span class="pre">named_modules</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.named_modules" title="Permalink to this definition"></a></dt>
<dd><p>Return a list with all the modules in this instance and their name
with dot notation.</p>
<dl class="field-list simple">
<dt class="field-odd">Returns<span class="colon">:</span></dt>
<dd class="field-odd"><p>A list of tuples (str, <a class="reference internal" href="#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Module</span></code></a>).</p>
</dd>
</dl>
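<p>For example, the names can be used to select submodules by type. <code class="docutils literal notranslate"><span class="pre">MyMLP</span></code> here is a hypothetical two-layer model, and the sketch sorts the result rather than relying on the order of the returned list.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>import mlx.nn as nn


class MyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(2, 16)
        self.out_proj = nn.Linear(16, 1)


model = MyMLP()
# Names of all Linear submodules, in dot notation.
linear_names = sorted(
    name for name, module in model.named_modules() if isinstance(module, nn.Linear)
)
print(linear_names)  # ['in_proj', 'out_proj']
</pre></div>
</div>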
</dd></dl>
<dl class="py method">
<dt class="sig sig-object py" id="mlx.nn.Module.parameters">
<span class="sig-name descname"><span class="pre">parameters</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.parameters" title="Permalink to this definition"></a></dt>
<dd><p>Recursively return all the <a class="reference internal" href="../_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a> members of this Module
as a dict of dicts and lists.</p>
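<p>A sketch of the returned structure for a hypothetical two-layer <code class="docutils literal notranslate"><span class="pre">MyMLP</span></code>: the nested dicts follow the attribute names, and <code class="docutils literal notranslate"><span class="pre">mlx.utils.tree_flatten</span></code> (not documented on this page) turns them into dotted names. Note that <code class="docutils literal notranslate"><span class="pre">nn.Linear</span></code> stores its weight with shape <code class="docutils literal notranslate"><span class="pre">(output_dims,</span> <span class="pre">input_dims)</span></code>.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>import mlx.nn as nn
from mlx.utils import tree_flatten


class MyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(2, 16)
        self.out_proj = nn.Linear(16, 1)


model = MyMLP()
params = model.parameters()
# The top-level keys are the attribute names of the submodules.
print(sorted(params.keys()))  # ['in_proj', 'out_proj']

# tree_flatten converts the nested dicts into (dotted name, array) pairs.
shapes = {name: tuple(p.shape) for name, p in tree_flatten(params)}
print(shapes["in_proj.weight"], shapes["in_proj.bias"])  # (16, 2) (16,)
</pre></div>
</div>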
</dd></dl>
<dl class="py method">
<dt class="sig sig-object py" id="mlx.nn.Module.save_weights">
<span class="sig-name descname"><span class="pre">save_weights</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">file</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><span class="pre">str</span></a></span></em><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.save_weights" title="Permalink to this definition"></a></dt>
<dd><p>Save the model's weights to a <cite>.npz</cite> file.</p>
</dd></dl>
<dl class="py method">
<dt class="sig sig-object py" id="mlx.nn.Module.trainable_parameters">
<span class="sig-name descname"><span class="pre">trainable_parameters</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.trainable_parameters" title="Permalink to this definition"></a></dt>
<dd><p>Recursively return all the non-frozen <a class="reference internal" href="../_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a> members of
this Module as a dict of dicts and lists.</p>
</dd></dl>
<dl class="py method">
<dt class="sig sig-object py" id="mlx.nn.Module.unfreeze">
<span class="sig-name descname"><span class="pre">unfreeze</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">recurse</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.12)"><span class="pre">bool</span></a></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">keys</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Optional" title="(in Python v3.12)"><span class="pre">Optional</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Union" title="(in Python v3.12)"><span class="pre">Union</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><span class="pre">str</span></a><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.List" title="(in Python v3.12)"><span class="pre">List</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><span class="pre">str</span></a><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">]</span></span></span><span class="w"> 
</span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">strict</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.12)"><span class="pre">bool</span></a></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.unfreeze" title="Permalink to this definition"></a></dt>
<dd><p>Unfreeze the Module's parameters, or a subset of them.</p>
<p>This function is idempotent, i.e. unfreezing a model that is not frozen is
a no-op.</p>
<p>For instance, to train only the biases one can do:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>model = ...
model.freeze()
model.unfreeze(keys=&quot;bias&quot;)
</pre></div></div>
<dl class="field-list simple">
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>recurse</strong> (<a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.12)"><em>bool</em></a><em>, </em><em>optional</em>) – If True, also unfreeze the parameters of the
submodules (default: True).</p></li>
<li><p><strong>keys</strong> (<a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><em>str</em></a><em> or </em><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#list" title="(in Python v3.12)"><em>list</em></a><em>[</em><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><em>str</em></a><em>]</em><em>, </em><em>optional</em>) – If provided, only these
parameters are unfrozen; otherwise all the parameters of the
module are. For instance, unfreeze all biases by calling
<code class="docutils literal notranslate"><span class="pre">module.unfreeze(keys=&quot;bias&quot;)</span></code>.</p></li>
<li><p><strong>strict</strong> (<a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.12)"><em>bool</em></a><em>, </em><em>optional</em>) – If True, validate that the passed keys exist
(default: False).</p></li>
</ul>
</dd>
</dl>
</dd></dl>
<dl class="py method">
<dt class="sig sig-object py" id="mlx.nn.Module.update">
<span class="sig-name descname"><span class="pre">update</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">parameters</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#dict" title="(in Python v3.12)"><span class="pre">dict</span></a></span></em><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.update" title="Permalink to this definition"></a></dt>
<dd><p>Replace the parameters of this Module with the provided ones in the
dict of dicts and lists.</p>
<p>Commonly used by the optimizer to replace the model's parameters with the
updated (optimized) ones. Also used by <a class="reference internal" href="../_autosummary/mlx.nn.value_and_grad.html#mlx.nn.value_and_grad" title="mlx.nn.value_and_grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.nn.value_and_grad()</span></code></a> to set the
tracers in the model in order to compute gradients.</p>
<p>The passed-in parameters dictionary need not be complete like the one
returned by <a class="reference internal" href="#mlx.nn.Module.parameters" title="mlx.nn.Module.parameters"><code class="xref py py-meth docutils literal notranslate"><span class="pre">parameters()</span></code></a>; only the provided locations are
updated.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
<dd class="field-odd"><p><strong>parameters</strong> (<a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#dict" title="(in Python v3.12)"><em>dict</em></a>) – A complete or partial dictionary of the module's
parameters.</p>
</dd>
</dl>
</dd></dl>
</dd></dl>
</section>
</div>
</div>
<footer>
<hr/>
<div role="contentinfo">
<p>&#169; Copyright 2023, MLX Contributors.</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script>
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>