mlx/docs/build/html/python/_autosummary/mlx.nn.MultiHeadAttention.html
Awni Hannun fbd10a48d4 docs
2025-06-04 01:01:47 +00:00

198 lines
17 KiB
HTML
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>mlx.nn.MultiHeadAttention &mdash; MLX 0.0.0 documentation</title>
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
<script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script>
<script src="../../_static/_sphinx_javascript_frameworks_compat.js"></script>
<script src="../../_static/doctools.js"></script>
<script src="../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<link rel="next" title="mlx.nn.Sequential" href="mlx.nn.Sequential.html" />
<link rel="prev" title="mlx.nn.RoPE" href="mlx.nn.RoPE.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../index.html" class="icon icon-home">
MLX
</a>
<div class="version">
0.0.0
</div>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
<p class="caption" role="heading"><span class="caption-text">Install</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../install.html">Build and Install</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Usage</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../quick_start.html">Quick Start Guide</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../using_streams.html">Using Streams</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Examples</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../examples/linear_regression.html">Linear Regression</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../examples/mlp.html">Multi-Layer Perceptron</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../examples/llama-inference.html">LLM inference</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Further Reading</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../dev/extensions.html">Developer Documentation</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API Reference</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../array.html">Array</a></li>
<li class="toctree-l1"><a class="reference internal" href="../devices_and_streams.html">Devices and Streams</a></li>
<li class="toctree-l1"><a class="reference internal" href="../ops.html">Operations</a></li>
<li class="toctree-l1"><a class="reference internal" href="../random.html">Random</a></li>
<li class="toctree-l1"><a class="reference internal" href="../transforms.html">Transforms</a></li>
<li class="toctree-l1"><a class="reference internal" href="../fft.html">FFT</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="../nn.html">Neural Networks</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="../nn.html#quick-start-with-neural-networks">Quick Start with Neural Networks</a></li>
<li class="toctree-l2"><a class="reference internal" href="../nn.html#the-module-class">The Module Class</a></li>
<li class="toctree-l2"><a class="reference internal" href="../nn.html#value-and-grad">Value and grad</a></li>
<li class="toctree-l2 current"><a class="reference internal" href="../nn.html#neural-network-layers">Neural Network Layers</a><ul class="current">
<li class="toctree-l3"><a class="reference internal" href="mlx.nn.Embedding.html">mlx.nn.Embedding</a></li>
<li class="toctree-l3"><a class="reference internal" href="mlx.nn.ReLU.html">mlx.nn.ReLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="mlx.nn.GELU.html">mlx.nn.GELU</a></li>
<li class="toctree-l3"><a class="reference internal" href="mlx.nn.SiLU.html">mlx.nn.SiLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="mlx.nn.Linear.html">mlx.nn.Linear</a></li>
<li class="toctree-l3"><a class="reference internal" href="mlx.nn.Conv1d.html">mlx.nn.Conv1d</a></li>
<li class="toctree-l3"><a class="reference internal" href="mlx.nn.Conv2d.html">mlx.nn.Conv2d</a></li>
<li class="toctree-l3"><a class="reference internal" href="mlx.nn.LayerNorm.html">mlx.nn.LayerNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="mlx.nn.RMSNorm.html">mlx.nn.RMSNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="mlx.nn.GroupNorm.html">mlx.nn.GroupNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="mlx.nn.RoPE.html">mlx.nn.RoPE</a></li>
<li class="toctree-l3 current"><a class="current reference internal" href="#">mlx.nn.MultiHeadAttention</a></li>
<li class="toctree-l3"><a class="reference internal" href="mlx.nn.Sequential.html">mlx.nn.Sequential</a></li>
<li class="toctree-l3"><a class="reference internal" href="../_autosummary_functions/mlx.nn.gelu.html">mlx.nn.gelu</a></li>
<li class="toctree-l3"><a class="reference internal" href="../_autosummary_functions/mlx.nn.gelu_approx.html">mlx.nn.gelu_approx</a></li>
<li class="toctree-l3"><a class="reference internal" href="../_autosummary_functions/mlx.nn.gelu_fast_approx.html">mlx.nn.gelu_fast_approx</a></li>
<li class="toctree-l3"><a class="reference internal" href="../_autosummary_functions/mlx.nn.relu.html">mlx.nn.relu</a></li>
<li class="toctree-l3"><a class="reference internal" href="../_autosummary_functions/mlx.nn.silu.html">mlx.nn.silu</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../optimizers.html">Optimizers</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tree_utils.html">Tree Utils</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">C++ API Reference</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../cpp/ops.html">Operations</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../index.html">MLX</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="Page navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../index.html" class="icon icon-home" aria-label="Home"></a></li>
<li class="breadcrumb-item"><a href="../nn.html">Neural Networks</a></li>
<li class="breadcrumb-item active">mlx.nn.MultiHeadAttention</li>
<li class="wy-breadcrumbs-aside">
<a href="../../_sources/python/_autosummary/mlx.nn.MultiHeadAttention.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<section id="mlx-nn-multiheadattention">
<h1>mlx.nn.MultiHeadAttention<a class="headerlink" href="#mlx-nn-multiheadattention" title="Permalink to this heading"></a></h1>
<dl class="py class">
<dt class="sig sig-object py" id="mlx.nn.MultiHeadAttention">
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">mlx.nn.</span></span><span class="sig-name descname"><span class="pre">MultiHeadAttention</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">dims</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.12)"><span class="pre">int</span></a></span></em>, <em class="sig-param"><span class="n"><span class="pre">num_heads</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.12)"><span class="pre">int</span></a></span></em>, <em class="sig-param"><span class="n"><span class="pre">query_input_dims</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Optional" title="(in Python v3.12)"><span class="pre">Optional</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.12)"><span class="pre">int</span></a><span class="p"><span class="pre">]</span></span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">key_input_dims</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Optional" title="(in Python v3.12)"><span class="pre">Optional</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.12)"><span class="pre">int</span></a><span class="p"><span class="pre">]</span></span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">value_input_dims</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Optional" title="(in Python v3.12)"><span class="pre">Optional</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.12)"><span class="pre">int</span></a><span class="p"><span class="pre">]</span></span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">value_dims</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Optional" title="(in Python v3.12)"><span class="pre">Optional</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.12)"><span class="pre">int</span></a><span class="p"><span class="pre">]</span></span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">value_output_dims</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Optional" title="(in Python v3.12)"><span class="pre">Optional</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.12)"><span class="pre">int</span></a><span class="p"><span class="pre">]</span></span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#mlx.nn.MultiHeadAttention" title="Permalink to this definition"></a></dt>
<dd><p>Implements the scaled dot product attention with multiple heads.</p>
<p>Given inputs for queries, keys and values the <code class="docutils literal notranslate"><span class="pre">MultiHeadAttention</span></code> produces
new values by aggregating information from the input values according to
the similarities of the input queries and keys.</p>
<p>All inputs as well as the output are lineary projected without biases.</p>
<p>MultiHeadAttention also expects an additive attention mask that should be
broadcastable with (batch, num_heads, # queries, # keys). The mask should
have <code class="docutils literal notranslate"><span class="pre">-inf</span></code> or very negative numbers to the positions that should <em>not</em> be
attended to.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>dims</strong> (<a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.12)"><em>int</em></a>) The model dimensions. If no other dims are provided then
dims is used for queries, keys, values and the output.</p></li>
<li><p><strong>num_heads</strong> (<a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.12)"><em>int</em></a>) How many attention heads to use</p></li>
<li><p><strong>query_input_dims</strong> (<a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.12)"><em>int</em></a><em>, </em><em>optional</em>) The input dimensions of the queries (default: dims).</p></li>
<li><p><strong>key_input_dims</strong> (<a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.12)"><em>int</em></a><em>, </em><em>optional</em>) The input dimensions of the keys (default: dims).</p></li>
<li><p><strong>value_input_dims</strong> (<a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.12)"><em>int</em></a><em>, </em><em>optional</em>) The input dimensions of the values (default: key_input_dims).</p></li>
<li><p><strong>value_dims</strong> (<a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.12)"><em>int</em></a><em>, </em><em>optional</em>) The dimensions of the values after the projection (default: dims).</p></li>
<li><p><strong>value_output_dims</strong> (<a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.12)"><em>int</em></a><em>, </em><em>optional</em>) The dimensions the new values will be projected to (default: dims).</p></li>
</ul>
</dd>
</dl>
</dd></dl>
</section>
</div>
</div>
<footer><div class="rst-footer-buttons" role="navigation" aria-label="Footer">
<a href="mlx.nn.RoPE.html" class="btn btn-neutral float-left" title="mlx.nn.RoPE" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
<a href="mlx.nn.Sequential.html" class="btn btn-neutral float-right" title="mlx.nn.Sequential" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
</div>
<hr/>
<div role="contentinfo">
<p>&#169; Copyright 2023, MLX Contributors.</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script>
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>