mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
1813 lines
201 KiB
HTML
1813 lines
201 KiB
HTML
|
||
<!DOCTYPE html>
|
||
|
||
|
||
<html lang="en" data-content_root="../" >
|
||
|
||
<head>
  <meta charset="utf-8" />
  <!-- Declared once; repeated viewport metas add nothing. -->
  <meta name="viewport" content="width=device-width, initial-scale=1" />

  <title>Custom Extensions in MLX — MLX 0.26.1 documentation</title>

  <script data-cfasync="false">
    // Copy the stored "mode" and "theme" values from localStorage onto the
    // root element's data attributes (empty string when nothing is stored).
    document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
    document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
  </script>

  <!-- Loaded before other Sphinx assets -->
  <link href="../_static/styles/theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
  <link href="../_static/styles/bootstrap.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
  <link href="../_static/styles/pydata-sphinx-theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />

  <link href="../_static/vendor/fontawesome/6.5.2/css/all.min.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
  <link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-solid-900.woff2" />
  <link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-brands-400.woff2" />
  <link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-regular-400.woff2" />

  <link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
  <link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=eba8b062" />

  <!-- Pre-loaded scripts that we'll load fully later -->
  <link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b" />
  <link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" />
  <script src="../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script>

  <!-- These scripts stay synchronous and in this order: the inline script
       below assigns to DOCUMENTATION_OPTIONS right after they are included. -->
  <script src="../_static/documentation_options.js?v=3724ff34"></script>
  <script src="../_static/doctools.js?v=9a2dae69"></script>
  <script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
  <script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
  <script>DOCUMENTATION_OPTIONS.pagename = 'dev/extensions';</script>

  <link rel="icon" href="../_static/mlx_logo.png"/>
  <link rel="index" title="Index" href="../genindex.html" />
  <link rel="search" title="Search" href="../search.html" />
  <link rel="next" title="Metal Debugger" href="metal_debugger.html" />
  <link rel="prev" title="Operations" href="../cpp/ops.html" />
  <meta name="docsearch:language" content="en"/>
</head>
|
||
|
||
|
||
<body data-bs-spy="scroll" data-bs-target=".bd-toc-nav" data-offset="180" data-bs-root-margin="0px 0px -60%" data-default-mode="">
|
||
|
||
|
||
|
||
<div id="pst-skip-link" class="skip-link d-print-none"><a href="#main-content">Skip to main content</a></div>
|
||
|
||
<div id="pst-scroll-pixel-helper"></div>
|
||
|
||
<!-- Back-to-top control; the arrow icon is decorative, the text is the label. -->
<button type="button" class="btn rounded-pill" id="pst-back-to-top">
  <i class="fa-solid fa-arrow-up" aria-hidden="true"></i>Back to top</button>
|
||
|
||
|
||
<!-- Checkbox plus the overlay label bound to it through for/id. -->
<input type="checkbox"
       class="sidebar-toggle"
       id="pst-primary-sidebar-checkbox"/>
<label class="overlay overlay-primary" for="pst-primary-sidebar-checkbox"></label>
|
||
|
||
<!-- Checkbox plus the overlay label bound to it through for/id. -->
<input type="checkbox"
       class="sidebar-toggle"
       id="pst-secondary-sidebar-checkbox"/>
<label class="overlay overlay-secondary" for="pst-secondary-sidebar-checkbox"></label>
|
||
|
||
<!-- Search overlay: a GET form that submits the query as "q" to ../search.html. -->
<div class="search-button__wrapper">
  <div class="search-button__overlay"></div>
  <div class="search-button__search-container">
    <form class="bd-search d-flex align-items-center"
          action="../search.html"
          method="get">
      <i class="fa-solid fa-magnifying-glass" aria-hidden="true"></i>
      <input type="search"
             class="form-control"
             name="q"
             id="search-input"
             placeholder="Search..."
             aria-label="Search..."
             autocomplete="off"
             autocorrect="off"
             autocapitalize="off"
             spellcheck="false"/>
      <span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd>K</kbd></span>
    </form>
  </div>
</div>
|
||
|
||
<!-- Empty placeholder for a version-warning banner. -->
<div class="pst-async-banner-revealer d-none">
  <aside id="bd-header-version-warning" class="d-none d-print-none" aria-label="Version warning"></aside>
</div>
|
||
|
||
|
||
<!-- Page header; it carries no content on this page. -->
<header class="bd-header navbar navbar-expand-lg bd-navbar d-print-none">
</header>
|
||
|
||
|
||
<div class="bd-container">
|
||
<div class="bd-container__inner bd-page-width">
|
||
|
||
|
||
|
||
<div class="bd-sidebar-primary bd-sidebar">
|
||
|
||
|
||
|
||
<!-- Sidebar header section; it carries no items on this page. -->
<div class="sidebar-header-items sidebar-primary__section">
</div>
|
||
|
||
<div class="sidebar-primary-items__start sidebar-primary__section">
|
||
<!-- Home link with the project logo. The light variant is static markup;
     the dark variant is written by script, so it exists only when JavaScript runs. -->
<div class="sidebar-primary-item">
  <a class="navbar-brand logo" href="../index.html">
    <img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.26.1 documentation - Home"/>
    <script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.26.1 documentation - Home"/>`);</script>
  </a>
</div>
|
||
<div class="sidebar-primary-item">
  <script>
    // The search button is written from script, so it is present only when
    // JavaScript runs. type="button" keeps it a plain action control, and the
    // magnifier icon is decorative because the button has its own aria-label.
    document.write(`
      <button type="button" class="btn search-button-field search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
        <i class="fa-solid fa-magnifying-glass" aria-hidden="true"></i>
        <span class="search-button__default-text">Search</span>
        <span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span>
      </button>
    `);
  </script>
</div>
|
||
<div class="sidebar-primary-item"><nav class="bd-links bd-docs-nav" aria-label="Main">
|
||
<div class="bd-toc-item navbar-nav active">
|
||
<!-- Sidebar section: Install -->
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Install</span></p>
<ul class="nav bd-sidenav">
  <li class="toctree-l1"><a class="reference internal" href="../install.html">Build and Install</a></li>
</ul>
|
||
<!-- Sidebar section: Usage -->
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Usage</span></p>
<ul class="nav bd-sidenav">
  <li class="toctree-l1"><a class="reference internal" href="../usage/quick_start.html">Quick Start Guide</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../usage/lazy_evaluation.html">Lazy Evaluation</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../usage/unified_memory.html">Unified Memory</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../usage/indexing.html">Indexing Arrays</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../usage/saving_and_loading.html">Saving and Loading Arrays</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../usage/function_transforms.html">Function Transforms</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../usage/compile.html">Compilation</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../usage/numpy.html">Conversion to NumPy and Other Frameworks</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../usage/distributed.html">Distributed Communication</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../usage/using_streams.html">Using Streams</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../usage/export.html">Exporting Functions</a></li>
</ul>
|
||
<!-- Sidebar section: Examples -->
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
<ul class="nav bd-sidenav">
  <li class="toctree-l1"><a class="reference internal" href="../examples/linear_regression.html">Linear Regression</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../examples/mlp.html">Multi-Layer Perceptron</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../examples/llama-inference.html">LLM inference</a></li>
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Python API Reference</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/array.html">Array</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.html">mlx.core.array</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.astype.html">mlx.core.array.astype</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.at.html">mlx.core.array.at</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.item.html">mlx.core.array.item</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.tolist.html">mlx.core.array.tolist</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.dtype.html">mlx.core.array.dtype</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.itemsize.html">mlx.core.array.itemsize</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.nbytes.html">mlx.core.array.nbytes</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.ndim.html">mlx.core.array.ndim</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.shape.html">mlx.core.array.shape</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.size.html">mlx.core.array.size</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.real.html">mlx.core.array.real</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.imag.html">mlx.core.array.imag</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.abs.html">mlx.core.array.abs</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.all.html">mlx.core.array.all</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.any.html">mlx.core.array.any</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.argmax.html">mlx.core.array.argmax</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.argmin.html">mlx.core.array.argmin</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.conj.html">mlx.core.array.conj</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.cos.html">mlx.core.array.cos</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.cummax.html">mlx.core.array.cummax</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.cummin.html">mlx.core.array.cummin</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.cumprod.html">mlx.core.array.cumprod</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.cumsum.html">mlx.core.array.cumsum</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.diag.html">mlx.core.array.diag</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.diagonal.html">mlx.core.array.diagonal</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.exp.html">mlx.core.array.exp</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.flatten.html">mlx.core.array.flatten</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.log.html">mlx.core.array.log</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.log10.html">mlx.core.array.log10</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.log1p.html">mlx.core.array.log1p</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.log2.html">mlx.core.array.log2</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.logcumsumexp.html">mlx.core.array.logcumsumexp</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.logsumexp.html">mlx.core.array.logsumexp</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.max.html">mlx.core.array.max</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.mean.html">mlx.core.array.mean</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.min.html">mlx.core.array.min</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.moveaxis.html">mlx.core.array.moveaxis</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.prod.html">mlx.core.array.prod</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.reciprocal.html">mlx.core.array.reciprocal</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.reshape.html">mlx.core.array.reshape</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.round.html">mlx.core.array.round</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.rsqrt.html">mlx.core.array.rsqrt</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.sin.html">mlx.core.array.sin</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.split.html">mlx.core.array.split</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.sqrt.html">mlx.core.array.sqrt</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.square.html">mlx.core.array.square</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.squeeze.html">mlx.core.array.squeeze</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.std.html">mlx.core.array.std</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.sum.html">mlx.core.array.sum</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.swapaxes.html">mlx.core.array.swapaxes</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.transpose.html">mlx.core.array.transpose</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.T.html">mlx.core.array.T</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.var.html">mlx.core.array.var</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.view.html">mlx.core.array.view</a></li>
</ul>
</details></li>
|
||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/data_types.html">Data Types</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Dtype.html">mlx.core.Dtype</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.DtypeCategory.html">mlx.core.DtypeCategory</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.finfo.html">mlx.core.finfo</a></li>
</ul>
</details></li>
|
||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Device.html">mlx.core.Device</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/stream_class.html">mlx.core.Stream</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.default_device.html">mlx.core.default_device</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.set_default_device.html">mlx.core.set_default_device</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.default_stream.html">mlx.core.default_stream</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.new_stream.html">mlx.core.new_stream</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.set_default_stream.html">mlx.core.set_default_stream</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.stream.html">mlx.core.stream</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.synchronize.html">mlx.core.synchronize</a></li>
</ul>
</details></li>
|
||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/export.html">Export Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_function.html">mlx.core.export_function</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.import_function.html">mlx.core.import_function</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exporter.html">mlx.core.exporter</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_to_dot.html">mlx.core.export_to_dot</a></li>
</ul>
</details></li>
|
||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/ops.html">Operations</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.abs.html">mlx.core.abs</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.add.html">mlx.core.add</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.addmm.html">mlx.core.addmm</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.all.html">mlx.core.all</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.allclose.html">mlx.core.allclose</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.any.html">mlx.core.any</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arange.html">mlx.core.arange</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arccos.html">mlx.core.arccos</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arccosh.html">mlx.core.arccosh</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arcsin.html">mlx.core.arcsin</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arcsinh.html">mlx.core.arcsinh</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arctan.html">mlx.core.arctan</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arctan2.html">mlx.core.arctan2</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arctanh.html">mlx.core.arctanh</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.argmax.html">mlx.core.argmax</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.argmin.html">mlx.core.argmin</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.argpartition.html">mlx.core.argpartition</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.argsort.html">mlx.core.argsort</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array_equal.html">mlx.core.array_equal</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.as_strided.html">mlx.core.as_strided</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.atleast_1d.html">mlx.core.atleast_1d</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.atleast_2d.html">mlx.core.atleast_2d</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.atleast_3d.html">mlx.core.atleast_3d</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.bitwise_and.html">mlx.core.bitwise_and</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.bitwise_invert.html">mlx.core.bitwise_invert</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.bitwise_or.html">mlx.core.bitwise_or</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.bitwise_xor.html">mlx.core.bitwise_xor</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.block_masked_mm.html">mlx.core.block_masked_mm</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.broadcast_arrays.html">mlx.core.broadcast_arrays</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.broadcast_to.html">mlx.core.broadcast_to</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.ceil.html">mlx.core.ceil</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.clip.html">mlx.core.clip</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.concatenate.html">mlx.core.concatenate</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.contiguous.html">mlx.core.contiguous</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conj.html">mlx.core.conj</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conjugate.html">mlx.core.conjugate</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.convolve.html">mlx.core.convolve</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv1d.html">mlx.core.conv1d</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv2d.html">mlx.core.conv2d</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv3d.html">mlx.core.conv3d</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv_transpose1d.html">mlx.core.conv_transpose1d</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv_transpose2d.html">mlx.core.conv_transpose2d</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv_transpose3d.html">mlx.core.conv_transpose3d</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv_general.html">mlx.core.conv_general</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.cos.html">mlx.core.cos</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.cosh.html">mlx.core.cosh</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.cummax.html">mlx.core.cummax</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.cummin.html">mlx.core.cummin</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.cumprod.html">mlx.core.cumprod</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.cumsum.html">mlx.core.cumsum</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.degrees.html">mlx.core.degrees</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.dequantize.html">mlx.core.dequantize</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.diag.html">mlx.core.diag</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.diagonal.html">mlx.core.diagonal</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.divide.html">mlx.core.divide</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.divmod.html">mlx.core.divmod</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.einsum.html">mlx.core.einsum</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.einsum_path.html">mlx.core.einsum_path</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.equal.html">mlx.core.equal</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.erf.html">mlx.core.erf</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.erfinv.html">mlx.core.erfinv</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exp.html">mlx.core.exp</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.expm1.html">mlx.core.expm1</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.expand_dims.html">mlx.core.expand_dims</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.eye.html">mlx.core.eye</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.flatten.html">mlx.core.flatten</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.floor.html">mlx.core.floor</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.floor_divide.html">mlx.core.floor_divide</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.full.html">mlx.core.full</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.gather_mm.html">mlx.core.gather_mm</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.gather_qmm.html">mlx.core.gather_qmm</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.greater.html">mlx.core.greater</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.greater_equal.html">mlx.core.greater_equal</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.hadamard_transform.html">mlx.core.hadamard_transform</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.identity.html">mlx.core.identity</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.imag.html">mlx.core.imag</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.inner.html">mlx.core.inner</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isfinite.html">mlx.core.isfinite</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isclose.html">mlx.core.isclose</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isinf.html">mlx.core.isinf</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isnan.html">mlx.core.isnan</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isneginf.html">mlx.core.isneginf</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isposinf.html">mlx.core.isposinf</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.kron.html">mlx.core.kron</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.left_shift.html">mlx.core.left_shift</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less.html">mlx.core.less</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less_equal.html">mlx.core.less_equal</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linspace.html">mlx.core.linspace</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.load.html">mlx.core.load</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.log.html">mlx.core.log</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.log2.html">mlx.core.log2</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.log10.html">mlx.core.log10</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.log1p.html">mlx.core.log1p</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.logaddexp.html">mlx.core.logaddexp</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.logcumsumexp.html">mlx.core.logcumsumexp</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.logical_not.html">mlx.core.logical_not</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.logical_and.html">mlx.core.logical_and</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.logical_or.html">mlx.core.logical_or</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.logsumexp.html">mlx.core.logsumexp</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.matmul.html">mlx.core.matmul</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.max.html">mlx.core.max</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.maximum.html">mlx.core.maximum</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.mean.html">mlx.core.mean</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.meshgrid.html">mlx.core.meshgrid</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.min.html">mlx.core.min</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.minimum.html">mlx.core.minimum</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.moveaxis.html">mlx.core.moveaxis</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.multiply.html">mlx.core.multiply</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.nan_to_num.html">mlx.core.nan_to_num</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.negative.html">mlx.core.negative</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.not_equal.html">mlx.core.not_equal</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.ones.html">mlx.core.ones</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.ones_like.html">mlx.core.ones_like</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.outer.html">mlx.core.outer</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.partition.html">mlx.core.partition</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.pad.html">mlx.core.pad</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.power.html">mlx.core.power</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.prod.html">mlx.core.prod</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.put_along_axis.html">mlx.core.put_along_axis</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.quantize.html">mlx.core.quantize</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.quantized_matmul.html">mlx.core.quantized_matmul</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.radians.html">mlx.core.radians</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.real.html">mlx.core.real</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.reciprocal.html">mlx.core.reciprocal</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.remainder.html">mlx.core.remainder</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.repeat.html">mlx.core.repeat</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.reshape.html">mlx.core.reshape</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.right_shift.html">mlx.core.right_shift</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.roll.html">mlx.core.roll</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.round.html">mlx.core.round</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.rsqrt.html">mlx.core.rsqrt</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.save.html">mlx.core.save</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.savez.html">mlx.core.savez</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.savez_compressed.html">mlx.core.savez_compressed</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.save_gguf.html">mlx.core.save_gguf</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.save_safetensors.html">mlx.core.save_safetensors</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sigmoid.html">mlx.core.sigmoid</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sin.html">mlx.core.sin</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sinh.html">mlx.core.sinh</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice.html">mlx.core.slice</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice_update.html">mlx.core.slice_update</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.softmax.html">mlx.core.softmax</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sort.html">mlx.core.sort</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.split.html">mlx.core.split</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sqrt.html">mlx.core.sqrt</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.square.html">mlx.core.square</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.squeeze.html">mlx.core.squeeze</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.stack.html">mlx.core.stack</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.std.html">mlx.core.std</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.stop_gradient.html">mlx.core.stop_gradient</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.subtract.html">mlx.core.subtract</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sum.html">mlx.core.sum</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.swapaxes.html">mlx.core.swapaxes</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.take.html">mlx.core.take</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.take_along_axis.html">mlx.core.take_along_axis</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tan.html">mlx.core.tan</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tanh.html">mlx.core.tanh</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tensordot.html">mlx.core.tensordot</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tile.html">mlx.core.tile</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.topk.html">mlx.core.topk</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.trace.html">mlx.core.trace</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.transpose.html">mlx.core.transpose</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.triu.html">mlx.core.triu</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.unflatten.html">mlx.core.unflatten</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.var.html">mlx.core.var</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.view.html">mlx.core.view</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.where.html">mlx.core.where</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.zeros.html">mlx.core.zeros</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.zeros_like.html">mlx.core.zeros_like</a></li>
|
||
</ul>
|
||
</details></li>
|
||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/random.html">Random</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.bernoulli.html">mlx.core.random.bernoulli</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.categorical.html">mlx.core.random.categorical</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.gumbel.html">mlx.core.random.gumbel</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.key.html">mlx.core.random.key</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.normal.html">mlx.core.random.normal</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.multivariate_normal.html">mlx.core.random.multivariate_normal</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.randint.html">mlx.core.random.randint</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.seed.html">mlx.core.random.seed</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.split.html">mlx.core.random.split</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.truncated_normal.html">mlx.core.random.truncated_normal</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.uniform.html">mlx.core.random.uniform</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.laplace.html">mlx.core.random.laplace</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.permutation.html">mlx.core.random.permutation</a></li>
|
||
</ul>
|
||
</details></li>
|
||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/transforms.html">Transforms</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.eval.html">mlx.core.eval</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.async_eval.html">mlx.core.async_eval</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.compile.html">mlx.core.compile</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.custom_function.html">mlx.core.custom_function</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.disable_compile.html">mlx.core.disable_compile</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.enable_compile.html">mlx.core.enable_compile</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.grad.html">mlx.core.grad</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.value_and_grad.html">mlx.core.value_and_grad</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.jvp.html">mlx.core.jvp</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.vjp.html">mlx.core.vjp</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.vmap.html">mlx.core.vmap</a></li>
|
||
</ul>
|
||
</details></li>
|
||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/fast.html">Fast</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fast.rms_norm.html">mlx.core.fast.rms_norm</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fast.layer_norm.html">mlx.core.fast.layer_norm</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fast.rope.html">mlx.core.fast.rope</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fast.scaled_dot_product_attention.html">mlx.core.fast.scaled_dot_product_attention</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fast.metal_kernel.html">mlx.core.fast.metal_kernel</a></li>
|
||
</ul>
|
||
</details></li>
|
||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/fft.html">FFT</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.fft.html">mlx.core.fft.fft</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.ifft.html">mlx.core.fft.ifft</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.fft2.html">mlx.core.fft.fft2</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.ifft2.html">mlx.core.fft.ifft2</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.fftn.html">mlx.core.fft.fftn</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.ifftn.html">mlx.core.fft.ifftn</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.rfft.html">mlx.core.fft.rfft</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.irfft.html">mlx.core.fft.irfft</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.rfft2.html">mlx.core.fft.rfft2</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.irfft2.html">mlx.core.fft.irfft2</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.rfftn.html">mlx.core.fft.rfftn</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.irfftn.html">mlx.core.fft.irfftn</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.fftshift.html">mlx.core.fft.fftshift</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.ifftshift.html">mlx.core.fft.ifftshift</a></li>
|
||
</ul>
|
||
</details></li>
|
||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/linalg.html">Linear Algebra</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.inv.html">mlx.core.linalg.inv</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.tri_inv.html">mlx.core.linalg.tri_inv</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.norm.html">mlx.core.linalg.norm</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.cholesky.html">mlx.core.linalg.cholesky</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.cholesky_inv.html">mlx.core.linalg.cholesky_inv</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.cross.html">mlx.core.linalg.cross</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.qr.html">mlx.core.linalg.qr</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.svd.html">mlx.core.linalg.svd</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.eigvals.html">mlx.core.linalg.eigvals</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.eig.html">mlx.core.linalg.eig</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.eigvalsh.html">mlx.core.linalg.eigvalsh</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.eigh.html">mlx.core.linalg.eigh</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.lu.html">mlx.core.linalg.lu</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.lu_factor.html">mlx.core.linalg.lu_factor</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.pinv.html">mlx.core.linalg.pinv</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.solve.html">mlx.core.linalg.solve</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.solve_triangular.html">mlx.core.linalg.solve_triangular</a></li>
|
||
</ul>
|
||
</details></li>
|
||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/metal.html">Metal</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.is_available.html">mlx.core.metal.is_available</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.device_info.html">mlx.core.metal.device_info</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.start_capture.html">mlx.core.metal.start_capture</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.stop_capture.html">mlx.core.metal.stop_capture</a></li>
|
||
</ul>
|
||
</details></li>
|
||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/memory_management.html">Memory Management</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.get_active_memory.html">mlx.core.get_active_memory</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.get_peak_memory.html">mlx.core.get_peak_memory</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.reset_peak_memory.html">mlx.core.reset_peak_memory</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.get_cache_memory.html">mlx.core.get_cache_memory</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.set_memory_limit.html">mlx.core.set_memory_limit</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.set_cache_limit.html">mlx.core.set_cache_limit</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.set_wired_limit.html">mlx.core.set_wired_limit</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.clear_cache.html">mlx.core.clear_cache</a></li>
|
||
</ul>
|
||
</details></li>
|
||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/nn.html">Neural Networks</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.nn.value_and_grad.html">mlx.nn.value_and_grad</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.nn.quantize.html">mlx.nn.quantize</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.nn.average_gradients.html">mlx.nn.average_gradients</a></li>
|
||
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/module.html">Module</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.training.html">mlx.nn.Module.training</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.state.html">mlx.nn.Module.state</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.apply.html">mlx.nn.Module.apply</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.apply_to_modules.html">mlx.nn.Module.apply_to_modules</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.children.html">mlx.nn.Module.children</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.eval.html">mlx.nn.Module.eval</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.filter_and_map.html">mlx.nn.Module.filter_and_map</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.freeze.html">mlx.nn.Module.freeze</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.leaf_modules.html">mlx.nn.Module.leaf_modules</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.load_weights.html">mlx.nn.Module.load_weights</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.modules.html">mlx.nn.Module.modules</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.named_modules.html">mlx.nn.Module.named_modules</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.parameters.html">mlx.nn.Module.parameters</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.save_weights.html">mlx.nn.Module.save_weights</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.set_dtype.html">mlx.nn.Module.set_dtype</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.train.html">mlx.nn.Module.train</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.trainable_parameters.html">mlx.nn.Module.trainable_parameters</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.unfreeze.html">mlx.nn.Module.unfreeze</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.update.html">mlx.nn.Module.update</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.update_modules.html">mlx.nn.Module.update_modules</a></li>
|
||
</ul>
|
||
</details></li>
|
||
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/layers.html">Layers</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ALiBi.html">mlx.nn.ALiBi</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.AvgPool1d.html">mlx.nn.AvgPool1d</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.AvgPool2d.html">mlx.nn.AvgPool2d</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.AvgPool3d.html">mlx.nn.AvgPool3d</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.BatchNorm.html">mlx.nn.BatchNorm</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.CELU.html">mlx.nn.CELU</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Conv1d.html">mlx.nn.Conv1d</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Conv2d.html">mlx.nn.Conv2d</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Conv3d.html">mlx.nn.Conv3d</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ConvTranspose1d.html">mlx.nn.ConvTranspose1d</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ConvTranspose2d.html">mlx.nn.ConvTranspose2d</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ConvTranspose3d.html">mlx.nn.ConvTranspose3d</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Dropout.html">mlx.nn.Dropout</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Dropout2d.html">mlx.nn.Dropout2d</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Dropout3d.html">mlx.nn.Dropout3d</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Embedding.html">mlx.nn.Embedding</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ELU.html">mlx.nn.ELU</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.GELU.html">mlx.nn.GELU</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.GLU.html">mlx.nn.GLU</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.GroupNorm.html">mlx.nn.GroupNorm</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.GRU.html">mlx.nn.GRU</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.HardShrink.html">mlx.nn.HardShrink</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.HardTanh.html">mlx.nn.HardTanh</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Hardswish.html">mlx.nn.Hardswish</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.InstanceNorm.html">mlx.nn.InstanceNorm</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.LayerNorm.html">mlx.nn.LayerNorm</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.LeakyReLU.html">mlx.nn.LeakyReLU</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Linear.html">mlx.nn.Linear</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.LogSigmoid.html">mlx.nn.LogSigmoid</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.LogSoftmax.html">mlx.nn.LogSoftmax</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.LSTM.html">mlx.nn.LSTM</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.MaxPool1d.html">mlx.nn.MaxPool1d</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.MaxPool2d.html">mlx.nn.MaxPool2d</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.MaxPool3d.html">mlx.nn.MaxPool3d</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Mish.html">mlx.nn.Mish</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.MultiHeadAttention.html">mlx.nn.MultiHeadAttention</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.PReLU.html">mlx.nn.PReLU</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.QuantizedEmbedding.html">mlx.nn.QuantizedEmbedding</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.QuantizedLinear.html">mlx.nn.QuantizedLinear</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.RMSNorm.html">mlx.nn.RMSNorm</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ReLU.html">mlx.nn.ReLU</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ReLU6.html">mlx.nn.ReLU6</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.RNN.html">mlx.nn.RNN</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.RoPE.html">mlx.nn.RoPE</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.SELU.html">mlx.nn.SELU</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Sequential.html">mlx.nn.Sequential</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Sigmoid.html">mlx.nn.Sigmoid</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.SiLU.html">mlx.nn.SiLU</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.html">mlx.nn.SinusoidalPositionalEncoding</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Softmin.html">mlx.nn.Softmin</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Softshrink.html">mlx.nn.Softshrink</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Softsign.html">mlx.nn.Softsign</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Softmax.html">mlx.nn.Softmax</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Softplus.html">mlx.nn.Softplus</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Step.html">mlx.nn.Step</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Tanh.html">mlx.nn.Tanh</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Transformer.html">mlx.nn.Transformer</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Upsample.html">mlx.nn.Upsample</a></li>
|
||
</ul>
|
||
</details></li>
|
||
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/functions.html">Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.elu.html">mlx.nn.elu</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.celu.html">mlx.nn.celu</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.gelu.html">mlx.nn.gelu</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.gelu_approx.html">mlx.nn.gelu_approx</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.html">mlx.nn.gelu_fast_approx</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.glu.html">mlx.nn.glu</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.hard_shrink.html">mlx.nn.hard_shrink</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.hard_tanh.html">mlx.nn.hard_tanh</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.hardswish.html">mlx.nn.hardswish</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.leaky_relu.html">mlx.nn.leaky_relu</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.log_sigmoid.html">mlx.nn.log_sigmoid</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.log_softmax.html">mlx.nn.log_softmax</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.mish.html">mlx.nn.mish</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.prelu.html">mlx.nn.prelu</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.relu.html">mlx.nn.relu</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.relu6.html">mlx.nn.relu6</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.selu.html">mlx.nn.selu</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.sigmoid.html">mlx.nn.sigmoid</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.silu.html">mlx.nn.silu</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.softmax.html">mlx.nn.softmax</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.softmin.html">mlx.nn.softmin</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.softplus.html">mlx.nn.softplus</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.softshrink.html">mlx.nn.softshrink</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.step.html">mlx.nn.step</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.tanh.html">mlx.nn.tanh</a></li>
|
||
</ul>
|
||
</details></li>
|
||
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/losses.html">Loss Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.html">mlx.nn.losses.binary_cross_entropy</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.html">mlx.nn.losses.cosine_similarity_loss</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.html">mlx.nn.losses.cross_entropy</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss.html">mlx.nn.losses.gaussian_nll_loss</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.html">mlx.nn.losses.hinge_loss</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.huber_loss.html">mlx.nn.losses.huber_loss</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.html">mlx.nn.losses.kl_div_loss</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.html">mlx.nn.losses.l1_loss</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.html">mlx.nn.losses.log_cosh_loss</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.margin_ranking_loss.html">mlx.nn.losses.margin_ranking_loss</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.html">mlx.nn.losses.mse_loss</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.html">mlx.nn.losses.nll_loss</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.html">mlx.nn.losses.smooth_l1_loss</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.html">mlx.nn.losses.triplet_loss</a></li>
|
||
</ul>
|
||
</details></li>
|
||
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/init.html">Initializers</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.constant.html">mlx.nn.init.constant</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.normal.html">mlx.nn.init.normal</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.uniform.html">mlx.nn.init.uniform</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.identity.html">mlx.nn.init.identity</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.glorot_normal.html">mlx.nn.init.glorot_normal</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.glorot_uniform.html">mlx.nn.init.glorot_uniform</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.he_normal.html">mlx.nn.init.he_normal</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.he_uniform.html">mlx.nn.init.he_uniform</a></li>
|
||
</ul>
|
||
</details></li>
|
||
</ul>
|
||
</details></li>
|
||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/optimizers.html">Optimizers</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/optimizers/optimizer.html">Optimizer</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Optimizer.state.html">mlx.optimizers.Optimizer.state</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Optimizer.apply_gradients.html">mlx.optimizers.Optimizer.apply_gradients</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Optimizer.init.html">mlx.optimizers.Optimizer.init</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Optimizer.update.html">mlx.optimizers.Optimizer.update</a></li>
|
||
</ul>
|
||
</details></li>
|
||
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/optimizers/common_optimizers.html">Common Optimizers</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.SGD.html">mlx.optimizers.SGD</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.RMSprop.html">mlx.optimizers.RMSprop</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Adagrad.html">mlx.optimizers.Adagrad</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Adafactor.html">mlx.optimizers.Adafactor</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.AdaDelta.html">mlx.optimizers.AdaDelta</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Adam.html">mlx.optimizers.Adam</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.AdamW.html">mlx.optimizers.AdamW</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Adamax.html">mlx.optimizers.Adamax</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Lion.html">mlx.optimizers.Lion</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.MultiOptimizer.html">mlx.optimizers.MultiOptimizer</a></li>
|
||
</ul>
|
||
</details></li>
|
||
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/optimizers/schedulers.html">Schedulers</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.cosine_decay.html">mlx.optimizers.cosine_decay</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.exponential_decay.html">mlx.optimizers.exponential_decay</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.join_schedules.html">mlx.optimizers.join_schedules</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.linear_schedule.html">mlx.optimizers.linear_schedule</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.step_decay.html">mlx.optimizers.step_decay</a></li>
|
||
</ul>
|
||
</details></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.clip_grad_norm.html">mlx.optimizers.clip_grad_norm</a></li>
|
||
</ul>
|
||
</details></li>
|
||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/distributed.html">Distributed Communication</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.Group.html">mlx.core.distributed.Group</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.is_available.html">mlx.core.distributed.is_available</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.init.html">mlx.core.distributed.init</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.all_sum.html">mlx.core.distributed.all_sum</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.all_gather.html">mlx.core.distributed.all_gather</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.send.html">mlx.core.distributed.send</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.recv.html">mlx.core.distributed.recv</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.recv_like.html">mlx.core.distributed.recv_like</a></li>
|
||
</ul>
|
||
</details></li>
|
||
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/tree_utils.html">Tree Utils</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_flatten.html">mlx.utils.tree_flatten</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_unflatten.html">mlx.utils.tree_unflatten</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_map.html">mlx.utils.tree_map</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_map_with_path.html">mlx.utils.tree_map_with_path</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_reduce.html">mlx.utils.tree_reduce</a></li>
|
||
</ul>
|
||
</details></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">C++ API Reference</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../cpp/ops.html">Operations</a></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Further Reading</span></p>
|
||
<ul class="current nav bd-sidenav">
|
||
<li class="toctree-l1 current active"><a class="current reference internal" href="#">Custom Extensions in MLX</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="metal_debugger.html">Metal Debugger</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="custom_metal_kernels.html">Custom Metal Kernels</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="mlx_in_cpp.html">Using MLX in C++</a></li>
|
||
</ul>
|
||
|
||
</div>
|
||
</nav></div>
|
||
</div>
|
||
|
||
|
||
<div class="sidebar-primary-items__end sidebar-primary__section">
|
||
</div>
|
||
|
||
<div id="rtd-footer-container"></div>
|
||
|
||
|
||
</div>
|
||
|
||
<main id="main-content" class="bd-main" role="main">
|
||
|
||
|
||
|
||
<div class="sbt-scroll-pixel-helper"></div>
|
||
|
||
<div class="bd-content">
|
||
<div class="bd-article-container">
|
||
|
||
<div class="bd-header-article d-print-none">
|
||
<div class="header-article-items header-article__inner">
|
||
|
||
<div class="header-article-items__start">
|
||
|
||
<div class="header-article-item"><button class="sidebar-toggle primary-toggle btn btn-sm" title="Toggle primary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip">
|
||
<span class="fa-solid fa-bars"></span>
|
||
</button></div>
|
||
|
||
</div>
|
||
|
||
|
||
<div class="header-article-items__end">
|
||
|
||
<div class="header-article-item">
|
||
|
||
<div class="article-header-buttons">
|
||
|
||
|
||
<a href="https://github.com/ml-explore/mlx" target="_blank"
|
||
class="btn btn-sm btn-source-repository-button"
|
||
title="Source repository"
|
||
data-bs-placement="bottom" data-bs-toggle="tooltip"
|
||
>
|
||
|
||
|
||
<span class="btn__icon-container">
|
||
<i class="fab fa-github"></i>
|
||
</span>
|
||
|
||
</a>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<div class="dropdown dropdown-download-buttons">
|
||
<button class="btn dropdown-toggle" type="button" data-bs-toggle="dropdown" aria-expanded="false" aria-label="Download this page">
|
||
<i class="fas fa-download"></i>
|
||
</button>
|
||
<ul class="dropdown-menu">
|
||
|
||
|
||
|
||
<li><a href="../_sources/dev/extensions.rst" target="_blank"
|
||
class="btn btn-sm btn-download-source-button dropdown-item"
|
||
title="Download source file"
|
||
data-bs-placement="left" data-bs-toggle="tooltip"
|
||
>
|
||
|
||
|
||
<span class="btn__icon-container">
|
||
<i class="fas fa-file"></i>
|
||
</span>
|
||
<span class="btn__text-container">.rst</span>
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
<li>
|
||
<button onclick="window.print()"
|
||
class="btn btn-sm btn-download-pdf-button dropdown-item"
|
||
title="Print to PDF"
|
||
data-bs-placement="left" data-bs-toggle="tooltip"
|
||
>
|
||
|
||
|
||
<span class="btn__icon-container">
|
||
<i class="fas fa-file-pdf"></i>
|
||
</span>
|
||
<span class="btn__text-container">.pdf</span>
|
||
</button>
|
||
</li>
|
||
|
||
</ul>
|
||
</div>
|
||
|
||
|
||
|
||
|
||
<button onclick="toggleFullScreen()"
|
||
class="btn btn-sm btn-fullscreen-button"
|
||
title="Fullscreen mode"
|
||
data-bs-placement="bottom" data-bs-toggle="tooltip"
|
||
>
|
||
|
||
|
||
<span class="btn__icon-container">
|
||
<i class="fas fa-expand"></i>
|
||
</span>
|
||
|
||
</button>
|
||
|
||
|
||
|
||
<script>
|
||
document.write(`
|
||
<button class="btn btn-sm nav-link pst-navbar-icon theme-switch-button" title="light/dark" aria-label="light/dark" data-bs-placement="bottom" data-bs-toggle="tooltip">
|
||
<i class="theme-switch fa-solid fa-sun fa-lg" data-mode="light"></i>
|
||
<i class="theme-switch fa-solid fa-moon fa-lg" data-mode="dark"></i>
|
||
<i class="theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode="auto"></i>
|
||
</button>
|
||
`);
|
||
</script>
|
||
|
||
|
||
<script>
|
||
document.write(`
|
||
<button class="btn btn-sm pst-navbar-icon search-button search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
|
||
<i class="fa-solid fa-magnifying-glass fa-lg"></i>
|
||
</button>
|
||
`);
|
||
</script>
|
||
<button class="sidebar-toggle secondary-toggle btn btn-sm" title="Toggle secondary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip">
|
||
<span class="fa-solid fa-list"></span>
|
||
</button>
|
||
</div></div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
</div>
|
||
|
||
|
||
|
||
<div id="jb-print-docs-body" class="onlyprint">
|
||
<h1>Custom Extensions in MLX</h1>
|
||
<!-- Table of contents -->
|
||
<div id="print-main-content">
|
||
<div id="jb-print-toc">
|
||
|
||
<div>
|
||
<h2> Contents </h2>
|
||
</div>
|
||
<nav aria-label="Page">
|
||
<ul class="visible nav section-nav flex-column">
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#introducing-the-example">Introducing the Example</a></li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#operations-and-primitives">Operations and Primitives</a><ul class="visible nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#operations">Operations</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#primitives">Primitives</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#using-the-primitive">Using the Primitive</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-primitive">Implementing the Primitive</a><ul class="visible nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-cpu-back-end">Implementing the CPU Back-end</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-gpu-back-end">Implementing the GPU Back-end</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#primitive-transforms">Primitive Transforms</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#building-and-binding">Building and Binding</a><ul class="visible nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#binding-to-python">Binding to Python</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#building-with-cmake">Building with CMake</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#building-with-setuptools">Building with <code class="docutils literal notranslate"><span class="pre">setuptools</span></code></a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#usage">Usage</a><ul class="visible nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#results">Results</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#scripts">Scripts</a></li>
|
||
</ul>
|
||
</nav>
|
||
</div>
|
||
</div>
|
||
</div>
|
||
|
||
|
||
|
||
<div id="searchbox"></div>
|
||
<article class="bd-article">
|
||
|
||
<section id="custom-extensions-in-mlx">
|
||
<h1>Custom Extensions in MLX<a class="headerlink" href="#custom-extensions-in-mlx" title="Link to this heading">#</a></h1>
|
||
<p>You can extend MLX with custom operations on the CPU or GPU. This guide
|
||
explains how to do that with a simple example.</p>
|
||
<section id="introducing-the-example">
|
||
<h2>Introducing the Example<a class="headerlink" href="#introducing-the-example" title="Link to this heading">#</a></h2>
|
||
<p>Let’s say you would like an operation that takes in two arrays, <code class="docutils literal notranslate"><span class="pre">x</span></code> and
|
||
<code class="docutils literal notranslate"><span class="pre">y</span></code>, scales them both by coefficients <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">beta</span></code> respectively,
|
||
and then adds them together to get the result <code class="docutils literal notranslate"><span class="pre">z</span> <span class="pre">=</span> <span class="pre">alpha</span> <span class="pre">*</span> <span class="pre">x</span> <span class="pre">+</span> <span class="pre">beta</span> <span class="pre">*</span> <span class="pre">y</span></code>.
|
||
You can do that in MLX directly:</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mlx.core</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">mx</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">simple_axpby</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">alpha</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">alpha</span> <span class="o">*</span> <span class="n">x</span> <span class="o">+</span> <span class="n">beta</span> <span class="o">*</span> <span class="n">y</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>This function performs that operation while leaving the implementation and
|
||
function transformations to MLX.</p>
|
||
<p>However, you may want to customize the underlying implementation, perhaps to
|
||
make it faster. In this tutorial we will go through adding custom extensions.
|
||
It will cover:</p>
|
||
<ul class="simple">
|
||
<li><p>The structure of the MLX library.</p></li>
|
||
<li><p>Implementing a CPU operation.</p></li>
|
||
<li><p>Implementing a GPU operation using Metal.</p></li>
|
||
<li><p>Adding the <code class="docutils literal notranslate"><span class="pre">vjp</span></code> and <code class="docutils literal notranslate"><span class="pre">jvp</span></code> function transformations.</p></li>
|
||
<li><p>Building a custom extension and binding it to Python.</p></li>
|
||
</ul>
|
||
</section>
|
||
<section id="operations-and-primitives">
|
||
<h2>Operations and Primitives<a class="headerlink" href="#operations-and-primitives" title="Link to this heading">#</a></h2>
|
||
<p>Operations in MLX build the computation graph. Primitives provide the rules for
|
||
evaluating and transforming the graph. Let’s start by discussing operations in
|
||
more detail.</p>
|
||
<section id="operations">
|
||
<h3>Operations<a class="headerlink" href="#operations" title="Link to this heading">#</a></h3>
|
||
<p>Operations are the front-end functions that operate on arrays. They are defined
|
||
in the C++ API (<a class="reference internal" href="../cpp/ops.html#cpp-ops"><span class="std std-ref">Operations</span></a>), and the Python API (<a class="reference internal" href="../python/ops.html#ops"><span class="std std-ref">Operations</span></a>) binds them.</p>
|
||
<p>We would like an operation <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code> that takes in two arrays, <code class="docutils literal notranslate"><span class="pre">x</span></code> and
|
||
<code class="docutils literal notranslate"><span class="pre">y</span></code>, and two scalars, <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">beta</span></code>. This is how to define it in
|
||
C++:</p>
|
||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/**</span>
|
||
<span class="cm">* Scale and sum two vectors element-wise</span>
|
||
<span class="cm">* z = alpha * x + beta * y</span>
|
||
<span class="cm">*</span>
|
||
<span class="cm">* Use NumPy-style broadcasting between x and y</span>
|
||
<span class="cm">* Inputs are upcasted to floats if needed</span>
|
||
<span class="cm">**/</span>
|
||
<span class="n">array</span><span class="w"> </span><span class="nf">axpby</span><span class="p">(</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array x</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array y</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha</span><span class="p">,</span><span class="w"> </span><span class="c1">// Scaling factor for x</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta</span><span class="p">,</span><span class="w"> </span><span class="c1">// Scaling factor for y</span>
|
||
<span class="w"> </span><span class="n">StreamOrDevice</span><span class="w"> </span><span class="n">s</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="p">{}</span><span class="w"> </span><span class="c1">// Stream on which to schedule the operation</span>
|
||
<span class="p">);</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>The simplest way to implement this is with existing operations:</p>
|
||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="n">array</span><span class="w"> </span><span class="nf">axpby</span><span class="p">(</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array x</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array y</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha</span><span class="p">,</span><span class="w"> </span><span class="c1">// Scaling factor for x</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta</span><span class="p">,</span><span class="w"> </span><span class="c1">// Scaling factor for y</span>
|
||
<span class="w"> </span><span class="n">StreamOrDevice</span><span class="w"> </span><span class="n">s</span><span class="w"> </span><span class="cm">/* = {} */</span><span class="w"> </span><span class="c1">// Stream on which to schedule the operation</span>
|
||
<span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="c1">// Scale x and y on the provided stream</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">ax</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">multiply</span><span class="p">(</span><span class="n">array</span><span class="p">(</span><span class="n">alpha</span><span class="p">),</span><span class="w"> </span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">s</span><span class="p">);</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">by</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">multiply</span><span class="p">(</span><span class="n">array</span><span class="p">(</span><span class="n">beta</span><span class="p">),</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">s</span><span class="p">);</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Add and return</span>
|
||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">add</span><span class="p">(</span><span class="n">ax</span><span class="p">,</span><span class="w"> </span><span class="n">by</span><span class="p">,</span><span class="w"> </span><span class="n">s</span><span class="p">);</span>
|
||
<span class="p">}</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>The operations themselves do not contain the implementations that act on the
|
||
data, nor do they contain the rules of transformations. Rather, they are an
|
||
easy-to-use interface that uses <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> building blocks.</p>
|
||
</section>
|
||
<section id="primitives">
|
||
<h3>Primitives<a class="headerlink" href="#primitives" title="Link to this heading">#</a></h3>
|
||
<p>A <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> is part of the computation graph of an <code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code>. It
|
||
defines how to create output arrays given input arrays. Further, a
|
||
<code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> has methods to run on the CPU or GPU and for function
|
||
transformations such as <code class="docutils literal notranslate"><span class="pre">vjp</span></code> and <code class="docutils literal notranslate"><span class="pre">jvp</span></code>. Let’s go back to our example to be
|
||
more concrete:</p>
|
||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">Axpby</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="k">public</span><span class="w"> </span><span class="n">Primitive</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="k">public</span><span class="o">:</span>
|
||
<span class="w"> </span><span class="k">explicit</span><span class="w"> </span><span class="n">Axpby</span><span class="p">(</span><span class="n">Stream</span><span class="w"> </span><span class="n">stream</span><span class="p">,</span><span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha</span><span class="p">,</span><span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta</span><span class="p">)</span>
|
||
<span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="n">Primitive</span><span class="p">(</span><span class="n">stream</span><span class="p">),</span><span class="w"> </span><span class="n">alpha_</span><span class="p">(</span><span class="n">alpha</span><span class="p">),</span><span class="w"> </span><span class="n">beta_</span><span class="p">(</span><span class="n">beta</span><span class="p">){};</span>
|
||
|
||
<span class="w"> </span><span class="cm">/**</span>
|
||
<span class="cm"> * A primitive must know how to evaluate itself on the CPU/GPU</span>
|
||
<span class="cm"> * for the given inputs and populate the output array.</span>
|
||
<span class="cm"> *</span>
|
||
<span class="cm"> * To avoid unnecessary allocations, the evaluation function</span>
|
||
<span class="cm"> * is responsible for allocating space for the array.</span>
|
||
<span class="cm"> */</span>
|
||
<span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="nf">eval_cpu</span><span class="p">(</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">outputs</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
|
||
<span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="nf">eval_gpu</span><span class="p">(</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">outputs</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
|
||
|
||
<span class="w"> </span><span class="cm">/** The Jacobian-vector product. */</span>
|
||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">></span><span class="w"> </span><span class="n">jvp</span><span class="p">(</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">primals</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">tangents</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
|
||
|
||
<span class="w"> </span><span class="cm">/** The vector-Jacobian product. */</span>
|
||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">></span><span class="w"> </span><span class="n">vjp</span><span class="p">(</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">primals</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">cotangents</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">argnums</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">outputs</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
|
||
|
||
<span class="w"> </span><span class="cm">/**</span>
|
||
<span class="cm"> * The primitive must know how to vectorize itself across</span>
|
||
<span class="cm"> * the given axes. The output is a pair containing the array</span>
|
||
<span class="cm"> * representing the vectorized computation and the axis which</span>
|
||
<span class="cm"> * corresponds to the output vectorized dimension.</span>
|
||
<span class="cm"> */</span>
|
||
<span class="w"> </span><span class="k">virtual</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o"><</span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">></span><span class="p">,</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>></span><span class="w"> </span><span class="n">vmap</span><span class="p">(</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">axes</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
|
||
|
||
<span class="w"> </span><span class="cm">/** Print the primitive. */</span>
|
||
<span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="nf">print</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">ostream</span><span class="o">&</span><span class="w"> </span><span class="n">os</span><span class="p">)</span><span class="w"> </span><span class="k">override</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="n">os</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="s">"Axpby"</span><span class="p">;</span>
|
||
<span class="w"> </span><span class="p">}</span>
|
||
|
||
<span class="w"> </span><span class="cm">/** Equivalence check **/</span>
|
||
<span class="w"> </span><span class="kt">bool</span><span class="w"> </span><span class="nf">is_equivalent</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">Primitive</span><span class="o">&</span><span class="w"> </span><span class="n">other</span><span class="p">)</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="k">override</span><span class="p">;</span>
|
||
|
||
<span class="w"> </span><span class="k">private</span><span class="o">:</span>
|
||
<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha_</span><span class="p">;</span>
|
||
<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta_</span><span class="p">;</span>
|
||
<span class="p">};</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>The <code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code> class derives from the base <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> class. The
|
||
<code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code> treats <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">beta</span></code> as parameters. It then provides
|
||
implementations of how the output array is produced given the inputs through
|
||
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_cpu()</span></code> and <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_gpu()</span></code>. It also provides rules
|
||
of transformations in <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::jvp()</span></code>, <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::vjp()</span></code>, and
|
||
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::vmap()</span></code>.</p>
|
||
</section>
|
||
<section id="using-the-primitive">
|
||
<h3>Using the Primitive<a class="headerlink" href="#using-the-primitive" title="Link to this heading">#</a></h3>
|
||
<p>Operations can use this <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> to add a new <code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code> to the
|
||
computation graph. An <code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code> can be constructed by providing its data
|
||
type, shape, the <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code> that computes it, and the <code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code>
|
||
inputs that are passed to the primitive.</p>
|
||
<p>Let’s reimplement our operation now in terms of our <code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code> primitive.</p>
|
||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="n">array</span><span class="w"> </span><span class="nf">axpby</span><span class="p">(</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array x</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="c1">// Input array y</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha</span><span class="p">,</span><span class="w"> </span><span class="c1">// Scaling factor for x</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta</span><span class="p">,</span><span class="w"> </span><span class="c1">// Scaling factor for y</span>
|
||
<span class="w"> </span><span class="n">StreamOrDevice</span><span class="w"> </span><span class="n">s</span><span class="w"> </span><span class="cm">/* = {} */</span><span class="w"> </span><span class="c1">// Stream on which to schedule the operation</span>
|
||
<span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="c1">// Promote dtypes between x and y as needed</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">promoted_dtype</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">promote_types</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">dtype</span><span class="p">(),</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">dtype</span><span class="p">());</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Upcast to float32 for non-floating point inputs x and y</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">out_dtype</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">issubdtype</span><span class="p">(</span><span class="n">promoted_dtype</span><span class="p">,</span><span class="w"> </span><span class="n">float32</span><span class="p">)</span>
|
||
<span class="w"> </span><span class="o">?</span><span class="w"> </span><span class="n">promoted_dtype</span>
|
||
<span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="n">promote_types</span><span class="p">(</span><span class="n">promoted_dtype</span><span class="p">,</span><span class="w"> </span><span class="n">float32</span><span class="p">);</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Cast x and y up to the determined dtype (on the same stream s)</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">x_casted</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">astype</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">out_dtype</span><span class="p">,</span><span class="w"> </span><span class="n">s</span><span class="p">);</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">y_casted</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">astype</span><span class="p">(</span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out_dtype</span><span class="p">,</span><span class="w"> </span><span class="n">s</span><span class="p">);</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Broadcast the shapes of x and y (on the same stream s)</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">broadcasted_inputs</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">broadcast_arrays</span><span class="p">({</span><span class="n">x_casted</span><span class="p">,</span><span class="w"> </span><span class="n">y_casted</span><span class="p">},</span><span class="w"> </span><span class="n">s</span><span class="p">);</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">out_shape</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">broadcasted_inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">shape</span><span class="p">();</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Construct the array as the output of the Axpby primitive</span>
|
||
<span class="w"> </span><span class="c1">// with the broadcasted and upcasted arrays as inputs</span>
|
||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">array</span><span class="p">(</span>
|
||
<span class="w"> </span><span class="cm">/* const std::vector<int>& shape = */</span><span class="w"> </span><span class="n">out_shape</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="cm">/* Dtype dtype = */</span><span class="w"> </span><span class="n">out_dtype</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="cm">/* std::shared_ptr<Primitive> primitive = */</span>
|
||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">make_shared</span><span class="o"><</span><span class="n">Axpby</span><span class="o">></span><span class="p">(</span><span class="n">to_stream</span><span class="p">(</span><span class="n">s</span><span class="p">),</span><span class="w"> </span><span class="n">alpha</span><span class="p">,</span><span class="w"> </span><span class="n">beta</span><span class="p">),</span>
|
||
<span class="w"> </span><span class="cm">/* const std::vector<array>& inputs = */</span><span class="w"> </span><span class="n">broadcasted_inputs</span><span class="p">);</span>
|
||
<span class="p">}</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>This operation now handles the following:</p>
|
||
<ol class="arabic simple">
|
||
<li><p>Upcast inputs and resolve the output data type.</p></li>
|
||
<li><p>Broadcast the inputs and resolve the output shape.</p></li>
|
||
<li><p>Construct the primitive <code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code> using the given stream, <code class="docutils literal notranslate"><span class="pre">alpha</span></code>, and <code class="docutils literal notranslate"><span class="pre">beta</span></code>.</p></li>
|
||
<li><p>Construct the output <code class="xref py py-class docutils literal notranslate"><span class="pre">array</span></code> using the primitive and the inputs.</p></li>
|
||
</ol>
|
||
</section>
|
||
</section>
|
||
<section id="implementing-the-primitive">
|
||
<h2>Implementing the Primitive<a class="headerlink" href="#implementing-the-primitive" title="Link to this heading">#</a></h2>
|
||
<p>No computation happens when we call the operation alone. The operation only
|
||
builds the computation graph. When we evaluate the output array, MLX schedules
|
||
the execution of the computation graph, and calls <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_cpu()</span></code> or
|
||
<code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_gpu()</span></code> depending on the stream/device specified by the user.</p>
|
||
<div class="admonition warning">
|
||
<p class="admonition-title">Warning</p>
|
||
<p>When <code class="xref py py-meth docutils literal notranslate"><span class="pre">Primitive::eval_cpu()</span></code> or <code class="xref py py-meth docutils literal notranslate"><span class="pre">Primitive::eval_gpu()</span></code> is called,
|
||
no memory has been allocated for the output array. It falls on the implementation
|
||
of these functions to allocate memory as needed.</p>
|
||
</div>
|
||
<section id="implementing-the-cpu-back-end">
|
||
<h3>Implementing the CPU Back-end<a class="headerlink" href="#implementing-the-cpu-back-end" title="Link to this heading">#</a></h3>
|
||
<p>Let’s start by implementing <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_cpu()</span></code>.</p>
|
||
<p>The method will go over each element of the output array, find the
|
||
corresponding input elements of <code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code> and perform the operation
|
||
point-wise. This is captured in the templated function <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby_impl()</span></code>.</p>
|
||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">template</span><span class="w"> </span><span class="o"><</span><span class="k">typename</span><span class="w"> </span><span class="nc">T</span><span class="o">></span>
|
||
<span class="kt">void</span><span class="w"> </span><span class="n">axpby_impl</span><span class="p">(</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">array</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="kt">float</span><span class="w"> </span><span class="n">beta_</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">Stream</span><span class="w"> </span><span class="n">stream</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">set_data</span><span class="p">(</span><span class="n">mx</span><span class="o">::</span><span class="n">allocator</span><span class="o">::</span><span class="n">malloc</span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">nbytes</span><span class="p">()));</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Get the CPU command encoder and register input and output arrays</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">encoder</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">cpu</span><span class="o">::</span><span class="n">get_command_encoder</span><span class="p">(</span><span class="n">stream</span><span class="p">);</span>
|
||
<span class="w"> </span><span class="n">encoder</span><span class="p">.</span><span class="n">set_input_array</span><span class="p">(</span><span class="n">x</span><span class="p">);</span>
|
||
<span class="w"> </span><span class="n">encoder</span><span class="p">.</span><span class="n">set_input_array</span><span class="p">(</span><span class="n">y</span><span class="p">);</span>
|
||
<span class="w"> </span><span class="n">encoder</span><span class="p">.</span><span class="n">set_output_array</span><span class="p">(</span><span class="n">out</span><span class="p">);</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Launch the CPU kernel</span>
|
||
<span class="w"> </span><span class="n">encoder</span><span class="p">.</span><span class="n">dispatch</span><span class="p">([</span><span class="n">x_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">data</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(),</span>
|
||
<span class="w"> </span><span class="n">y_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">data</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(),</span>
|
||
<span class="w"> </span><span class="n">out_ptr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">data</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(),</span>
|
||
<span class="w"> </span><span class="n">size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span>
|
||
<span class="w"> </span><span class="n">shape</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">shape</span><span class="p">(),</span>
|
||
<span class="w"> </span><span class="n">x_strides</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">x</span><span class="p">.</span><span class="n">strides</span><span class="p">(),</span>
|
||
<span class="w"> </span><span class="n">y_strides</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">y</span><span class="p">.</span><span class="n">strides</span><span class="p">(),</span>
|
||
<span class="w"> </span><span class="n">alpha_</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="n">beta_</span><span class="p">]()</span><span class="w"> </span><span class="p">{</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Cast alpha and beta to the relevant types</span>
|
||
<span class="w"> </span><span class="n">T</span><span class="w"> </span><span class="n">alpha</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="k">static_cast</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(</span><span class="n">alpha_</span><span class="p">);</span>
|
||
<span class="w"> </span><span class="n">T</span><span class="w"> </span><span class="n">beta</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="k">static_cast</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(</span><span class="n">beta_</span><span class="p">);</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Do the element-wise operation for each output</span>
|
||
<span class="w"> </span><span class="k">for</span><span class="w"> </span><span class="p">(</span><span class="kt">size_t</span><span class="w"> </span><span class="n">out_idx</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="mi">0</span><span class="p">;</span><span class="w"> </span><span class="n">out_idx</span><span class="w"> </span><span class="o"><</span><span class="w"> </span><span class="n">size</span><span class="p">;</span><span class="w"> </span><span class="n">out_idx</span><span class="o">++</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="c1">// Map linear indices to offsets in x and y</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">x_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">out_idx</span><span class="p">,</span><span class="w"> </span><span class="n">shape</span><span class="p">,</span><span class="w"> </span><span class="n">x_strides</span><span class="p">);</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">y_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">out_idx</span><span class="p">,</span><span class="w"> </span><span class="n">shape</span><span class="p">,</span><span class="w"> </span><span class="n">y_strides</span><span class="p">);</span>
|
||
|
||
<span class="w"> </span><span class="c1">// We allocate the output to be contiguous and regularly strided</span>
|
||
<span class="w"> </span><span class="c1">// (defaults to row major) and hence it doesn't need additional mapping</span>
|
||
<span class="w"> </span><span class="n">out_ptr</span><span class="p">[</span><span class="n">out_idx</span><span class="p">]</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">alpha</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">x_ptr</span><span class="p">[</span><span class="n">x_offset</span><span class="p">]</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">beta</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">y_ptr</span><span class="p">[</span><span class="n">y_offset</span><span class="p">];</span>
|
||
<span class="w"> </span><span class="p">}</span>
|
||
<span class="w"> </span><span class="p">});</span>
|
||
<span class="p">}</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>Our implementation should work for all incoming floating point arrays.
|
||
Accordingly, we add dispatches for <code class="docutils literal notranslate"><span class="pre">float32</span></code>, <code class="docutils literal notranslate"><span class="pre">float16</span></code>, <code class="docutils literal notranslate"><span class="pre">bfloat16</span></code> and
|
||
<code class="docutils literal notranslate"><span class="pre">complex64</span></code>. We throw an error if we encounter an unexpected type.</p>
|
||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval_cpu</span><span class="p">(</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">mx</span><span class="o">::</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">mx</span><span class="o">::</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">outputs</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Dispatch to the correct dtype</span>
|
||
<span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">float32</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby_impl</span><span class="o"><</span><span class="kt">float</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">,</span><span class="w"> </span><span class="n">stream</span><span class="p">());</span>
|
||
<span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">float16</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby_impl</span><span class="o"><</span><span class="n">mx</span><span class="o">::</span><span class="n">float16_t</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">,</span><span class="w"> </span><span class="n">stream</span><span class="p">());</span>
|
||
<span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">bfloat16</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby_impl</span><span class="o"><</span><span class="n">mx</span><span class="o">::</span><span class="n">bfloat16_t</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">,</span><span class="w"> </span><span class="n">stream</span><span class="p">());</span>
|
||
<span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="n">mx</span><span class="o">::</span><span class="n">complex64</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">axpby_impl</span><span class="o"><</span><span class="n">mx</span><span class="o">::</span><span class="n">complex64_t</span><span class="o">></span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">,</span><span class="w"> </span><span class="n">stream</span><span class="p">());</span>
|
||
<span class="w"> </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="k">throw</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">runtime_error</span><span class="p">(</span>
|
||
<span class="w"> </span><span class="s">"Axpby is only supported for floating point types."</span><span class="p">);</span>
|
||
<span class="w"> </span><span class="p">}</span>
|
||
<span class="p">}</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>Just this much is enough to run the operation <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code> on a CPU stream! If
|
||
you do not plan on running the operation on the GPU or using transforms on
|
||
computation graphs that contain <code class="xref py py-class docutils literal notranslate"><span class="pre">Axpby</span></code>, you can stop implementing the
|
||
primitive here.</p>
|
||
</section>
|
||
<section id="implementing-the-gpu-back-end">
|
||
<h3>Implementing the GPU Back-end<a class="headerlink" href="#implementing-the-gpu-back-end" title="Link to this heading">#</a></h3>
|
||
<p>Apple silicon devices address their GPUs using the <a class="reference external" href="https://developer.apple.com/documentation/metal?language=objc">Metal</a> shading language, and
|
||
GPU kernels in MLX are written using Metal.</p>
|
||
<div class="admonition note">
|
||
<p class="admonition-title">Note</p>
|
||
<p>Here are some helpful resources if you are new to Metal:</p>
|
||
<ul class="simple">
|
||
<li><p>A walkthrough of the Metal compute pipeline: <a class="reference external" href="https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc">Metal Example</a></p></li>
|
||
<li><p>Documentation for the Metal Shading Language: <a class="reference external" href="https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf">Metal Specification</a></p></li>
|
||
<li><p>Using Metal from C++: <a class="reference external" href="https://developer.apple.com/metal/cpp/">Metal-cpp</a></p></li>
|
||
</ul>
|
||
</div>
|
||
<p>Let’s keep the GPU kernel simple. We will launch exactly as many threads as
|
||
there are elements in the output. Each thread will pick the element it needs
|
||
from <code class="docutils literal notranslate"><span class="pre">x</span></code> and <code class="docutils literal notranslate"><span class="pre">y</span></code>, do the point-wise operation, and update its assigned
|
||
element in the output.</p>
|
||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="k">template</span><span class="w"> </span><span class="o"><</span><span class="k">typename</span><span class="w"> </span><span class="nc">T</span><span class="o">></span>
|
||
<span class="p">[[</span><span class="n">kernel</span><span class="p">]]</span><span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="n">axpby_general</span><span class="p">(</span>
|
||
<span class="w"> </span><span class="n">device</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">0</span><span class="p">)]],</span>
|
||
<span class="w"> </span><span class="n">device</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">y</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">1</span><span class="p">)]],</span>
|
||
<span class="w"> </span><span class="n">device</span><span class="w"> </span><span class="n">T</span><span class="o">*</span><span class="w"> </span><span class="n">out</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">2</span><span class="p">)]],</span>
|
||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">float</span><span class="o">&</span><span class="w"> </span><span class="n">alpha</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">3</span><span class="p">)]],</span>
|
||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">float</span><span class="o">&</span><span class="w"> </span><span class="n">beta</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">4</span><span class="p">)]],</span>
|
||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">int</span><span class="o">*</span><span class="w"> </span><span class="n">shape</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">5</span><span class="p">)]],</span>
|
||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">int64_t</span><span class="o">*</span><span class="w"> </span><span class="n">x_strides</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">6</span><span class="p">)]],</span>
|
||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">int64_t</span><span class="o">*</span><span class="w"> </span><span class="n">y_strides</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">7</span><span class="p">)]],</span>
|
||
<span class="w"> </span><span class="n">constant</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">int</span><span class="o">&</span><span class="w"> </span><span class="n">ndim</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">8</span><span class="p">)]],</span>
|
||
<span class="w"> </span><span class="n">uint</span><span class="w"> </span><span class="n">index</span><span class="w"> </span><span class="p">[[</span><span class="n">thread_position_in_grid</span><span class="p">]])</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="c1">// Convert linear indices to offsets in array</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">x_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">index</span><span class="p">,</span><span class="w"> </span><span class="n">shape</span><span class="p">,</span><span class="w"> </span><span class="n">x_strides</span><span class="p">,</span><span class="w"> </span><span class="n">ndim</span><span class="p">);</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">y_offset</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">elem_to_loc</span><span class="p">(</span><span class="n">index</span><span class="p">,</span><span class="w"> </span><span class="n">shape</span><span class="p">,</span><span class="w"> </span><span class="n">y_strides</span><span class="p">,</span><span class="w"> </span><span class="n">ndim</span><span class="p">);</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Do the operation and update the output</span>
|
||
<span class="w"> </span><span class="n">out</span><span class="p">[</span><span class="n">index</span><span class="p">]</span><span class="w"> </span><span class="o">=</span>
|
||
<span class="w"> </span><span class="k">static_cast</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(</span><span class="n">alpha</span><span class="p">)</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">x</span><span class="p">[</span><span class="n">x_offset</span><span class="p">]</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="k">static_cast</span><span class="o"><</span><span class="n">T</span><span class="o">></span><span class="p">(</span><span class="n">beta</span><span class="p">)</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">y</span><span class="p">[</span><span class="n">y_offset</span><span class="p">];</span>
|
||
<span class="p">}</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>We then need to instantiate this template for all floating point types and give
|
||
each instantiation a unique host name so we can identify it.</p>
|
||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="n">instantiate_kernel</span><span class="p">(</span><span class="s">"axpby_general_float32"</span><span class="p">,</span><span class="w"> </span><span class="n">axpby_general</span><span class="p">,</span><span class="w"> </span><span class="kt">float</span><span class="p">)</span>
|
||
<span class="n">instantiate_kernel</span><span class="p">(</span><span class="s">"axpby_general_float16"</span><span class="p">,</span><span class="w"> </span><span class="n">axpby_general</span><span class="p">,</span><span class="w"> </span><span class="n">float16_t</span><span class="p">)</span>
|
||
<span class="n">instantiate_kernel</span><span class="p">(</span><span class="s">"axpby_general_bfloat16"</span><span class="p">,</span><span class="w"> </span><span class="n">axpby_general</span><span class="p">,</span><span class="w"> </span><span class="n">bfloat16_t</span><span class="p">)</span>
|
||
<span class="n">instantiate_kernel</span><span class="p">(</span><span class="s">"axpby_general_complex64"</span><span class="p">,</span><span class="w"> </span><span class="n">axpby_general</span><span class="p">,</span><span class="w"> </span><span class="n">complex64_t</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>The logic to determine the kernel, set the inputs, resolve the grid dimensions,
|
||
and dispatch to the GPU is contained in <code class="xref py py-meth docutils literal notranslate"><span class="pre">Axpby::eval_gpu()</span></code> as shown
|
||
below.</p>
|
||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Evaluate primitive on GPU */</span>
|
||
<span class="kt">void</span><span class="w"> </span><span class="nf">Axpby::eval_gpu</span><span class="p">(</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">outputs</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="c1">// Prepare inputs</span>
|
||
<span class="w"> </span><span class="n">assert</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">2</span><span class="p">);</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">out</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Each primitive carries the stream it should execute on</span>
|
||
<span class="w"> </span><span class="c1">// and each stream carries its device identifiers</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">s</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">stream</span><span class="p">();</span>
|
||
<span class="w"> </span><span class="c1">// We get the needed metal device using the stream</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">d</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">metal</span><span class="o">::</span><span class="n">device</span><span class="p">(</span><span class="n">s</span><span class="p">.</span><span class="n">device</span><span class="p">);</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Allocate output memory</span>
|
||
<span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">set_data</span><span class="p">(</span><span class="n">allocator</span><span class="o">::</span><span class="n">malloc</span><span class="p">(</span><span class="n">out</span><span class="p">.</span><span class="n">nbytes</span><span class="p">()));</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Resolve name of kernel</span>
|
||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">ostringstream</span><span class="w"> </span><span class="n">kname</span><span class="p">;</span>
|
||
<span class="w"> </span><span class="n">kname</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="s">"axpby_"</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="s">"general_"</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="n">type_to_name</span><span class="p">(</span><span class="n">out</span><span class="p">);</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Make sure the metal library is available</span>
|
||
<span class="w"> </span><span class="n">d</span><span class="p">.</span><span class="n">register_library</span><span class="p">(</span><span class="s">"mlx_ext"</span><span class="p">);</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Make a kernel from this metal library</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">kernel</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">d</span><span class="p">.</span><span class="n">get_kernel</span><span class="p">(</span><span class="n">kname</span><span class="p">.</span><span class="n">str</span><span class="p">(),</span><span class="w"> </span><span class="s">"mlx_ext"</span><span class="p">);</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Prepare to encode kernel</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">compute_encoder</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">d</span><span class="p">.</span><span class="n">get_command_encoder</span><span class="p">(</span><span class="n">s</span><span class="p">.</span><span class="n">index</span><span class="p">);</span>
|
||
<span class="w"> </span><span class="n">compute_encoder</span><span class="p">.</span><span class="n">set_compute_pipeline_state</span><span class="p">(</span><span class="n">kernel</span><span class="p">);</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Kernel parameters are registered with buffer indices corresponding to</span>
|
||
<span class="w"> </span><span class="c1">// those in the kernel declaration at axpby.metal</span>
|
||
<span class="w"> </span><span class="kt">int</span><span class="w"> </span><span class="n">ndim</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">ndim</span><span class="p">();</span>
|
||
<span class="w"> </span><span class="kt">size_t</span><span class="w"> </span><span class="n">nelem</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">out</span><span class="p">.</span><span class="n">size</span><span class="p">();</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Encode input arrays to kernel</span>
|
||
<span class="w"> </span><span class="n">compute_encoder</span><span class="p">.</span><span class="n">set_input_array</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="mi">0</span><span class="p">);</span>
|
||
<span class="w"> </span><span class="n">compute_encoder</span><span class="p">.</span><span class="n">set_input_array</span><span class="p">(</span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="mi">1</span><span class="p">);</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Encode output arrays to kernel</span>
|
||
<span class="w"> </span><span class="n">compute_encoder</span><span class="p">.</span><span class="n">set_output_array</span><span class="p">(</span><span class="n">out</span><span class="p">,</span><span class="w"> </span><span class="mi">2</span><span class="p">);</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Encode alpha and beta</span>
|
||
<span class="w"> </span><span class="n">compute_encoder</span><span class="p">.</span><span class="n">set_bytes</span><span class="p">(</span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="mi">3</span><span class="p">);</span>
|
||
<span class="w"> </span><span class="n">compute_encoder</span><span class="p">.</span><span class="n">set_bytes</span><span class="p">(</span><span class="n">beta_</span><span class="p">,</span><span class="w"> </span><span class="mi">4</span><span class="p">);</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Encode shape, strides and ndim</span>
|
||
<span class="w"> </span><span class="n">compute_encoder</span><span class="p">.</span><span class="n">set_vector_bytes</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">shape</span><span class="p">(),</span><span class="w"> </span><span class="mi">5</span><span class="p">);</span>
|
||
<span class="w"> </span><span class="n">compute_encoder</span><span class="p">.</span><span class="n">set_vector_bytes</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">strides</span><span class="p">(),</span><span class="w"> </span><span class="mi">6</span><span class="p">);</span>
|
||
<span class="w"> </span><span class="n">compute_encoder</span><span class="p">.</span><span class="n">set_vector_bytes</span><span class="p">(</span><span class="n">y</span><span class="p">.</span><span class="n">strides</span><span class="p">(),</span><span class="w"> </span><span class="mi">7</span><span class="p">);</span>
|
||
<span class="w"> </span><span class="n">compute_encoder</span><span class="p">.</span><span class="n">set_bytes</span><span class="p">(</span><span class="n">ndim</span><span class="p">,</span><span class="w"> </span><span class="mi">8</span><span class="p">);</span>
|
||
|
||
<span class="w"> </span><span class="c1">// We launch 1 thread for each input and make sure that the number of</span>
|
||
<span class="w"> </span><span class="c1">// threads in any given threadgroup is not higher than the max allowed</span>
|
||
<span class="w"> </span><span class="kt">size_t</span><span class="w"> </span><span class="n">tgp_size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">min</span><span class="p">(</span><span class="n">nelem</span><span class="p">,</span><span class="w"> </span><span class="n">kernel</span><span class="o">-></span><span class="n">maxTotalThreadsPerThreadgroup</span><span class="p">());</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Fix the 3D size of each threadgroup (in terms of threads)</span>
|
||
<span class="w"> </span><span class="n">MTL</span><span class="o">::</span><span class="n">Size</span><span class="w"> </span><span class="n">group_dims</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">MTL</span><span class="o">::</span><span class="n">Size</span><span class="p">(</span><span class="n">tgp_size</span><span class="p">,</span><span class="w"> </span><span class="mi">1</span><span class="p">,</span><span class="w"> </span><span class="mi">1</span><span class="p">);</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Fix the 3D size of the launch grid (in terms of threads)</span>
|
||
<span class="w"> </span><span class="n">MTL</span><span class="o">::</span><span class="n">Size</span><span class="w"> </span><span class="n">grid_dims</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">MTL</span><span class="o">::</span><span class="n">Size</span><span class="p">(</span><span class="n">nelem</span><span class="p">,</span><span class="w"> </span><span class="mi">1</span><span class="p">,</span><span class="w"> </span><span class="mi">1</span><span class="p">);</span>
|
||
|
||
<span class="w"> </span><span class="c1">// Launch the grid with the given number of threads divided among</span>
|
||
<span class="w"> </span><span class="c1">// the given threadgroups</span>
|
||
<span class="w"> </span><span class="n">compute_encoder</span><span class="p">.</span><span class="n">dispatch_threads</span><span class="p">(</span><span class="n">grid_dims</span><span class="p">,</span><span class="w"> </span><span class="n">group_dims</span><span class="p">);</span>
|
||
<span class="p">}</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>We can now call the <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code> operation on both the CPU and the GPU!</p>
|
||
<p>A few things to note about MLX and Metal before moving on. MLX keeps track of
|
||
the active <code class="docutils literal notranslate"><span class="pre">command_buffer</span></code> and the <code class="docutils literal notranslate"><span class="pre">MTLCommandBuffer</span></code> to which it is
|
||
associated. We rely on <code class="xref py py-meth docutils literal notranslate"><span class="pre">d.get_command_encoder()</span></code> to give us the active
|
||
Metal compute command encoder instead of building a new one and calling
|
||
<code class="xref py py-meth docutils literal notranslate"><span class="pre">compute_encoder->end_encoding()</span></code> at the end. MLX adds kernels (compute
|
||
pipelines) to the active command buffer until some specified limit is hit or
|
||
the command buffer needs to be flushed for synchronization.</p>
|
||
</section>
|
||
<section id="primitive-transforms">
|
||
<h3>Primitive Transforms<a class="headerlink" href="#primitive-transforms" title="Link to this heading">#</a></h3>
|
||
<p>Next, let’s add implementations for transformations in a <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code>.
|
||
These transformations can be built on top of other operations, including the
|
||
one we just defined:</p>
|
||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** The Jacobian-vector product. */</span>
|
||
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">></span><span class="w"> </span><span class="n">Axpby</span><span class="o">::</span><span class="n">jvp</span><span class="p">(</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">primals</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">tangents</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="c1">// Forward mode diff that pushes along the tangents</span>
|
||
<span class="w"> </span><span class="c1">// The jvp transform on the primitive can be built with ops</span>
|
||
<span class="w"> </span><span class="c1">// that are scheduled on the same stream as the primitive</span>
|
||
|
||
<span class="w"> </span><span class="c1">// If argnums = {0}, we only push along x in which case the</span>
|
||
<span class="w"> </span><span class="c1">// jvp is just the tangent scaled by alpha</span>
|
||
<span class="w"> </span><span class="c1">// Similarly, if argnums = {1}, the jvp is just the tangent</span>
|
||
<span class="w"> </span><span class="c1">// scaled by beta</span>
|
||
<span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">argnums</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">1</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">argnums</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">0</span><span class="w"> </span><span class="o">?</span><span class="w"> </span><span class="n">alpha_</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="n">beta_</span><span class="p">;</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale_arr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">array</span><span class="p">(</span><span class="n">scale</span><span class="p">,</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">dtype</span><span class="p">());</span>
|
||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="p">{</span><span class="n">multiply</span><span class="p">(</span><span class="n">scale_arr</span><span class="p">,</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="w"> </span><span class="n">stream</span><span class="p">())};</span>
|
||
<span class="w"> </span><span class="p">}</span>
|
||
<span class="w"> </span><span class="c1">// If argnums = {0, 1}, we take contributions from both</span>
|
||
<span class="w"> </span><span class="c1">// which gives us jvp = tangent_x * alpha + tangent_y * beta</span>
|
||
<span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="p">{</span><span class="n">axpby</span><span class="p">(</span><span class="n">tangents</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="w"> </span><span class="n">tangents</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span><span class="w"> </span><span class="n">alpha_</span><span class="p">,</span><span class="w"> </span><span class="n">beta_</span><span class="p">,</span><span class="w"> </span><span class="n">stream</span><span class="p">())};</span>
|
||
<span class="w"> </span><span class="p">}</span>
|
||
<span class="p">}</span>
|
||
</pre></div>
|
||
</div>
|
||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** The vector-Jacobian product. */</span>
|
||
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">></span><span class="w"> </span><span class="n">Axpby</span><span class="o">::</span><span class="n">vjp</span><span class="p">(</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">primals</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">cotangents</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">argnums</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="cm">/* unused */</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="c1">// Reverse mode diff</span>
|
||
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">></span><span class="w"> </span><span class="n">vjps</span><span class="p">;</span>
|
||
<span class="w"> </span><span class="k">for</span><span class="w"> </span><span class="p">(</span><span class="k">auto</span><span class="w"> </span><span class="n">arg</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="n">argnums</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">arg</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">0</span><span class="w"> </span><span class="o">?</span><span class="w"> </span><span class="n">alpha_</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="n">beta_</span><span class="p">;</span>
|
||
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">scale_arr</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">array</span><span class="p">(</span><span class="n">scale</span><span class="p">,</span><span class="w"> </span><span class="n">cotangents</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">dtype</span><span class="p">());</span>
|
||
<span class="w"> </span><span class="n">vjps</span><span class="p">.</span><span class="n">push_back</span><span class="p">(</span><span class="n">multiply</span><span class="p">(</span><span class="n">scale_arr</span><span class="p">,</span><span class="w"> </span><span class="n">cotangents</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="w"> </span><span class="n">stream</span><span class="p">()));</span>
|
||
<span class="w"> </span><span class="p">}</span>
|
||
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">vjps</span><span class="p">;</span>
|
||
<span class="p">}</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>Note, a transformation does not need to be fully defined to start using
|
||
the <code class="xref py py-class docutils literal notranslate"><span class="pre">Primitive</span></code>.</p>
|
||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="cm">/** Vectorize primitive along given axis */</span>
|
||
<span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o"><</span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">></span><span class="p">,</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>></span><span class="w"> </span><span class="n">Axpby</span><span class="o">::</span><span class="n">vmap</span><span class="p">(</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">array</span><span class="o">>&</span><span class="w"> </span><span class="n">inputs</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>&</span><span class="w"> </span><span class="n">axes</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="k">throw</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">runtime_error</span><span class="p">(</span><span class="s">"[Axpby] vmap not implemented."</span><span class="p">);</span>
|
||
<span class="p">}</span>
|
||
</pre></div>
|
||
</div>
|
||
</section>
|
||
</section>
|
||
<section id="building-and-binding">
|
||
<h2>Building and Binding<a class="headerlink" href="#building-and-binding" title="Link to this heading">#</a></h2>
|
||
<p>Let’s look at the overall directory structure first.</p>
|
||
<div class="line-block">
|
||
<div class="line">extensions</div>
|
||
<div class="line">├── axpby</div>
|
||
<div class="line">│ ├── axpby.cpp</div>
|
||
<div class="line">│ ├── axpby.h</div>
|
||
<div class="line">│ └── axpby.metal</div>
|
||
<div class="line">├── mlx_sample_extensions</div>
|
||
<div class="line">│ └── __init__.py</div>
|
||
<div class="line">├── bindings.cpp</div>
|
||
<div class="line">├── CMakeLists.txt</div>
|
||
<div class="line">└── setup.py</div>
|
||
</div>
|
||
<ul class="simple">
|
||
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/axpby/</span></code> defines the C++ extension library</p></li>
|
||
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/mlx_sample_extensions</span></code> sets out the structure for the
|
||
associated Python package</p></li>
|
||
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/bindings.cpp</span></code> provides Python bindings for our operation</p></li>
|
||
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/CMakeLists.txt</span></code> holds CMake rules to build the library and
|
||
Python bindings</p></li>
|
||
<li><p><code class="docutils literal notranslate"><span class="pre">extensions/setup.py</span></code> holds the <code class="docutils literal notranslate"><span class="pre">setuptools</span></code> rules to build and install
|
||
the Python package</p></li>
|
||
</ul>
|
||
<section id="binding-to-python">
|
||
<h3>Binding to Python<a class="headerlink" href="#binding-to-python" title="Link to this heading">#</a></h3>
|
||
<p>We use <a class="reference external" href="https://nanobind.readthedocs.io/en/latest/">nanobind</a> to build a Python API for the C++ library. Since bindings for
|
||
components such as <a class="reference internal" href="../python/_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a>, <a class="reference internal" href="../python/_autosummary/mlx.core.stream.html#mlx.core.stream" title="mlx.core.stream"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.stream</span></code></a>, etc. are
|
||
already provided, adding our <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code> is simple.</p>
|
||
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="n">NB_MODULE</span><span class="p">(</span><span class="n">_ext</span><span class="p">,</span><span class="w"> </span><span class="n">m</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
|
||
<span class="w"> </span><span class="n">m</span><span class="p">.</span><span class="n">doc</span><span class="p">()</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s">"Sample extension for MLX"</span><span class="p">;</span>
|
||
|
||
<span class="w"> </span><span class="n">m</span><span class="p">.</span><span class="n">def</span><span class="p">(</span>
|
||
<span class="w"> </span><span class="s">"axpby"</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="o">&</span><span class="n">axpby</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="s">"x"</span><span class="n">_a</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="s">"y"</span><span class="n">_a</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="s">"alpha"</span><span class="n">_a</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="s">"beta"</span><span class="n">_a</span><span class="p">,</span>
|
||
<span class="w"> </span><span class="n">nb</span><span class="o">::</span><span class="n">kw_only</span><span class="p">(),</span>
|
||
<span class="w"> </span><span class="s">"stream"</span><span class="n">_a</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">nb</span><span class="o">::</span><span class="n">none</span><span class="p">(),</span>
|
||
<span class="w"> </span><span class="sa">R</span><span class="s">"</span><span class="dl">(</span>
|
||
<span class="s"> Scale and sum two vectors element-wise</span>
|
||
<span class="s"> ``z = alpha * x + beta * y``</span>
|
||
|
||
<span class="s"> Follows numpy style broadcasting between ``x`` and ``y``</span>
|
||
<span class="s"> Inputs are upcasted to floats if needed</span>
|
||
|
||
<span class="s"> Args:</span>
|
||
<span class="s"> x (array): Input array.</span>
|
||
<span class="s"> y (array): Input array.</span>
|
||
<span class="s"> alpha (float): Scaling factor for ``x``.</span>
|
||
<span class="s"> beta (float): Scaling factor for ``y``.</span>
|
||
|
||
<span class="s"> Returns:</span>
|
||
<span class="s"> array: ``alpha * x + beta * y``</span>
|
||
<span class="s"> </span><span class="dl">)</span><span class="s">"</span><span class="p">);</span>
|
||
<span class="w"> </span><span class="p">}</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>Most of the complexity in the above example comes from additional bells and
|
||
whistles such as the literal names and doc-strings.</p>
|
||
<div class="admonition warning">
|
||
<p class="admonition-title">Warning</p>
|
||
<p><code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.core</span></code> must be imported before importing
|
||
<code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx_sample_extensions</span></code> as defined by the nanobind module above to
|
||
ensure that the casters for <code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.core</span></code> components like
|
||
<a class="reference internal" href="../python/_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a> are available.</p>
|
||
</div>
|
||
</section>
|
||
<section id="building-with-cmake">
|
||
<span id="id1"></span><h3>Building with CMake<a class="headerlink" href="#building-with-cmake" title="Link to this heading">#</a></h3>
|
||
<p>Building the C++ extension library only requires that you <code class="docutils literal notranslate"><span class="pre">find_package(MLX</span>
|
||
<span class="pre">CONFIG)</span></code> and then link it to your library.</p>
|
||
<div class="highlight-cmake notranslate"><div class="highlight"><pre><span></span><span class="c"># Add library</span>
|
||
<span class="nb">add_library</span><span class="p">(</span><span class="s">mlx_ext</span><span class="p">)</span>
|
||
|
||
<span class="c"># Add sources</span>
|
||
<span class="nb">target_sources</span><span class="p">(</span>
|
||
<span class="w"> </span><span class="s">mlx_ext</span>
|
||
<span class="w"> </span><span class="s">PUBLIC</span>
|
||
<span class="w"> </span><span class="o">${</span><span class="nv">CMAKE_CURRENT_LIST_DIR</span><span class="o">}</span><span class="s">/axpby/axpby.cpp</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="c"># Add include headers</span>
|
||
<span class="nb">target_include_directories</span><span class="p">(</span>
|
||
<span class="w"> </span><span class="s">mlx_ext</span><span class="w"> </span><span class="s">PUBLIC</span><span class="w"> </span><span class="o">${</span><span class="nv">CMAKE_CURRENT_LIST_DIR</span><span class="o">}</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="c"># Link to mlx</span>
|
||
<span class="nb">target_link_libraries</span><span class="p">(</span><span class="s">mlx_ext</span><span class="w"> </span><span class="s">PUBLIC</span><span class="w"> </span><span class="s">mlx</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>We also need to build the attached Metal library. For convenience, we provide a
|
||
<code class="xref py py-meth docutils literal notranslate"><span class="pre">mlx_build_metallib()</span></code> function that builds a <code class="docutils literal notranslate"><span class="pre">.metallib</span></code> target given
|
||
sources, headers, destinations, etc. (defined in <code class="docutils literal notranslate"><span class="pre">cmake/extension.cmake</span></code> and
|
||
automatically imported with the MLX package).</p>
|
||
<p>Here is what that looks like in practice:</p>
|
||
<div class="highlight-cmake notranslate"><div class="highlight"><pre><span></span><span class="c"># Build metallib</span>
|
||
<span class="nb">if</span><span class="p">(</span><span class="s">MLX_BUILD_METAL</span><span class="p">)</span>
|
||
|
||
<span class="nb">mlx_build_metallib</span><span class="p">(</span>
|
||
<span class="w"> </span><span class="s">TARGET</span><span class="w"> </span><span class="s">mlx_ext_metallib</span>
|
||
<span class="w"> </span><span class="s">TITLE</span><span class="w"> </span><span class="s">mlx_ext</span>
|
||
<span class="w"> </span><span class="s">SOURCES</span><span class="w"> </span><span class="o">${</span><span class="nv">CMAKE_CURRENT_LIST_DIR</span><span class="o">}</span><span class="s">/axpby/axpby.metal</span>
|
||
<span class="w"> </span><span class="s">INCLUDE_DIRS</span><span class="w"> </span><span class="o">${</span><span class="nv">PROJECT_SOURCE_DIR</span><span class="o">}</span><span class="w"> </span><span class="o">${</span><span class="nv">MLX_INCLUDE_DIRS</span><span class="o">}</span>
|
||
<span class="w"> </span><span class="s">OUTPUT_DIRECTORY</span><span class="w"> </span><span class="o">${</span><span class="nv">CMAKE_LIBRARY_OUTPUT_DIRECTORY</span><span class="o">}</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="nb">add_dependencies</span><span class="p">(</span>
|
||
<span class="w"> </span><span class="s">mlx_ext</span>
|
||
<span class="w"> </span><span class="s">mlx_ext_metallib</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="nb">endif</span><span class="p">()</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>Finally, we build the <a class="reference external" href="https://nanobind.readthedocs.io/en/latest/">nanobind</a> bindings:</p>
|
||
<div class="highlight-cmake notranslate"><div class="highlight"><pre><span></span><span class="nb">nanobind_add_module</span><span class="p">(</span>
|
||
<span class="w"> </span><span class="s">_ext</span>
|
||
<span class="w"> </span><span class="s">NB_STATIC</span><span class="w"> </span><span class="s">STABLE_ABI</span><span class="w"> </span><span class="s">LTO</span><span class="w"> </span><span class="s">NOMINSIZE</span>
|
||
<span class="w"> </span><span class="s">NB_DOMAIN</span><span class="w"> </span><span class="s">mlx</span>
|
||
<span class="w"> </span><span class="o">${</span><span class="nv">CMAKE_CURRENT_LIST_DIR</span><span class="o">}</span><span class="s">/bindings.cpp</span>
|
||
<span class="p">)</span>
|
||
<span class="nb">target_link_libraries</span><span class="p">(</span><span class="s">_ext</span><span class="w"> </span><span class="s">PRIVATE</span><span class="w"> </span><span class="s">mlx_ext</span><span class="p">)</span>
|
||
|
||
<span class="nb">if</span><span class="p">(</span><span class="s">BUILD_SHARED_LIBS</span><span class="p">)</span>
|
||
<span class="w"> </span><span class="nb">target_link_options</span><span class="p">(</span><span class="s">_ext</span><span class="w"> </span><span class="s">PRIVATE</span><span class="w"> </span><span class="s">-Wl,-rpath,@loader_path</span><span class="p">)</span>
|
||
<span class="nb">endif</span><span class="p">()</span>
|
||
</pre></div>
|
||
</div>
|
||
</section>
|
||
<section id="building-with-setuptools">
|
||
<h3>Building with <code class="docutils literal notranslate"><span class="pre">setuptools</span></code><a class="headerlink" href="#building-with-setuptools" title="Link to this heading">#</a></h3>
|
||
<p>Once we have set out the CMake build rules as described above, we can use the
|
||
build utilities defined in <code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.extension</span></code>:</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span><span class="w"> </span><span class="nn">mlx</span><span class="w"> </span><span class="kn">import</span> <span class="n">extension</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">setuptools</span><span class="w"> </span><span class="kn">import</span> <span class="n">setup</span>
|
||
|
||
<span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">"__main__"</span><span class="p">:</span>
|
||
<span class="n">setup</span><span class="p">(</span>
|
||
<span class="n">name</span><span class="o">=</span><span class="s2">"mlx_sample_extensions"</span><span class="p">,</span>
|
||
<span class="n">version</span><span class="o">=</span><span class="s2">"0.0.0"</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Sample C++ and Metal extensions for MLX primitives."</span><span class="p">,</span>
|
||
<span class="n">ext_modules</span><span class="o">=</span><span class="p">[</span><span class="n">extension</span><span class="o">.</span><span class="n">CMakeExtension</span><span class="p">(</span><span class="s2">"mlx_sample_extensions._ext"</span><span class="p">)],</span>
|
||
<span class="n">cmdclass</span><span class="o">=</span><span class="p">{</span><span class="s2">"build_ext"</span><span class="p">:</span> <span class="n">extension</span><span class="o">.</span><span class="n">CMakeBuild</span><span class="p">},</span>
|
||
<span class="n">packages</span><span class="o">=</span><span class="p">[</span><span class="s2">"mlx_sample_extensions"</span><span class="p">],</span>
|
||
<span class="n">package_data</span><span class="o">=</span><span class="p">{</span><span class="s2">"mlx_sample_extensions"</span><span class="p">:</span> <span class="p">[</span><span class="s2">"*.so"</span><span class="p">,</span> <span class="s2">"*.dylib"</span><span class="p">,</span> <span class="s2">"*.metallib"</span><span class="p">]},</span>
|
||
<span class="n">extras_require</span><span class="o">=</span><span class="p">{</span><span class="s2">"dev"</span><span class="p">:[]},</span>
|
||
<span class="n">zip_safe</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">python_requires</span><span class="o">=</span><span class="s2">">=3.8"</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
<div class="admonition note">
|
||
<p class="admonition-title">Note</p>
|
||
<p>We treat <code class="docutils literal notranslate"><span class="pre">extensions/mlx_sample_extensions</span></code> as the package directory
|
||
even though it only contains a <code class="docutils literal notranslate"><span class="pre">__init__.py</span></code> to ensure the following:</p>
|
||
<ul class="simple">
|
||
<li><p><code class="xref py py-mod docutils literal notranslate"><span class="pre">mlx.core</span></code> must be imported before importing <code class="xref py py-mod docutils literal notranslate"><span class="pre">_ext</span></code></p></li>
|
||
<li><p>The C++ extension library and the Metal library are co-located with the Python
|
||
bindings and copied together if the package is installed</p></li>
|
||
</ul>
|
||
</div>
|
||
<p>To build the package, first install the build dependencies with <code class="docutils literal notranslate"><span class="pre">pip</span> <span class="pre">install</span>
|
||
<span class="pre">-r</span> <span class="pre">requirements.txt</span></code>. You can then build in place for development using
|
||
<code class="docutils literal notranslate"><span class="pre">python</span> <span class="pre">setup.py</span> <span class="pre">build_ext</span> <span class="pre">-j8</span> <span class="pre">--inplace</span></code> (in <code class="docutils literal notranslate"><span class="pre">extensions/</span></code>).</p>
|
||
<p>This results in the directory structure:</p>
|
||
<div class="line-block">
|
||
<div class="line">extensions</div>
|
||
<div class="line">├── mlx_sample_extensions</div>
|
||
<div class="line">│ ├── __init__.py</div>
|
||
<div class="line">│ ├── libmlx_ext.dylib # C++ extension library</div>
|
||
<div class="line">│ ├── mlx_ext.metallib # Metal library</div>
|
||
<div class="line">│ └── _ext.cpython-3x-darwin.so # Python Binding</div>
|
||
<div class="line">…</div>
|
||
</div>
|
||
<p>When you try to install using the command <code class="docutils literal notranslate"><span class="pre">python</span> <span class="pre">-m</span> <span class="pre">pip</span> <span class="pre">install</span> <span class="pre">.</span></code> (in
|
||
<code class="docutils literal notranslate"><span class="pre">extensions/</span></code>), the package will be installed with the same structure as
|
||
<code class="docutils literal notranslate"><span class="pre">extensions/mlx_sample_extensions</span></code> and the C++ and Metal libraries will be
|
||
copied along with the Python binding since they are specified as
|
||
<code class="docutils literal notranslate"><span class="pre">package_data</span></code>.</p>
|
||
</section>
|
||
</section>
|
||
<section id="usage">
|
||
<h2>Usage<a class="headerlink" href="#usage" title="Link to this heading">#</a></h2>
|
||
<p>After installing the extension as described above, you should be able to simply
|
||
import the Python package and play with it as you would any other MLX operation.</p>
|
||
<p>Let’s look at a simple script and its results:</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mlx.core</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">mx</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">mlx_sample_extensions</span><span class="w"> </span><span class="kn">import</span> <span class="n">axpby</span>
|
||
|
||
<span class="n">a</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span>
|
||
<span class="n">b</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span>
|
||
<span class="n">c</span> <span class="o">=</span> <span class="n">axpby</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="mf">4.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">,</span> <span class="n">stream</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">cpu</span><span class="p">)</span>
|
||
|
||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"c shape: </span><span class="si">{</span><span class="n">c</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"c dtype: </span><span class="si">{</span><span class="n">c</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"c is correct: </span><span class="si">{</span><span class="n">mx</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">c</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mf">6.0</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>Output:</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">c</span> <span class="n">shape</span><span class="p">:</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
|
||
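<span class="n">c</span> <span class="n">dtype</span><span class="p">:</span> <span class="n">mlx</span><span class="o">.</span><span class="n">core</span><span class="o">.</span><span class="n">float32</span>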
|
||
<span class="n">c</span> <span class="ow">is</span> <span class="n">correct</span><span class="p">:</span> <span class="kc">True</span>
|
||
</pre></div>
|
||
</div>
|
||
<section id="results">
|
||
<h3>Results<a class="headerlink" href="#results" title="Link to this heading">#</a></h3>
|
||
<p>Let’s run a quick benchmark and see how our new <code class="docutils literal notranslate"><span class="pre">axpby</span></code> operation compares
|
||
with the naive <code class="xref py py-meth docutils literal notranslate"><span class="pre">simple_axpby()</span></code> we first defined.</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mlx.core</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">mx</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">mlx_sample_extensions</span><span class="w"> </span><span class="kn">import</span> <span class="n">axpby</span>
|
||
<span class="kn">import</span><span class="w"> </span><span class="nn">time</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">simple_axpby</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">alpha</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">alpha</span> <span class="o">*</span> <span class="n">x</span> <span class="o">+</span> <span class="n">beta</span> <span class="o">*</span> <span class="n">y</span>
|
||
|
||
<span class="n">M</span> <span class="o">=</span> <span class="mi">4096</span>
|
||
<span class="n">N</span> <span class="o">=</span> <span class="mi">4096</span>
|
||
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">))</span>
|
||
<span class="n">y</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">))</span>
|
||
<span class="n">alpha</span> <span class="o">=</span> <span class="mf">4.0</span>
|
||
<span class="n">beta</span> <span class="o">=</span> <span class="mf">2.0</span>
|
||
|
||
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">bench</span><span class="p">(</span><span class="n">f</span><span class="p">):</span>
|
||
<span class="c1"># Warm up</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">5</span><span class="p">):</span>
|
||
<span class="n">z</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">alpha</span><span class="p">,</span> <span class="n">beta</span><span class="p">)</span>
|
||
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Timed run</span>
|
||
<span class="n">s</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">100</span><span class="p">):</span>
|
||
<span class="n">z</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">alpha</span><span class="p">,</span> <span class="n">beta</span><span class="p">)</span>
|
||
<span class="n">mx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
|
||
<span class="n">e</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
|
||
<span class="k">return</span> <span class="mi">1000</span> <span class="o">*</span> <span class="p">(</span><span class="n">e</span> <span class="o">-</span> <span class="n">s</span><span class="p">)</span> <span class="o">/</span> <span class="mi">100</span>
|
||
|
||
<span class="n">simple_time</span> <span class="o">=</span> <span class="n">bench</span><span class="p">(</span><span class="n">simple_axpby</span><span class="p">)</span>
|
||
<span class="n">custom_time</span> <span class="o">=</span> <span class="n">bench</span><span class="p">(</span><span class="n">axpby</span><span class="p">)</span>
|
||
|
||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Simple axpby: </span><span class="si">{</span><span class="n">simple_time</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> ms | Custom axpby: </span><span class="si">{</span><span class="n">custom_time</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> ms"</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>The results are <code class="docutils literal notranslate"><span class="pre">Simple</span> <span class="pre">axpby:</span> <span class="pre">1.559</span> <span class="pre">ms</span> <span class="pre">|</span> <span class="pre">Custom</span> <span class="pre">axpby:</span> <span class="pre">0.774</span> <span class="pre">ms</span></code>. We see
|
||
modest improvements right away!</p>
|
||
<p>This operation can now be used to build other operations, in
|
||
<a class="reference internal" href="../python/nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Module</span></code></a> calls, and also as a part of graph transformations like
|
||
<code class="xref py py-meth docutils literal notranslate"><span class="pre">grad()</span></code>.</p>
|
||
</section>
|
||
</section>
|
||
<section id="scripts">
|
||
<h2>Scripts<a class="headerlink" href="#scripts" title="Link to this heading">#</a></h2>
|
||
<div class="admonition-download-the-code admonition">
|
||
<p class="admonition-title">Download the code</p>
|
||
<p>The full example code is available in <a class="reference external" href="https://github.com/ml-explore/mlx/tree/main/examples/extensions/">mlx</a>.</p>
|
||
</div>
|
||
</section>
|
||
</section>
|
||
|
||
|
||
</article>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<footer class="prev-next-footer d-print-none">
|
||
|
||
<div class="prev-next-area">
|
||
<a class="left-prev"
|
||
href="../cpp/ops.html"
|
||
title="previous page">
|
||
<i class="fa-solid fa-angle-left"></i>
|
||
<div class="prev-next-info">
|
||
<p class="prev-next-subtitle">previous</p>
|
||
<p class="prev-next-title">Operations</p>
|
||
</div>
|
||
</a>
|
||
<a class="right-next"
|
||
href="metal_debugger.html"
|
||
title="next page">
|
||
<div class="prev-next-info">
|
||
<p class="prev-next-subtitle">next</p>
|
||
<p class="prev-next-title">Metal Debugger</p>
|
||
</div>
|
||
<i class="fa-solid fa-angle-right"></i>
|
||
</a>
|
||
</div>
|
||
</footer>
|
||
|
||
</div>
|
||
|
||
|
||
|
||
<div class="bd-sidebar-secondary bd-toc"><div class="sidebar-secondary-items sidebar-secondary__inner">
|
||
|
||
|
||
<div class="sidebar-secondary-item">
|
||
<div class="page-toc tocsection onthispage">
|
||
<i class="fa-solid fa-list"></i> Contents
|
||
</div>
|
||
<nav class="bd-toc-nav page-toc">
|
||
<ul class="visible nav section-nav flex-column">
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#introducing-the-example">Introducing the Example</a></li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#operations-and-primitives">Operations and Primitives</a><ul class="visible nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#operations">Operations</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#primitives">Primitives</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#using-the-primitive">Using the Primitive</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-primitive">Implementing the Primitive</a><ul class="visible nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-cpu-back-end">Implementing the CPU Back-end</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#implementing-the-gpu-back-end">Implementing the GPU Back-end</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#primitive-transforms">Primitive Transforms</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#building-and-binding">Building and Binding</a><ul class="visible nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#binding-to-python">Binding to Python</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#building-with-cmake">Building with CMake</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#building-with-setuptools">Building with <code class="docutils literal notranslate"><span class="pre">setuptools</span></code></a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#usage">Usage</a><ul class="visible nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#results">Results</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#scripts">Scripts</a></li>
|
||
</ul>
|
||
</nav></div>
|
||
|
||
</div></div>
|
||
|
||
|
||
</div>
|
||
<footer class="bd-footer-content">
|
||
|
||
<div class="bd-footer-content__inner container">
|
||
|
||
<div class="footer-item">
|
||
|
||
<p class="component-author">
|
||
By MLX Contributors
|
||
</p>
|
||
|
||
</div>
|
||
|
||
<div class="footer-item">
|
||
|
||
|
||
<p class="copyright">
|
||
|
||
© Copyright 2023, Apple.
|
||
<br/>
|
||
|
||
</p>
|
||
|
||
</div>
|
||
|
||
<div class="footer-item">
|
||
|
||
</div>
|
||
|
||
<div class="footer-item">
|
||
|
||
</div>
|
||
|
||
</div>
|
||
</footer>
|
||
|
||
|
||
</main>
|
||
</div>
|
||
</div>
|
||
|
||
<!-- Scripts loaded after <body> so the DOM is not blocked -->
|
||
<script src="../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b"></script>
|
||
<script src="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b"></script>
|
||
|
||
<footer class="bd-footer">
|
||
</footer>
|
||
</body>
|
||
</html> |