<!DOCTYPE html>
<html lang="en" data-content_root="../" >
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
<title>LLM inference &#8212; MLX 0.26.1 documentation</title>
<script data-cfasync="false">
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
</script>
<!-- Loaded before other Sphinx assets -->
<link href="../_static/styles/theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
<link href="../_static/styles/bootstrap.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
<link href="../_static/styles/pydata-sphinx-theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
<link href="../_static/vendor/fontawesome/6.5.2/css/all.min.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
<link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-solid-900.woff2" />
<link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-brands-400.woff2" />
<link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-regular-400.woff2" />
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
<link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=eba8b062" />
<!-- Pre-loaded scripts that we'll load fully later -->
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" />
<script src="../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script>
<script src="../_static/documentation_options.js?v=3724ff34"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
<script>DOCUMENTATION_OPTIONS.pagename = 'examples/llama-inference';</script>
<link rel="icon" href="../_static/mlx_logo.png"/>
<link rel="index" title="Index" href="../genindex.html" />
<link rel="search" title="Search" href="../search.html" />
<link rel="next" title="Array" href="../python/array.html" />
<link rel="prev" title="Multi-Layer Perceptron" href="mlp.html" />
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<meta name="docsearch:language" content="en"/>
</head>
<body data-bs-spy="scroll" data-bs-target=".bd-toc-nav" data-offset="180" data-bs-root-margin="0px 0px -60%" data-default-mode="">
<div id="pst-skip-link" class="skip-link d-print-none"><a href="#main-content">Skip to main content</a></div>
<div id="pst-scroll-pixel-helper"></div>
<button type="button" class="btn rounded-pill" id="pst-back-to-top">
<i class="fa-solid fa-arrow-up"></i>Back to top</button>
<input type="checkbox"
class="sidebar-toggle"
id="pst-primary-sidebar-checkbox"/>
<label class="overlay overlay-primary" for="pst-primary-sidebar-checkbox"></label>
<input type="checkbox"
class="sidebar-toggle"
id="pst-secondary-sidebar-checkbox"/>
<label class="overlay overlay-secondary" for="pst-secondary-sidebar-checkbox"></label>
<div class="search-button__wrapper">
<div class="search-button__overlay"></div>
<div class="search-button__search-container">
<form class="bd-search d-flex align-items-center"
action="../search.html"
method="get">
<i class="fa-solid fa-magnifying-glass"></i>
<input type="search"
class="form-control"
name="q"
id="search-input"
placeholder="Search..."
aria-label="Search..."
autocomplete="off"
autocorrect="off"
autocapitalize="off"
spellcheck="false"/>
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd>K</kbd></span>
</form></div>
</div>
<div class="pst-async-banner-revealer d-none">
<aside id="bd-header-version-warning" class="d-none d-print-none" aria-label="Version warning"></aside>
</div>
<header class="bd-header navbar navbar-expand-lg bd-navbar d-print-none">
</header>
<div class="bd-container">
<div class="bd-container__inner bd-page-width">
<div class="bd-sidebar-primary bd-sidebar">
<div class="sidebar-header-items sidebar-primary__section">
</div>
<div class="sidebar-primary-items__start sidebar-primary__section">
<div class="sidebar-primary-item">
<a class="navbar-brand logo" href="../index.html">
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.26.1 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.26.1 documentation - Home"/>`);</script>
</a></div>
<div class="sidebar-primary-item">
<script>
document.write(`
<button class="btn search-button-field search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="fa-solid fa-magnifying-glass"></i>
<span class="search-button__default-text">Search</span>
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span>
</button>
`);
</script></div>
<div class="sidebar-primary-item"><nav class="bd-links bd-docs-nav" aria-label="Main">
<div class="bd-toc-item navbar-nav active">
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Install</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../install.html">Build and Install</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Usage</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../usage/quick_start.html">Quick Start Guide</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/lazy_evaluation.html">Lazy Evaluation</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/unified_memory.html">Unified Memory</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/indexing.html">Indexing Arrays</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/saving_and_loading.html">Saving and Loading Arrays</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/function_transforms.html">Function Transforms</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/compile.html">Compilation</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/numpy.html">Conversion to NumPy and Other Frameworks</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/distributed.html">Distributed Communication</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/using_streams.html">Using Streams</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/export.html">Exporting Functions</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
<ul class="current nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="linear_regression.html">Linear Regression</a></li>
<li class="toctree-l1"><a class="reference internal" href="mlp.html">Multi-Layer Perceptron</a></li>
<li class="toctree-l1 current active"><a class="current reference internal" href="#">LLM inference</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Python API Reference</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/array.html">Array</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.html">mlx.core.array</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.astype.html">mlx.core.array.astype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.at.html">mlx.core.array.at</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.item.html">mlx.core.array.item</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.tolist.html">mlx.core.array.tolist</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.dtype.html">mlx.core.array.dtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.itemsize.html">mlx.core.array.itemsize</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.nbytes.html">mlx.core.array.nbytes</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.ndim.html">mlx.core.array.ndim</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.shape.html">mlx.core.array.shape</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.size.html">mlx.core.array.size</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.real.html">mlx.core.array.real</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.imag.html">mlx.core.array.imag</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.abs.html">mlx.core.array.abs</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.all.html">mlx.core.array.all</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.any.html">mlx.core.array.any</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.argmax.html">mlx.core.array.argmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.argmin.html">mlx.core.array.argmin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.conj.html">mlx.core.array.conj</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.cos.html">mlx.core.array.cos</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.cummax.html">mlx.core.array.cummax</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.cummin.html">mlx.core.array.cummin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.cumprod.html">mlx.core.array.cumprod</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.cumsum.html">mlx.core.array.cumsum</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.diag.html">mlx.core.array.diag</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.diagonal.html">mlx.core.array.diagonal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.exp.html">mlx.core.array.exp</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.flatten.html">mlx.core.array.flatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.log.html">mlx.core.array.log</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.log10.html">mlx.core.array.log10</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.log1p.html">mlx.core.array.log1p</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.log2.html">mlx.core.array.log2</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.logcumsumexp.html">mlx.core.array.logcumsumexp</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.logsumexp.html">mlx.core.array.logsumexp</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.max.html">mlx.core.array.max</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.mean.html">mlx.core.array.mean</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.min.html">mlx.core.array.min</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.moveaxis.html">mlx.core.array.moveaxis</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.prod.html">mlx.core.array.prod</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.reciprocal.html">mlx.core.array.reciprocal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.reshape.html">mlx.core.array.reshape</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.round.html">mlx.core.array.round</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.rsqrt.html">mlx.core.array.rsqrt</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.sin.html">mlx.core.array.sin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.split.html">mlx.core.array.split</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.sqrt.html">mlx.core.array.sqrt</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.square.html">mlx.core.array.square</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.squeeze.html">mlx.core.array.squeeze</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.std.html">mlx.core.array.std</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.sum.html">mlx.core.array.sum</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.swapaxes.html">mlx.core.array.swapaxes</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.transpose.html">mlx.core.array.transpose</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.T.html">mlx.core.array.T</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.var.html">mlx.core.array.var</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.view.html">mlx.core.array.view</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/data_types.html">Data Types</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Dtype.html">mlx.core.Dtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.DtypeCategory.html">mlx.core.DtypeCategory</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.finfo.html">mlx.core.finfo</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Device.html">mlx.core.Device</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/stream_class.html">mlx.core.Stream</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.default_device.html">mlx.core.default_device</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.set_default_device.html">mlx.core.set_default_device</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.default_stream.html">mlx.core.default_stream</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.new_stream.html">mlx.core.new_stream</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.set_default_stream.html">mlx.core.set_default_stream</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.stream.html">mlx.core.stream</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.synchronize.html">mlx.core.synchronize</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/export.html">Export Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_function.html">mlx.core.export_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.import_function.html">mlx.core.import_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exporter.html">mlx.core.exporter</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_to_dot.html">mlx.core.export_to_dot</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/ops.html">Operations</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.abs.html">mlx.core.abs</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.add.html">mlx.core.add</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.addmm.html">mlx.core.addmm</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.all.html">mlx.core.all</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.allclose.html">mlx.core.allclose</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.any.html">mlx.core.any</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arange.html">mlx.core.arange</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arccos.html">mlx.core.arccos</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arccosh.html">mlx.core.arccosh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arcsin.html">mlx.core.arcsin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arcsinh.html">mlx.core.arcsinh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arctan.html">mlx.core.arctan</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arctan2.html">mlx.core.arctan2</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arctanh.html">mlx.core.arctanh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.argmax.html">mlx.core.argmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.argmin.html">mlx.core.argmin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.argpartition.html">mlx.core.argpartition</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.argsort.html">mlx.core.argsort</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array_equal.html">mlx.core.array_equal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.as_strided.html">mlx.core.as_strided</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.atleast_1d.html">mlx.core.atleast_1d</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.atleast_2d.html">mlx.core.atleast_2d</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.atleast_3d.html">mlx.core.atleast_3d</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.bitwise_and.html">mlx.core.bitwise_and</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.bitwise_invert.html">mlx.core.bitwise_invert</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.bitwise_or.html">mlx.core.bitwise_or</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.bitwise_xor.html">mlx.core.bitwise_xor</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.block_masked_mm.html">mlx.core.block_masked_mm</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.broadcast_arrays.html">mlx.core.broadcast_arrays</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.broadcast_to.html">mlx.core.broadcast_to</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.ceil.html">mlx.core.ceil</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.clip.html">mlx.core.clip</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.concatenate.html">mlx.core.concatenate</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.contiguous.html">mlx.core.contiguous</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conj.html">mlx.core.conj</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conjugate.html">mlx.core.conjugate</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.convolve.html">mlx.core.convolve</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv1d.html">mlx.core.conv1d</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv2d.html">mlx.core.conv2d</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv3d.html">mlx.core.conv3d</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv_transpose1d.html">mlx.core.conv_transpose1d</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv_transpose2d.html">mlx.core.conv_transpose2d</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv_transpose3d.html">mlx.core.conv_transpose3d</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv_general.html">mlx.core.conv_general</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.cos.html">mlx.core.cos</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.cosh.html">mlx.core.cosh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.cummax.html">mlx.core.cummax</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.cummin.html">mlx.core.cummin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.cumprod.html">mlx.core.cumprod</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.cumsum.html">mlx.core.cumsum</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.degrees.html">mlx.core.degrees</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.dequantize.html">mlx.core.dequantize</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.diag.html">mlx.core.diag</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.diagonal.html">mlx.core.diagonal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.divide.html">mlx.core.divide</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.divmod.html">mlx.core.divmod</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.einsum.html">mlx.core.einsum</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.einsum_path.html">mlx.core.einsum_path</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.equal.html">mlx.core.equal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.erf.html">mlx.core.erf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.erfinv.html">mlx.core.erfinv</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exp.html">mlx.core.exp</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.expm1.html">mlx.core.expm1</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.expand_dims.html">mlx.core.expand_dims</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.eye.html">mlx.core.eye</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.flatten.html">mlx.core.flatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.floor.html">mlx.core.floor</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.floor_divide.html">mlx.core.floor_divide</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.full.html">mlx.core.full</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.gather_mm.html">mlx.core.gather_mm</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.gather_qmm.html">mlx.core.gather_qmm</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.greater.html">mlx.core.greater</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.greater_equal.html">mlx.core.greater_equal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.hadamard_transform.html">mlx.core.hadamard_transform</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.identity.html">mlx.core.identity</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.imag.html">mlx.core.imag</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.inner.html">mlx.core.inner</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isfinite.html">mlx.core.isfinite</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isclose.html">mlx.core.isclose</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isinf.html">mlx.core.isinf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isnan.html">mlx.core.isnan</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isneginf.html">mlx.core.isneginf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isposinf.html">mlx.core.isposinf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.kron.html">mlx.core.kron</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.left_shift.html">mlx.core.left_shift</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less.html">mlx.core.less</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less_equal.html">mlx.core.less_equal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linspace.html">mlx.core.linspace</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.load.html">mlx.core.load</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.log.html">mlx.core.log</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.log2.html">mlx.core.log2</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.log10.html">mlx.core.log10</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.log1p.html">mlx.core.log1p</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.logaddexp.html">mlx.core.logaddexp</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.logcumsumexp.html">mlx.core.logcumsumexp</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.logical_not.html">mlx.core.logical_not</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.logical_and.html">mlx.core.logical_and</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.logical_or.html">mlx.core.logical_or</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.logsumexp.html">mlx.core.logsumexp</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.matmul.html">mlx.core.matmul</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.max.html">mlx.core.max</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.maximum.html">mlx.core.maximum</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.mean.html">mlx.core.mean</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.meshgrid.html">mlx.core.meshgrid</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.min.html">mlx.core.min</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.minimum.html">mlx.core.minimum</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.moveaxis.html">mlx.core.moveaxis</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.multiply.html">mlx.core.multiply</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.nan_to_num.html">mlx.core.nan_to_num</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.negative.html">mlx.core.negative</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.not_equal.html">mlx.core.not_equal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.ones.html">mlx.core.ones</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.ones_like.html">mlx.core.ones_like</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.outer.html">mlx.core.outer</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.partition.html">mlx.core.partition</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.pad.html">mlx.core.pad</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.power.html">mlx.core.power</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.prod.html">mlx.core.prod</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.put_along_axis.html">mlx.core.put_along_axis</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.quantize.html">mlx.core.quantize</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.quantized_matmul.html">mlx.core.quantized_matmul</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.radians.html">mlx.core.radians</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.real.html">mlx.core.real</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.reciprocal.html">mlx.core.reciprocal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.remainder.html">mlx.core.remainder</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.repeat.html">mlx.core.repeat</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.reshape.html">mlx.core.reshape</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.right_shift.html">mlx.core.right_shift</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.roll.html">mlx.core.roll</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.round.html">mlx.core.round</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.rsqrt.html">mlx.core.rsqrt</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.save.html">mlx.core.save</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.savez.html">mlx.core.savez</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.savez_compressed.html">mlx.core.savez_compressed</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.save_gguf.html">mlx.core.save_gguf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.save_safetensors.html">mlx.core.save_safetensors</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sigmoid.html">mlx.core.sigmoid</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sin.html">mlx.core.sin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sinh.html">mlx.core.sinh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice.html">mlx.core.slice</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice_update.html">mlx.core.slice_update</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.softmax.html">mlx.core.softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sort.html">mlx.core.sort</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.split.html">mlx.core.split</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sqrt.html">mlx.core.sqrt</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.square.html">mlx.core.square</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.squeeze.html">mlx.core.squeeze</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.stack.html">mlx.core.stack</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.std.html">mlx.core.std</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.stop_gradient.html">mlx.core.stop_gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.subtract.html">mlx.core.subtract</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sum.html">mlx.core.sum</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.swapaxes.html">mlx.core.swapaxes</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.take.html">mlx.core.take</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.take_along_axis.html">mlx.core.take_along_axis</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tan.html">mlx.core.tan</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tanh.html">mlx.core.tanh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tensordot.html">mlx.core.tensordot</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tile.html">mlx.core.tile</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.topk.html">mlx.core.topk</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.trace.html">mlx.core.trace</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.transpose.html">mlx.core.transpose</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.triu.html">mlx.core.triu</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.unflatten.html">mlx.core.unflatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.var.html">mlx.core.var</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.view.html">mlx.core.view</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.where.html">mlx.core.where</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.zeros.html">mlx.core.zeros</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.zeros_like.html">mlx.core.zeros_like</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/random.html">Random</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.bernoulli.html">mlx.core.random.bernoulli</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.categorical.html">mlx.core.random.categorical</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.gumbel.html">mlx.core.random.gumbel</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.key.html">mlx.core.random.key</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.normal.html">mlx.core.random.normal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.multivariate_normal.html">mlx.core.random.multivariate_normal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.randint.html">mlx.core.random.randint</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.seed.html">mlx.core.random.seed</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.split.html">mlx.core.random.split</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.truncated_normal.html">mlx.core.random.truncated_normal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.uniform.html">mlx.core.random.uniform</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.laplace.html">mlx.core.random.laplace</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.permutation.html">mlx.core.random.permutation</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/transforms.html">Transforms</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.eval.html">mlx.core.eval</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.async_eval.html">mlx.core.async_eval</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.compile.html">mlx.core.compile</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.custom_function.html">mlx.core.custom_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.disable_compile.html">mlx.core.disable_compile</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.enable_compile.html">mlx.core.enable_compile</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.grad.html">mlx.core.grad</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.value_and_grad.html">mlx.core.value_and_grad</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.jvp.html">mlx.core.jvp</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.vjp.html">mlx.core.vjp</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.vmap.html">mlx.core.vmap</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/fast.html">Fast</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fast.rms_norm.html">mlx.core.fast.rms_norm</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fast.layer_norm.html">mlx.core.fast.layer_norm</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fast.rope.html">mlx.core.fast.rope</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fast.scaled_dot_product_attention.html">mlx.core.fast.scaled_dot_product_attention</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fast.metal_kernel.html">mlx.core.fast.metal_kernel</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/fft.html">FFT</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.fft.html">mlx.core.fft.fft</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.ifft.html">mlx.core.fft.ifft</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.fft2.html">mlx.core.fft.fft2</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.ifft2.html">mlx.core.fft.ifft2</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.fftn.html">mlx.core.fft.fftn</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.ifftn.html">mlx.core.fft.ifftn</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.rfft.html">mlx.core.fft.rfft</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.irfft.html">mlx.core.fft.irfft</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.rfft2.html">mlx.core.fft.rfft2</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.irfft2.html">mlx.core.fft.irfft2</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.rfftn.html">mlx.core.fft.rfftn</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.irfftn.html">mlx.core.fft.irfftn</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.fftshift.html">mlx.core.fft.fftshift</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.ifftshift.html">mlx.core.fft.ifftshift</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/linalg.html">Linear Algebra</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.inv.html">mlx.core.linalg.inv</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.tri_inv.html">mlx.core.linalg.tri_inv</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.norm.html">mlx.core.linalg.norm</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.cholesky.html">mlx.core.linalg.cholesky</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.cholesky_inv.html">mlx.core.linalg.cholesky_inv</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.cross.html">mlx.core.linalg.cross</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.qr.html">mlx.core.linalg.qr</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.svd.html">mlx.core.linalg.svd</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.eigvals.html">mlx.core.linalg.eigvals</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.eig.html">mlx.core.linalg.eig</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.eigvalsh.html">mlx.core.linalg.eigvalsh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.eigh.html">mlx.core.linalg.eigh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.lu.html">mlx.core.linalg.lu</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.lu_factor.html">mlx.core.linalg.lu_factor</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.pinv.html">mlx.core.linalg.pinv</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.solve.html">mlx.core.linalg.solve</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.solve_triangular.html">mlx.core.linalg.solve_triangular</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/metal.html">Metal</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.is_available.html">mlx.core.metal.is_available</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.device_info.html">mlx.core.metal.device_info</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.start_capture.html">mlx.core.metal.start_capture</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.stop_capture.html">mlx.core.metal.stop_capture</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/memory_management.html">Memory Management</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.get_active_memory.html">mlx.core.get_active_memory</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.get_peak_memory.html">mlx.core.get_peak_memory</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.reset_peak_memory.html">mlx.core.reset_peak_memory</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.get_cache_memory.html">mlx.core.get_cache_memory</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.set_memory_limit.html">mlx.core.set_memory_limit</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.set_cache_limit.html">mlx.core.set_cache_limit</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.set_wired_limit.html">mlx.core.set_wired_limit</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.clear_cache.html">mlx.core.clear_cache</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/nn.html">Neural Networks</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.nn.value_and_grad.html">mlx.nn.value_and_grad</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.nn.quantize.html">mlx.nn.quantize</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.nn.average_gradients.html">mlx.nn.average_gradients</a></li>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/module.html">Module</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.training.html">mlx.nn.Module.training</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.state.html">mlx.nn.Module.state</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.apply.html">mlx.nn.Module.apply</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.apply_to_modules.html">mlx.nn.Module.apply_to_modules</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.children.html">mlx.nn.Module.children</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.eval.html">mlx.nn.Module.eval</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.filter_and_map.html">mlx.nn.Module.filter_and_map</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.freeze.html">mlx.nn.Module.freeze</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.leaf_modules.html">mlx.nn.Module.leaf_modules</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.load_weights.html">mlx.nn.Module.load_weights</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.modules.html">mlx.nn.Module.modules</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.named_modules.html">mlx.nn.Module.named_modules</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.parameters.html">mlx.nn.Module.parameters</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.save_weights.html">mlx.nn.Module.save_weights</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.set_dtype.html">mlx.nn.Module.set_dtype</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.train.html">mlx.nn.Module.train</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.trainable_parameters.html">mlx.nn.Module.trainable_parameters</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.unfreeze.html">mlx.nn.Module.unfreeze</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.update.html">mlx.nn.Module.update</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.update_modules.html">mlx.nn.Module.update_modules</a></li>
</ul>
</details></li>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/layers.html">Layers</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ALiBi.html">mlx.nn.ALiBi</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.AvgPool1d.html">mlx.nn.AvgPool1d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.AvgPool2d.html">mlx.nn.AvgPool2d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.AvgPool3d.html">mlx.nn.AvgPool3d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.BatchNorm.html">mlx.nn.BatchNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.CELU.html">mlx.nn.CELU</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Conv1d.html">mlx.nn.Conv1d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Conv2d.html">mlx.nn.Conv2d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Conv3d.html">mlx.nn.Conv3d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ConvTranspose1d.html">mlx.nn.ConvTranspose1d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ConvTranspose2d.html">mlx.nn.ConvTranspose2d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ConvTranspose3d.html">mlx.nn.ConvTranspose3d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Dropout.html">mlx.nn.Dropout</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Dropout2d.html">mlx.nn.Dropout2d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Dropout3d.html">mlx.nn.Dropout3d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Embedding.html">mlx.nn.Embedding</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ELU.html">mlx.nn.ELU</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.GELU.html">mlx.nn.GELU</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.GLU.html">mlx.nn.GLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.GroupNorm.html">mlx.nn.GroupNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.GRU.html">mlx.nn.GRU</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.HardShrink.html">mlx.nn.HardShrink</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.HardTanh.html">mlx.nn.HardTanh</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Hardswish.html">mlx.nn.Hardswish</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.InstanceNorm.html">mlx.nn.InstanceNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.LayerNorm.html">mlx.nn.LayerNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.LeakyReLU.html">mlx.nn.LeakyReLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Linear.html">mlx.nn.Linear</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.LogSigmoid.html">mlx.nn.LogSigmoid</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.LogSoftmax.html">mlx.nn.LogSoftmax</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.LSTM.html">mlx.nn.LSTM</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.MaxPool1d.html">mlx.nn.MaxPool1d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.MaxPool2d.html">mlx.nn.MaxPool2d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.MaxPool3d.html">mlx.nn.MaxPool3d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Mish.html">mlx.nn.Mish</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.MultiHeadAttention.html">mlx.nn.MultiHeadAttention</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.PReLU.html">mlx.nn.PReLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.QuantizedEmbedding.html">mlx.nn.QuantizedEmbedding</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.QuantizedLinear.html">mlx.nn.QuantizedLinear</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.RMSNorm.html">mlx.nn.RMSNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ReLU.html">mlx.nn.ReLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ReLU6.html">mlx.nn.ReLU6</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.RNN.html">mlx.nn.RNN</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.RoPE.html">mlx.nn.RoPE</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.SELU.html">mlx.nn.SELU</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Sequential.html">mlx.nn.Sequential</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Sigmoid.html">mlx.nn.Sigmoid</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.SiLU.html">mlx.nn.SiLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.html">mlx.nn.SinusoidalPositionalEncoding</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Softmin.html">mlx.nn.Softmin</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Softshrink.html">mlx.nn.Softshrink</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Softsign.html">mlx.nn.Softsign</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Softmax.html">mlx.nn.Softmax</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Softplus.html">mlx.nn.Softplus</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Step.html">mlx.nn.Step</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Tanh.html">mlx.nn.Tanh</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Transformer.html">mlx.nn.Transformer</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Upsample.html">mlx.nn.Upsample</a></li>
</ul>
</details></li>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/functions.html">Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.elu.html">mlx.nn.elu</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.celu.html">mlx.nn.celu</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.gelu.html">mlx.nn.gelu</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.gelu_approx.html">mlx.nn.gelu_approx</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.html">mlx.nn.gelu_fast_approx</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.glu.html">mlx.nn.glu</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.hard_shrink.html">mlx.nn.hard_shrink</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.hard_tanh.html">mlx.nn.hard_tanh</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.hardswish.html">mlx.nn.hardswish</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.leaky_relu.html">mlx.nn.leaky_relu</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.log_sigmoid.html">mlx.nn.log_sigmoid</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.log_softmax.html">mlx.nn.log_softmax</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.mish.html">mlx.nn.mish</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.prelu.html">mlx.nn.prelu</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.relu.html">mlx.nn.relu</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.relu6.html">mlx.nn.relu6</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.selu.html">mlx.nn.selu</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.sigmoid.html">mlx.nn.sigmoid</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.silu.html">mlx.nn.silu</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.softmax.html">mlx.nn.softmax</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.softmin.html">mlx.nn.softmin</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.softplus.html">mlx.nn.softplus</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.softshrink.html">mlx.nn.softshrink</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.step.html">mlx.nn.step</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.tanh.html">mlx.nn.tanh</a></li>
</ul>
</details></li>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/losses.html">Loss Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.html">mlx.nn.losses.binary_cross_entropy</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.html">mlx.nn.losses.cosine_similarity_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.html">mlx.nn.losses.cross_entropy</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss.html">mlx.nn.losses.gaussian_nll_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.html">mlx.nn.losses.hinge_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.huber_loss.html">mlx.nn.losses.huber_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.html">mlx.nn.losses.kl_div_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.html">mlx.nn.losses.l1_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.html">mlx.nn.losses.log_cosh_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.margin_ranking_loss.html">mlx.nn.losses.margin_ranking_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.html">mlx.nn.losses.mse_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.html">mlx.nn.losses.nll_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.html">mlx.nn.losses.smooth_l1_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.html">mlx.nn.losses.triplet_loss</a></li>
</ul>
</details></li>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/init.html">Initializers</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.constant.html">mlx.nn.init.constant</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.normal.html">mlx.nn.init.normal</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.uniform.html">mlx.nn.init.uniform</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.identity.html">mlx.nn.init.identity</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.glorot_normal.html">mlx.nn.init.glorot_normal</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.glorot_uniform.html">mlx.nn.init.glorot_uniform</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.he_normal.html">mlx.nn.init.he_normal</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.he_uniform.html">mlx.nn.init.he_uniform</a></li>
</ul>
</details></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/optimizers.html">Optimizers</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/optimizers/optimizer.html">Optimizer</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Optimizer.state.html">mlx.optimizers.Optimizer.state</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Optimizer.apply_gradients.html">mlx.optimizers.Optimizer.apply_gradients</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Optimizer.init.html">mlx.optimizers.Optimizer.init</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Optimizer.update.html">mlx.optimizers.Optimizer.update</a></li>
</ul>
</details></li>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/optimizers/common_optimizers.html">Common Optimizers</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.SGD.html">mlx.optimizers.SGD</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.RMSprop.html">mlx.optimizers.RMSprop</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Adagrad.html">mlx.optimizers.Adagrad</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Adafactor.html">mlx.optimizers.Adafactor</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.AdaDelta.html">mlx.optimizers.AdaDelta</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Adam.html">mlx.optimizers.Adam</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.AdamW.html">mlx.optimizers.AdamW</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Adamax.html">mlx.optimizers.Adamax</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Lion.html">mlx.optimizers.Lion</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.MultiOptimizer.html">mlx.optimizers.MultiOptimizer</a></li>
</ul>
</details></li>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/optimizers/schedulers.html">Schedulers</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.cosine_decay.html">mlx.optimizers.cosine_decay</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.exponential_decay.html">mlx.optimizers.exponential_decay</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.join_schedules.html">mlx.optimizers.join_schedules</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.linear_schedule.html">mlx.optimizers.linear_schedule</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.step_decay.html">mlx.optimizers.step_decay</a></li>
</ul>
</details></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.clip_grad_norm.html">mlx.optimizers.clip_grad_norm</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/distributed.html">Distributed Communication</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.Group.html">mlx.core.distributed.Group</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.is_available.html">mlx.core.distributed.is_available</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.init.html">mlx.core.distributed.init</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.all_sum.html">mlx.core.distributed.all_sum</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.all_gather.html">mlx.core.distributed.all_gather</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.send.html">mlx.core.distributed.send</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.recv.html">mlx.core.distributed.recv</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.recv_like.html">mlx.core.distributed.recv_like</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/tree_utils.html">Tree Utils</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_flatten.html">mlx.utils.tree_flatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_unflatten.html">mlx.utils.tree_unflatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_map.html">mlx.utils.tree_map</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_map_with_path.html">mlx.utils.tree_map_with_path</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_reduce.html">mlx.utils.tree_reduce</a></li>
</ul>
</details></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">C++ API Reference</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../cpp/ops.html">Operations</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Further Reading</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../dev/extensions.html">Custom Extensions in MLX</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/metal_debugger.html">Metal Debugger</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/custom_metal_kernels.html">Custom Metal Kernels</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dev/mlx_in_cpp.html">Using MLX in C++</a></li>
</ul>
</div>
</nav></div>
</div>
<div class="sidebar-primary-items__end sidebar-primary__section">
</div>
<div id="rtd-footer-container"></div>
</div>
<main id="main-content" class="bd-main" role="main">
<div class="sbt-scroll-pixel-helper"></div>
<div class="bd-content">
<div class="bd-article-container">
<div class="bd-header-article d-print-none">
<div class="header-article-items header-article__inner">
<div class="header-article-items__start">
<div class="header-article-item"><button class="sidebar-toggle primary-toggle btn btn-sm" title="Toggle primary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip">
<span class="fa-solid fa-bars"></span>
</button></div>
</div>
<div class="header-article-items__end">
<div class="header-article-item">
<div class="article-header-buttons">
<a href="https://github.com/ml-explore/mlx" target="_blank"
class="btn btn-sm btn-source-repository-button"
title="Source repository"
data-bs-placement="bottom" data-bs-toggle="tooltip"
>
<span class="btn__icon-container">
<i class="fab fa-github"></i>
</span>
</a>
<div class="dropdown dropdown-download-buttons">
<button class="btn dropdown-toggle" type="button" data-bs-toggle="dropdown" aria-expanded="false" aria-label="Download this page">
<i class="fas fa-download"></i>
</button>
<ul class="dropdown-menu">
<li><a href="../_sources/examples/llama-inference.rst" target="_blank"
class="btn btn-sm btn-download-source-button dropdown-item"
title="Download source file"
data-bs-placement="left" data-bs-toggle="tooltip"
>
<span class="btn__icon-container">
<i class="fas fa-file"></i>
</span>
<span class="btn__text-container">.rst</span>
</a>
</li>
<li>
<button onclick="window.print()"
class="btn btn-sm btn-download-pdf-button dropdown-item"
title="Print to PDF"
data-bs-placement="left" data-bs-toggle="tooltip"
>
<span class="btn__icon-container">
<i class="fas fa-file-pdf"></i>
</span>
<span class="btn__text-container">.pdf</span>
</button>
</li>
</ul>
</div>
<button onclick="toggleFullScreen()"
class="btn btn-sm btn-fullscreen-button"
title="Fullscreen mode"
data-bs-placement="bottom" data-bs-toggle="tooltip"
>
<span class="btn__icon-container">
<i class="fas fa-expand"></i>
</span>
</button>
<script>
document.write(`
<button class="btn btn-sm nav-link pst-navbar-icon theme-switch-button" title="light/dark" aria-label="light/dark" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="theme-switch fa-solid fa-sun fa-lg" data-mode="light"></i>
<i class="theme-switch fa-solid fa-moon fa-lg" data-mode="dark"></i>
<i class="theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode="auto"></i>
</button>
`);
</script>
<script>
document.write(`
<button class="btn btn-sm pst-navbar-icon search-button search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="fa-solid fa-magnifying-glass fa-lg"></i>
</button>
`);
</script>
<button class="sidebar-toggle secondary-toggle btn btn-sm" title="Toggle secondary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip">
<span class="fa-solid fa-list"></span>
</button>
</div></div>
</div>
</div>
</div>
<div id="jb-print-docs-body" class="onlyprint">
<h1>LLM inference</h1>
<!-- Table of contents -->
<div id="print-main-content">
<div id="jb-print-toc">
<div>
<h2> Contents </h2>
</div>
<nav aria-label="Page">
<ul class="visible nav section-nav flex-column">
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-model">Implementing the model</a><ul class="visible nav section-nav flex-column">
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#attention-layer">Attention layer</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#encoder-layer">Encoder layer</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#full-model">Full model</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#generation">Generation</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#putting-it-all-together">Putting it all together</a></li>
</ul>
</li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#converting-the-weights">Converting the weights</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#weight-loading-and-benchmarking">Weight loading and benchmarking</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#scripts">Scripts</a></li>
</ul>
</nav>
</div>
</div>
</div>
<div id="searchbox"></div>
<article class="bd-article">
<section id="llm-inference">
<h1>LLM inference<a class="headerlink" href="#llm-inference" title="Link to this heading">#</a></h1>
<p>MLX enables efficient inference of large-ish transformers on Apple silicon
without compromising on ease of use. In this example we will create an
inference script for the Llama family of transformer models in which the model
is defined in less than 200 lines of Python.</p>
<section id="implementing-the-model">
<h2>Implementing the model<a class="headerlink" href="#implementing-the-model" title="Link to this heading">#</a></h2>
<p>We will use the neural network building blocks defined in the <code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.nn</span></code>
module to concisely define the model architecture.</p>
<section id="attention-layer">
<h3>Attention layer<a class="headerlink" href="#attention-layer" title="Link to this heading">#</a></h3>
<p>We will start with the Llama attention layer which notably uses the RoPE
positional encoding. <a class="footnote-reference brackets" href="#id4" id="id1" role="doc-noteref"><span class="fn-bracket">[</span>1<span class="fn-bracket">]</span></a> In addition, our attention layer will optionally use a
key/value cache that will be concatenated with the provided keys and values to
support efficient inference.</p>
<p>Our implementation uses <a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Linear.html#mlx.nn.Linear" title="mlx.nn.Linear"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Linear</span></code></a> for all the projections and
<a class="reference internal" href="../python/nn/_autosummary/mlx.nn.RoPE.html#mlx.nn.RoPE" title="mlx.nn.RoPE"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.RoPE</span></code></a> for the positional encoding.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mlx.core</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">mx</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">mlx.nn</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">nn</span>
<span class="k">class</span><span class="w"> </span><span class="nc">LlamaAttention</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span> <span class="o">=</span> <span class="n">num_heads</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rope</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">RoPE</span><span class="p">(</span><span class="n">dims</span> <span class="o">//</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">traditional</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">query_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="n">dims</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">key_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="n">dims</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">value_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="n">dims</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">out_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="n">dims</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">queries</span><span class="p">,</span> <span class="n">keys</span><span class="p">,</span> <span class="n">values</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">cache</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="n">queries</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_proj</span><span class="p">(</span><span class="n">queries</span><span class="p">)</span>
<span class="n">keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">key_proj</span><span class="p">(</span><span class="n">keys</span><span class="p">)</span>
<span class="n">values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_proj</span><span class="p">(</span><span class="n">values</span><span class="p">)</span>
<span class="c1"># Extract some shapes</span>
<span class="n">num_heads</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span>
<span class="n">B</span><span class="p">,</span> <span class="n">L</span><span class="p">,</span> <span class="n">D</span> <span class="o">=</span> <span class="n">queries</span><span class="o">.</span><span class="n">shape</span>
<span class="c1"># Prepare the queries, keys and values for the attention computation</span>
<span class="n">queries</span> <span class="o">=</span> <span class="n">queries</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">L</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<span class="n">keys</span> <span class="o">=</span> <span class="n">keys</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">L</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<span class="n">values</span> <span class="o">=</span> <span class="n">values</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">L</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<span class="c1"># Add RoPE to the queries and keys and combine them with the cache</span>
<span class="k">if</span> <span class="n">cache</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">key_cache</span><span class="p">,</span> <span class="n">value_cache</span> <span class="o">=</span> <span class="n">cache</span>
<span class="n">queries</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">rope</span><span class="p">(</span><span class="n">queries</span><span class="p">,</span> <span class="n">offset</span><span class="o">=</span><span class="n">key_cache</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">])</span>
<span class="n">keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">rope</span><span class="p">(</span><span class="n">keys</span><span class="p">,</span> <span class="n">offset</span><span class="o">=</span><span class="n">key_cache</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">])</span>
<span class="n">keys</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">key_cache</span><span class="p">,</span> <span class="n">keys</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">values</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">value_cache</span><span class="p">,</span> <span class="n">values</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">queries</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">rope</span><span class="p">(</span><span class="n">queries</span><span class="p">)</span>
<span class="n">keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">rope</span><span class="p">(</span><span class="n">keys</span><span class="p">)</span>
<span class="c1"># Finally perform the attention computation</span>
<span class="n">scale</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">1</span> <span class="o">/</span> <span class="n">queries</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="n">scores</span> <span class="o">=</span> <span class="p">(</span><span class="n">queries</span> <span class="o">*</span> <span class="n">scale</span><span class="p">)</span> <span class="o">@</span> <span class="n">keys</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">scores</span> <span class="o">=</span> <span class="n">scores</span> <span class="o">+</span> <span class="n">mask</span>
<span class="n">scores</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">values_hat</span> <span class="o">=</span> <span class="p">(</span><span class="n">scores</span> <span class="o">@</span> <span class="n">values</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">L</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># Note that we return the keys and values to possibly be used as a cache</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">out_proj</span><span class="p">(</span><span class="n">values_hat</span><span class="p">),</span> <span class="p">(</span><span class="n">keys</span><span class="p">,</span> <span class="n">values</span><span class="p">)</span>
</pre></div>
</div>
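<p>As a quick sanity check, the attention layer can be exercised on its own. The
following is a minimal sketch with made-up sizes (they are illustrative and do not
correspond to any particular Llama checkpoint) and it assumes the
<code class="docutils literal notranslate"><span class="pre">LlamaAttention</span></code> class defined above is in scope:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import mlx.core as mx
import mlx.nn as nn

# Illustrative sizes only: 512 model dimensions split across 8 heads.
attn = LlamaAttention(dims=512, num_heads=8)

# Process a short "prompt" of 6 tokens. The layer returns its output and the
# (keys, values) pair that can serve as the cache for the next step.
x = mx.random.normal((1, 6, 512))
mask = nn.MultiHeadAttention.create_additive_causal_mask(6)
y, cache = attn(x, x, x, mask=mask)
print(y.shape)         # (1, 6, 512)
print(cache[0].shape)  # keys: (1, 8, 6, 64)

# Generate one more position: a single query attends to the cached keys and
# values, so no mask is needed and the cache grows by one.
x_new = mx.random.normal((1, 1, 512))
y, cache = attn(x_new, x_new, x_new, cache=cache)
print(cache[0].shape)  # keys: (1, 8, 7, 64)
</pre></div>
</div>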
</section>
<section id="encoder-layer">
<h3>Encoder layer<a class="headerlink" href="#encoder-layer" title="Link to this heading">#</a></h3>
<p>The other component of the Llama model is the encoder layer which uses RMS
normalization <a class="footnote-reference brackets" href="#id5" id="id2" role="doc-noteref"><span class="fn-bracket">[</span>2<span class="fn-bracket">]</span></a> and SwiGLU. <a class="footnote-reference brackets" href="#id6" id="id3" role="doc-noteref"><span class="fn-bracket">[</span>3<span class="fn-bracket">]</span></a> For RMS normalization we will use
<a class="reference internal" href="../python/nn/_autosummary/mlx.nn.RMSNorm.html#mlx.nn.RMSNorm" title="mlx.nn.RMSNorm"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.RMSNorm</span></code></a> that is already provided in <code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.nn</span></code>.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">LlamaEncoderLayer</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">mlp_dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">attention</span> <span class="o">=</span> <span class="n">LlamaAttention</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">norm1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">RMSNorm</span><span class="p">(</span><span class="n">dims</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">norm2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">RMSNorm</span><span class="p">(</span><span class="n">dims</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">linear1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="n">mlp_dims</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">linear2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="n">mlp_dims</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">linear3</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">mlp_dims</span><span class="p">,</span> <span class="n">dims</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">cache</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="n">y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">y</span><span class="p">,</span> <span class="n">cache</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">mask</span><span class="p">,</span> <span class="n">cache</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
<span class="n">y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">a</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear1</span><span class="p">(</span><span class="n">y</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear2</span><span class="p">(</span><span class="n">y</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">a</span> <span class="o">*</span> <span class="n">mx</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="o">*</span> <span class="n">b</span>
<span class="n">y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear3</span><span class="p">(</span><span class="n">y</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
<span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">cache</span>
</pre></div>
</div>
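<p>The feed-forward block above is the SwiGLU variant: <code class="docutils literal notranslate"><span class="pre">linear1</span></code> and
<code class="docutils literal notranslate"><span class="pre">linear2</span></code> both project from <code class="docutils literal notranslate"><span class="pre">dims</span></code> to <code class="docutils literal notranslate"><span class="pre">mlp_dims</span></code>, the gate
<code class="docutils literal notranslate"><span class="pre">a</span> <span class="pre">*</span> <span class="pre">sigmoid(a)</span></code> (i.e. SiLU) multiplies the second projection elementwise, and
<code class="docutils literal notranslate"><span class="pre">linear3</span></code> maps back to <code class="docutils literal notranslate"><span class="pre">dims</span></code> so the residual connection lines up. A minimal
shape check, again with made-up sizes and assuming the classes above are defined:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span># Illustrative sizes only; real Llama configurations are much larger.
layer = LlamaEncoderLayer(dims=512, mlp_dims=2048, num_heads=8)

x = mx.random.normal((1, 6, 512))
y, cache = layer(x)  # no mask or cache: fine for a quick shape check
print(y.shape)       # (1, 6, 512), the residual stream keeps its width
</pre></div>
</div>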
</section>
<section id="full-model">
<h3>Full model<a class="headerlink" href="#full-model" title="Link to this heading">#</a></h3>
<p>To implement any Llama model we simply have to combine <code class="docutils literal notranslate"><span class="pre">LlamaEncoderLayer</span></code>
instances with an <a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Embedding.html#mlx.nn.Embedding" title="mlx.nn.Embedding"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Embedding</span></code></a> to embed the input tokens.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">Llama</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span> <span class="n">num_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">vocab_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">mlp_dims</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span>
<span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">embedding</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">dims</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">LlamaEncoderLayer</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="n">mlp_dims</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_layers</span><span class="p">)</span>
<span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">RMSNorm</span><span class="p">(</span><span class="n">dims</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">out_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dims</span><span class="p">,</span> <span class="n">vocab_size</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MultiHeadAttention</span><span class="o">.</span><span class="n">create_additive_causal_mask</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">embedding</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">embedding</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">for</span> <span class="n">l</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span>
<span class="n">x</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">l</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">out_proj</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</pre></div>
</div>
<p>Note that in the implementation above we use a plain Python list to hold the encoder
layers; <code class="docutils literal notranslate"><span class="pre">model.parameters()</span></code> will still pick up the parameters of these layers.</p>
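<p>For instance, a toy configuration (the sizes below are chosen purely for
illustration) shows that the weights stored inside the plain list of layers are
reported together with everything else:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>from mlx.utils import tree_flatten

# Tiny, made-up configuration; real Llama models are far larger.
model = Llama(num_layers=2, vocab_size=100, dims=64, mlp_dims=128, num_heads=4)

flat = tree_flatten(model.parameters())
print(len(flat))  # one (name, array) entry per weight
print(any(name.startswith("layers.0") for name, _ in flat))  # True: list entries are included
</pre></div>
</div>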
</section>
<section id="generation">
<h3>Generation<a class="headerlink" href="#generation" title="Link to this heading">#</a></h3>
<p>Our <code class="docutils literal notranslate"><span class="pre">Llama</span></code> module can be used for training but not inference as the
<code class="docutils literal notranslate"><span class="pre">__call__</span></code> method above processes one input, completely ignores the cache and
performs no sampling whatsoever. In the rest of this subsection, we will
implement the inference function as a python generator that processes the
prompt and then autoregressively yields tokens one at a time.</p>
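<p>Before looking at the implementation, here is a minimal sketch of how such a
generator might be consumed; the prompt array, the token count and the
<code class="docutils literal notranslate"><span class="pre">model</span></code> instance are placeholders, and the complete pipeline appears later in
this example:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span># `prompt` stands in for tokenizer output; its shape is (1, prompt_length).
prompt = mx.array([[1, 2, 3, 4]])

tokens = []
for token, _ in zip(model.generate(prompt, temp=0.8), range(16)):
    tokens.append(token)

# MLX is lazy, so force the computation of all sampled tokens at once.
mx.eval(tokens)
print([t.item() for t in tokens])
</pre></div>
</div>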
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">Llama</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="o">...</span>
<span class="k">def</span><span class="w"> </span><span class="nf">generate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">temp</span><span class="o">=</span><span class="mf">1.0</span><span class="p">):</span>
<span class="n">cache</span> <span class="o">=</span> <span class="p">[]</span>
<span class="c1"># Make an additive causal mask. We will need that to process the prompt.</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MultiHeadAttention</span><span class="o">.</span><span class="n">create_additive_causal_mask</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">embedding</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="c1"># First we process the prompt x the same way as in __call__ but</span>
<span class="c1"># save the caches in cache</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">embedding</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">for</span> <span class="n">l</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span>
<span class="n">x</span><span class="p">,</span> <span class="n">c</span> <span class="o">=</span> <span class="n">l</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="n">cache</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">c</span><span class="p">)</span> <span class="c1"># &lt;--- we store the per layer cache in a</span>
<span class="c1"># simple python list</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">out_proj</span><span class="p">(</span><span class="n">x</span><span class="p">[:,</span> <span class="o">-</span><span class="mi">1</span><span class="p">])</span> <span class="c1"># &lt;--- we only care about the last logits</span>
<span class="c1"># that generate the next token</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">categorical</span><span class="p">(</span><span class="n">y</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span><span class="o">/</span><span class="n">temp</span><span class="p">))</span>
<span class="c1"># y now has size [1]</span>
<span class="c1"># Since MLX is lazily evaluated nothing is computed yet.</span>
<span class="c1"># Calling y.item() would force the computation to happen at</span>
<span class="c1"># this point but we can also choose not to do that and let the</span>
<span class="c1"># user choose when to start the computation.</span>
<span class="k">yield</span> <span class="n">y</span>
<span class="c1"># Now we parsed the prompt and generated the first token we</span>
<span class="c1"># need to feed it back into the model and loop to generate the</span>
<span class="c1"># rest.</span>
<span class="k">while</span> <span class="kc">True</span><span class="p">:</span>
<span class="c1"># Unsqueezing the last dimension to add a sequence length</span>
<span class="c1"># dimension of 1</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">y</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">embedding</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">cache</span><span class="p">)):</span>
<span class="c1"># We are overwriting the arrays in the cache list. When</span>
<span class="c1"># the computation will happen, MLX will be discarding the</span>
<span class="c1"># old cache the moment it is not needed anymore.</span>
<span class="n">x</span><span class="p">,</span> <span class="n">cache</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="n">i</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">cache</span><span class="o">=</span><span class="n">cache</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">out_proj</span><span class="p">(</span><span class="n">x</span><span class="p">[:,</span> <span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">categorical</span><span class="p">(</span><span class="n">y</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span><span class="o">/</span><span class="n">temp</span><span class="p">))</span>
<span class="k">yield</span> <span class="n">y</span>
</pre></div>
</div>
</section>
<section id="putting-it-all-together">
<h3>Putting it all together<a class="headerlink" href="#putting-it-all-together" title="Link to this heading">#</a></h3>
<p>We now have everything we need to create a Llama model and sample tokens from
it. In the following code, we randomly initialize a small Llama model, process
a 6-token prompt and generate 10 tokens.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">model</span> <span class="o">=</span> <span class="n">Llama</span><span class="p">(</span><span class="n">num_layers</span><span class="o">=</span><span class="mi">12</span><span class="p">,</span> <span class="n">vocab_size</span><span class="o">=</span><span class="mi">8192</span><span class="p">,</span> <span class="n">dims</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="n">mlp_dims</span><span class="o">=</span><span class="mi">1024</span><span class="p">,</span> <span class="n">num_heads</span><span class="o">=</span><span class="mi">8</span><span class="p">)</span>
<span class="c1"># Since MLX is lazily evaluated nothing has actually been materialized yet.</span>
<span class="c1"># We could have set the `dims` to 20_000 on a machine with 8GB of RAM and the</span>
<span class="c1"># code above would still run. Let&#39;s actually materialize the model.</span>
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span>
<span class="n">prompt</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">44</span><span class="p">,</span> <span class="mi">7</span><span class="p">]])</span> <span class="c1"># &lt;-- Note the double brackets because we</span>
<span class="c1"># have a batch dimension even</span>
<span class="c1"># though it is 1 in this case</span>
<span class="n">generated</span> <span class="o">=</span> <span class="p">[</span><span class="n">t</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span> <span class="n">model</span><span class="o">.</span><span class="n">generate</span><span class="p">(</span><span class="n">prompt</span><span class="p">,</span> <span class="mf">0.8</span><span class="p">))]</span>
<span class="c1"># Since we haven&#39;t evaluated anything, nothing is computed yet. The list</span>
<span class="c1"># `generated` contains the arrays that hold the computation graph for the</span>
<span class="c1"># full processing of the prompt and the generation of 10 tokens.</span>
<span class="c1">#</span>
<span class="c1"># We can evaluate them one at a time, or all together. Concatenate them or</span>
<span class="c1"># print them. They would all result in very similar runtimes and give exactly</span>
<span class="c1"># the same results.</span>
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">generated</span><span class="p">)</span>
</pre></div>
</div>
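<p>If you prefer to stream the output instead, you can force the computation one
token at a time by evaluating each array as it is yielded. The snippet below is
a small sketch of that pattern with the randomly initialized model above; with
real weights you would additionally decode each token id with the tokenizer.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span># Sketch: evaluate and print each sampled token id as soon as it is yielded.
# With a real checkpoint you would decode these ids into text with the tokenizer.
for _, token in zip(range(10), model.generate(prompt, temp=0.8)):
    mx.eval(token)        # force the computation needed for this step
    print(token.item())   # the sampled token id (the array has size 1)
</pre></div>
</div>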
</section>
</section>
<section id="converting-the-weights">
<h2>Converting the weights<a class="headerlink" href="#converting-the-weights" title="Link to this heading">#</a></h2>
<p>This section assumes that you have access to the original Llama weights and the
SentencePiece model that comes with them. We will write a small script to
convert the PyTorch weights to MLX compatible ones and write them in a NPZ file
that can be loaded directly by MLX.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">argparse</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">itertools</span><span class="w"> </span><span class="kn">import</span> <span class="n">starmap</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
<span class="k">def</span><span class="w"> </span><span class="nf">map_torch_to_mlx</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">):</span>
<span class="k">if</span> <span class="s2">&quot;tok_embedding&quot;</span> <span class="ow">in</span> <span class="n">key</span><span class="p">:</span>
<span class="n">key</span> <span class="o">=</span> <span class="s2">&quot;embedding.weight&quot;</span>
<span class="k">elif</span> <span class="s2">&quot;norm&quot;</span> <span class="ow">in</span> <span class="n">key</span><span class="p">:</span>
<span class="n">key</span> <span class="o">=</span> <span class="n">key</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;attention_norm&quot;</span><span class="p">,</span> <span class="s2">&quot;norm1&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;ffn_norm&quot;</span><span class="p">,</span> <span class="s2">&quot;norm2&quot;</span><span class="p">)</span>
<span class="k">elif</span> <span class="s2">&quot;wq&quot;</span> <span class="ow">in</span> <span class="n">key</span> <span class="ow">or</span> <span class="s2">&quot;wk&quot;</span> <span class="ow">in</span> <span class="n">key</span> <span class="ow">or</span> <span class="s2">&quot;wv&quot;</span> <span class="ow">in</span> <span class="n">key</span> <span class="ow">or</span> <span class="s2">&quot;wo&quot;</span> <span class="ow">in</span> <span class="n">key</span><span class="p">:</span>
<span class="n">key</span> <span class="o">=</span> <span class="n">key</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;wq&quot;</span><span class="p">,</span> <span class="s2">&quot;query_proj&quot;</span><span class="p">)</span>
<span class="n">key</span> <span class="o">=</span> <span class="n">key</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;wk&quot;</span><span class="p">,</span> <span class="s2">&quot;key_proj&quot;</span><span class="p">)</span>
<span class="n">key</span> <span class="o">=</span> <span class="n">key</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;wv&quot;</span><span class="p">,</span> <span class="s2">&quot;value_proj&quot;</span><span class="p">)</span>
<span class="n">key</span> <span class="o">=</span> <span class="n">key</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;wo&quot;</span><span class="p">,</span> <span class="s2">&quot;out_proj&quot;</span><span class="p">)</span>
<span class="k">elif</span> <span class="s2">&quot;w1&quot;</span> <span class="ow">in</span> <span class="n">key</span> <span class="ow">or</span> <span class="s2">&quot;w2&quot;</span> <span class="ow">in</span> <span class="n">key</span> <span class="ow">or</span> <span class="s2">&quot;w3&quot;</span> <span class="ow">in</span> <span class="n">key</span><span class="p">:</span>
<span class="c1"># The FFN is a separate submodule in PyTorch</span>
<span class="n">key</span> <span class="o">=</span> <span class="n">key</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;feed_forward.w1&quot;</span><span class="p">,</span> <span class="s2">&quot;linear1&quot;</span><span class="p">)</span>
<span class="n">key</span> <span class="o">=</span> <span class="n">key</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;feed_forward.w3&quot;</span><span class="p">,</span> <span class="s2">&quot;linear2&quot;</span><span class="p">)</span>
<span class="n">key</span> <span class="o">=</span> <span class="n">key</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;feed_forward.w2&quot;</span><span class="p">,</span> <span class="s2">&quot;linear3&quot;</span><span class="p">)</span>
<span class="k">elif</span> <span class="s2">&quot;output&quot;</span> <span class="ow">in</span> <span class="n">key</span><span class="p">:</span>
<span class="n">key</span> <span class="o">=</span> <span class="n">key</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;output&quot;</span><span class="p">,</span> <span class="s2">&quot;out_proj&quot;</span><span class="p">)</span>
<span class="k">elif</span> <span class="s2">&quot;rope&quot;</span> <span class="ow">in</span> <span class="n">key</span><span class="p">:</span>
<span class="k">return</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span>
<span class="k">return</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
<span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">&quot;__main__&quot;</span><span class="p">:</span>
<span class="n">parser</span> <span class="o">=</span> <span class="n">argparse</span><span class="o">.</span><span class="n">ArgumentParser</span><span class="p">(</span><span class="n">description</span><span class="o">=</span><span class="s2">&quot;Convert Llama weights to MLX&quot;</span><span class="p">)</span>
<span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;torch_weights&quot;</span><span class="p">)</span>
<span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;output_file&quot;</span><span class="p">)</span>
<span class="n">args</span> <span class="o">=</span> <span class="n">parser</span><span class="o">.</span><span class="n">parse_args</span><span class="p">()</span>
<span class="n">state</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">args</span><span class="o">.</span><span class="n">torch_weights</span><span class="p">)</span>
<span class="n">np</span><span class="o">.</span><span class="n">savez</span><span class="p">(</span>
<span class="n">args</span><span class="o">.</span><span class="n">output_file</span><span class="p">,</span>
<span class="o">**</span><span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">starmap</span><span class="p">(</span><span class="n">map_torch_to_mlx</span><span class="p">,</span> <span class="n">state</span><span class="o">.</span><span class="n">items</span><span class="p">())</span> <span class="k">if</span> <span class="n">k</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">}</span>
<span class="p">)</span>
</pre></div>
</div>
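<p>As a quick sanity check of the key mapping above, here is a hedged example on a
single, hypothetical checkpoint key (the exact key names depend on the original
PyTorch checkpoint):</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span># Hypothetical key used only for illustration.
key, value = map_torch_to_mlx(&quot;layers.2.attention.wq.weight&quot;, torch.zeros(8, 8))
print(key)          # layers.2.attention.query_proj.weight
print(value.shape)  # (8, 8), now a numpy array
</pre></div>
</div>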
</section>
<section id="weight-loading-and-benchmarking">
<h2>Weight loading and benchmarking<a class="headerlink" href="#weight-loading-and-benchmarking" title="Link to this heading">#</a></h2>
<p>After converting the weights to be compatible with our implementation, all that is
left is to load them from disk so that we can finally use the LLM to generate text.
We can load numpy format files using the <a class="reference internal" href="../python/_autosummary/mlx.core.load.html#mlx.core.load" title="mlx.core.load"><code class="xref py py-func docutils literal notranslate"><span class="pre">mlx.core.load()</span></code></a> operation.</p>
<p>To create a parameter dictionary from the key/value representation of NPZ files
we will use the <a class="reference internal" href="../python/_autosummary/mlx.utils.tree_unflatten.html#mlx.utils.tree_unflatten" title="mlx.utils.tree_unflatten"><code class="xref py py-func docutils literal notranslate"><span class="pre">mlx.utils.tree_unflatten()</span></code></a> helper method as follows:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span><span class="w"> </span><span class="nn">mlx.utils</span><span class="w"> </span><span class="kn">import</span> <span class="n">tree_unflatten</span>
<span class="n">model</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">tree_unflatten</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">weight_file</span><span class="p">)</span><span class="o">.</span><span class="n">items</span><span class="p">())))</span>
</pre></div>
</div>
<p><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_unflatten.html#mlx.utils.tree_unflatten" title="mlx.utils.tree_unflatten"><code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx.utils.tree_unflatten()</span></code></a> will take keys from the NPZ file that look
like <code class="docutils literal notranslate"><span class="pre">layers.2.attention.query_proj.weight</span></code> and will transform them to</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="p">{</span><span class="s2">&quot;layers&quot;</span><span class="p">:</span> <span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="o">...</span><span class="p">,</span> <span class="p">{</span><span class="s2">&quot;attention&quot;</span><span class="p">:</span> <span class="p">{</span><span class="s2">&quot;query_proj&quot;</span><span class="p">:</span> <span class="p">{</span><span class="s2">&quot;weight&quot;</span><span class="p">:</span> <span class="o">...</span><span class="p">}}}]}</span>
</pre></div>
</div>
<p>which can then be used to update the model. Note that the method above incurs
several unnecessary copies from disk to numpy and then from numpy to MLX. It
will be replaced in the future with direct loading to MLX.</p>
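<p>As a small, self-contained sketch of the key unflattening described above,
using a dummy array:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span># Sketch: a single flat key becomes a nested dict/list structure.
from mlx.utils import tree_unflatten
import mlx.core as mx

flat = [(&quot;layers.2.attention.query_proj.weight&quot;, mx.zeros((8, 8)))]
nested = tree_unflatten(flat)
print(nested[&quot;layers&quot;][2][&quot;attention&quot;][&quot;query_proj&quot;][&quot;weight&quot;].shape)  # (8, 8)
</pre></div>
</div>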
<p>You can download the full example code from <a class="reference external" href="https://github.com/ml-explore/mlx-examples/tree/main/llms/llama">mlx-examples</a>. Assuming that
<code class="docutils literal notranslate"><span class="pre">weights.pth</span></code> and <code class="docutils literal notranslate"><span class="pre">tokenizer.model</span></code> exist in the current working
directory, we can play around with our inference script as follows (the timings
are representative of an M1 Ultra and the 7B parameter Llama model):</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>$<span class="w"> </span>python<span class="w"> </span>convert.py<span class="w"> </span>weights.pth<span class="w"> </span>llama-7B.mlx.npz
$<span class="w"> </span>python<span class="w"> </span>llama.py<span class="w"> </span>llama-7B.mlx.npz<span class="w"> </span>tokenizer.model<span class="w"> </span><span class="s1">&#39;Call me Ishmael. Some years ago never mind how long precisely&#39;</span>
<span class="o">[</span>INFO<span class="o">]</span><span class="w"> </span>Loading<span class="w"> </span>model<span class="w"> </span>from<span class="w"> </span>disk:<span class="w"> </span><span class="m">5</span>.247<span class="w"> </span>s
Press<span class="w"> </span>enter<span class="w"> </span>to<span class="w"> </span>start<span class="w"> </span>generation
------
,<span class="w"> </span>having<span class="w"> </span>little<span class="w"> </span>or<span class="w"> </span>no<span class="w"> </span>money<span class="w"> </span><span class="k">in</span><span class="w"> </span>my<span class="w"> </span>purse,<span class="w"> </span>and<span class="w"> </span>nothing<span class="w"> </span>of<span class="w"> </span>greater<span class="w"> </span>consequence<span class="w"> </span><span class="k">in</span><span class="w"> </span>my<span class="w"> </span>mind,<span class="w"> </span>I<span class="w"> </span>happened<span class="w"> </span>to<span class="w"> </span>be<span class="w"> </span>walking<span class="w"> </span>down<span class="w"> </span>Gower<span class="w"> </span>Street<span class="w"> </span><span class="k">in</span><span class="w"> </span>the<span class="w"> </span>afternoon,<span class="w"> </span><span class="k">in</span><span class="w"> </span>the<span class="w"> </span>heavy<span class="w"> </span>rain,<span class="w"> </span>and<span class="w"> </span>I<span class="w"> </span>saw<span class="w"> </span>a<span class="w"> </span>few<span class="w"> </span>steps<span class="w"> </span>off,<span class="w"> </span>a<span class="w"> </span>man<span class="w"> </span><span class="k">in</span><span class="w"> </span>rags,<span class="w"> </span>who<span class="w"> </span>sat<span class="w"> </span>upon<span class="w"> </span>his<span class="w"> </span>bundle<span class="w"> </span>and<span class="w"> </span>looked<span class="w"> </span>hard<span class="w"> </span>into<span class="w"> </span>the<span class="w"> </span>wet<span class="w"> </span>as<span class="w"> </span><span class="k">if</span><span class="w"> </span>he<span class="w"> </span>were<span class="w"> </span>going<span class="w"> </span>to<span class="w"> </span>cry.<span class="w"> </span>I<span class="w"> </span>watched<span class="w"> </span>him<span class="w"> </span>attentively<span class="w"> </span><span class="k">for</span><span class="w"> </span>some<span class="w"> </span>time,<span class="w"> </span>and<span class="w"> </span>could<span class="w"> </span>not<span class="w"> </span>but<span class="w"> </span>observe<span class="w"> </span>that,<span class="w"> </span>though<span class="w"> </span>a<span class="w"> </span>numerous<span class="w"> </span>crowd<span class="w"> </span>was<span class="w"> </span>hurrying<span class="w"> </span>up<span class="w"> </span>and<span class="w"> </span>down,
------
<span class="o">[</span>INFO<span class="o">]</span><span class="w"> </span>Prompt<span class="w"> </span>processing:<span class="w"> </span><span class="m">0</span>.437<span class="w"> </span>s
<span class="o">[</span>INFO<span class="o">]</span><span class="w"> </span>Full<span class="w"> </span>generation:<span class="w"> </span><span class="m">4</span>.330<span class="w"> </span>s
</pre></div>
</div>
<p>We observe that 4.3 seconds are required to generate 100 tokens and 0.4 seconds
of those are spent processing the prompt. This amounts to a little over <strong>39 ms
per token</strong>.</p>
<p>By running with a much bigger prompt we can see that both the per-token generation
time and the prompt processing time remain almost constant.</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>$<span class="w"> </span>python<span class="w"> </span>llama.py<span class="w"> </span>llama-7B.mlx.npz<span class="w"> </span>tokenizer.model<span class="w"> </span><span class="s1">&#39;Call me Ishmael. Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. Still he made no immediate reply, but again put his face down upon his hands, and did not&#39;</span>
<span class="o">[</span>INFO<span class="o">]</span><span class="w"> </span>Loading<span class="w"> </span>model<span class="w"> </span>from<span class="w"> </span>disk:<span class="w"> </span><span class="m">5</span>.247<span class="w"> </span>s
Press<span class="w"> </span>enter<span class="w"> </span>to<span class="w"> </span>start<span class="w"> </span>generation
------
take<span class="w"> </span>his<span class="w"> </span>eyes<span class="w"> </span>from<span class="w"> </span>the<span class="w"> </span>ground.<span class="w"> </span>“What<span class="w"> </span>is<span class="w"> </span>it<span class="w"> </span>you<span class="w"> </span>are<span class="w"> </span>waiting<span class="w"> </span><span class="k">for</span>?”<span class="w"> </span>said<span class="w"> </span>I.<span class="w"> </span>“I<span class="w"> </span>am<span class="w"> </span>not<span class="w"> </span>accustomed<span class="w"> </span>to<span class="w"> </span>be<span class="w"> </span>thus<span class="w"> </span>questioned,”<span class="w"> </span>said<span class="w"> </span>he.<span class="w"> </span>“You<span class="w"> </span>look<span class="w"> </span>like<span class="w"> </span>a<span class="w"> </span>reasonable<span class="w"> </span>man—tell<span class="w"> </span>me,<span class="w"> </span><span class="k">then</span>,<span class="w"> </span>what<span class="w"> </span>are<span class="w"> </span>you<span class="w"> </span>waiting<span class="w"> </span><span class="k">for</span>?”<span class="w"> </span>“You<span class="w"> </span>would<span class="w"> </span>not<span class="w"> </span>understand,”<span class="w"> </span>he<span class="w"> </span>replied<span class="p">;</span><span class="w"> </span>“and<span class="w"> </span>how<span class="w"> </span>could<span class="w"> </span>you<span class="w"> </span><span class="nb">help</span><span class="w"> </span>me,<span class="w"> </span><span class="k">if</span><span class="w"> </span>I<span class="w"> </span>were<span class="w"> </span>to<span class="w"> </span>tell<span class="w"> </span>you?”<span class="w"> </span>“I<span class="w"> </span>should<span class="w"> </span>not<span class="w"> </span>only<span class="w"> </span>understand,<span class="w"> </span>but<span class="w"> </span>would<span class="w"> </span><span class="k">do</span><span class="w"> </span>all<span class="w"> </span>that<span class="w"> </span>I<span class="w"> </span>could,”<span class="w"> </span>said<span class="w"> </span>I.<span class="w"> </span>He<span class="w"> </span>did<span class="w"> </span>not
------
<span class="o">[</span>INFO<span class="o">]</span><span class="w"> </span>Prompt<span class="w"> </span>processing:<span class="w"> </span><span class="m">0</span>.579<span class="w"> </span>s
<span class="o">[</span>INFO<span class="o">]</span><span class="w"> </span>Full<span class="w"> </span>generation:<span class="w"> </span><span class="m">4</span>.690<span class="w"> </span>s
$<span class="w"> </span>python<span class="w"> </span>llama.py<span class="w"> </span>--num-tokens<span class="w"> </span><span class="m">500</span><span class="w"> </span>llama-7B.mlx.npz<span class="w"> </span>tokenizer.model<span class="w"> </span><span class="s1">&#39;Call me Ishmael. Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. Still he made no immediate reply, but again put his face down upon his hands, and did not&#39;</span>
<span class="o">[</span>INFO<span class="o">]</span><span class="w"> </span>Loading<span class="w"> </span>model<span class="w"> </span>from<span class="w"> </span>disk:<span class="w"> </span><span class="m">5</span>.628<span class="w"> </span>s
Press<span class="w"> </span>enter<span class="w"> </span>to<span class="w"> </span>start<span class="w"> </span>generation
------
take<span class="w"> </span>his<span class="w"> </span>eyes<span class="w"> </span>from<span class="w"> </span>the<span class="w"> </span>ground.<span class="w"> </span>“What<span class="w"> </span>is<span class="w"> </span>it<span class="w"> </span>you<span class="w"> </span>are<span class="w"> </span>waiting<span class="w"> </span><span class="k">for</span>?”<span class="w"> </span>said<span class="w"> </span>I.<span class="w"> </span>“I<span class="w"> </span>am<span class="w"> </span>not<span class="w"> </span>accustomed<span class="w"> </span>to<span class="w"> </span>be<span class="w"> </span>thus<span class="w"> </span>questioned,”<span class="w"> </span>said<span class="w"> </span>he.<span class="w"> </span>“You<span class="w"> </span>look<span class="w"> </span>like<span class="w"> </span>a<span class="w"> </span>reasonable<span class="w"> </span>man—tell<span class="w"> </span>me,<span class="w"> </span><span class="k">then</span>,<span class="w"> </span>what<span class="w"> </span>are<span class="w"> </span>you<span class="w"> </span>waiting<span class="w"> </span><span class="k">for</span>?”<span class="w"> </span>“You<span class="w"> </span>would<span class="w"> </span>not<span class="w"> </span>understand,”<span class="w"> </span>he<span class="w"> </span>replied<span class="p">;</span><span class="w"> </span>“and<span class="w"> </span>how<span class="w"> </span>could<span class="w"> </span>you<span class="w"> </span><span class="nb">help</span><span class="w"> </span>me,<span class="w"> </span><span class="k">if</span><span class="w"> </span>I<span class="w"> </span>were<span class="w"> </span>to<span class="w"> </span>tell<span class="w"> </span>you?”<span class="w"> </span>“I<span class="w"> </span>should<span class="w"> </span>not<span class="w"> </span>only<span class="w"> </span>understand,<span class="w"> </span>but<span class="w"> </span>would<span class="w"> </span><span class="k">do</span><span class="w"> </span>all<span class="w"> </span>that<span class="w"> </span>I<span class="w"> </span>could,”<span class="w"> </span>said<span class="w"> </span>I.<span class="w"> </span>He<span class="w"> </span>did<span class="w"> </span>not<span class="w"> </span>reply,<span class="w"> </span>but<span class="w"> </span>still<span class="w"> </span>went<span class="w"> </span>on<span class="w"> </span>looking<span class="w"> </span>at<span class="w"> </span>the<span class="w"> </span>ground,<span class="w"> </span>and<span class="w"> </span>took<span class="w"> </span>hold<span class="w"> </span>of<span class="w"> </span>his<span class="w"> </span>bundle<span class="w"> </span>with<span class="w"> </span>a<span class="w"> </span>nervous<span class="w"> </span>trembling.<span class="w"> </span>I<span class="w"> </span>waited<span class="w"> </span>some<span class="w"> </span>time,<span class="w"> </span>and<span class="w"> </span><span class="k">then</span><span class="w"> </span>resumed.<span class="w"> </span>“It<span class="w"> </span>is<span class="w"> </span>of<span class="w"> </span>no<span class="w"> </span>use<span class="w"> </span>to<span class="w"> </span>say<span class="w"> </span>you<span class="w"> </span>would<span class="w"> </span>not<span class="w"> </span>understand,<span class="w"> </span><span class="k">if</span><span class="w"> </span>I<span class="w"> </span>were<span class="w"> </span>to<span class="w"> </span>tell<span class="w"> </span>you,”<span class="w"> </span>said<span class="w"> </span>he.<span class="w"> </span>“I<span class="w"> 
</span>have<span class="w"> </span>not<span class="w"> </span>told<span class="w"> </span>you<span class="w"> </span>why<span class="w"> </span>I<span class="w"> </span>am<span class="w"> </span>waiting<span class="w"> </span><span class="k">for</span><span class="w"> </span>him,”<span class="w"> </span>said<span class="w"> </span>I.<span class="w"> </span>“And<span class="w"> </span>I<span class="w"> </span>am<span class="w"> </span>sure<span class="w"> </span>I<span class="w"> </span>should<span class="w"> </span>not<span class="w"> </span>understand,”<span class="w"> </span>replied<span class="w"> </span>he.<span class="w"> </span>“I<span class="w"> </span>will<span class="w"> </span>tell<span class="w"> </span>you<span class="w"> </span><span class="k">then</span>,”<span class="w"> </span>said<span class="w"> </span>I,<span class="w"> </span>“and,<span class="w"> </span>perhaps,<span class="w"> </span>you<span class="w"> </span>would<span class="w"> </span>not<span class="w"> </span>be<span class="w"> </span>surprised.”<span class="w"> </span>“No<span class="w"> </span>matter,”<span class="w"> </span>said<span class="w"> </span>he,<span class="w"> </span>“I<span class="w"> </span>shall<span class="w"> </span>be<span class="w"> </span>surprised<span class="w"> </span>anyhow<span class="p">;</span><span class="w"> </span>so<span class="w"> </span>tell<span class="w"> </span>me<span class="w"> </span>why<span class="w"> </span>you<span class="w"> </span>are<span class="w"> </span>waiting<span class="w"> </span><span class="k">for</span><span class="w"> </span>him.”<span class="w"> </span>“He<span class="w"> </span>is<span class="w"> </span>my<span class="w"> </span>friend,”<span class="w"> </span>said<span class="w"> </span>I.<span class="w"> </span>“Yes,”<span class="w"> </span>said<span class="w"> </span>he,<span class="w"> </span>with<span class="w"> </span>a<span class="w"> </span>slight<span class="w"> </span>smile,<span class="w"> </span>“I<span class="w"> </span>know.”<span class="w"> </span>“He<span class="w"> </span>has<span class="w"> </span>been<span class="w"> </span>kind<span class="w"> </span>to<span class="w"> </span>me,”<span class="w"> </span>said<span class="w"> </span>I,<span class="w"> </span>“and<span class="w"> </span>I<span class="w"> </span>am<span class="w"> </span>waiting<span class="w"> </span><span class="k">for</span><span class="w"> </span>him.<span class="w"> </span>I<span class="w"> </span>want<span class="w"> </span>to<span class="w"> </span>see<span class="w"> </span>him,<span class="w"> </span>and<span class="w"> </span>could<span class="w"> </span>have<span class="w"> </span>waited<span class="w"> </span>as<span class="w"> </span>I<span class="w"> </span>am<span class="w"> </span>now,<span class="w"> </span><span class="k">for</span><span class="w"> </span>a<span class="w"> </span>much<span class="w"> </span>longer<span class="w"> </span>time.”<span class="w"> </span>“He<span class="w"> </span>will<span class="w"> </span>not<span class="w"> </span>soon<span class="w"> </span>come,”<span class="w"> </span>said<span class="w"> </span>he.<span class="w"> </span>“Unless<span class="w"> </span>he<span class="w"> </span>sees<span class="w"> </span>you<span class="w"> </span>here,<span class="w"> </span>he<span class="w"> </span>will<span class="w"> </span>not<span class="w"> </span>know<span class="w"> </span>of<span class="w"> </span>your<span class="w"> </span>having<span class="w"> </span>waited,<span class="w"> </span>and<span class="w"> </span>he<span 
class="w"> </span>will<span class="w"> </span>be<span class="w"> </span>very<span class="w"> </span>unlikely<span class="w"> </span>to<span class="w"> </span>come.”<span class="w"> </span>“No<span class="w"> </span>matter,”<span class="w"> </span>said<span class="w"> </span>I,<span class="w"> </span>“I<span class="w"> </span>shall<span class="w"> </span><span class="nb">wait</span><span class="w"> </span><span class="k">for</span><span class="w"> </span>him.”<span class="w"> </span>“This<span class="w"> </span>is<span class="w"> </span>a<span class="w"> </span>strange<span class="w"> </span>thing,”<span class="w"> </span>said<span class="w"> </span>he,<span class="w"> </span>still<span class="w"> </span>with<span class="w"> </span>the<span class="w"> </span>same<span class="w"> </span>amused<span class="w"> </span>smile.<span class="w"> </span>“How<span class="w"> </span>did<span class="w"> </span>you<span class="w"> </span>know,”<span class="w"> </span>said<span class="w"> </span>I,<span class="w"> </span>“that<span class="w"> </span>he<span class="w"> </span>was<span class="w"> </span>coming?<span class="w"> </span>How<span class="w"> </span>should<span class="w"> </span>you<span class="w"> </span>be<span class="w"> </span>waiting?”<span class="w"> </span>“That<span class="w"> </span>is<span class="w"> </span>my<span class="w"> </span>secret,”<span class="w"> </span>said<span class="w"> </span>he.<span class="w"> </span>“And<span class="w"> </span>you<span class="w"> </span>expect<span class="w"> </span>him?”<span class="w"> </span>“Yes,”<span class="w"> </span>said<span class="w"> </span>I.<span class="w"> </span>“Are<span class="w"> </span>you<span class="w"> </span>disappointed<span class="w"> </span><span class="k">then</span>,<span class="w"> </span><span class="k">if</span><span class="w"> </span>he<span class="w"> </span>does<span class="w"> </span>not<span class="w"> </span>come?”<span class="w"> </span>“No,”<span class="w"> </span>said<span class="w"> </span>I,<span class="w"> </span>“it<span class="w"> </span>is<span class="w"> </span>his<span class="w"> </span>secret,<span class="w"> </span>not<span class="w"> </span>mine.”<span class="w"> </span>“If<span class="w"> </span>he<span class="w"> </span>comes,”<span class="w"> </span>said<span class="w"> </span>he,<span class="w"> </span>“do<span class="w"> </span>you<span class="w"> </span>mean<span class="w"> </span>to<span class="w"> </span>go<span class="w"> </span>straight<span class="w"> </span>away?”<span class="w"> </span>“Yes,”<span class="w"> </span>said<span class="w"> </span>I,<span class="w"> </span>“I<span class="w"> </span>cannot<span class="w"> </span>be<span class="w"> </span>happy<span class="w"> </span><span class="k">if</span><span class="w"> </span>I<span class="w"> </span><span class="k">do</span><span class="w"> </span>not<span class="w"> </span>go<span class="w"> </span>straight<span class="w"> </span>away<span class="w"> </span>after<span class="w"> </span>him.”<span class="w"> </span>“Did<span class="w"> </span>you<span class="w"> </span>know<span class="w"> </span>this<span class="w"> </span>place<span class="w"> </span>before?”<span class="w"> </span>asked<span class="w"> </span>he.<span class="w"> </span>“Yes,”<span class="w"> </span>said<span class="w"> </span>I.<span class="w"> </span>“Is<span class="w"> </span>there<span class="w"> </span>any<span class="w"> </span>shop<span class="w"> </span>to<span class="w"> </span>buy<span class="w"> </span>food<span class="w"> </span>here?”<span class="w"> 
</span>
------
<span class="o">[</span>INFO<span class="o">]</span><span class="w"> </span>Prompt<span class="w"> </span>processing:<span class="w"> </span><span class="m">0</span>.633<span class="w"> </span>s
<span class="o">[</span>INFO<span class="o">]</span><span class="w"> </span>Full<span class="w"> </span>generation:<span class="w"> </span><span class="m">21</span>.475<span class="w"> </span>s
</pre></div>
</div>
</section>
<section id="scripts">
<h2>Scripts<a class="headerlink" href="#scripts" title="Link to this heading">#</a></h2>
<div class="admonition-download-the-code admonition">
<p class="admonition-title">Download the code</p>
<p>The full example code is available in <a class="reference external" href="https://github.com/ml-explore/mlx-examples/tree/main/llms/llama">mlx-examples</a>.</p>
</div>
<aside class="footnote-list brackets">
<aside class="footnote brackets" id="id4" role="doc-footnote">
<span class="label"><span class="fn-bracket">[</span><a role="doc-backlink" href="#id1">1</a><span class="fn-bracket">]</span></span>
<p>Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B. and Liu, Y., 2021.
RoFormer: Enhanced transformer with rotary position embedding. arXiv
preprint arXiv:2104.09864.</p>
</aside>
<aside class="footnote brackets" id="id5" role="doc-footnote">
<span class="label"><span class="fn-bracket">[</span><a role="doc-backlink" href="#id2">2</a><span class="fn-bracket">]</span></span>
<p>Zhang, B. and Sennrich, R., 2019. Root mean square layer normalization.
Advances in Neural Information Processing Systems, 32.</p>
</aside>
<aside class="footnote brackets" id="id6" role="doc-footnote">
<span class="label"><span class="fn-bracket">[</span><a role="doc-backlink" href="#id3">3</a><span class="fn-bracket">]</span></span>
<p>Shazeer, N., 2020. GLU variants improve transformer. arXiv preprint
arXiv:2002.05202.</p>
</aside>
</aside>
</section>
</section>
</article>
<footer class="prev-next-footer d-print-none">
<div class="prev-next-area">
<a class="left-prev"
href="mlp.html"
title="previous page">
<i class="fa-solid fa-angle-left"></i>
<div class="prev-next-info">
<p class="prev-next-subtitle">previous</p>
<p class="prev-next-title">Multi-Layer Perceptron</p>
</div>
</a>
<a class="right-next"
href="../python/array.html"
title="next page">
<div class="prev-next-info">
<p class="prev-next-subtitle">next</p>
<p class="prev-next-title">Array</p>
</div>
<i class="fa-solid fa-angle-right"></i>
</a>
</div>
</footer>
</div>
<div class="bd-sidebar-secondary bd-toc"><div class="sidebar-secondary-items sidebar-secondary__inner">
<div class="sidebar-secondary-item">
<div class="page-toc tocsection onthispage">
<i class="fa-solid fa-list"></i> Contents
</div>
<nav class="bd-toc-nav page-toc">
<ul class="visible nav section-nav flex-column">
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-model">Implementing the model</a><ul class="visible nav section-nav flex-column">
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#attention-layer">Attention layer</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#encoder-layer">Encoder layer</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#full-model">Full model</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#generation">Generation</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#putting-it-all-together">Putting it all together</a></li>
</ul>
</li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#converting-the-weights">Converting the weights</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#weight-loading-and-benchmarking">Weight loading and benchmarking</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#scripts">Scripts</a></li>
</ul>
</nav></div>
</div></div>
</div>
<footer class="bd-footer-content">
<div class="bd-footer-content__inner container">
<div class="footer-item">
<p class="component-author">
By MLX Contributors
</p>
</div>
<div class="footer-item">
<p class="copyright">
© Copyright 2023, Apple.
<br/>
</p>
</div>
<div class="footer-item">
</div>
<div class="footer-item">
</div>
</div>
</footer>
</main>
</div>
</div>
<!-- Scripts loaded after <body> so the DOM is not blocked -->
<script src="../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b"></script>
<script src="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b"></script>
<footer class="bd-footer">
</footer>
</body>
</html>