mlx/docs/build/html/dev/custom_metal_kernels.html

1445 lines
139 KiB
HTML
Raw Normal View History

2024-08-24 03:14:53 +08:00
<!DOCTYPE html>
<html lang="en" data-content_root="../" >
<head>
<meta charset="utf-8" />
2024-10-15 23:12:17 +08:00
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
2024-08-24 03:14:53 +08:00
2025-06-04 09:03:47 +08:00
<title>Custom Metal Kernels &#8212; MLX 0.26.1 documentation</title>
2024-08-24 03:14:53 +08:00
<script data-cfasync="false">
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
2024-10-15 23:12:17 +08:00
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
2024-08-24 03:14:53 +08:00
</script>
<!-- Loaded before other Sphinx assets -->
2025-03-06 05:30:09 +08:00
<link href="../_static/styles/theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
<link href="../_static/styles/bootstrap.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
<link href="../_static/styles/pydata-sphinx-theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
<link href="../_static/vendor/fontawesome/6.5.2/css/all.min.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
<link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-solid-900.woff2" />
<link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-brands-400.woff2" />
<link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-regular-400.woff2" />
2024-08-24 03:14:53 +08:00
2025-01-10 05:56:20 +08:00
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
2025-03-06 05:30:09 +08:00
<link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=eba8b062" />
2024-08-24 03:14:53 +08:00
<!-- Pre-loaded scripts that we'll load fully later -->
2025-03-06 05:30:09 +08:00
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" />
<script src="../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script>
2024-08-24 03:14:53 +08:00
2025-06-04 09:03:47 +08:00
<script src="../_static/documentation_options.js?v=3724ff34"></script>
2024-10-15 23:12:17 +08:00
<script src="../_static/doctools.js?v=9a2dae69"></script>
2024-08-24 03:14:53 +08:00
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
2024-10-15 23:12:17 +08:00
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
2024-08-24 03:14:53 +08:00
<script>DOCUMENTATION_OPTIONS.pagename = 'dev/custom_metal_kernels';</script>
2024-10-31 11:00:19 +08:00
<link rel="icon" href="../_static/mlx_logo.png"/>
2024-08-24 03:14:53 +08:00
<link rel="index" title="Index" href="../genindex.html" />
<link rel="search" title="Search" href="../search.html" />
2025-01-10 05:56:20 +08:00
<link rel="next" title="Using MLX in C++" href="mlx_in_cpp.html" />
2024-08-24 03:14:53 +08:00
<link rel="prev" title="Metal Debugger" href="metal_debugger.html" />
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<meta name="docsearch:language" content="en"/>
</head>
<body data-bs-spy="scroll" data-bs-target=".bd-toc-nav" data-offset="180" data-bs-root-margin="0px 0px -60%" data-default-mode="">
2024-10-15 23:12:17 +08:00
<div id="pst-skip-link" class="skip-link d-print-none"><a href="#main-content">Skip to main content</a></div>
2024-08-24 03:14:53 +08:00
<div id="pst-scroll-pixel-helper"></div>
<button type="button" class="btn rounded-pill" id="pst-back-to-top">
2024-10-15 23:12:17 +08:00
<i class="fa-solid fa-arrow-up"></i>Back to top</button>
2024-08-24 03:14:53 +08:00
2025-03-06 05:30:09 +08:00
<input type="checkbox"
class="sidebar-toggle"
id="pst-primary-sidebar-checkbox"/>
<label class="overlay overlay-primary" for="pst-primary-sidebar-checkbox"></label>
<input type="checkbox"
class="sidebar-toggle"
id="pst-secondary-sidebar-checkbox"/>
<label class="overlay overlay-secondary" for="pst-secondary-sidebar-checkbox"></label>
<div class="search-button__wrapper">
<div class="search-button__overlay"></div>
<div class="search-button__search-container">
2024-08-24 03:14:53 +08:00
<form class="bd-search d-flex align-items-center"
action="../search.html"
method="get">
<i class="fa-solid fa-magnifying-glass"></i>
<input type="search"
class="form-control"
name="q"
2025-03-06 05:30:09 +08:00
id="search-input"
2024-08-24 03:14:53 +08:00
placeholder="Search..."
aria-label="Search..."
autocomplete="off"
autocorrect="off"
autocapitalize="off"
spellcheck="false"/>
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd>K</kbd></span>
2025-03-06 05:30:09 +08:00
</form></div>
</div>
2024-10-15 23:12:17 +08:00
<div class="pst-async-banner-revealer d-none">
<aside id="bd-header-version-warning" class="d-none d-print-none" aria-label="Version warning"></aside>
</div>
2024-08-24 03:14:53 +08:00
2024-10-15 23:12:17 +08:00
<header class="bd-header navbar navbar-expand-lg bd-navbar d-print-none">
</header>
2024-08-24 03:14:53 +08:00
2024-10-15 23:12:17 +08:00
2024-08-24 03:14:53 +08:00
<div class="bd-container">
<div class="bd-container__inner bd-page-width">
2024-10-15 23:12:17 +08:00
2025-03-06 05:30:09 +08:00
<div class="bd-sidebar-primary bd-sidebar">
2024-08-24 03:14:53 +08:00
<div class="sidebar-header-items sidebar-primary__section">
</div>
<div class="sidebar-primary-items__start sidebar-primary__section">
<div class="sidebar-primary-item">
2024-10-15 23:12:17 +08:00
2024-08-24 03:14:53 +08:00
<a class="navbar-brand logo" href="../index.html">
2025-06-04 09:03:47 +08:00
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.26.1 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.26.1 documentation - Home"/>`);</script>
2024-08-24 03:14:53 +08:00
</a></div>
<div class="sidebar-primary-item">
2025-03-06 05:30:09 +08:00
<script>
document.write(`
<button class="btn search-button-field search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="fa-solid fa-magnifying-glass"></i>
<span class="search-button__default-text">Search</span>
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span>
</button>
`);
</script></div>
2024-08-24 03:14:53 +08:00
<div class="sidebar-primary-item"><nav class="bd-links bd-docs-nav" aria-label="Main">
<div class="bd-toc-item navbar-nav active">
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Install</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../install.html">Build and Install</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Usage</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../usage/quick_start.html">Quick Start Guide</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/lazy_evaluation.html">Lazy Evaluation</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/unified_memory.html">Unified Memory</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/indexing.html">Indexing Arrays</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/saving_and_loading.html">Saving and Loading Arrays</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/function_transforms.html">Function Transforms</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/compile.html">Compilation</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/numpy.html">Conversion to NumPy and Other Frameworks</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/distributed.html">Distributed Communication</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/using_streams.html">Using Streams</a></li>
2025-01-10 05:56:20 +08:00
<li class="toctree-l1"><a class="reference internal" href="../usage/export.html">Exporting Functions</a></li>
2024-08-24 03:14:53 +08:00
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../examples/linear_regression.html">Linear Regression</a></li>
<li class="toctree-l1"><a class="reference internal" href="../examples/mlp.html">Multi-Layer Perceptron</a></li>
<li class="toctree-l1"><a class="reference internal" href="../examples/llama-inference.html">LLM inference</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Python API Reference</span></p>
<ul class="nav bd-sidenav">
2024-10-15 23:12:17 +08:00
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/array.html">Array</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.html">mlx.core.array</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.astype.html">mlx.core.array.astype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.at.html">mlx.core.array.at</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.item.html">mlx.core.array.item</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.tolist.html">mlx.core.array.tolist</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.dtype.html">mlx.core.array.dtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.itemsize.html">mlx.core.array.itemsize</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.nbytes.html">mlx.core.array.nbytes</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.ndim.html">mlx.core.array.ndim</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.shape.html">mlx.core.array.shape</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.size.html">mlx.core.array.size</a></li>
2025-06-03 07:29:32 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.real.html">mlx.core.array.real</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.imag.html">mlx.core.array.imag</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.abs.html">mlx.core.array.abs</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.all.html">mlx.core.array.all</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.any.html">mlx.core.array.any</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.argmax.html">mlx.core.array.argmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.argmin.html">mlx.core.array.argmin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.conj.html">mlx.core.array.conj</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.cos.html">mlx.core.array.cos</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.cummax.html">mlx.core.array.cummax</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.cummin.html">mlx.core.array.cummin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.cumprod.html">mlx.core.array.cumprod</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.cumsum.html">mlx.core.array.cumsum</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.diag.html">mlx.core.array.diag</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.diagonal.html">mlx.core.array.diagonal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.exp.html">mlx.core.array.exp</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.flatten.html">mlx.core.array.flatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.log.html">mlx.core.array.log</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.log10.html">mlx.core.array.log10</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.log1p.html">mlx.core.array.log1p</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.log2.html">mlx.core.array.log2</a></li>
2025-04-18 06:29:33 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.logcumsumexp.html">mlx.core.array.logcumsumexp</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.logsumexp.html">mlx.core.array.logsumexp</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.max.html">mlx.core.array.max</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.mean.html">mlx.core.array.mean</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.min.html">mlx.core.array.min</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.moveaxis.html">mlx.core.array.moveaxis</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.prod.html">mlx.core.array.prod</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.reciprocal.html">mlx.core.array.reciprocal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.reshape.html">mlx.core.array.reshape</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.round.html">mlx.core.array.round</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.rsqrt.html">mlx.core.array.rsqrt</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.sin.html">mlx.core.array.sin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.split.html">mlx.core.array.split</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.sqrt.html">mlx.core.array.sqrt</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.square.html">mlx.core.array.square</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.squeeze.html">mlx.core.array.squeeze</a></li>
2024-09-18 03:06:14 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.std.html">mlx.core.array.std</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.sum.html">mlx.core.array.sum</a></li>
2024-09-18 03:06:14 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.swapaxes.html">mlx.core.array.swapaxes</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.transpose.html">mlx.core.array.transpose</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.T.html">mlx.core.array.T</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.var.html">mlx.core.array.var</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.view.html">mlx.core.array.view</a></li>
</ul>
2024-10-15 23:12:17 +08:00
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/data_types.html">Data Types</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Dtype.html">mlx.core.Dtype</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.DtypeCategory.html">mlx.core.DtypeCategory</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
2025-01-10 05:56:20 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.finfo.html">mlx.core.finfo</a></li>
2024-08-24 03:14:53 +08:00
</ul>
2024-10-15 23:12:17 +08:00
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Device.html">mlx.core.Device</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/stream_class.html">mlx.core.Stream</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.default_device.html">mlx.core.default_device</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.set_default_device.html">mlx.core.set_default_device</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.default_stream.html">mlx.core.default_stream</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.new_stream.html">mlx.core.new_stream</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.set_default_stream.html">mlx.core.set_default_stream</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.stream.html">mlx.core.stream</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.synchronize.html">mlx.core.synchronize</a></li>
</ul>
2024-10-15 23:12:17 +08:00
</details></li>
2025-01-10 05:56:20 +08:00
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/export.html">Export Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_function.html">mlx.core.export_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.import_function.html">mlx.core.import_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exporter.html">mlx.core.exporter</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_to_dot.html">mlx.core.export_to_dot</a></li>
</ul>
</details></li>
2024-10-15 23:12:17 +08:00
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/ops.html">Operations</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.abs.html">mlx.core.abs</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.add.html">mlx.core.add</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.addmm.html">mlx.core.addmm</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.all.html">mlx.core.all</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.allclose.html">mlx.core.allclose</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.any.html">mlx.core.any</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arange.html">mlx.core.arange</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arccos.html">mlx.core.arccos</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arccosh.html">mlx.core.arccosh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arcsin.html">mlx.core.arcsin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arcsinh.html">mlx.core.arcsinh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arctan.html">mlx.core.arctan</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arctan2.html">mlx.core.arctan2</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.arctanh.html">mlx.core.arctanh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.argmax.html">mlx.core.argmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.argmin.html">mlx.core.argmin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.argpartition.html">mlx.core.argpartition</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.argsort.html">mlx.core.argsort</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array_equal.html">mlx.core.array_equal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.as_strided.html">mlx.core.as_strided</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.atleast_1d.html">mlx.core.atleast_1d</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.atleast_2d.html">mlx.core.atleast_2d</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.atleast_3d.html">mlx.core.atleast_3d</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.bitwise_and.html">mlx.core.bitwise_and</a></li>
2025-02-15 05:44:39 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.bitwise_invert.html">mlx.core.bitwise_invert</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.bitwise_or.html">mlx.core.bitwise_or</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.bitwise_xor.html">mlx.core.bitwise_xor</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.block_masked_mm.html">mlx.core.block_masked_mm</a></li>
2025-04-04 04:25:24 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.broadcast_arrays.html">mlx.core.broadcast_arrays</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.broadcast_to.html">mlx.core.broadcast_to</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.ceil.html">mlx.core.ceil</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.clip.html">mlx.core.clip</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.concatenate.html">mlx.core.concatenate</a></li>
2025-04-04 04:25:24 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.contiguous.html">mlx.core.contiguous</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conj.html">mlx.core.conj</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conjugate.html">mlx.core.conjugate</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.convolve.html">mlx.core.convolve</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv1d.html">mlx.core.conv1d</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv2d.html">mlx.core.conv2d</a></li>
2024-09-18 03:06:14 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv3d.html">mlx.core.conv3d</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv_transpose1d.html">mlx.core.conv_transpose1d</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv_transpose2d.html">mlx.core.conv_transpose2d</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv_transpose3d.html">mlx.core.conv_transpose3d</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.conv_general.html">mlx.core.conv_general</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.cos.html">mlx.core.cos</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.cosh.html">mlx.core.cosh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.cummax.html">mlx.core.cummax</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.cummin.html">mlx.core.cummin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.cumprod.html">mlx.core.cumprod</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.cumsum.html">mlx.core.cumsum</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.degrees.html">mlx.core.degrees</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.dequantize.html">mlx.core.dequantize</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.diag.html">mlx.core.diag</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.diagonal.html">mlx.core.diagonal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.divide.html">mlx.core.divide</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.divmod.html">mlx.core.divmod</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.einsum.html">mlx.core.einsum</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.einsum_path.html">mlx.core.einsum_path</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.equal.html">mlx.core.equal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.erf.html">mlx.core.erf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.erfinv.html">mlx.core.erfinv</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exp.html">mlx.core.exp</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.expm1.html">mlx.core.expm1</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.expand_dims.html">mlx.core.expand_dims</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.eye.html">mlx.core.eye</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.flatten.html">mlx.core.flatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.floor.html">mlx.core.floor</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.floor_divide.html">mlx.core.floor_divide</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.full.html">mlx.core.full</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.gather_mm.html">mlx.core.gather_mm</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.gather_qmm.html">mlx.core.gather_qmm</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.greater.html">mlx.core.greater</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.greater_equal.html">mlx.core.greater_equal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.hadamard_transform.html">mlx.core.hadamard_transform</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.identity.html">mlx.core.identity</a></li>
2024-10-19 03:13:44 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.imag.html">mlx.core.imag</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.inner.html">mlx.core.inner</a></li>
2024-09-18 03:06:14 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isfinite.html">mlx.core.isfinite</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isclose.html">mlx.core.isclose</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isinf.html">mlx.core.isinf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isnan.html">mlx.core.isnan</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isneginf.html">mlx.core.isneginf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.isposinf.html">mlx.core.isposinf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
2025-01-10 05:56:20 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.kron.html">mlx.core.kron</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.left_shift.html">mlx.core.left_shift</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less.html">mlx.core.less</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.less_equal.html">mlx.core.less_equal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linspace.html">mlx.core.linspace</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.load.html">mlx.core.load</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.log.html">mlx.core.log</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.log2.html">mlx.core.log2</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.log10.html">mlx.core.log10</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.log1p.html">mlx.core.log1p</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.logaddexp.html">mlx.core.logaddexp</a></li>
2025-04-18 06:29:33 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.logcumsumexp.html">mlx.core.logcumsumexp</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.logical_not.html">mlx.core.logical_not</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.logical_and.html">mlx.core.logical_and</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.logical_or.html">mlx.core.logical_or</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.logsumexp.html">mlx.core.logsumexp</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.matmul.html">mlx.core.matmul</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.max.html">mlx.core.max</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.maximum.html">mlx.core.maximum</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.mean.html">mlx.core.mean</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.meshgrid.html">mlx.core.meshgrid</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.min.html">mlx.core.min</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.minimum.html">mlx.core.minimum</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.moveaxis.html">mlx.core.moveaxis</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.multiply.html">mlx.core.multiply</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.nan_to_num.html">mlx.core.nan_to_num</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.negative.html">mlx.core.negative</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.not_equal.html">mlx.core.not_equal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.ones.html">mlx.core.ones</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.ones_like.html">mlx.core.ones_like</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.outer.html">mlx.core.outer</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.partition.html">mlx.core.partition</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.pad.html">mlx.core.pad</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.power.html">mlx.core.power</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.prod.html">mlx.core.prod</a></li>
2024-09-29 02:04:59 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.put_along_axis.html">mlx.core.put_along_axis</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.quantize.html">mlx.core.quantize</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.quantized_matmul.html">mlx.core.quantized_matmul</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.radians.html">mlx.core.radians</a></li>
2024-10-19 03:13:44 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.real.html">mlx.core.real</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.reciprocal.html">mlx.core.reciprocal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.remainder.html">mlx.core.remainder</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.repeat.html">mlx.core.repeat</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.reshape.html">mlx.core.reshape</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.right_shift.html">mlx.core.right_shift</a></li>
2024-10-15 04:10:48 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.roll.html">mlx.core.roll</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.round.html">mlx.core.round</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.rsqrt.html">mlx.core.rsqrt</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.save.html">mlx.core.save</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.savez.html">mlx.core.savez</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.savez_compressed.html">mlx.core.savez_compressed</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.save_gguf.html">mlx.core.save_gguf</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.save_safetensors.html">mlx.core.save_safetensors</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sigmoid.html">mlx.core.sigmoid</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sign.html">mlx.core.sign</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sin.html">mlx.core.sin</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sinh.html">mlx.core.sinh</a></li>
2025-01-10 05:56:20 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice.html">mlx.core.slice</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.slice_update.html">mlx.core.slice_update</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.softmax.html">mlx.core.softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sort.html">mlx.core.sort</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.split.html">mlx.core.split</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sqrt.html">mlx.core.sqrt</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.square.html">mlx.core.square</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.squeeze.html">mlx.core.squeeze</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.stack.html">mlx.core.stack</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.std.html">mlx.core.std</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.stop_gradient.html">mlx.core.stop_gradient</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.subtract.html">mlx.core.subtract</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.sum.html">mlx.core.sum</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.swapaxes.html">mlx.core.swapaxes</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.take.html">mlx.core.take</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.take_along_axis.html">mlx.core.take_along_axis</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tan.html">mlx.core.tan</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tanh.html">mlx.core.tanh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tensordot.html">mlx.core.tensordot</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tile.html">mlx.core.tile</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.topk.html">mlx.core.topk</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.trace.html">mlx.core.trace</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.transpose.html">mlx.core.transpose</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tri.html">mlx.core.tri</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.tril.html">mlx.core.tril</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.triu.html">mlx.core.triu</a></li>
2025-01-10 05:56:20 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.unflatten.html">mlx.core.unflatten</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.var.html">mlx.core.var</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.view.html">mlx.core.view</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.where.html">mlx.core.where</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.zeros.html">mlx.core.zeros</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.zeros_like.html">mlx.core.zeros_like</a></li>
</ul>
2024-10-15 23:12:17 +08:00
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/random.html">Random</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.bernoulli.html">mlx.core.random.bernoulli</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.categorical.html">mlx.core.random.categorical</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.gumbel.html">mlx.core.random.gumbel</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.key.html">mlx.core.random.key</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.normal.html">mlx.core.random.normal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.multivariate_normal.html">mlx.core.random.multivariate_normal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.randint.html">mlx.core.random.randint</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.seed.html">mlx.core.random.seed</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.split.html">mlx.core.random.split</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.truncated_normal.html">mlx.core.random.truncated_normal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.uniform.html">mlx.core.random.uniform</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.laplace.html">mlx.core.random.laplace</a></li>
2024-10-15 04:10:48 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.random.permutation.html">mlx.core.random.permutation</a></li>
2024-08-24 03:14:53 +08:00
</ul>
2024-10-15 23:12:17 +08:00
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/transforms.html">Transforms</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.eval.html">mlx.core.eval</a></li>
2025-04-04 04:25:24 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.async_eval.html">mlx.core.async_eval</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.compile.html">mlx.core.compile</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.custom_function.html">mlx.core.custom_function</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.disable_compile.html">mlx.core.disable_compile</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.enable_compile.html">mlx.core.enable_compile</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.grad.html">mlx.core.grad</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.value_and_grad.html">mlx.core.value_and_grad</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.jvp.html">mlx.core.jvp</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.vjp.html">mlx.core.vjp</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.vmap.html">mlx.core.vmap</a></li>
</ul>
2024-10-15 23:12:17 +08:00
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/fast.html">Fast</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fast.rms_norm.html">mlx.core.fast.rms_norm</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fast.layer_norm.html">mlx.core.fast.layer_norm</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fast.rope.html">mlx.core.fast.rope</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fast.scaled_dot_product_attention.html">mlx.core.fast.scaled_dot_product_attention</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fast.metal_kernel.html">mlx.core.fast.metal_kernel</a></li>
</ul>
2024-10-15 23:12:17 +08:00
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/fft.html">FFT</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.fft.html">mlx.core.fft.fft</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.ifft.html">mlx.core.fft.ifft</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.fft2.html">mlx.core.fft.fft2</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.ifft2.html">mlx.core.fft.ifft2</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.fftn.html">mlx.core.fft.fftn</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.ifftn.html">mlx.core.fft.ifftn</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.rfft.html">mlx.core.fft.rfft</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.irfft.html">mlx.core.fft.irfft</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.rfft2.html">mlx.core.fft.rfft2</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.irfft2.html">mlx.core.fft.irfft2</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.rfftn.html">mlx.core.fft.rfftn</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.irfftn.html">mlx.core.fft.irfftn</a></li>
2025-05-10 05:42:00 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.fftshift.html">mlx.core.fft.fftshift</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.fft.ifftshift.html">mlx.core.fft.ifftshift</a></li>
2024-08-24 03:14:53 +08:00
</ul>
2024-10-15 23:12:17 +08:00
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/linalg.html">Linear Algebra</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.inv.html">mlx.core.linalg.inv</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.tri_inv.html">mlx.core.linalg.tri_inv</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.norm.html">mlx.core.linalg.norm</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.cholesky.html">mlx.core.linalg.cholesky</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.cholesky_inv.html">mlx.core.linalg.cholesky_inv</a></li>
2024-09-29 02:04:59 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.cross.html">mlx.core.linalg.cross</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.qr.html">mlx.core.linalg.qr</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.svd.html">mlx.core.linalg.svd</a></li>
2025-06-03 07:29:32 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.eigvals.html">mlx.core.linalg.eigvals</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.eig.html">mlx.core.linalg.eig</a></li>
2024-10-26 04:23:45 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.eigvalsh.html">mlx.core.linalg.eigvalsh</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.eigh.html">mlx.core.linalg.eigh</a></li>
2025-02-15 05:44:39 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.lu.html">mlx.core.linalg.lu</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.lu_factor.html">mlx.core.linalg.lu_factor</a></li>
2025-04-04 04:25:24 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.pinv.html">mlx.core.linalg.pinv</a></li>
2025-02-15 05:44:39 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.solve.html">mlx.core.linalg.solve</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.linalg.solve_triangular.html">mlx.core.linalg.solve_triangular</a></li>
2024-08-24 03:14:53 +08:00
</ul>
2024-10-15 23:12:17 +08:00
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/metal.html">Metal</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.is_available.html">mlx.core.metal.is_available</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.device_info.html">mlx.core.metal.device_info</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.start_capture.html">mlx.core.metal.start_capture</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.metal.stop_capture.html">mlx.core.metal.stop_capture</a></li>
</ul>
2024-10-15 23:12:17 +08:00
</details></li>
2025-03-25 04:24:41 +08:00
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/memory_management.html">Memory Management</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.get_active_memory.html">mlx.core.get_active_memory</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.get_peak_memory.html">mlx.core.get_peak_memory</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.reset_peak_memory.html">mlx.core.reset_peak_memory</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.get_cache_memory.html">mlx.core.get_cache_memory</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.set_memory_limit.html">mlx.core.set_memory_limit</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.set_cache_limit.html">mlx.core.set_cache_limit</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.set_wired_limit.html">mlx.core.set_wired_limit</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.clear_cache.html">mlx.core.clear_cache</a></li>
</ul>
</details></li>
2024-10-15 23:12:17 +08:00
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/nn.html">Neural Networks</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.nn.value_and_grad.html">mlx.nn.value_and_grad</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.nn.quantize.html">mlx.nn.quantize</a></li>
2025-03-06 05:30:09 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.nn.average_gradients.html">mlx.nn.average_gradients</a></li>
2024-10-15 23:12:17 +08:00
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/module.html">Module</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
2024-08-24 03:14:53 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.training.html">mlx.nn.Module.training</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.state.html">mlx.nn.Module.state</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.apply.html">mlx.nn.Module.apply</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.apply_to_modules.html">mlx.nn.Module.apply_to_modules</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.children.html">mlx.nn.Module.children</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.eval.html">mlx.nn.Module.eval</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.filter_and_map.html">mlx.nn.Module.filter_and_map</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.freeze.html">mlx.nn.Module.freeze</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.leaf_modules.html">mlx.nn.Module.leaf_modules</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.load_weights.html">mlx.nn.Module.load_weights</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.modules.html">mlx.nn.Module.modules</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.named_modules.html">mlx.nn.Module.named_modules</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.parameters.html">mlx.nn.Module.parameters</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.save_weights.html">mlx.nn.Module.save_weights</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.set_dtype.html">mlx.nn.Module.set_dtype</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.train.html">mlx.nn.Module.train</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.trainable_parameters.html">mlx.nn.Module.trainable_parameters</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.unfreeze.html">mlx.nn.Module.unfreeze</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.update.html">mlx.nn.Module.update</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.update_modules.html">mlx.nn.Module.update_modules</a></li>
</ul>
2024-10-15 23:12:17 +08:00
</details></li>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/layers.html">Layers</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
2024-08-24 03:14:53 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ALiBi.html">mlx.nn.ALiBi</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.AvgPool1d.html">mlx.nn.AvgPool1d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.AvgPool2d.html">mlx.nn.AvgPool2d</a></li>
2024-11-23 04:24:16 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.AvgPool3d.html">mlx.nn.AvgPool3d</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.BatchNorm.html">mlx.nn.BatchNorm</a></li>
2024-09-29 02:04:59 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.CELU.html">mlx.nn.CELU</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Conv1d.html">mlx.nn.Conv1d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Conv2d.html">mlx.nn.Conv2d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Conv3d.html">mlx.nn.Conv3d</a></li>
2024-09-18 03:06:14 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ConvTranspose1d.html">mlx.nn.ConvTranspose1d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ConvTranspose2d.html">mlx.nn.ConvTranspose2d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ConvTranspose3d.html">mlx.nn.ConvTranspose3d</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Dropout.html">mlx.nn.Dropout</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Dropout2d.html">mlx.nn.Dropout2d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Dropout3d.html">mlx.nn.Dropout3d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Embedding.html">mlx.nn.Embedding</a></li>
2024-09-29 02:04:59 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ELU.html">mlx.nn.ELU</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.GELU.html">mlx.nn.GELU</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.GLU.html">mlx.nn.GLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.GroupNorm.html">mlx.nn.GroupNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.GRU.html">mlx.nn.GRU</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.HardShrink.html">mlx.nn.HardShrink</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.HardTanh.html">mlx.nn.HardTanh</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Hardswish.html">mlx.nn.Hardswish</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.InstanceNorm.html">mlx.nn.InstanceNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.LayerNorm.html">mlx.nn.LayerNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.LeakyReLU.html">mlx.nn.LeakyReLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Linear.html">mlx.nn.Linear</a></li>
2024-09-29 02:04:59 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.LogSigmoid.html">mlx.nn.LogSigmoid</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.LogSoftmax.html">mlx.nn.LogSoftmax</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.LSTM.html">mlx.nn.LSTM</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.MaxPool1d.html">mlx.nn.MaxPool1d</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.MaxPool2d.html">mlx.nn.MaxPool2d</a></li>
2024-11-23 04:24:16 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.MaxPool3d.html">mlx.nn.MaxPool3d</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Mish.html">mlx.nn.Mish</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.MultiHeadAttention.html">mlx.nn.MultiHeadAttention</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.PReLU.html">mlx.nn.PReLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.QuantizedEmbedding.html">mlx.nn.QuantizedEmbedding</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.QuantizedLinear.html">mlx.nn.QuantizedLinear</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.RMSNorm.html">mlx.nn.RMSNorm</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ReLU.html">mlx.nn.ReLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.ReLU6.html">mlx.nn.ReLU6</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.RNN.html">mlx.nn.RNN</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.RoPE.html">mlx.nn.RoPE</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.SELU.html">mlx.nn.SELU</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Sequential.html">mlx.nn.Sequential</a></li>
2024-09-29 02:04:59 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Sigmoid.html">mlx.nn.Sigmoid</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.SiLU.html">mlx.nn.SiLU</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.html">mlx.nn.SinusoidalPositionalEncoding</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Softmin.html">mlx.nn.Softmin</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Softshrink.html">mlx.nn.Softshrink</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Softsign.html">mlx.nn.Softsign</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Softmax.html">mlx.nn.Softmax</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Softplus.html">mlx.nn.Softplus</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Step.html">mlx.nn.Step</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Tanh.html">mlx.nn.Tanh</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Transformer.html">mlx.nn.Transformer</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Upsample.html">mlx.nn.Upsample</a></li>
</ul>
2024-10-15 23:12:17 +08:00
</details></li>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/functions.html">Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
2024-08-24 03:14:53 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.elu.html">mlx.nn.elu</a></li>
2024-09-29 02:04:59 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.celu.html">mlx.nn.celu</a></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.gelu.html">mlx.nn.gelu</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.gelu_approx.html">mlx.nn.gelu_approx</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.html">mlx.nn.gelu_fast_approx</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.glu.html">mlx.nn.glu</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.hard_shrink.html">mlx.nn.hard_shrink</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.hard_tanh.html">mlx.nn.hard_tanh</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.hardswish.html">mlx.nn.hardswish</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.leaky_relu.html">mlx.nn.leaky_relu</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.log_sigmoid.html">mlx.nn.log_sigmoid</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.log_softmax.html">mlx.nn.log_softmax</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.mish.html">mlx.nn.mish</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.prelu.html">mlx.nn.prelu</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.relu.html">mlx.nn.relu</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.relu6.html">mlx.nn.relu6</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.selu.html">mlx.nn.selu</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.sigmoid.html">mlx.nn.sigmoid</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.silu.html">mlx.nn.silu</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.softmax.html">mlx.nn.softmax</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.softmin.html">mlx.nn.softmin</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.softplus.html">mlx.nn.softplus</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.softshrink.html">mlx.nn.softshrink</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.step.html">mlx.nn.step</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.tanh.html">mlx.nn.tanh</a></li>
</ul>
2024-10-15 23:12:17 +08:00
</details></li>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/losses.html">Loss Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
2024-08-24 03:14:53 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.html">mlx.nn.losses.binary_cross_entropy</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.html">mlx.nn.losses.cosine_similarity_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.html">mlx.nn.losses.cross_entropy</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss.html">mlx.nn.losses.gaussian_nll_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.html">mlx.nn.losses.hinge_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.huber_loss.html">mlx.nn.losses.huber_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.html">mlx.nn.losses.kl_div_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.html">mlx.nn.losses.l1_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.html">mlx.nn.losses.log_cosh_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.margin_ranking_loss.html">mlx.nn.losses.margin_ranking_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.html">mlx.nn.losses.mse_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.html">mlx.nn.losses.nll_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.html">mlx.nn.losses.smooth_l1_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.html">mlx.nn.losses.triplet_loss</a></li>
</ul>
2024-10-15 23:12:17 +08:00
</details></li>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/init.html">Initializers</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
2024-08-24 03:14:53 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.constant.html">mlx.nn.init.constant</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.normal.html">mlx.nn.init.normal</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.uniform.html">mlx.nn.init.uniform</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.identity.html">mlx.nn.init.identity</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.glorot_normal.html">mlx.nn.init.glorot_normal</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.glorot_uniform.html">mlx.nn.init.glorot_uniform</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.he_normal.html">mlx.nn.init.he_normal</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.init.he_uniform.html">mlx.nn.init.he_uniform</a></li>
</ul>
2024-10-15 23:12:17 +08:00
</details></li>
2024-08-24 03:14:53 +08:00
</ul>
2024-10-15 23:12:17 +08:00
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/optimizers.html">Optimizers</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/optimizers/optimizer.html">Optimizer</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
2024-08-24 03:14:53 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Optimizer.state.html">mlx.optimizers.Optimizer.state</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Optimizer.apply_gradients.html">mlx.optimizers.Optimizer.apply_gradients</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Optimizer.init.html">mlx.optimizers.Optimizer.init</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Optimizer.update.html">mlx.optimizers.Optimizer.update</a></li>
</ul>
2024-10-15 23:12:17 +08:00
</details></li>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/optimizers/common_optimizers.html">Common Optimizers</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
2024-08-24 03:14:53 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.SGD.html">mlx.optimizers.SGD</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.RMSprop.html">mlx.optimizers.RMSprop</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Adagrad.html">mlx.optimizers.Adagrad</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Adafactor.html">mlx.optimizers.Adafactor</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.AdaDelta.html">mlx.optimizers.AdaDelta</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Adam.html">mlx.optimizers.Adam</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.AdamW.html">mlx.optimizers.AdamW</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Adamax.html">mlx.optimizers.Adamax</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.Lion.html">mlx.optimizers.Lion</a></li>
2025-04-18 06:29:33 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.MultiOptimizer.html">mlx.optimizers.MultiOptimizer</a></li>
2024-08-24 03:14:53 +08:00
</ul>
2024-10-15 23:12:17 +08:00
</details></li>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/optimizers/schedulers.html">Schedulers</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
2024-08-24 03:14:53 +08:00
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.cosine_decay.html">mlx.optimizers.cosine_decay</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.exponential_decay.html">mlx.optimizers.exponential_decay</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.join_schedules.html">mlx.optimizers.join_schedules</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.linear_schedule.html">mlx.optimizers.linear_schedule</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/optimizers/_autosummary/mlx.optimizers.step_decay.html">mlx.optimizers.step_decay</a></li>
</ul>
2024-10-15 23:12:17 +08:00
</details></li>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.clip_grad_norm.html">mlx.optimizers.clip_grad_norm</a></li>
</ul>
2024-10-15 23:12:17 +08:00
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/distributed.html">Distributed Communication</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.Group.html">mlx.core.distributed.Group</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.is_available.html">mlx.core.distributed.is_available</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.init.html">mlx.core.distributed.init</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.all_sum.html">mlx.core.distributed.all_sum</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.all_gather.html">mlx.core.distributed.all_gather</a></li>
2024-09-18 03:06:14 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.send.html">mlx.core.distributed.send</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.recv.html">mlx.core.distributed.recv</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.distributed.recv_like.html">mlx.core.distributed.recv_like</a></li>
2024-08-24 03:14:53 +08:00
</ul>
2024-10-15 23:12:17 +08:00
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/tree_utils.html">Tree Utils</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
2024-08-24 03:14:53 +08:00
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_flatten.html">mlx.utils.tree_flatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_unflatten.html">mlx.utils.tree_unflatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_map.html">mlx.utils.tree_map</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_map_with_path.html">mlx.utils.tree_map_with_path</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_reduce.html">mlx.utils.tree_reduce</a></li>
</ul>
2024-10-15 23:12:17 +08:00
</details></li>
2024-08-24 03:14:53 +08:00
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">C++ API Reference</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../cpp/ops.html">Operations</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Further Reading</span></p>
<ul class="current nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="extensions.html">Custom Extensions in MLX</a></li>
<li class="toctree-l1"><a class="reference internal" href="metal_debugger.html">Metal Debugger</a></li>
<li class="toctree-l1 current active"><a class="current reference internal" href="#">Custom Metal Kernels</a></li>
2025-01-10 05:56:20 +08:00
<li class="toctree-l1"><a class="reference internal" href="mlx_in_cpp.html">Using MLX in C++</a></li>
2024-08-24 03:14:53 +08:00
</ul>
</div>
</nav></div>
</div>
<div class="sidebar-primary-items__end sidebar-primary__section">
</div>
2025-03-06 05:30:09 +08:00
<div id="rtd-footer-container"></div>
2024-08-24 03:14:53 +08:00
</div>
2024-10-15 23:12:17 +08:00
<main id="main-content" class="bd-main" role="main">
2024-08-24 03:14:53 +08:00
<div class="sbt-scroll-pixel-helper"></div>
<div class="bd-content">
<div class="bd-article-container">
2024-10-15 23:12:17 +08:00
<div class="bd-header-article d-print-none">
2024-08-24 03:14:53 +08:00
<div class="header-article-items header-article__inner">
<div class="header-article-items__start">
2024-10-15 23:12:17 +08:00
<div class="header-article-item"><button class="sidebar-toggle primary-toggle btn btn-sm" title="Toggle primary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip">
2024-08-24 03:14:53 +08:00
<span class="fa-solid fa-bars"></span>
2024-10-15 23:12:17 +08:00
</button></div>
2024-08-24 03:14:53 +08:00
</div>
<div class="header-article-items__end">
<div class="header-article-item">
<div class="article-header-buttons">
<a href="https://github.com/ml-explore/mlx" target="_blank"
class="btn btn-sm btn-source-repository-button"
title="Source repository"
data-bs-placement="bottom" data-bs-toggle="tooltip"
>
<span class="btn__icon-container">
<i class="fab fa-github"></i>
</span>
</a>
<div class="dropdown dropdown-download-buttons">
<button class="btn dropdown-toggle" type="button" data-bs-toggle="dropdown" aria-expanded="false" aria-label="Download this page">
<i class="fas fa-download"></i>
</button>
<ul class="dropdown-menu">
<li><a href="../_sources/dev/custom_metal_kernels.rst" target="_blank"
class="btn btn-sm btn-download-source-button dropdown-item"
title="Download source file"
data-bs-placement="left" data-bs-toggle="tooltip"
>
<span class="btn__icon-container">
<i class="fas fa-file"></i>
</span>
<span class="btn__text-container">.rst</span>
</a>
</li>
<li>
<button onclick="window.print()"
class="btn btn-sm btn-download-pdf-button dropdown-item"
title="Print to PDF"
data-bs-placement="left" data-bs-toggle="tooltip"
>
<span class="btn__icon-container">
<i class="fas fa-file-pdf"></i>
</span>
<span class="btn__text-container">.pdf</span>
</button>
</li>
</ul>
</div>
<button onclick="toggleFullScreen()"
class="btn btn-sm btn-fullscreen-button"
title="Fullscreen mode"
data-bs-placement="bottom" data-bs-toggle="tooltip"
>
<span class="btn__icon-container">
<i class="fas fa-expand"></i>
</span>
</button>
2025-03-06 05:30:09 +08:00
<script>
document.write(`
<button class="btn btn-sm nav-link pst-navbar-icon theme-switch-button" title="light/dark" aria-label="light/dark" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="theme-switch fa-solid fa-sun fa-lg" data-mode="light"></i>
<i class="theme-switch fa-solid fa-moon fa-lg" data-mode="dark"></i>
<i class="theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode="auto"></i>
</button>
`);
</script>
2024-08-24 03:14:53 +08:00
2025-03-06 05:30:09 +08:00
<script>
document.write(`
<button class="btn btn-sm pst-navbar-icon search-button search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
2024-08-24 03:14:53 +08:00
<i class="fa-solid fa-magnifying-glass fa-lg"></i>
2025-03-06 05:30:09 +08:00
</button>
`);
</script>
2024-10-15 23:12:17 +08:00
<button class="sidebar-toggle secondary-toggle btn btn-sm" title="Toggle secondary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip">
2024-08-24 03:14:53 +08:00
<span class="fa-solid fa-list"></span>
2024-10-15 23:12:17 +08:00
</button>
2024-08-24 03:14:53 +08:00
</div></div>
</div>
</div>
</div>
<div id="jb-print-docs-body" class="onlyprint">
<h1>Custom Metal Kernels</h1>
<!-- Table of contents -->
<div id="print-main-content">
<div id="jb-print-toc">
<div>
<h2> Contents </h2>
</div>
<nav aria-label="Page">
<ul class="visible nav section-nav flex-column">
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#simple-example">Simple Example</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#using-shape-strides">Using Shape/Strides</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#complex-example">Complex Example</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#grid-sample-vjp">Grid Sample VJP</a></li>
</ul>
</nav>
</div>
</div>
</div>
<div id="searchbox"></div>
2024-10-15 23:12:17 +08:00
<article class="bd-article">
2024-08-24 03:14:53 +08:00
<section id="custom-metal-kernels">
<span id="id1"></span><h1>Custom Metal Kernels<a class="headerlink" href="#custom-metal-kernels" title="Link to this heading">#</a></h1>
<p>MLX supports writing custom Metal kernels through the Python and C++ APIs.</p>
<section id="simple-example">
<h2>Simple Example<a class="headerlink" href="#simple-example" title="Link to this heading">#</a></h2>
<p>Let’s write a custom kernel that computes <code class="docutils literal notranslate"><span class="pre">exp</span></code> elementwise:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">exp_elementwise</span><span class="p">(</span><span class="n">a</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">):</span>
    <span class="n">source</span> <span class="o">=</span> <span class="s2">&quot;&quot;&quot;</span>
