mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
374 lines
40 KiB
HTML
374 lines
40 KiB
HTML
<!DOCTYPE html>
|
||
<html class="writer-html5" lang="en" >
|
||
<head>
|
||
<meta charset="utf-8" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
|
||
|
||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||
<title>mlx.nn.Module — MLX 0.0.0 documentation</title>
|
||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||
<!--[if lt IE 9]>
|
||
<script src="../../_static/js/html5shiv.min.js"></script>
|
||
<![endif]-->
|
||
|
||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||
<script src="../../_static/jquery.js"></script>
|
||
<script src="../../_static/underscore.js"></script>
|
||
<script src="../../_static/_sphinx_javascript_frameworks_compat.js"></script>
|
||
<script src="../../_static/doctools.js"></script>
|
||
<script src="../../_static/js/theme.js"></script>
|
||
<link rel="index" title="Index" href="../../genindex.html" />
|
||
<link rel="search" title="Search" href="../../search.html" />
|
||
</head>
|
||
|
||
<body class="wy-body-for-nav">
|
||
<div class="wy-grid-for-nav">
|
||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||
<div class="wy-side-scroll">
|
||
<div class="wy-side-nav-search" >
|
||
|
||
|
||
|
||
<a href="../../index.html" class="icon icon-home">
|
||
MLX
|
||
</a>
|
||
<div class="version">
|
||
0.0.0
|
||
</div>
|
||
<div role="search">
|
||
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
|
||
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
|
||
<input type="hidden" name="check_keywords" value="yes" />
|
||
<input type="hidden" name="area" value="default" />
|
||
</form>
|
||
</div>
|
||
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
|
||
<p class="caption" role="heading"><span class="caption-text">Install</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../install.html">Build and Install</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Usage</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../quick_start.html">Quick Start Guide</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../using_streams.html">Using Streams</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Examples</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../examples/linear_regression.html">Linear Regression</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../examples/mlp.html">Multi-Layer Perceptron</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../examples/llama-inference.html">LLM inference</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Further Reading</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../dev/extensions.html">Developer Documentation</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Python API Reference</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../array.html">Array</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../devices_and_streams.html">Devices and Streams</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../ops.html">Operations</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../random.html">Random</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../transforms.html">Transforms</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../fft.html">FFT</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../nn.html">Neural Networks</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../optimizers.html">Optimizers</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../tree_utils.html">Tree Utils</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">C++ API Reference</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../cpp/ops.html">Operations</a></li>
|
||
</ul>
|
||
|
||
</div>
|
||
</div>
|
||
</nav>
|
||
|
||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
|
||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||
<a href="../../index.html">MLX</a>
|
||
</nav>
|
||
|
||
<div class="wy-nav-content">
|
||
<div class="rst-content">
|
||
<div role="navigation" aria-label="Page navigation">
|
||
<ul class="wy-breadcrumbs">
|
||
<li><a href="../../index.html" class="icon icon-home" aria-label="Home"></a></li>
|
||
<li class="breadcrumb-item active">mlx.nn.Module</li>
|
||
<li class="wy-breadcrumbs-aside">
|
||
<a href="../../_sources/python/nn/module.rst.txt" rel="nofollow"> View page source</a>
|
||
</li>
|
||
</ul>
|
||
<hr/>
|
||
</div>
|
||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||
<div itemprop="articleBody">
|
||
|
||
<section id="mlx-nn-module">
|
||
<h1>mlx.nn.Module<a class="headerlink" href="#mlx-nn-module" title="Permalink to this heading"></a></h1>
|
||
<dl class="py class">
|
||
<dt class="sig sig-object py" id="mlx.nn.Module">
|
||
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">mlx.nn.</span></span><span class="sig-name descname"><span class="pre">Module</span></span><a class="headerlink" href="#mlx.nn.Module" title="Permalink to this definition"></a></dt>
|
||
<dd><p>Base class for building neural networks with MLX.</p>
|
||
<p>All the layers provided in <code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.nn.layers</span></code> subclass this class and
|
||
your models should do the same.</p>
|
||
<p>A <code class="docutils literal notranslate"><span class="pre">Module</span></code> can contain other <code class="docutils literal notranslate"><span class="pre">Module</span></code> instances or <a class="reference internal" href="../_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a>
|
||
instances in arbitrary nesting of python lists or dicts. The <code class="docutils literal notranslate"><span class="pre">Module</span></code>
|
||
then allows recursively extracting all the <a class="reference internal" href="../_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a> instances
|
||
using <a class="reference internal" href="#mlx.nn.Module.parameters" title="mlx.nn.Module.parameters"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.nn.Module.parameters()</span></code></a>.</p>
|
||
<p>In addition, the <code class="docutils literal notranslate"><span class="pre">Module</span></code> has the concept of trainable and non-trainable
|
||
parameters (called “frozen”). When using <a class="reference internal" href="../_autosummary/mlx.nn.value_and_grad.html#mlx.nn.value_and_grad" title="mlx.nn.value_and_grad"><code class="xref py py-func docutils literal notranslate"><span class="pre">mlx.nn.value_and_grad()</span></code></a>
|
||
the gradients are returned only with respect to the trainable parameters.
|
||
All arrays in a module are trainable unless they are added to the “frozen”
|
||
set by calling <a class="reference internal" href="#mlx.nn.Module.freeze" title="mlx.nn.Module.freeze"><code class="xref py py-meth docutils literal notranslate"><span class="pre">freeze()</span></code></a>.</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mlx.core</span> <span class="k">as</span> <span class="nn">mx</span>
|
||
<span class="kn">import</span> <span class="nn">mlx.nn</span> <span class="k">as</span> <span class="nn">nn</span>
|
||
|
||
<span class="k">class</span> <span class="nc">MyMLP</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
|
||
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">hidden_dims</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">16</span><span class="p">):</span>
|
||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">in_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_dims</span><span class="p">,</span> <span class="n">hidden_dims</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">out_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">hidden_dims</span><span class="p">,</span> <span class="n">out_dims</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">in_proj</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">out_proj</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||
|
||
<span class="n">model</span> <span class="o">=</span> <span class="n">MyMLP</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||
|
||
<span class="c1"># All the model parameters are created but since MLX is lazy by</span>
|
||
<span class="c1"># default, they are not evaluated yet. Calling `mx.eval` actually</span>
|
||
<span class="c1"># allocates memory and initializes the parameters.</span>
|
||
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span>
|
||
|
||
<span class="c1"># Setting a parameter to a new value is as simple as accessing that</span>
|
||
<span class="c1"># parameter and assigning a new array to it.</span>
|
||
<span class="n">model</span><span class="o">.</span><span class="n">in_proj</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">in_proj</span><span class="o">.</span><span class="n">weight</span> <span class="o">*</span> <span class="mi">2</span>
|
||
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span>
|
||
</pre></div>
|
||
</div>
|
||
<dl class="py method">
|
||
<dt class="sig sig-object py" id="mlx.nn.Module.apply">
|
||
<span class="sig-name descname"><span class="pre">apply</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">map_fn</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Callable" title="(in Python v3.12)"><span class="pre">Callable</span></a><span class="p"><span class="pre">[</span></span><span class="p"><span class="pre">[</span></span><a class="reference internal" href="../_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><span class="pre">array</span></a><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference internal" href="../_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><span class="pre">array</span></a><span class="p"><span class="pre">]</span></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">filter_fn</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Optional" title="(in Python v3.12)"><span class="pre">Optional</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Callable" title="(in Python v3.12)"><span class="pre">Callable</span></a><span class="p"><span class="pre">[</span></span><span class="p"><span class="pre">[</span></span><a class="reference internal" href="#mlx.nn.Module" title="mlx.nn.Module"><span class="pre">Module</span></a><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><span class="pre">str</span></a><span class="p"><span 
class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Any" title="(in Python v3.12)"><span class="pre">Any</span></a><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.12)"><span class="pre">bool</span></a><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">]</span></span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.apply" title="Permalink to this definition"></a></dt>
|
||
<dd><p>Map all the parameters using the provided <code class="docutils literal notranslate"><span class="pre">map_fn</span></code> and immediately
|
||
update the module with the mapped parameters.</p>
|
||
<p>For instance running <code class="docutils literal notranslate"><span class="pre">model.apply(lambda</span> <span class="pre">x:</span> <span class="pre">x.astype(mx.float16))</span></code>
|
||
casts all parameters to 16 bit floats.</p>
|
||
<dl class="field-list simple">
|
||
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
|
||
<dd class="field-odd"><ul class="simple">
|
||
<li><p><strong>map_fn</strong> (<em>Callable</em>) – Maps an array to another array</p></li>
|
||
<li><p><strong>filter_fn</strong> (<em>Callable</em><em>, </em><em>optional</em>) – Filter to select which arrays to
|
||
map (default: <code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.valid_parameter_filter()</span></code>).</p></li>
|
||
</ul>
|
||
</dd>
|
||
</dl>
|
||
</dd></dl>
|
||
|
||
<dl class="py method">
|
||
<dt class="sig sig-object py" id="mlx.nn.Module.apply_to_modules">
|
||
<span class="sig-name descname"><span class="pre">apply_to_modules</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">apply_fn</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Callable" title="(in Python v3.12)"><span class="pre">Callable</span></a><span class="p"><span class="pre">[</span></span><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><span class="pre">str</span></a><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference internal" href="#mlx.nn.Module" title="mlx.nn.Module"><span class="pre">Module</span></a><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Any" title="(in Python v3.12)"><span class="pre">Any</span></a><span class="p"><span class="pre">]</span></span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.apply_to_modules" title="Permalink to this definition"></a></dt>
|
||
<dd><p>Apply a function to all the modules in this instance (including this
|
||
instance).</p>
|
||
<dl class="field-list simple">
|
||
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
|
||
<dd class="field-odd"><p><strong>apply_fn</strong> (<em>Callable</em>) – The function to apply to the modules.</p>
|
||
</dd>
|
||
</dl>
|
||
</dd></dl>
|
||
|
||
<dl class="py method">
|
||
<dt class="sig sig-object py" id="mlx.nn.Module.children">
|
||
<span class="sig-name descname"><span class="pre">children</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.children" title="Permalink to this definition"></a></dt>
|
||
<dd><p>Return the direct descendants of this Module instance.</p>
|
||
</dd></dl>
|
||
|
||
<dl class="py method">
|
||
<dt class="sig sig-object py" id="mlx.nn.Module.filter_and_map">
|
||
<span class="sig-name descname"><span class="pre">filter_and_map</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">filter_fn</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Callable" title="(in Python v3.12)"><span class="pre">Callable</span></a><span class="p"><span class="pre">[</span></span><span class="p"><span class="pre">[</span></span><a class="reference internal" href="#mlx.nn.Module" title="mlx.nn.Module"><span class="pre">Module</span></a><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><span class="pre">str</span></a><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Any" title="(in Python v3.12)"><span class="pre">Any</span></a><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.12)"><span class="pre">bool</span></a><span class="p"><span class="pre">]</span></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">map_fn</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Optional" title="(in Python v3.12)"><span class="pre">Optional</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Callable" title="(in Python v3.12)"><span class="pre">Callable</span></a><span class="p"><span 
class="pre">]</span></span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">is_leaf_fn</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Optional" title="(in Python v3.12)"><span class="pre">Optional</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Callable" title="(in Python v3.12)"><span class="pre">Callable</span></a><span class="p"><span class="pre">[</span></span><span class="p"><span class="pre">[</span></span><a class="reference internal" href="#mlx.nn.Module" title="mlx.nn.Module"><span class="pre">Module</span></a><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><span class="pre">str</span></a><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Any" title="(in Python v3.12)"><span class="pre">Any</span></a><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.12)"><span class="pre">bool</span></a><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">]</span></span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" 
href="#mlx.nn.Module.filter_and_map" title="Permalink to this definition"></a></dt>
|
||
<dd><p>Recursively filter the contents of the module using <code class="docutils literal notranslate"><span class="pre">filter_fn</span></code>,
|
||
namely only select keys and values where <code class="docutils literal notranslate"><span class="pre">filter_fn</span></code> returns true.</p>
|
||
<p>This is used to implement <a class="reference internal" href="#mlx.nn.Module.parameters" title="mlx.nn.Module.parameters"><code class="xref py py-meth docutils literal notranslate"><span class="pre">parameters()</span></code></a> and <a class="reference internal" href="#mlx.nn.Module.trainable_parameters" title="mlx.nn.Module.trainable_parameters"><code class="xref py py-meth docutils literal notranslate"><span class="pre">trainable_parameters()</span></code></a>
|
||
but it can also be used to extract any subset of the module’s parameters.</p>
|
||
<dl class="field-list simple">
|
||
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
|
||
<dd class="field-odd"><ul class="simple">
|
||
<li><p><strong>filter_fn</strong> (<em>Callable</em>) – Given a value, the key in which it is found
|
||
and the containing module, decide whether to keep the value or
|
||
drop it.</p></li>
|
||
<li><p><strong>map_fn</strong> (<em>Callable</em><em>, </em><em>optional</em>) – Optionally transform the value before
|
||
returning it.</p></li>
|
||
<li><p><strong>is_leaf_fn</strong> (<em>Callable</em><em>, </em><em>optional</em>) – Given a value, the key in which it
|
||
is found and the containing module, decide if it is a leaf.</p></li>
|
||
</ul>
|
||
</dd>
|
||
<dt class="field-even">Returns<span class="colon">:</span></dt>
|
||
<dd class="field-even"><p>A dictionary containing the contents of the module recursively filtered.</p>
|
||
</dd>
|
||
</dl>
|
||
</dd></dl>
|
||
|
||
<dl class="py method">
|
||
<dt class="sig sig-object py" id="mlx.nn.Module.freeze">
|
||
<span class="sig-name descname"><span class="pre">freeze</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">recurse</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.12)"><span class="pre">bool</span></a></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">keys</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Optional" title="(in Python v3.12)"><span class="pre">Optional</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Union" title="(in Python v3.12)"><span class="pre">Union</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><span class="pre">str</span></a><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.List" title="(in Python v3.12)"><span class="pre">List</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><span class="pre">str</span></a><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">]</span></span></span><span class="w"> 
</span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">strict</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.12)"><span class="pre">bool</span></a></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.freeze" title="Permalink to this definition"></a></dt>
|
||
<dd><p>Freeze the Module’s parameters or some of them. Freezing a parameter means not
|
||
computing gradients for it.</p>
|
||
<p>This function is idempotent, i.e. freezing a frozen model is a no-op.</p>
|
||
<p>For instance to only train the attention parameters from a transformer:</p>
|
||
<blockquote>
|
||
<div><pre>model = ...
model.freeze()
model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)</pre>
|
||
</div></blockquote>
|
||
<dl class="field-list simple">
|
||
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
|
||
<dd class="field-odd"><ul class="simple">
|
||
<li><p><strong>recurse</strong> (<a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.12)"><em>bool</em></a><em>, </em><em>optional</em>) – If True then freeze the parameters of the
|
||
submodules as well (default: True).</p></li>
|
||
<li><p><strong>keys</strong> (<a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><em>str</em></a><em> or </em><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#list" title="(in Python v3.12)"><em>list</em></a><em>[</em><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><em>str</em></a><em>]</em><em>, </em><em>optional</em>) – If provided then only these
|
||
parameters will be frozen; otherwise all the parameters of the
module are frozen. For instance, freeze all biases by calling
|
||
<code class="docutils literal notranslate"><span class="pre">module.freeze(keys="bias")</span></code>.</p></li>
|
||
<li><p><strong>strict</strong> (<a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.12)"><em>bool</em></a><em>, </em><em>optional</em>) – If set to True validate that the passed keys exist
|
||
(default: False).</p></li>
|
||
</ul>
|
||
</dd>
|
||
</dl>
|
||
</dd></dl>
|
||
|
||
<dl class="py method">
|
||
<dt class="sig sig-object py" id="mlx.nn.Module.leaf_modules">
|
||
<span class="sig-name descname"><span class="pre">leaf_modules</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.leaf_modules" title="Permalink to this definition"></a></dt>
|
||
<dd><p>Return the submodules that do not contain other modules.</p>
|
||
</dd></dl>
|
||
|
||
<dl class="py method">
|
||
<dt class="sig sig-object py" id="mlx.nn.Module.load_weights">
|
||
<span class="sig-name descname"><span class="pre">load_weights</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">file</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><span class="pre">str</span></a></span></em><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.load_weights" title="Permalink to this definition"></a></dt>
|
||
<dd><p>Load and update the model’s weights from a <cite>.npz</cite> file.</p>
|
||
</dd></dl>
|
||
|
||
<dl class="py method">
|
||
<dt class="sig sig-object py" id="mlx.nn.Module.modules">
|
||
<span class="sig-name descname"><span class="pre">modules</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.modules" title="Permalink to this definition"></a></dt>
|
||
<dd><p>Return a list with all the modules in this instance.</p>
|
||
<dl class="field-list simple">
|
||
<dt class="field-odd">Returns<span class="colon">:</span></dt>
|
||
<dd class="field-odd"><p>A list of <a class="reference internal" href="#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Module</span></code></a> instances.</p>
|
||
</dd>
|
||
</dl>
|
||
</dd></dl>
|
||
|
||
<dl class="py method">
|
||
<dt class="sig sig-object py" id="mlx.nn.Module.named_modules">
|
||
<span class="sig-name descname"><span class="pre">named_modules</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.named_modules" title="Permalink to this definition"></a></dt>
|
||
<dd><p>Return a list with all the modules in this instance and their name
|
||
with dot notation.</p>
|
||
<dl class="field-list simple">
|
||
<dt class="field-odd">Returns<span class="colon">:</span></dt>
|
||
<dd class="field-odd"><p>A list of tuples (str, <a class="reference internal" href="#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Module</span></code></a>).</p>
|
||
</dd>
|
||
</dl>
|
||
</dd></dl>
|
||
|
||
<dl class="py method">
|
||
<dt class="sig sig-object py" id="mlx.nn.Module.parameters">
|
||
<span class="sig-name descname"><span class="pre">parameters</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.parameters" title="Permalink to this definition"></a></dt>
|
||
<dd><p>Recursively return all the <a class="reference internal" href="../_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a> members of this Module
|
||
as a dict of dicts and lists.</p>
|
||
</dd></dl>
|
||
|
||
<dl class="py method">
|
||
<dt class="sig sig-object py" id="mlx.nn.Module.save_weights">
|
||
<span class="sig-name descname"><span class="pre">save_weights</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">file</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><span class="pre">str</span></a></span></em><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.save_weights" title="Permalink to this definition"></a></dt>
|
||
<dd><p>Save the model’s weights to a <cite>.npz</cite> file.</p>
|
||
</dd></dl>
|
||
|
||
<dl class="py method">
|
||
<dt class="sig sig-object py" id="mlx.nn.Module.trainable_parameters">
|
||
<span class="sig-name descname"><span class="pre">trainable_parameters</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.trainable_parameters" title="Permalink to this definition"></a></dt>
|
||
<dd><p>Recursively return all the non-frozen <a class="reference internal" href="../_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a> members of
|
||
this Module as a dict of dicts and lists.</p>
|
||
</dd></dl>
|
||
|
||
<dl class="py method">
|
||
<dt class="sig sig-object py" id="mlx.nn.Module.unfreeze">
|
||
<span class="sig-name descname"><span class="pre">unfreeze</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">recurse</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.12)"><span class="pre">bool</span></a></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">keys</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Optional" title="(in Python v3.12)"><span class="pre">Optional</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Union" title="(in Python v3.12)"><span class="pre">Union</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><span class="pre">str</span></a><span class="p"><span class="pre">,</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.List" title="(in Python v3.12)"><span class="pre">List</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><span class="pre">str</span></a><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">]</span></span></span><span class="w"> 
</span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">strict</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.12)"><span class="pre">bool</span></a></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.unfreeze" title="Permalink to this definition"></a></dt>
|
||
<dd><p>Unfreeze the Module’s parameters or some of them.</p>
|
||
<p>This function is idempotent, i.e. unfreezing a model that is not frozen is
|
||
a no-op.</p>
|
||
<p>For instance to only train the biases one can do:</p>
|
||
<blockquote>
|
||
<div><p>model = ...
|
||
model.freeze()
|
||
model.unfreeze(keys=&quot;bias&quot;)</p>
|
||
</div></blockquote>
|
||
<dl class="field-list simple">
|
||
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
|
||
<dd class="field-odd"><ul class="simple">
|
||
<li><p><strong>recurse</strong> (<a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.12)"><em>bool</em></a><em>, </em><em>optional</em>) – If True then unfreeze the parameters of the
|
||
submodules as well (default: True).</p></li>
|
||
<li><p><strong>keys</strong> (<a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><em>str</em></a><em> or </em><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#list" title="(in Python v3.12)"><em>list</em></a><em>[</em><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.12)"><em>str</em></a><em>]</em><em>, </em><em>optional</em>) – If provided then only these
|
||
parameters will be unfrozen otherwise all the parameters of a
|
||
module. For instance unfreeze all biases by calling
|
||
<code class="docutils literal notranslate"><span class="pre">module.unfreeze(keys="bias")</span></code>.</p></li>
|
||
<li><p><strong>strict</strong> (<a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.12)"><em>bool</em></a><em>, </em><em>optional</em>) – If set to True validate that the passed keys exist
|
||
(default: False).</p></li>
|
||
</ul>
|
||
</dd>
|
||
</dl>
|
||
</dd></dl>
|
||
|
||
<dl class="py method">
|
||
<dt class="sig sig-object py" id="mlx.nn.Module.update">
|
||
<span class="sig-name descname"><span class="pre">update</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">parameters</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#dict" title="(in Python v3.12)"><span class="pre">dict</span></a></span></em><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.Module.update" title="Permalink to this definition"></a></dt>
|
||
<dd><p>Replace the parameters of this Module with the provided ones in the
|
||
dict of dicts and lists.</p>
|
||
<p>Commonly used by the optimizer to change the model to the updated
|
||
(optimized) parameters. Also used by <a class="reference internal" href="../_autosummary/mlx.nn.value_and_grad.html#mlx.nn.value_and_grad" title="mlx.nn.value_and_grad"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.nn.value_and_grad()</span></code></a> to set the
|
||
tracers in the model in order to compute gradients.</p>
|
||
<p>The passed in parameters dictionary need not be a full dictionary
|
||
similar to <a class="reference internal" href="#mlx.nn.Module.parameters" title="mlx.nn.Module.parameters"><code class="xref py py-meth docutils literal notranslate"><span class="pre">parameters()</span></code></a>. Only the provided locations will be
|
||
updated.</p>
|
||
<dl class="field-list simple">
|
||
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
|
||
<dd class="field-odd"><p><strong>parameters</strong> (<a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#dict" title="(in Python v3.12)"><em>dict</em></a>) – A complete or partial dictionary of the module’s
|
||
parameters.</p>
|
||
</dd>
|
||
</dl>
|
||
</dd></dl>
|
||
|
||
</dd></dl>
|
||
|
||
</section>
|
||
|
||
|
||
</div>
|
||
</div>
|
||
<footer>
|
||
|
||
<hr/>
|
||
|
||
<div role="contentinfo">
|
||
<p>© Copyright 2023, MLX Contributors.</p>
|
||
</div>
|
||
|
||
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
|
||
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
|
||
provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||
|
||
|
||
</footer>
|
||
</div>
|
||
</div>
|
||
</section>
|
||
</div>
|
||
<script>
|
||
jQuery(function () {
|
||
SphinxRtdTheme.Navigation.enable(true);
|
||
});
|
||
</script>
|
||
|
||
</body>
|
||
</html> |