<span class="s2">        uint elem = thread_position_in_grid.x;</span>
<span class="s2">        T tmp = inp[elem];</span>
<span class="s2">        out[elem] = metal::exp(tmp);</span>
<span class="s2">    &quot;&quot;&quot;</span>
    <span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
        <span class="n">name</span><span class="o">=</span><span class="s2">&quot;myexp&quot;</span><span class="p">,</span>
        <span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;inp&quot;</span><span class="p">],</span>
        <span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;out&quot;</span><span class="p">],</span>
        <span class="n">source</span><span class="o">=</span><span class="n">source</span><span class="p">,</span>
    <span class="p">)</span>
    <span class="n">outputs</span> <span class="o">=</span> <span class="n">kernel</span><span class="p">(</span>
        <span class="n">inputs</span><span class="o">=</span><span class="p">[</span><span class="n">a</span><span class="p">],</span>
        <span class="n">template</span><span class="o">=</span><span class="p">[(</span><span class="s2">&quot;T&quot;</span><span class="p">,</span> <span class="n">mx</span><span class="o">.</span><span class="n">float32</span><span class="p">)],</span>
        <span class="n">grid</span><span class="o">=</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
        <span class="n">threadgroup</span><span class="o">=</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
        <span class="n">output_shapes</span><span class="o">=</span><span class="p">[</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">],</span>
        <span class="n">output_dtypes</span><span class="o">=</span><span class="p">[</span><span class="n">a</span><span class="o">.</span><span class="n">dtype</span><span class="p">],</span>
    <span class="p">)</span>
    <span class="k">return</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>

<span class="n">a</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">16</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">exp_elementwise</span><span class="p">(</span><span class="n">a</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">mx</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="n">mx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">a</span><span class="p">))</span>
</pre></div>
</div>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>We are only required to pass the body of the Metal kernel in <code class="docutils literal notranslate"><span class="pre">source</span></code>.</p>
</div>
<p>The full function signature will be generated using:</p>
<ul class="simple">
<li><dl class="simple">
<dt>The shapes/dtypes of <code class="docutils literal notranslate"><span class="pre">inputs</span></code></dt><dd><p>In the above, <code class="docutils literal notranslate"><span class="pre">a</span></code> is an <code class="docutils literal notranslate"><span class="pre">mx.array</span></code> of type <code class="docutils literal notranslate"><span class="pre">mx.float16</span></code> and we name it <code class="docutils literal notranslate"><span class="pre">inp</span></code> via <code class="docutils literal notranslate"><span class="pre">input_names</span></code>,
so we add <code class="docutils literal notranslate"><span class="pre">const</span> <span class="pre">device</span> <span class="pre">float16_t*</span> <span class="pre">inp</span></code> to the signature.
<code class="docutils literal notranslate"><span class="pre">inp_shape</span></code>, <code class="docutils literal notranslate"><span class="pre">inp_strides</span></code> and <code class="docutils literal notranslate"><span class="pre">inp_ndim</span></code> are also added for convenience if they are referenced
in <code class="docutils literal notranslate"><span class="pre">source</span></code>.</p>
</dd>
</dl>
</li>
<li><dl class="simple">
2024-09-18 03:06:14 +08:00
<dt>The list of <code class="docutils literal notranslate"><span class="pre">output_dtypes</span></code></dt><dd><p>In the above, <code class="docutils literal notranslate"><span class="pre">out</span></code> is an <code class="docutils literal notranslate"><span class="pre">mx.array</span></code> of type <code class="docutils literal notranslate"><span class="pre">mx.float16</span></code>
2024-08-24 03:14:53 +08:00
so we add <code class="docutils literal notranslate"><span class="pre">device</span> <span class="pre">float16_t*</span> <span class="pre">out</span></code>.</p>
</dd>
</dl>
</li>
<li><dl class="simple">
2024-09-18 03:06:14 +08:00
<dt>Template parameters passed using <code class="docutils literal notranslate"><span class="pre">template</span></code></dt><dd><p>In the above, <code class="docutils literal notranslate"><span class="pre">template=[(&quot;T&quot;,</span> <span class="pre">mx.float32)]</span></code> adds a template of <code class="docutils literal notranslate"><span class="pre">template</span> <span class="pre">&lt;typename</span> <span class="pre">T&gt;</span></code> to the function
2024-08-24 03:14:53 +08:00
and instantiates the template with <code class="docutils literal notranslate"><span class="pre">custom_kernel_myexp_float&lt;float&gt;</span></code>.
Template parameters can be <code class="docutils literal notranslate"><span class="pre">mx.core.Dtype</span></code>, <code class="docutils literal notranslate"><span class="pre">int</span></code> or <code class="docutils literal notranslate"><span class="pre">bool</span></code>.</p>
</dd>
</dl>
</li>
<li><dl class="simple">
<dt>Metal attributes used in <code class="docutils literal notranslate"><span class="pre">source</span></code> such as <code class="docutils literal notranslate"><span class="pre">[[thread_position_in_grid]]</span></code></dt><dd><p>These will be added as function arguments.
All the attributes defined in Table 5.8 of the <a class="reference external" href="https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf">Metal Shading Language Specification</a> are supported.</p>
</dd>
</dl>
</li>
</ul>
<p>Putting this all together, the generated function signature for <code class="docutils literal notranslate"><span class="pre">myexp</span></code> is as follows:</p>
<div class="highlight-cpp notranslate"><div class="highlight"><pre><span></span><span class="k">template</span><span class="w"> </span><span class="o">&lt;</span><span class="k">typename</span><span class="w"> </span><span class="nc">T</span><span class="o">&gt;</span>
<span class="p">[[</span><span class="n">kernel</span><span class="p">]]</span><span class="w"> </span><span class="kt">void</span><span class="w"> </span><span class="n">custom_kernel_myexp_float</span><span class="p">(</span>
<span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="n">device</span><span class="w"> </span><span class="n">float16_t</span><span class="o">*</span><span class="w"> </span><span class="n">inp</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">0</span><span class="p">)]],</span>
<span class="w"> </span><span class="n">device</span><span class="w"> </span><span class="n">float16_t</span><span class="o">*</span><span class="w"> </span><span class="n">out</span><span class="w"> </span><span class="p">[[</span><span class="n">buffer</span><span class="p">(</span><span class="mi">1</span><span class="p">)]],</span>
<span class="w"> </span><span class="n">uint3</span><span class="w"> </span><span class="n">thread_position_in_grid</span><span class="w"> </span><span class="p">[[</span><span class="n">thread_position_in_grid</span><span class="p">]])</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="n">uint</span><span class="w"> </span><span class="n">elem</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">thread_position_in_grid</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
<span class="w"> </span><span class="n">T</span><span class="w"> </span><span class="n">tmp</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">inp</span><span class="p">[</span><span class="n">elem</span><span class="p">];</span>
<span class="w"> </span><span class="n">out</span><span class="p">[</span><span class="n">elem</span><span class="p">]</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">metal</span><span class="o">::</span><span class="n">exp</span><span class="p">(</span><span class="n">tmp</span><span class="p">);</span>
<span class="p">}</span>
<span class="k">template</span><span class="w"> </span><span class="p">[[</span><span class="n">host_name</span><span class="p">(</span><span class="s">&quot;custom_kernel_myexp_float&quot;</span><span class="p">)]]</span><span class="w"> </span><span class="p">[[</span><span class="n">kernel</span><span class="p">]]</span><span class="w"> </span><span class="k">decltype</span><span class="p">(</span><span class="n">custom_kernel_myexp_float</span><span class="o">&lt;</span><span class="kt">float</span><span class="o">&gt;</span><span class="p">)</span><span class="w"> </span><span class="n">custom_kernel_myexp_float</span><span class="o">&lt;</span><span class="kt">float</span><span class="o">&gt;</span><span class="p">;</span>
</pre></div>
</div>
2024-11-06 03:54:16 +08:00
<p>Note: <code class="docutils literal notranslate"><span class="pre">grid</span></code> and <code class="docutils literal notranslate"><span class="pre">threadgroup</span></code> are parameters to the Metal <a class="reference external" href="https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads">dispatchThreads</a> function.
This means we will launch <code class="docutils literal notranslate"><span class="pre">mx.prod(grid)</span></code> threads, subdivided into threadgroups of size <code class="docutils literal notranslate"><span class="pre">threadgroup</span></code>.
For optimal performance, each threadgroup dimension should be less than or equal to the corresponding grid dimension.</p>
2024-09-18 03:06:14 +08:00
<p>Passing <code class="docutils literal notranslate"><span class="pre">verbose=True</span></code> to <code class="docutils literal notranslate"><span class="pre">mx.fast.metal_kernel.__call__</span></code> will print the generated code for debugging purposes.</p>
2024-08-24 03:14:53 +08:00
</section>
<section id="using-shape-strides">
<h2>Using Shape/Strides<a class="headerlink" href="#using-shape-strides" title="Link to this heading">#</a></h2>
<p><code class="docutils literal notranslate"><span class="pre">mx.fast.metal_kernel</span></code> supports an argument <code class="docutils literal notranslate"><span class="pre">ensure_row_contiguous</span></code> which is <code class="docutils literal notranslate"><span class="pre">True</span></code> by default.
This will copy the <code class="docutils literal notranslate"><span class="pre">mx.array</span></code> inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
Generally this makes writing the kernel easier, since we don&#8217;t have to worry about gaps or the ordering of the dims
when indexing.</p>
<p>If we want to avoid this copy, <code class="docutils literal notranslate"><span class="pre">metal_kernel</span></code> automatically passes <code class="docutils literal notranslate"><span class="pre">a_shape</span></code>, <code class="docutils literal notranslate"><span class="pre">a_strides</span></code> and <code class="docutils literal notranslate"><span class="pre">a_ndim</span></code> for each
input array <code class="docutils literal notranslate"><span class="pre">a</span></code> if any are present in <code class="docutils literal notranslate"><span class="pre">source</span></code>.
We can then use MLX&#8217;s built-in indexing utils to fetch the right elements for each thread.</p>
<p>Let&#8217;s convert <code class="docutils literal notranslate"><span class="pre">myexp</span></code> above to support arbitrarily strided arrays without relying on a copy from <code class="docutils literal notranslate"><span class="pre">ensure_row_contiguous</span></code>:</p>
2025-01-10 05:56:20 +08:00
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">exp_elementwise</span><span class="p">(</span><span class="n">a</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">):</span>
2024-08-24 03:14:53 +08:00
<span class="n">source</span> <span class="o">=</span> <span class="s2">&quot;&quot;&quot;</span>
<span class="s2"> uint elem = thread_position_in_grid.x;</span>
<span class="s2"> // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included</span>
<span class="s2"> uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);</span>
<span class="s2"> T tmp = inp[loc];</span>
<span class="s2"> // Output arrays are always row contiguous</span>
<span class="s2"> out[elem] = metal::exp(tmp);</span>
<span class="s2"> &quot;&quot;&quot;</span>
<span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="s2">&quot;myexp_strided&quot;</span><span class="p">,</span>
2024-09-18 03:06:14 +08:00
<span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;inp&quot;</span><span class="p">],</span>
<span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;out&quot;</span><span class="p">],</span>
2024-08-24 03:14:53 +08:00
<span class="n">source</span><span class="o">=</span><span class="n">source</span>
<span class="p">)</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">kernel</span><span class="p">(</span>
2024-09-18 03:06:14 +08:00
<span class="n">inputs</span><span class="o">=</span><span class="p">[</span><span class="n">a</span><span class="p">],</span>
<span class="n">template</span><span class="o">=</span><span class="p">[(</span><span class="s2">&quot;T&quot;</span><span class="p">,</span> <span class="n">mx</span><span class="o">.</span><span class="n">float32</span><span class="p">)],</span>
2024-08-24 03:14:53 +08:00
<span class="n">grid</span><span class="o">=</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">threadgroup</span><span class="o">=</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
2024-09-18 03:06:14 +08:00
<span class="n">output_shapes</span><span class="o">=</span><span class="p">[</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">],</span>
<span class="n">output_dtypes</span><span class="o">=</span><span class="p">[</span><span class="n">a</span><span class="o">.</span><span class="n">dtype</span><span class="p">],</span>
2024-08-24 03:14:53 +08:00
<span class="n">ensure_row_contiguous</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="p">)</span>
2024-09-18 03:06:14 +08:00
<span class="k">return</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
2024-08-24 03:14:53 +08:00
<span class="n">a</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">16</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="c1"># make non-contiguous</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">a</span><span class="p">[::</span><span class="mi">2</span><span class="p">]</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">exp_elementwise</span><span class="p">(</span><span class="n">a</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">mx</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="n">mx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">a</span><span class="p">))</span>
</pre></div>
</div>
</section>
2024-09-18 03:06:14 +08:00
<section id="complex-example">
<h2>Complex Example<a class="headerlink" href="#complex-example" title="Link to this heading">#</a></h2>
<p>Let&#8217;s implement a more complex example: <code class="docutils literal notranslate"><span class="pre">grid_sample</span></code> in <code class="docutils literal notranslate"><span class="pre">&quot;bilinear&quot;</span></code> mode.</p>
<p>We&#8217;ll start with the following MLX implementation using standard ops:</p>
2025-01-10 05:56:20 +08:00
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">grid_sample_ref</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">grid</span><span class="p">):</span>
2024-09-18 03:06:14 +08:00
<span class="n">N</span><span class="p">,</span> <span class="n">H_in</span><span class="p">,</span> <span class="n">W_in</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
<span class="n">ix</span> <span class="o">=</span> <span class="p">((</span><span class="n">grid</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">W_in</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span>
<span class="n">iy</span> <span class="o">=</span> <span class="p">((</span><span class="n">grid</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">H_in</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span>
<span class="n">ix_nw</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">ix</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">iy_nw</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">iy</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">ix_ne</span> <span class="o">=</span> <span class="n">ix_nw</span> <span class="o">+</span> <span class="mi">1</span>
<span class="n">iy_ne</span> <span class="o">=</span> <span class="n">iy_nw</span>
<span class="n">ix_sw</span> <span class="o">=</span> <span class="n">ix_nw</span>
<span class="n">iy_sw</span> <span class="o">=</span> <span class="n">iy_nw</span> <span class="o">+</span> <span class="mi">1</span>
<span class="n">ix_se</span> <span class="o">=</span> <span class="n">ix_nw</span> <span class="o">+</span> <span class="mi">1</span>
<span class="n">iy_se</span> <span class="o">=</span> <span class="n">iy_nw</span> <span class="o">+</span> <span class="mi">1</span>
<span class="n">nw</span> <span class="o">=</span> <span class="p">(</span><span class="n">ix_se</span> <span class="o">-</span> <span class="n">ix</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">iy_se</span> <span class="o">-</span> <span class="n">iy</span><span class="p">)</span>
<span class="n">ne</span> <span class="o">=</span> <span class="p">(</span><span class="n">ix</span> <span class="o">-</span> <span class="n">ix_sw</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">iy_sw</span> <span class="o">-</span> <span class="n">iy</span><span class="p">)</span>
<span class="n">sw</span> <span class="o">=</span> <span class="p">(</span><span class="n">ix_ne</span> <span class="o">-</span> <span class="n">ix</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">iy</span> <span class="o">-</span> <span class="n">iy_ne</span><span class="p">)</span>
<span class="n">se</span> <span class="o">=</span> <span class="p">(</span><span class="n">ix</span> <span class="o">-</span> <span class="n">ix_nw</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">iy</span> <span class="o">-</span> <span class="n">iy_nw</span><span class="p">)</span>
<span class="n">I_nw</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">N</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span> <span class="n">iy_nw</span><span class="p">,</span> <span class="n">ix_nw</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">I_ne</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">N</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span> <span class="n">iy_ne</span><span class="p">,</span> <span class="n">ix_ne</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">I_sw</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">N</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span> <span class="n">iy_sw</span><span class="p">,</span> <span class="n">ix_sw</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">I_se</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">N</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span> <span class="n">iy_se</span><span class="p">,</span> <span class="n">ix_se</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">mask_nw</span> <span class="o">=</span> <span class="p">(</span><span class="n">iy_nw</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">iy_nw</span> <span class="o">&lt;=</span> <span class="n">H_in</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">ix_nw</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">ix_nw</span> <span class="o">&lt;=</span> <span class="n">W_in</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">mask_ne</span> <span class="o">=</span> <span class="p">(</span><span class="n">iy_ne</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">iy_ne</span> <span class="o">&lt;=</span> <span class="n">H_in</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">ix_ne</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">ix_ne</span> <span class="o">&lt;=</span> <span class="n">W_in</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">mask_sw</span> <span class="o">=</span> <span class="p">(</span><span class="n">iy_sw</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">iy_sw</span> <span class="o">&lt;=</span> <span class="n">H_in</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">ix_sw</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">ix_sw</span> <span class="o">&lt;=</span> <span class="n">W_in</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">mask_se</span> <span class="o">=</span> <span class="p">(</span><span class="n">iy_se</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">iy_se</span> <span class="o">&lt;=</span> <span class="n">H_in</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">ix_se</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">ix_se</span> <span class="o">&lt;=</span> <span class="n">W_in</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">I_nw</span> <span class="o">*=</span> <span class="n">mask_nw</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span>
<span class="n">I_ne</span> <span class="o">*=</span> <span class="n">mask_ne</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span>
<span class="n">I_sw</span> <span class="o">*=</span> <span class="n">mask_sw</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span>
<span class="n">I_se</span> <span class="o">*=</span> <span class="n">mask_se</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">nw</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">I_nw</span> <span class="o">+</span> <span class="n">ne</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">I_ne</span> <span class="o">+</span> <span class="n">sw</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">I_sw</span> <span class="o">+</span> <span class="n">se</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">I_se</span>
<span class="k">return</span> <span class="n">output</span>
</pre></div>
</div>
<p>Now let&#8217;s use <code class="docutils literal notranslate"><span class="pre">mx.custom_function</span></code> together with <code class="docutils literal notranslate"><span class="pre">mx.fast.metal_kernel</span></code>
to write a fast GPU kernel for both the forward and backward passes.</p>
<p>First we&#8217;ll implement the forward pass as a fused kernel:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nd">@mx</span><span class="o">.</span><span class="n">custom_function</span>
2025-01-10 05:56:20 +08:00
<span class="k">def</span><span class="w"> </span><span class="nf">grid_sample</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">grid</span><span class="p">):</span>
2024-09-18 03:06:14 +08:00
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">4</span><span class="p">,</span> <span class="s2">&quot;`x` must be 4D.&quot;</span>
<span class="k">assert</span> <span class="n">grid</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">4</span><span class="p">,</span> <span class="s2">&quot;`grid` must be 4D.&quot;</span>
<span class="n">B</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
<span class="n">_</span><span class="p">,</span> <span class="n">gN</span><span class="p">,</span> <span class="n">gM</span><span class="p">,</span> <span class="n">D</span> <span class="o">=</span> <span class="n">grid</span><span class="o">.</span><span class="n">shape</span>
<span class="n">out_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">gN</span><span class="p">,</span> <span class="n">gM</span><span class="p">,</span> <span class="n">C</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">D</span> <span class="o">==</span> <span class="mi">2</span><span class="p">,</span> <span class="s2">&quot;Last dim of `grid` must be size 2.&quot;</span>
<span class="n">source</span> <span class="o">=</span> <span class="s2">&quot;&quot;&quot;</span>
<span class="s2"> uint elem = thread_position_in_grid.x;</span>
<span class="s2"> int H = x_shape[1];</span>
<span class="s2"> int W = x_shape[2];</span>
<span class="s2"> int C = x_shape[3];</span>
<span class="s2"> int gH = grid_shape[1];</span>
<span class="s2"> int gW = grid_shape[2];</span>
<span class="s2"> int w_stride = C;</span>
<span class="s2"> int h_stride = W * w_stride;</span>
<span class="s2"> int b_stride = H * h_stride;</span>
<span class="s2"> uint grid_idx = elem / C * 2;</span>
<span class="s2"> float ix = ((grid[grid_idx] + 1) * W - 1) / 2;</span>
<span class="s2"> float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;</span>
<span class="s2"> int ix_nw = floor(ix);</span>
<span class="s2"> int iy_nw = floor(iy);</span>
<span class="s2"> int ix_ne = ix_nw + 1;</span>
<span class="s2"> int iy_ne = iy_nw;</span>
<span class="s2"> int ix_sw = ix_nw;</span>
<span class="s2"> int iy_sw = iy_nw + 1;</span>
<span class="s2"> int ix_se = ix_nw + 1;</span>
<span class="s2"> int iy_se = iy_nw + 1;</span>
<span class="s2"> T nw = (ix_se - ix) * (iy_se - iy);</span>
<span class="s2"> T ne = (ix - ix_sw) * (iy_sw - iy);</span>
<span class="s2"> T sw = (ix_ne - ix) * (iy - iy_ne);</span>
<span class="s2"> T se = (ix - ix_nw) * (iy - iy_nw);</span>
<span class="s2"> int batch_idx = elem / C / gH / gW * b_stride;</span>
<span class="s2"> int channel_idx = elem % C;</span>
<span class="s2"> int base_idx = batch_idx + channel_idx;</span>
<span class="s2"> T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];</span>
<span class="s2"> T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];</span>
<span class="s2"> T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];</span>
<span class="s2"> T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];</span>
<span class="s2"> I_nw = iy_nw &gt;= 0 &amp;&amp; iy_nw &lt;= H - 1 &amp;&amp; ix_nw &gt;= 0 &amp;&amp; ix_nw &lt;= W - 1 ? I_nw : 0;</span>
<span class="s2"> I_ne = iy_ne &gt;= 0 &amp;&amp; iy_ne &lt;= H - 1 &amp;&amp; ix_ne &gt;= 0 &amp;&amp; ix_ne &lt;= W - 1 ? I_ne : 0;</span>
<span class="s2"> I_sw = iy_sw &gt;= 0 &amp;&amp; iy_sw &lt;= H - 1 &amp;&amp; ix_sw &gt;= 0 &amp;&amp; ix_sw &lt;= W - 1 ? I_sw : 0;</span>
<span class="s2"> I_se = iy_se &gt;= 0 &amp;&amp; iy_se &lt;= H - 1 &amp;&amp; ix_se &gt;= 0 &amp;&amp; ix_se &lt;= W - 1 ? I_se : 0;</span>
<span class="s2"> out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;</span>
<span class="s2"> &quot;&quot;&quot;</span>
<span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="s2">&quot;grid_sample&quot;</span><span class="p">,</span>
<span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;x&quot;</span><span class="p">,</span> <span class="s2">&quot;grid&quot;</span><span class="p">],</span>
<span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;out&quot;</span><span class="p">],</span>
<span class="n">source</span><span class="o">=</span><span class="n">source</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">kernel</span><span class="p">(</span>
<span class="n">inputs</span><span class="o">=</span><span class="p">[</span><span class="n">x</span><span class="p">,</span> <span class="n">grid</span><span class="p">],</span>
<span class="n">template</span><span class="o">=</span><span class="p">[(</span><span class="s2">&quot;T&quot;</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)],</span>
<span class="n">output_shapes</span><span class="o">=</span><span class="p">[</span><span class="n">out_shape</span><span class="p">],</span>
<span class="n">output_dtypes</span><span class="o">=</span><span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">],</span>
<span class="n">grid</span><span class="o">=</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">out_shape</span><span class="p">),</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">threadgroup</span><span class="o">=</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</pre></div>
</div>
<p>For a reasonably sized input such as:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">x</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">1024</span><span class="p">,</span> <span class="mi">1024</span><span class="p">,</span> <span class="mi">64</span><span class="p">)</span>
<span class="n">grid</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
</pre></div>
</div>
<p>On an M1 Max, we see a big performance improvement:</p>
<p><code class="docutils literal notranslate"><span class="pre">55.7ms</span> <span class="pre">-&gt;</span> <span class="pre">6.7ms</span> <span class="pre">=&gt;</span> <span class="pre">8x</span> <span class="pre">speed</span> <span class="pre">up</span></code></p>
</section>
<section id="grid-sample-vjp">
<h2>Grid Sample VJP<a class="headerlink" href="#grid-sample-vjp" title="Link to this heading">#</a></h2>
<p>Since we decorated <code class="docutils literal notranslate"><span class="pre">grid_sample</span></code> with <code class="docutils literal notranslate"><span class="pre">mx.custom_function</span></code>, we can now define
its custom vjp transform so MLX can differentiate it.</p>
<p>The backwards pass requires atomically updating <code class="docutils literal notranslate"><span class="pre">x_grad</span></code>/<code class="docutils literal notranslate"><span class="pre">grid_grad</span></code> and so
requires a few extra <code class="docutils literal notranslate"><span class="pre">mx.fast.metal_kernel</span></code> features:</p>
<ul class="simple">
<li><dl class="simple">
<dt><code class="docutils literal notranslate"><span class="pre">init_value=0</span></code></dt><dd><p>Initialize all of the kernel’s outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.</p>
</dd>
</dl>
</li>
<li><dl class="simple">
<dt><code class="docutils literal notranslate"><span class="pre">atomic_outputs=True</span></code></dt><dd><p>Designate all of the kernel outputs as <code class="docutils literal notranslate"><span class="pre">atomic</span></code> in the function signature.
This means we can use Metal’s <code class="docutils literal notranslate"><span class="pre">atomic</span></code> features to simultaneously update the <code class="docutils literal notranslate"><span class="pre">x_grad</span></code> and <code class="docutils literal notranslate"><span class="pre">grid_grad</span></code> arrays from multiple threadgroups.
See section 6.15 of the <a class="reference external" href="https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf">Metal Shading Language Specification</a> for more details.</p>
</dd>
</dl>
</li>
</ul>
<p>We can then implement the backwards pass as follows:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nd">@grid_sample</span><span class="o">.</span><span class="n">vjp</span>
<span class="k">def</span><span class="w"> </span><span class="nf">grid_sample_vjp</span><span class="p">(</span><span class="n">primals</span><span class="p">,</span> <span class="n">cotangent</span><span class="p">,</span> <span class="n">_</span><span class="p">):</span>
<span class="n">x</span><span class="p">,</span> <span class="n">grid</span> <span class="o">=</span> <span class="n">primals</span>
<span class="n">B</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
<span class="n">_</span><span class="p">,</span> <span class="n">gN</span><span class="p">,</span> <span class="n">gM</span><span class="p">,</span> <span class="n">D</span> <span class="o">=</span> <span class="n">grid</span><span class="o">.</span><span class="n">shape</span>
<span class="k">assert</span> <span class="n">D</span> <span class="o">==</span> <span class="mi">2</span><span class="p">,</span> <span class="s2">&quot;Last dim of `grid` must be size 2.&quot;</span>
<span class="n">source</span> <span class="o">=</span> <span class="s2">&quot;&quot;&quot;</span>
<span class="s2"> uint elem = thread_position_in_grid.x;</span>
<span class="s2"> int H = x_shape[1];</span>
<span class="s2"> int W = x_shape[2];</span>
<span class="s2"> int C = x_shape[3];</span>
<span class="s2"> // Pad C to the nearest larger simdgroup size multiple</span>
<span class="s2"> int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;</span>
<span class="s2"> int gH = grid_shape[1];</span>
<span class="s2"> int gW = grid_shape[2];</span>
<span class="s2"> int w_stride = C;</span>
<span class="s2"> int h_stride = W * w_stride;</span>
<span class="s2"> int b_stride = H * h_stride;</span>
<span class="s2"> uint grid_idx = elem / C_padded * 2;</span>
<span class="s2"> float ix = ((grid[grid_idx] + 1) * W - 1) / 2;</span>
<span class="s2"> float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;</span>
<span class="s2"> int ix_nw = floor(ix);</span>
<span class="s2"> int iy_nw = floor(iy);</span>
<span class="s2"> int ix_ne = ix_nw + 1;</span>
<span class="s2"> int iy_ne = iy_nw;</span>
<span class="s2"> int ix_sw = ix_nw;</span>
<span class="s2"> int iy_sw = iy_nw + 1;</span>
<span class="s2"> int ix_se = ix_nw + 1;</span>
<span class="s2"> int iy_se = iy_nw + 1;</span>
<span class="s2"> T nw = (ix_se - ix) * (iy_se - iy);</span>
<span class="s2"> T ne = (ix - ix_sw) * (iy_sw - iy);</span>
<span class="s2"> T sw = (ix_ne - ix) * (iy - iy_ne);</span>
<span class="s2"> T se = (ix - ix_nw) * (iy - iy_nw);</span>
<span class="s2"> int batch_idx = elem / C_padded / gH / gW * b_stride;</span>
<span class="s2"> int channel_idx = elem % C_padded;</span>
<span class="s2"> int base_idx = batch_idx + channel_idx;</span>
<span class="s2"> T gix = T(0);</span>
<span class="s2"> T giy = T(0);</span>
<span class="s2"> if (channel_idx &lt; C) {</span>
<span class="s2"> int cot_index = elem / C_padded * C + channel_idx;</span>
<span class="s2"> T cot = cotangent[cot_index];</span>
<span class="s2"> if (iy_nw &gt;= 0 &amp;&amp; iy_nw &lt;= H - 1 &amp;&amp; ix_nw &gt;= 0 &amp;&amp; ix_nw &lt;= W - 1) {</span>
<span class="s2"> int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;x_grad[offset], nw * cot, memory_order_relaxed);</span>
<span class="s2"> T I_nw = x[offset];</span>
<span class="s2"> gix -= I_nw * (iy_se - iy) * cot;</span>
<span class="s2"> giy -= I_nw * (ix_se - ix) * cot;</span>
<span class="s2"> }</span>
<span class="s2"> if (iy_ne &gt;= 0 &amp;&amp; iy_ne &lt;= H - 1 &amp;&amp; ix_ne &gt;= 0 &amp;&amp; ix_ne &lt;= W - 1) {</span>
<span class="s2"> int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;x_grad[offset], ne * cot, memory_order_relaxed);</span>
<span class="s2"> T I_ne = x[offset];</span>
<span class="s2"> gix += I_ne * (iy_sw - iy) * cot;</span>
<span class="s2"> giy -= I_ne * (ix - ix_sw) * cot;</span>
<span class="s2"> }</span>
<span class="s2"> if (iy_sw &gt;= 0 &amp;&amp; iy_sw &lt;= H - 1 &amp;&amp; ix_sw &gt;= 0 &amp;&amp; ix_sw &lt;= W - 1) {</span>
<span class="s2"> int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;x_grad[offset], sw * cot, memory_order_relaxed);</span>
<span class="s2"> T I_sw = x[offset];</span>
<span class="s2"> gix -= I_sw * (iy - iy_ne) * cot;</span>
<span class="s2"> giy += I_sw * (ix_ne - ix) * cot;</span>
<span class="s2"> }</span>
<span class="s2"> if (iy_se &gt;= 0 &amp;&amp; iy_se &lt;= H - 1 &amp;&amp; ix_se &gt;= 0 &amp;&amp; ix_se &lt;= W - 1) {</span>
<span class="s2"> int offset = base_idx + iy_se * h_stride + ix_se * w_stride;</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;x_grad[offset], se * cot, memory_order_relaxed);</span>
<span class="s2"> T I_se = x[offset];</span>
<span class="s2"> gix += I_se * (iy - iy_nw) * cot;</span>
<span class="s2"> giy += I_se * (ix - ix_nw) * cot;</span>
<span class="s2"> }</span>
<span class="s2"> }</span>
<span class="s2"> T gix_mult = W / 2;</span>
<span class="s2"> T giy_mult = H / 2;</span>
<span class="s2"> // Reduce across each simdgroup first.</span>
<span class="s2"> // This is much faster than relying purely on atomics.</span>
<span class="s2"> gix = simd_sum(gix);</span>
<span class="s2"> giy = simd_sum(giy);</span>
<span class="s2"> if (thread_index_in_simdgroup == 0) {</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);</span>
<span class="s2"> atomic_fetch_add_explicit(&amp;grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);</span>
<span class="s2"> }</span>
<span class="s2"> &quot;&quot;&quot;</span>
<span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="s2">&quot;grid_sample_grad&quot;</span><span class="p">,</span>
<span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;x&quot;</span><span class="p">,</span> <span class="s2">&quot;grid&quot;</span><span class="p">,</span> <span class="s2">&quot;cotangent&quot;</span><span class="p">],</span>
<span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;x_grad&quot;</span><span class="p">,</span> <span class="s2">&quot;grid_grad&quot;</span><span class="p">],</span>
<span class="n">source</span><span class="o">=</span><span class="n">source</span><span class="p">,</span>
<span class="n">atomic_outputs</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="p">)</span>
<span class="c1"># pad the output channels to simd group size</span>
<span class="c1"># so that our `simd_sum`s don&#39;t overlap.</span>
<span class="n">simdgroup_size</span> <span class="o">=</span> <span class="mi">32</span>
<span class="n">C_padded</span> <span class="o">=</span> <span class="p">(</span><span class="n">C</span> <span class="o">+</span> <span class="n">simdgroup_size</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">simdgroup_size</span> <span class="o">*</span> <span class="n">simdgroup_size</span>
<span class="n">grid_size</span> <span class="o">=</span> <span class="n">B</span> <span class="o">*</span> <span class="n">gN</span> <span class="o">*</span> <span class="n">gM</span> <span class="o">*</span> <span class="n">C_padded</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">kernel</span><span class="p">(</span>
<span class="n">inputs</span><span class="o">=</span><span class="p">[</span><span class="n">x</span><span class="p">,</span> <span class="n">grid</span><span class="p">,</span> <span class="n">cotangent</span><span class="p">],</span>
<span class="n">template</span><span class="o">=</span><span class="p">[(</span><span class="s2">&quot;T&quot;</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)],</span>
<span class="n">output_shapes</span><span class="o">=</span><span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">grid</span><span class="o">.</span><span class="n">shape</span><span class="p">],</span>
<span class="n">output_dtypes</span><span class="o">=</span><span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">],</span>
<span class="n">grid</span><span class="o">=</span><span class="p">(</span><span class="n">grid_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">threadgroup</span><span class="o">=</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">init_value</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
</pre></div>
</div>
<p>There’s an even larger speedup for the VJP:</p>
<p><code class="docutils literal notranslate"><span class="pre">676.4ms</span> <span class="pre">-&gt;</span> <span class="pre">16.7ms</span> <span class="pre">=&gt;</span> <span class="pre">40x</span> <span class="pre">speed</span> <span class="pre">up</span></code></p>
</section>
2024-08-24 03:14:53 +08:00
</section>
</article>
2024-10-15 23:12:17 +08:00
<footer class="prev-next-footer d-print-none">
2024-08-24 03:14:53 +08:00
<div class="prev-next-area">
<a class="left-prev"
href="metal_debugger.html"
title="previous page">
<i class="fa-solid fa-angle-left"></i>
<div class="prev-next-info">
<p class="prev-next-subtitle">previous</p>
<p class="prev-next-title">Metal Debugger</p>
</div>
</a>
2025-01-10 05:56:20 +08:00
<a class="right-next"
href="mlx_in_cpp.html"
title="next page">
<div class="prev-next-info">
<p class="prev-next-subtitle">next</p>
<p class="prev-next-title">Using MLX in C++</p>
</div>
<i class="fa-solid fa-angle-right"></i>
</a>
2024-08-24 03:14:53 +08:00
</div>
</footer>
</div>
2025-03-06 05:30:09 +08:00
<div class="bd-sidebar-secondary bd-toc"><div class="sidebar-secondary-items sidebar-secondary__inner">
2024-08-24 03:14:53 +08:00
2024-10-15 23:12:17 +08:00
2024-08-24 03:14:53 +08:00
<div class="sidebar-secondary-item">
<div class="page-toc tocsection onthispage">
<i class="fa-solid fa-list"></i> Contents
</div>
<nav class="bd-toc-nav page-toc">
<ul class="visible nav section-nav flex-column">
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#simple-example">Simple Example</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#using-shape-strides">Using Shape/Strides</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#complex-example">Complex Example</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#grid-sample-vjp">Grid Sample VJP</a></li>
</ul>
</nav></div>
</div></div>
</div>
<footer class="bd-footer-content">
<div class="bd-footer-content__inner container">
<div class="footer-item">
<p class="component-author">
By MLX Contributors
</p>
</div>
<div class="footer-item">
<p class="copyright">
© Copyright 2023, Apple.
<br/>
</p>
</div>
<div class="footer-item">
</div>
<div class="footer-item">
</div>
</div>
</footer>
</main>
</div>
</div>
<!-- Scripts loaded after <body> so the DOM is not blocked -->
2025-03-06 05:30:09 +08:00
<script src="../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b"></script>
<script src="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b"></script>
2024-08-24 03:14:53 +08:00
<footer class="bd-footer">
</footer>
</body>
</html>