<!DOCTYPE html>
< html lang = "en" data-content_root = "../" >
2023-12-06 04:10:03 +08:00
<head>
  <!-- Document metadata. A single viewport declaration is kept; the source carried three. -->
  <meta charset="utf-8" />
  <meta name="viewport" content="width=device-width, initial-scale=1" />

  <title>Custom Extensions in MLX — MLX 0.26.1 documentation</title>

  <!-- Copy the stored mode/theme onto <html> before the stylesheets below are requested. -->
  <script data-cfasync="false">
    document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
    document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
  </script>

  <!-- Loaded before other Sphinx assets -->
  <link href="../_static/styles/theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
  <link href="../_static/styles/bootstrap.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
  <link href="../_static/styles/pydata-sphinx-theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
  <link href="../_static/vendor/fontawesome/6.5.2/css/all.min.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />

  <!-- Font Awesome webfonts, preloaded; crossorigin is required for font preloads. -->
  <link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-solid-900.woff2" />
  <link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-brands-400.woff2" />
  <link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-regular-400.woff2" />

  <link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
  <link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=eba8b062" />

  <!-- Pre-loaded scripts that we'll load fully later -->
  <link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b" />
  <link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" />
  <script src="../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script>

  <script src="../_static/documentation_options.js?v=3724ff34"></script>
  <script src="../_static/doctools.js?v=9a2dae69"></script>
  <script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
  <script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
  <!-- Must run after documentation_options.js, which is loaded above. -->
  <script>DOCUMENTATION_OPTIONS.pagename = 'dev/extensions';</script>

  <!-- Icon and document-relationship links -->
  <link rel="icon" href="../_static/mlx_logo.png" />
  <link rel="index" title="Index" href="../genindex.html" />
  <link rel="search" title="Search" href="../search.html" />
  <link rel="next" title="Metal Debugger" href="metal_debugger.html" />
  <link rel="prev" title="Operations" href="../cpp/ops.html" />
  <meta name="docsearch:language" content="en" />
</head>
< body data-bs-spy = "scroll" data-bs-target = ".bd-toc-nav" data-offset = "180" data-bs-root-margin = "0px 0px -60%" data-default-mode = "" >
2024-10-15 23:12:17 +08:00
<!-- Skip link targeting #main-content, scroll helper element, and the back-to-top button. -->
<div id="pst-skip-link" class="skip-link d-print-none"><a href="#main-content">Skip to main content</a></div>
<div id="pst-scroll-pixel-helper"></div>
<button type="button" class="btn rounded-pill" id="pst-back-to-top">
  <i class="fa-solid fa-arrow-up"></i>Back to top</button>
2023-12-06 04:10:03 +08:00
2025-03-06 05:30:09 +08:00
<!-- Checkbox + overlay label pairs for the primary and secondary sidebars; each label is bound to its checkbox through for/id. -->
<input type="checkbox" class="sidebar-toggle" id="pst-primary-sidebar-checkbox"/>
<label class="overlay overlay-primary" for="pst-primary-sidebar-checkbox"></label>
<input type="checkbox" class="sidebar-toggle" id="pst-secondary-sidebar-checkbox"/>
<label class="overlay overlay-secondary" for="pst-secondary-sidebar-checkbox"></label>
<!-- Search wrapper: an overlay element plus a GET form that submits the "q" field to ../search.html. -->
<div class="search-button__wrapper">
  <div class="search-button__overlay"></div>
  <div class="search-button__search-container">
    <form class="bd-search d-flex align-items-center"
          action="../search.html"
          method="get">
      <i class="fa-solid fa-magnifying-glass"></i>
      <!-- No visible label; the accessible name comes from aria-label. -->
      <input type="search"
             class="form-control"
             name="q"
             id="search-input"
             placeholder="Search..."
             aria-label="Search..."
             autocomplete="off"
             autocorrect="off"
             autocapitalize="off"
             spellcheck="false"/>
      <span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd>K</kbd></span>
    </form></div>
</div>
2024-10-15 23:12:17 +08:00
<!-- Placeholder for a version-warning banner; both elements carry d-none in the static markup. -->
<div class="pst-async-banner-revealer d-none">
  <aside id="bd-header-version-warning" class="d-none d-print-none" aria-label="Version warning"></aside>
</div>
2023-12-06 04:10:03 +08:00
2024-10-15 23:12:17 +08:00
<!-- Page header landmark; it has no children on this page. -->
<header class="bd-header navbar navbar-expand-lg bd-navbar d-print-none">
</header>
2023-12-06 04:10:03 +08:00
2024-10-15 23:12:17 +08:00
2023-12-06 04:10:03 +08:00
< div class = "bd-container" >
< div class = "bd-container__inner bd-page-width" >
2024-10-15 23:12:17 +08:00
2025-03-06 05:30:09 +08:00
< div class = "bd-sidebar-primary bd-sidebar" >
2023-12-06 04:10:03 +08:00
<!-- Sidebar section for header items; empty on this page. -->
<div class="sidebar-header-items sidebar-primary__section">
</div>
< div class = "sidebar-primary-items__start sidebar-primary__section" >
<!-- Logo link to the documentation home page. The light-mode image is static markup; the dark-mode image is written by script. -->
<div class="sidebar-primary-item">
  <a class="navbar-brand logo" href="../index.html">
    <img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.26.1 documentation - Home"/>
    <script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.26.1 documentation - Home"/>`);</script>
  </a></div>
2024-03-31 08:32:20 +08:00
<!-- Sidebar search button. It is emitted via document.write, so it only exists when scripting is enabled. -->
<div class="sidebar-primary-item">
  <script>
    document.write(`
      <button class="btn search-button-field search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
        <i class="fa-solid fa-magnifying-glass"></i>
        <span class="search-button__default-text">Search</span>
        <span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span>
      </button>
    `);
  </script></div>
2024-03-31 08:32:20 +08:00
< div class = "sidebar-primary-item" > < nav class = "bd-links bd-docs-nav" aria-label = "Main" >
2023-12-06 04:10:03 +08:00
< div class = "bd-toc-item navbar-nav active" >
<!-- Navigation section: Install -->
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Install</span></p>
<ul class="nav bd-sidenav">
  <li class="toctree-l1"><a class="reference internal" href="../install.html">Build and Install</a></li>
</ul>
2023-12-06 04:10:03 +08:00
<!-- Navigation section: Usage -->
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Usage</span></p>
<ul class="nav bd-sidenav">
  <li class="toctree-l1"><a class="reference internal" href="../usage/quick_start.html">Quick Start Guide</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../usage/lazy_evaluation.html">Lazy Evaluation</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../usage/unified_memory.html">Unified Memory</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../usage/indexing.html">Indexing Arrays</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../usage/saving_and_loading.html">Saving and Loading Arrays</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../usage/function_transforms.html">Function Transforms</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../usage/compile.html">Compilation</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../usage/numpy.html">Conversion to NumPy and Other Frameworks</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../usage/distributed.html">Distributed Communication</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../usage/using_streams.html">Using Streams</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../usage/export.html">Exporting Functions</a></li>
</ul>
2023-12-06 04:10:03 +08:00
<!-- Navigation section: Examples -->
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
<ul class="nav bd-sidenav">
  <li class="toctree-l1"><a class="reference internal" href="../examples/linear_regression.html">Linear Regression</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../examples/mlp.html">Multi-Layer Perceptron</a></li>
  <li class="toctree-l1"><a class="reference internal" href="../examples/llama-inference.html">LLM inference</a></li>
</ul>
2023-12-06 04:10:03 +08:00
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Python API Reference</span></p>
< ul class = "nav bd-sidenav" >
2024-10-15 23:12:17 +08:00
<!-- Python API: Array — collapsible entry listing the mlx.core.array reference pages. -->
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/array.html">Array</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.html">mlx.core.array</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.astype.html">mlx.core.array.astype</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.at.html">mlx.core.array.at</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.item.html">mlx.core.array.item</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.tolist.html">mlx.core.array.tolist</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.dtype.html">mlx.core.array.dtype</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.itemsize.html">mlx.core.array.itemsize</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.nbytes.html">mlx.core.array.nbytes</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.ndim.html">mlx.core.array.ndim</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.shape.html">mlx.core.array.shape</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.size.html">mlx.core.array.size</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.real.html">mlx.core.array.real</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.imag.html">mlx.core.array.imag</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.abs.html">mlx.core.array.abs</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.all.html">mlx.core.array.all</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.any.html">mlx.core.array.any</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.argmax.html">mlx.core.array.argmax</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.argmin.html">mlx.core.array.argmin</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.conj.html">mlx.core.array.conj</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.cos.html">mlx.core.array.cos</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.cummax.html">mlx.core.array.cummax</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.cummin.html">mlx.core.array.cummin</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.cumprod.html">mlx.core.array.cumprod</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.cumsum.html">mlx.core.array.cumsum</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.diag.html">mlx.core.array.diag</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.diagonal.html">mlx.core.array.diagonal</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.exp.html">mlx.core.array.exp</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.flatten.html">mlx.core.array.flatten</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.log.html">mlx.core.array.log</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.log10.html">mlx.core.array.log10</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.log1p.html">mlx.core.array.log1p</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.log2.html">mlx.core.array.log2</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.logcumsumexp.html">mlx.core.array.logcumsumexp</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.logsumexp.html">mlx.core.array.logsumexp</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.max.html">mlx.core.array.max</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.mean.html">mlx.core.array.mean</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.min.html">mlx.core.array.min</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.moveaxis.html">mlx.core.array.moveaxis</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.prod.html">mlx.core.array.prod</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.reciprocal.html">mlx.core.array.reciprocal</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.reshape.html">mlx.core.array.reshape</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.round.html">mlx.core.array.round</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.rsqrt.html">mlx.core.array.rsqrt</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.sin.html">mlx.core.array.sin</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.split.html">mlx.core.array.split</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.sqrt.html">mlx.core.array.sqrt</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.square.html">mlx.core.array.square</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.squeeze.html">mlx.core.array.squeeze</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.std.html">mlx.core.array.std</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.sum.html">mlx.core.array.sum</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.swapaxes.html">mlx.core.array.swapaxes</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.transpose.html">mlx.core.array.transpose</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.T.html">mlx.core.array.T</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.var.html">mlx.core.array.var</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.array.view.html">mlx.core.array.view</a></li>
</ul>
</details></li>
<!-- Python API: Data Types — collapsible entry. -->
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/data_types.html">Data Types</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Dtype.html">mlx.core.Dtype</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.DtypeCategory.html">mlx.core.DtypeCategory</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.issubdtype.html">mlx.core.issubdtype</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.finfo.html">mlx.core.finfo</a></li>
</ul>
</details></li>
<!-- Python API: Devices and Streams — collapsible entry. -->
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/devices_and_streams.html">Devices and Streams</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.Device.html">mlx.core.Device</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/stream_class.html">mlx.core.Stream</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.default_device.html">mlx.core.default_device</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.set_default_device.html">mlx.core.set_default_device</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.default_stream.html">mlx.core.default_stream</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.new_stream.html">mlx.core.new_stream</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.set_default_stream.html">mlx.core.set_default_stream</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.stream.html">mlx.core.stream</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.synchronize.html">mlx.core.synchronize</a></li>
</ul>
</details></li>
2025-01-10 05:56:20 +08:00
<!-- Python API: Export Functions — collapsible entry. -->
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/export.html">Export Functions</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_function.html">mlx.core.export_function</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.import_function.html">mlx.core.import_function</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.exporter.html">mlx.core.exporter</a></li>
  <li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.export_to_dot.html">mlx.core.export_to_dot</a></li>
</ul>
</details></li>
2024-10-15 23:12:17 +08:00
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/ops.html" > Operations< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.abs.html" > mlx.core.abs< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.add.html" > mlx.core.add< / a > < / li >
2024-05-21 00:40:17 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.addmm.html" > mlx.core.addmm< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.all.html" > mlx.core.all< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.allclose.html" > mlx.core.allclose< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.any.html" > mlx.core.any< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.arange.html" > mlx.core.arange< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.arccos.html" > mlx.core.arccos< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.arccosh.html" > mlx.core.arccosh< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.arcsin.html" > mlx.core.arcsin< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.arcsinh.html" > mlx.core.arcsinh< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.arctan.html" > mlx.core.arctan< / a > < / li >
2024-05-10 23:49:36 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.arctan2.html" > mlx.core.arctan2< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.arctanh.html" > mlx.core.arctanh< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.argmax.html" > mlx.core.argmax< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.argmin.html" > mlx.core.argmin< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.argpartition.html" > mlx.core.argpartition< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.argsort.html" > mlx.core.argsort< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array_equal.html" > mlx.core.array_equal< / a > < / li >
2024-05-21 00:40:17 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.as_strided.html" > mlx.core.as_strided< / a > < / li >
2024-03-01 04:39:18 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.atleast_1d.html" > mlx.core.atleast_1d< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.atleast_2d.html" > mlx.core.atleast_2d< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.atleast_3d.html" > mlx.core.atleast_3d< / a > < / li >
2024-05-10 23:49:36 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.bitwise_and.html" > mlx.core.bitwise_and< / a > < / li >
2025-02-15 05:44:39 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.bitwise_invert.html" > mlx.core.bitwise_invert< / a > < / li >
2024-05-10 23:49:36 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.bitwise_or.html" > mlx.core.bitwise_or< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.bitwise_xor.html" > mlx.core.bitwise_xor< / a > < / li >
2024-04-26 23:24:09 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.block_masked_mm.html" > mlx.core.block_masked_mm< / a > < / li >
2025-04-04 04:25:24 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.broadcast_arrays.html" > mlx.core.broadcast_arrays< / a > < / li >
2024-05-10 23:49:36 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.broadcast_to.html" > mlx.core.broadcast_to< / a > < / li >
2023-12-18 05:23:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.ceil.html" > mlx.core.ceil< / a > < / li >
2023-12-22 14:13:41 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.clip.html" > mlx.core.clip< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.concatenate.html" > mlx.core.concatenate< / a > < / li >
2025-04-04 04:25:24 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.contiguous.html" > mlx.core.contiguous< / a > < / li >
2024-05-10 23:49:36 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.conj.html" > mlx.core.conj< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.conjugate.html" > mlx.core.conjugate< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.convolve.html" > mlx.core.convolve< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.conv1d.html" > mlx.core.conv1d< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.conv2d.html" > mlx.core.conv2d< / a > < / li >
2024-09-18 03:06:14 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.conv3d.html" > mlx.core.conv3d< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.conv_transpose1d.html" > mlx.core.conv_transpose1d< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.conv_transpose2d.html" > mlx.core.conv_transpose2d< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.conv_transpose3d.html" > mlx.core.conv_transpose3d< / a > < / li >
2024-03-01 04:39:18 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.conv_general.html" > mlx.core.conv_general< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.cos.html" > mlx.core.cos< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.cosh.html" > mlx.core.cosh< / a > < / li >
2024-03-31 08:32:20 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.cummax.html" > mlx.core.cummax< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.cummin.html" > mlx.core.cummin< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.cumprod.html" > mlx.core.cumprod< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.cumsum.html" > mlx.core.cumsum< / a > < / li >
2024-04-26 23:24:09 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.degrees.html" > mlx.core.degrees< / a > < / li >
2023-12-22 14:13:41 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.dequantize.html" > mlx.core.dequantize< / a > < / li >
2024-02-02 05:08:29 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.diag.html" > mlx.core.diag< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.diagonal.html" > mlx.core.diagonal< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.divide.html" > mlx.core.divide< / a > < / li >
2024-01-11 06:14:38 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.divmod.html" > mlx.core.divmod< / a > < / li >
2024-07-26 02:59:11 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.einsum.html" > mlx.core.einsum< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.einsum_path.html" > mlx.core.einsum_path< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.equal.html" > mlx.core.equal< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.erf.html" > mlx.core.erf< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.erfinv.html" > mlx.core.erfinv< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.exp.html" > mlx.core.exp< / a > < / li >
2024-04-12 08:33:33 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.expm1.html" > mlx.core.expm1< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.expand_dims.html" > mlx.core.expand_dims< / a > < / li >
2023-12-14 06:46:24 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.eye.html" > mlx.core.eye< / a > < / li >
2023-12-18 05:23:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.flatten.html" > mlx.core.flatten< / a > < / li >
2023-12-22 14:13:41 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.floor.html" > mlx.core.floor< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.floor_divide.html" > mlx.core.floor_divide< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.full.html" > mlx.core.full< / a > < / li >
2024-05-24 12:11:24 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.gather_mm.html" > mlx.core.gather_mm< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.gather_qmm.html" > mlx.core.gather_qmm< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.greater.html" > mlx.core.greater< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.greater_equal.html" > mlx.core.greater_equal< / a > < / li >
2024-07-12 06:32:08 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.hadamard_transform.html" > mlx.core.hadamard_transform< / a > < / li >
2023-12-14 06:46:24 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.identity.html" > mlx.core.identity< / a > < / li >
2024-10-19 03:13:44 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.imag.html" > mlx.core.imag< / a > < / li >
2024-01-11 06:14:38 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.inner.html" > mlx.core.inner< / a > < / li >
2024-09-18 03:06:14 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.isfinite.html" > mlx.core.isfinite< / a > < / li >
2024-03-15 03:46:45 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.isclose.html" > mlx.core.isclose< / a > < / li >
2024-03-31 08:32:20 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.isinf.html" > mlx.core.isinf< / a > < / li >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.isnan.html" > mlx.core.isnan< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.isneginf.html" > mlx.core.isneginf< / a > < / li >
2024-03-31 08:32:20 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.isposinf.html" > mlx.core.isposinf< / a > < / li >
2024-05-21 00:40:17 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.issubdtype.html" > mlx.core.issubdtype< / a > < / li >
2025-01-10 05:56:20 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.kron.html" > mlx.core.kron< / a > < / li >
2024-05-10 23:49:36 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.left_shift.html" > mlx.core.left_shift< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.less.html" > mlx.core.less< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.less_equal.html" > mlx.core.less_equal< / a > < / li >
2023-12-22 14:13:41 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linspace.html" > mlx.core.linspace< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.load.html" > mlx.core.load< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.log.html" > mlx.core.log< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.log2.html" > mlx.core.log2< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.log10.html" > mlx.core.log10< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.log1p.html" > mlx.core.log1p< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.logaddexp.html" > mlx.core.logaddexp< / a > < / li >
2025-04-18 06:29:33 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.logcumsumexp.html" > mlx.core.logcumsumexp< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.logical_not.html" > mlx.core.logical_not< / a > < / li >
2024-01-11 06:14:38 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.logical_and.html" > mlx.core.logical_and< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.logical_or.html" > mlx.core.logical_or< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.logsumexp.html" > mlx.core.logsumexp< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.matmul.html" > mlx.core.matmul< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.max.html" > mlx.core.max< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.maximum.html" > mlx.core.maximum< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.mean.html" > mlx.core.mean< / a > < / li >
2024-04-12 08:33:33 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.meshgrid.html" > mlx.core.meshgrid< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.min.html" > mlx.core.min< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.minimum.html" > mlx.core.minimum< / a > < / li >
2023-12-18 05:23:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.moveaxis.html" > mlx.core.moveaxis< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.multiply.html" > mlx.core.multiply< / a > < / li >
2024-07-26 02:59:11 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.nan_to_num.html" > mlx.core.nan_to_num< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.negative.html" > mlx.core.negative< / a > < / li >
2024-04-26 23:24:09 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.not_equal.html" > mlx.core.not_equal< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.ones.html" > mlx.core.ones< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.ones_like.html" > mlx.core.ones_like< / a > < / li >
2024-01-11 06:14:38 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.outer.html" > mlx.core.outer< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.partition.html" > mlx.core.partition< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.pad.html" > mlx.core.pad< / a > < / li >
2024-05-21 00:40:17 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.power.html" > mlx.core.power< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.prod.html" > mlx.core.prod< / a > < / li >
2024-09-29 02:04:59 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.put_along_axis.html" > mlx.core.put_along_axis< / a > < / li >
2023-12-22 14:13:41 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.quantize.html" > mlx.core.quantize< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.quantized_matmul.html" > mlx.core.quantized_matmul< / a > < / li >
2024-04-26 23:24:09 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.radians.html" > mlx.core.radians< / a > < / li >
2024-10-19 03:13:44 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.real.html" > mlx.core.real< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.reciprocal.html" > mlx.core.reciprocal< / a > < / li >
2024-05-21 00:40:17 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.remainder.html" > mlx.core.remainder< / a > < / li >
2024-01-04 12:14:05 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.repeat.html" > mlx.core.repeat< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.reshape.html" > mlx.core.reshape< / a > < / li >
2024-05-10 23:49:36 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.right_shift.html" > mlx.core.right_shift< / a > < / li >
2024-10-15 04:10:48 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.roll.html" > mlx.core.roll< / a > < / li >
2023-12-22 14:13:41 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.round.html" > mlx.core.round< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.rsqrt.html" > mlx.core.rsqrt< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.save.html" > mlx.core.save< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.savez.html" > mlx.core.savez< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.savez_compressed.html" > mlx.core.savez_compressed< / a > < / li >
2024-01-11 06:14:38 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.save_gguf.html" > mlx.core.save_gguf< / a > < / li >
2024-01-04 12:14:05 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.save_safetensors.html" > mlx.core.save_safetensors< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.sigmoid.html" > mlx.core.sigmoid< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.sign.html" > mlx.core.sign< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.sin.html" > mlx.core.sin< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.sinh.html" > mlx.core.sinh< / a > < / li >
2025-01-10 05:56:20 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.slice.html" > mlx.core.slice< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.slice_update.html" > mlx.core.slice_update< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.softmax.html" > mlx.core.softmax< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.sort.html" > mlx.core.sort< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.split.html" > mlx.core.split< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.sqrt.html" > mlx.core.sqrt< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.square.html" > mlx.core.square< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.squeeze.html" > mlx.core.squeeze< / a > < / li >
2023-12-18 05:23:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.stack.html" > mlx.core.stack< / a > < / li >
2024-04-12 08:33:33 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.std.html" > mlx.core.std< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.stop_gradient.html" > mlx.core.stop_gradient< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.subtract.html" > mlx.core.subtract< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.sum.html" > mlx.core.sum< / a > < / li >
2023-12-18 05:23:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.swapaxes.html" > mlx.core.swapaxes< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.take.html" > mlx.core.take< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.take_along_axis.html" > mlx.core.take_along_axis< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.tan.html" > mlx.core.tan< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.tanh.html" > mlx.core.tanh< / a > < / li >
2024-01-04 12:14:05 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.tensordot.html" > mlx.core.tensordot< / a > < / li >
2024-03-15 03:46:45 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.tile.html" > mlx.core.tile< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.topk.html" > mlx.core.topk< / a > < / li >
2024-05-24 12:11:24 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.trace.html" > mlx.core.trace< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.transpose.html" > mlx.core.transpose< / a > < / li >
2023-12-18 05:23:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.tri.html" > mlx.core.tri< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.tril.html" > mlx.core.tril< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.triu.html" > mlx.core.triu< / a > < / li >
2025-01-10 05:56:20 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.unflatten.html" > mlx.core.unflatten< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.var.html" > mlx.core.var< / a > < / li >
2024-06-07 11:28:06 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.view.html" > mlx.core.view< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.where.html" > mlx.core.where< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.zeros.html" > mlx.core.zeros< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.zeros_like.html" > mlx.core.zeros_like< / a > < / li >
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/random.html" > Random< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.bernoulli.html" > mlx.core.random.bernoulli< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.categorical.html" > mlx.core.random.categorical< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.gumbel.html" > mlx.core.random.gumbel< / a > < / li >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.key.html" > mlx.core.random.key< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.normal.html" > mlx.core.random.normal< / a > < / li >
2024-04-12 08:33:33 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.multivariate_normal.html" > mlx.core.random.multivariate_normal< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.randint.html" > mlx.core.random.randint< / a > < / li >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.seed.html" > mlx.core.random.seed< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.split.html" > mlx.core.random.split< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.truncated_normal.html" > mlx.core.random.truncated_normal< / a > < / li >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.uniform.html" > mlx.core.random.uniform< / a > < / li >
2024-07-26 02:59:11 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.laplace.html" > mlx.core.random.laplace< / a > < / li >
2024-10-15 04:10:48 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.permutation.html" > mlx.core.random.permutation< / a > < / li >
2023-11-30 04:41:56 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/transforms.html" > Transforms< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.eval.html" > mlx.core.eval< / a > < / li >
2025-04-04 04:25:24 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.async_eval.html" > mlx.core.async_eval< / a > < / li >
2024-02-09 04:44:23 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.compile.html" > mlx.core.compile< / a > < / li >
2024-07-12 06:32:08 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.custom_function.html" > mlx.core.custom_function< / a > < / li >
2024-02-09 04:44:23 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.disable_compile.html" > mlx.core.disable_compile< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.enable_compile.html" > mlx.core.enable_compile< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.grad.html" > mlx.core.grad< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.value_and_grad.html" > mlx.core.value_and_grad< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.jvp.html" > mlx.core.jvp< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.vjp.html" > mlx.core.vjp< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.vmap.html" > mlx.core.vmap< / a > < / li >
2023-11-30 04:41:56 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/fast.html" > Fast< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-03-31 08:32:20 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fast.rms_norm.html" > mlx.core.fast.rms_norm< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fast.layer_norm.html" > mlx.core.fast.layer_norm< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fast.rope.html" > mlx.core.fast.rope< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fast.scaled_dot_product_attention.html" > mlx.core.fast.scaled_dot_product_attention< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fast.metal_kernel.html" > mlx.core.fast.metal_kernel< / a > < / li >
2024-03-31 08:32:20 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/fft.html" > FFT< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.fft.html" > mlx.core.fft.fft< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.ifft.html" > mlx.core.fft.ifft< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.fft2.html" > mlx.core.fft.fft2< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.ifft2.html" > mlx.core.fft.ifft2< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.fftn.html" > mlx.core.fft.fftn< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.ifftn.html" > mlx.core.fft.ifftn< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.rfft.html" > mlx.core.fft.rfft< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.irfft.html" > mlx.core.fft.irfft< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.rfft2.html" > mlx.core.fft.rfft2< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.irfft2.html" > mlx.core.fft.irfft2< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.rfftn.html" > mlx.core.fft.rfftn< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.irfftn.html" > mlx.core.fft.irfftn< / a > < / li >
2025-05-10 05:42:00 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.fftshift.html" > mlx.core.fft.fftshift< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.ifftshift.html" > mlx.core.fft.ifftshift< / a > < / li >
2023-11-30 04:41:56 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/linalg.html" > Linear Algebra< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-05-21 00:40:17 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.inv.html" > mlx.core.linalg.inv< / a > < / li >
2024-08-11 00:24:35 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.tri_inv.html" > mlx.core.linalg.tri_inv< / a > < / li >
2024-01-04 12:14:05 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.norm.html" > mlx.core.linalg.norm< / a > < / li >
2024-05-21 00:40:17 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.cholesky.html" > mlx.core.linalg.cholesky< / a > < / li >
2024-08-11 00:24:35 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.cholesky_inv.html" > mlx.core.linalg.cholesky_inv< / a > < / li >
2024-09-29 02:04:59 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.cross.html" > mlx.core.linalg.cross< / a > < / li >
2024-02-02 05:08:29 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.qr.html" > mlx.core.linalg.qr< / a > < / li >
2024-05-21 00:40:17 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.svd.html" > mlx.core.linalg.svd< / a > < / li >
2025-06-03 07:29:32 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.eigvals.html" > mlx.core.linalg.eigvals< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.eig.html" > mlx.core.linalg.eig< / a > < / li >
2024-10-26 04:23:45 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.eigvalsh.html" > mlx.core.linalg.eigvalsh< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.eigh.html" > mlx.core.linalg.eigh< / a > < / li >
2025-02-15 05:44:39 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.lu.html" > mlx.core.linalg.lu< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.lu_factor.html" > mlx.core.linalg.lu_factor< / a > < / li >
2025-04-04 04:25:24 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.pinv.html" > mlx.core.linalg.pinv< / a > < / li >
2025-02-15 05:44:39 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.solve.html" > mlx.core.linalg.solve< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.solve_triangular.html" > mlx.core.linalg.solve_triangular< / a > < / li >
2024-01-04 12:14:05 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/metal.html" > Metal< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-03-15 03:46:45 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.metal.is_available.html" > mlx.core.metal.is_available< / a > < / li >
2024-05-10 23:49:36 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.metal.device_info.html" > mlx.core.metal.device_info< / a > < / li >
2024-04-12 08:33:33 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.metal.start_capture.html" > mlx.core.metal.start_capture< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.metal.stop_capture.html" > mlx.core.metal.stop_capture< / a > < / li >
2024-03-15 03:46:45 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
2025-03-25 04:24:41 +08:00
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/memory_management.html" > Memory Management< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.get_active_memory.html" > mlx.core.get_active_memory< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.get_peak_memory.html" > mlx.core.get_peak_memory< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.reset_peak_memory.html" > mlx.core.reset_peak_memory< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.get_cache_memory.html" > mlx.core.get_cache_memory< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.set_memory_limit.html" > mlx.core.set_memory_limit< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.set_cache_limit.html" > mlx.core.set_cache_limit< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.set_wired_limit.html" > mlx.core.set_wired_limit< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.clear_cache.html" > mlx.core.clear_cache< / a > < / li >
< / ul >
< / details > < / li >
2024-10-15 23:12:17 +08:00
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/nn.html" > Neural Networks< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.nn.value_and_grad.html" > mlx.nn.value_and_grad< / a > < / li >
2024-04-26 23:24:09 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.nn.quantize.html" > mlx.nn.quantize< / a > < / li >
2025-03-06 05:30:09 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.nn.average_gradients.html" > mlx.nn.average_gradients< / a > < / li >
2024-10-15 23:12:17 +08:00
< li class = "toctree-l2 has-children" > < a class = "reference internal" href = "../python/nn/module.html" > Module< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-01-04 12:14:05 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.training.html" > mlx.nn.Module.training< / a > < / li >
2024-02-09 04:44:23 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.state.html" > mlx.nn.Module.state< / a > < / li >
2024-01-04 12:14:05 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.apply.html" > mlx.nn.Module.apply< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.apply_to_modules.html" > mlx.nn.Module.apply_to_modules< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.children.html" > mlx.nn.Module.children< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.eval.html" > mlx.nn.Module.eval< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.filter_and_map.html" > mlx.nn.Module.filter_and_map< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.freeze.html" > mlx.nn.Module.freeze< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.leaf_modules.html" > mlx.nn.Module.leaf_modules< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.load_weights.html" > mlx.nn.Module.load_weights< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.modules.html" > mlx.nn.Module.modules< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.named_modules.html" > mlx.nn.Module.named_modules< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.parameters.html" > mlx.nn.Module.parameters< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.save_weights.html" > mlx.nn.Module.save_weights< / a > < / li >
2024-03-31 08:32:20 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.set_dtype.html" > mlx.nn.Module.set_dtype< / a > < / li >
2024-01-04 12:14:05 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.train.html" > mlx.nn.Module.train< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.trainable_parameters.html" > mlx.nn.Module.trainable_parameters< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.unfreeze.html" > mlx.nn.Module.unfreeze< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.update.html" > mlx.nn.Module.update< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.update_modules.html" > mlx.nn.Module.update_modules< / a > < / li >
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l2 has-children" > < a class = "reference internal" href = "../python/nn/layers.html" > Layers< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.ALiBi.html" > mlx.nn.ALiBi< / a > < / li >
2024-02-18 05:25:37 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.AvgPool1d.html" > mlx.nn.AvgPool1d< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.AvgPool2d.html" > mlx.nn.AvgPool2d< / a > < / li >
2024-11-23 04:24:16 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.AvgPool3d.html" > mlx.nn.AvgPool3d< / a > < / li >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.BatchNorm.html" > mlx.nn.BatchNorm< / a > < / li >
2024-09-29 02:04:59 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.CELU.html" > mlx.nn.CELU< / a > < / li >
2023-12-18 05:23:03 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Conv1d.html" > mlx.nn.Conv1d< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Conv2d.html" > mlx.nn.Conv2d< / a > < / li >
2024-05-21 00:40:17 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Conv3d.html" > mlx.nn.Conv3d< / a > < / li >
2024-09-18 03:06:14 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.ConvTranspose1d.html" > mlx.nn.ConvTranspose1d< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.ConvTranspose2d.html" > mlx.nn.ConvTranspose2d< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.ConvTranspose3d.html" > mlx.nn.ConvTranspose3d< / a > < / li >
2024-01-04 12:14:05 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Dropout.html" > mlx.nn.Dropout< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Dropout2d.html" > mlx.nn.Dropout2d< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Dropout3d.html" > mlx.nn.Dropout3d< / a > < / li >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Embedding.html" > mlx.nn.Embedding< / a > < / li >
2024-09-29 02:04:59 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.ELU.html" > mlx.nn.ELU< / a > < / li >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.GELU.html" > mlx.nn.GELU< / a > < / li >
2024-06-07 11:28:06 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.GLU.html" > mlx.nn.GLU< / a > < / li >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.GroupNorm.html" > mlx.nn.GroupNorm< / a > < / li >
2024-03-15 03:46:45 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.GRU.html" > mlx.nn.GRU< / a > < / li >
2024-06-07 11:28:06 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.HardShrink.html" > mlx.nn.HardShrink< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.HardTanh.html" > mlx.nn.HardTanh< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Hardswish.html" > mlx.nn.Hardswish< / a > < / li >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.InstanceNorm.html" > mlx.nn.InstanceNorm< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.LayerNorm.html" > mlx.nn.LayerNorm< / a > < / li >
2024-06-07 11:28:06 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.LeakyReLU.html" > mlx.nn.LeakyReLU< / a > < / li >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Linear.html" > mlx.nn.Linear< / a > < / li >
2024-09-29 02:04:59 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.LogSigmoid.html" > mlx.nn.LogSigmoid< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.LogSoftmax.html" > mlx.nn.LogSoftmax< / a > < / li >
2024-03-15 03:46:45 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.LSTM.html" > mlx.nn.LSTM< / a > < / li >
2024-02-18 05:25:37 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.MaxPool1d.html" > mlx.nn.MaxPool1d< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.MaxPool2d.html" > mlx.nn.MaxPool2d< / a > < / li >
2024-11-23 04:24:16 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.MaxPool3d.html" > mlx.nn.MaxPool3d< / a > < / li >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Mish.html" > mlx.nn.Mish< / a > < / li >
2023-12-18 05:23:03 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.MultiHeadAttention.html" > mlx.nn.MultiHeadAttention< / a > < / li >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.PReLU.html" > mlx.nn.PReLU< / a > < / li >
2024-04-26 23:24:09 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.QuantizedEmbedding.html" > mlx.nn.QuantizedEmbedding< / a > < / li >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.QuantizedLinear.html" > mlx.nn.QuantizedLinear< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.RMSNorm.html" > mlx.nn.RMSNorm< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.ReLU.html" > mlx.nn.ReLU< / a > < / li >
2024-06-07 11:28:06 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.ReLU6.html" > mlx.nn.ReLU6< / a > < / li >
2024-03-15 03:46:45 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.RNN.html" > mlx.nn.RNN< / a > < / li >
2024-01-04 12:14:05 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.RoPE.html" > mlx.nn.RoPE< / a > < / li >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.SELU.html" > mlx.nn.SELU< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Sequential.html" > mlx.nn.Sequential< / a > < / li >
2024-09-29 02:04:59 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Sigmoid.html" > mlx.nn.Sigmoid< / a > < / li >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.SiLU.html" > mlx.nn.SiLU< / a > < / li >
2024-01-04 12:14:05 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.html" > mlx.nn.SinusoidalPositionalEncoding< / a > < / li >
2024-06-07 11:28:06 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Softmin.html" > mlx.nn.Softmin< / a > < / li >
2024-02-02 05:08:29 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Softshrink.html" > mlx.nn.Softshrink< / a > < / li >
2024-06-07 11:28:06 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Softsign.html" > mlx.nn.Softsign< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Softmax.html" > mlx.nn.Softmax< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Softplus.html" > mlx.nn.Softplus< / a > < / li >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Step.html" > mlx.nn.Step< / a > < / li >
2024-06-07 11:28:06 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Tanh.html" > mlx.nn.Tanh< / a > < / li >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Transformer.html" > mlx.nn.Transformer< / a > < / li >
2024-03-01 04:39:18 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Upsample.html" > mlx.nn.Upsample< / a > < / li >
2023-11-30 04:41:56 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l2 has-children" > < a class = "reference internal" href = "../python/nn/functions.html" > Functions< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-03-01 04:39:18 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.elu.html" > mlx.nn.elu< / a > < / li >
2024-09-29 02:04:59 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.celu.html" > mlx.nn.celu< / a > < / li >
2023-12-18 05:23:03 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.gelu.html" > mlx.nn.gelu< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.gelu_approx.html" > mlx.nn.gelu_approx< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.html" > mlx.nn.gelu_fast_approx< / a > < / li >
2024-03-01 04:39:18 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.glu.html" > mlx.nn.glu< / a > < / li >
2024-06-07 11:28:06 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.hard_shrink.html" > mlx.nn.hard_shrink< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.hard_tanh.html" > mlx.nn.hard_tanh< / a > < / li >
2024-03-01 04:39:18 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.hardswish.html" > mlx.nn.hardswish< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.leaky_relu.html" > mlx.nn.leaky_relu< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.log_sigmoid.html" > mlx.nn.log_sigmoid< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.log_softmax.html" > mlx.nn.log_softmax< / a > < / li >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.mish.html" > mlx.nn.mish< / a > < / li >
2023-12-18 05:23:03 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.prelu.html" > mlx.nn.prelu< / a > < / li >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.relu.html" > mlx.nn.relu< / a > < / li >
2024-03-01 04:39:18 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.relu6.html" > mlx.nn.relu6< / a > < / li >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.selu.html" > mlx.nn.selu< / a > < / li >
2024-03-01 04:39:18 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.sigmoid.html" > mlx.nn.sigmoid< / a > < / li >
2023-12-18 05:23:03 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.silu.html" > mlx.nn.silu< / a > < / li >
2024-03-01 04:39:18 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.softmax.html" > mlx.nn.softmax< / a > < / li >
2024-06-07 11:28:06 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.softmin.html" > mlx.nn.softmin< / a > < / li >
2024-03-01 04:39:18 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.softplus.html" > mlx.nn.softplus< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.softshrink.html" > mlx.nn.softshrink< / a > < / li >
2023-12-18 05:23:03 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.step.html" > mlx.nn.step< / a > < / li >
2024-03-01 04:39:18 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.tanh.html" > mlx.nn.tanh< / a > < / li >
2023-12-18 05:23:03 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l2 has-children" > < a class = "reference internal" href = "../python/nn/losses.html" > Loss Functions< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2023-12-18 05:23:03 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.html" > mlx.nn.losses.binary_cross_entropy< / a > < / li >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.html" > mlx.nn.losses.cosine_similarity_loss< / a > < / li >
2023-12-22 14:13:41 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.html" > mlx.nn.losses.cross_entropy< / a > < / li >
2024-02-02 05:08:29 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss.html" > mlx.nn.losses.gaussian_nll_loss< / a > < / li >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.html" > mlx.nn.losses.hinge_loss< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.huber_loss.html" > mlx.nn.losses.huber_loss< / a > < / li >
2023-12-22 14:13:41 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.html" > mlx.nn.losses.kl_div_loss< / a > < / li >
2023-12-18 05:23:03 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.html" > mlx.nn.losses.l1_loss< / a > < / li >
2024-01-18 09:15:29 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.html" > mlx.nn.losses.log_cosh_loss< / a > < / li >
2024-02-09 04:44:23 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.margin_ranking_loss.html" > mlx.nn.losses.margin_ranking_loss< / a > < / li >
2023-12-18 05:23:03 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.html" > mlx.nn.losses.mse_loss< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.html" > mlx.nn.losses.nll_loss< / a > < / li >
2023-12-22 14:13:41 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.html" > mlx.nn.losses.smooth_l1_loss< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.html" > mlx.nn.losses.triplet_loss< / a > < / li >
2023-12-18 05:23:03 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l2 has-children" > < a class = "reference internal" href = "../python/nn/init.html" > Initializers< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-02-02 05:08:29 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.init.constant.html" > mlx.nn.init.constant< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.init.normal.html" > mlx.nn.init.normal< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.init.uniform.html" > mlx.nn.init.uniform< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.init.identity.html" > mlx.nn.init.identity< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.init.glorot_normal.html" > mlx.nn.init.glorot_normal< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.init.glorot_uniform.html" > mlx.nn.init.glorot_uniform< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.init.he_normal.html" > mlx.nn.init.he_normal< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.init.he_uniform.html" > mlx.nn.init.he_uniform< / a > < / li >
2023-12-18 05:23:03 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
2024-02-02 05:08:29 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/optimizers.html" > Optimizers< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
< li class = "toctree-l2 has-children" > < a class = "reference internal" href = "../python/optimizers/optimizer.html" > Optimizer< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-02-18 05:25:37 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.Optimizer.state.html" > mlx.optimizers.Optimizer.state< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.Optimizer.apply_gradients.html" > mlx.optimizers.Optimizer.apply_gradients< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.Optimizer.init.html" > mlx.optimizers.Optimizer.init< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.Optimizer.update.html" > mlx.optimizers.Optimizer.update< / a > < / li >
2024-02-09 04:44:23 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l2 has-children" > < a class = "reference internal" href = "../python/optimizers/common_optimizers.html" > Common Optimizers< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-02-18 05:25:37 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.SGD.html" > mlx.optimizers.SGD< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.RMSprop.html" > mlx.optimizers.RMSprop< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.Adagrad.html" > mlx.optimizers.Adagrad< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.Adafactor.html" > mlx.optimizers.Adafactor< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.AdaDelta.html" > mlx.optimizers.AdaDelta< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.Adam.html" > mlx.optimizers.Adam< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.AdamW.html" > mlx.optimizers.AdamW< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.Adamax.html" > mlx.optimizers.Adamax< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.Lion.html" > mlx.optimizers.Lion< / a > < / li >
2025-04-18 06:29:33 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.MultiOptimizer.html" > mlx.optimizers.MultiOptimizer< / a > < / li >
2023-11-30 04:41:56 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l2 has-children" > < a class = "reference internal" href = "../python/optimizers/schedulers.html" > Schedulers< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-02-18 05:25:37 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.cosine_decay.html" > mlx.optimizers.cosine_decay< / a > < / li >
2024-03-01 04:39:18 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.exponential_decay.html" > mlx.optimizers.exponential_decay< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.join_schedules.html" > mlx.optimizers.join_schedules< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.linear_schedule.html" > mlx.optimizers.linear_schedule< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.step_decay.html" > mlx.optimizers.step_decay< / a > < / li >
2024-02-18 05:25:37 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
2024-05-10 23:49:36 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.optimizers.clip_grad_norm.html" > mlx.optimizers.clip_grad_norm< / a > < / li >
2024-02-18 05:25:37 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/distributed.html" > Distributed Communication< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-06-07 11:28:06 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.distributed.Group.html" > mlx.core.distributed.Group< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.distributed.is_available.html" > mlx.core.distributed.is_available< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.distributed.init.html" > mlx.core.distributed.init< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.distributed.all_sum.html" > mlx.core.distributed.all_sum< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.distributed.all_gather.html" > mlx.core.distributed.all_gather< / a > < / li >
2024-09-18 03:06:14 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.distributed.send.html" > mlx.core.distributed.send< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.distributed.recv.html" > mlx.core.distributed.recv< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.distributed.recv_like.html" > mlx.core.distributed.recv_like< / a > < / li >
2024-06-07 11:28:06 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/tree_utils.html" > Tree Utils< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2023-12-06 04:10:03 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.utils.tree_flatten.html" > mlx.utils.tree_flatten< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.utils.tree_unflatten.html" > mlx.utils.tree_unflatten< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.utils.tree_map.html" > mlx.utils.tree_map< / a > < / li >
2024-04-26 23:24:09 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.utils.tree_map_with_path.html" > mlx.utils.tree_map_with_path< / a > < / li >
2024-05-10 23:49:36 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.utils.tree_reduce.html" > mlx.utils.tree_reduce< / a > < / li >
2023-11-30 04:41:56 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
2023-11-30 04:41:56 +08:00
< / ul >
2023-12-06 04:10:03 +08:00
< p aria-level = "2" class = "caption" role = "heading" > < span class = "caption-text" > C++ API Reference< / span > < / p >
< ul class = "nav bd-sidenav" >
2023-11-30 04:41:56 +08:00
< li class = "toctree-l1" > < a class = "reference internal" href = "../cpp/ops.html" > Operations< / a > < / li >
2023-12-06 04:10:03 +08:00
< / ul >
< p aria-level = "2" class = "caption" role = "heading" > < span class = "caption-text" > Further Reading< / span > < / p >
< ul class = "current nav bd-sidenav" >
2024-05-21 00:40:17 +08:00
< li class = "toctree-l1 current active" > < a class = "current reference internal" href = "#" > Custom Extensions in MLX< / a > < / li >
2024-03-31 08:32:20 +08:00
< li class = "toctree-l1" > < a class = "reference internal" href = "metal_debugger.html" > Metal Debugger< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l1" > < a class = "reference internal" href = "custom_metal_kernels.html" > Custom Metal Kernels< / a > < / li >
2025-01-10 05:56:20 +08:00
< li class = "toctree-l1" > < a class = "reference internal" href = "mlx_in_cpp.html" > Using MLX in C++< / a > < / li >
2023-11-30 04:41:56 +08:00
< / ul >
2023-12-06 04:10:03 +08:00
< / div >
< / nav > < / div >
< / div >
< div class = "sidebar-primary-items__end sidebar-primary__section" >
< / div >
2025-03-06 05:30:09 +08:00
< div id = "rtd-footer-container" > < / div >
2023-12-06 04:10:03 +08:00
2023-11-30 04:41:56 +08:00
< / div >
2023-12-06 04:10:03 +08:00
2024-10-15 23:12:17 +08:00
< main id = "main-content" class = "bd-main" role = "main" >
2023-12-06 04:10:03 +08:00
< div class = "sbt-scroll-pixel-helper" > < / div >
< div class = "bd-content" >
< div class = "bd-article-container" >
2024-10-15 23:12:17 +08:00
< div class = "bd-header-article d-print-none" >
2023-12-06 04:10:03 +08:00
< div class = "header-article-items header-article__inner" >
< div class = "header-article-items__start" >
2024-10-15 23:12:17 +08:00
< div class = "header-article-item" > < button class = "sidebar-toggle primary-toggle btn btn-sm" title = "Toggle primary sidebar" data-bs-placement = "bottom" data-bs-toggle = "tooltip" >
2023-12-06 04:10:03 +08:00
< span class = "fa-solid fa-bars" > < / span >
2024-10-15 23:12:17 +08:00
< / button > < / div >
2023-12-06 04:10:03 +08:00
< / div >
< div class = "header-article-items__end" >
< div class = "header-article-item" >
< div class = "article-header-buttons" >
2023-11-30 04:41:56 +08:00
2023-12-06 04:10:03 +08:00
< a href = "https://github.com/ml-explore/mlx" target = "_blank"
class="btn btn-sm btn-source-repository-button"
title="Source repository"
data-bs-placement="bottom" data-bs-toggle="tooltip"
>
< span class = "btn__icon-container" >
< i class = "fab fa-github" > < / i >
< / span >
< / a >
< div class = "dropdown dropdown-download-buttons" >
< button class = "btn dropdown-toggle" type = "button" data-bs-toggle = "dropdown" aria-expanded = "false" aria-label = "Download this page" >
< i class = "fas fa-download" > < / i >
< / button >
< ul class = "dropdown-menu" >
< li > < a href = "../_sources/dev/extensions.rst" target = "_blank"
class="btn btn-sm btn-download-source-button dropdown-item"
title="Download source file"
data-bs-placement="left" data-bs-toggle="tooltip"
>
< span class = "btn__icon-container" >
< i class = "fas fa-file" > < / i >
< / span >
< span class = "btn__text-container" > .rst< / span >
< / a >
< / li >
< li >
< button onclick = "window.print()"
class="btn btn-sm btn-download-pdf-button dropdown-item"
title="Print to PDF"
data-bs-placement="left" data-bs-toggle="tooltip"
>
< span class = "btn__icon-container" >
< i class = "fas fa-file-pdf" > < / i >
< / span >
< span class = "btn__text-container" > .pdf< / span >
< / button >
< / li >
2023-11-30 04:41:56 +08:00
< / ul >
< / div >
2023-12-06 04:10:03 +08:00
< button onclick = "toggleFullScreen()"
class="btn btn-sm btn-fullscreen-button"
title="Fullscreen mode"
data-bs-placement="bottom" data-bs-toggle="tooltip"
>
< span class = "btn__icon-container" >
< i class = "fas fa-expand" > < / i >
< / span >
< / button >
2025-03-06 05:30:09 +08:00
< script >
document.write(`
< button class = "btn btn-sm nav-link pst-navbar-icon theme-switch-button" title = "light/dark" aria-label = "light/dark" data-bs-placement = "bottom" data-bs-toggle = "tooltip" >
< i class = "theme-switch fa-solid fa-sun fa-lg" data-mode = "light" > < / i >
< i class = "theme-switch fa-solid fa-moon fa-lg" data-mode = "dark" > < / i >
< i class = "theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode = "auto" > < / i >
< / button >
`);
< / script >
2023-12-06 04:10:03 +08:00
2025-03-06 05:30:09 +08:00
< script >
document.write(`
< button class = "btn btn-sm pst-navbar-icon search-button search-button__button" title = "Search" aria-label = "Search" data-bs-placement = "bottom" data-bs-toggle = "tooltip" >
2023-12-06 04:10:03 +08:00
< i class = "fa-solid fa-magnifying-glass fa-lg" > < / i >
2025-03-06 05:30:09 +08:00
< / button >
`);
< / script >
2024-10-15 23:12:17 +08:00
< button class = "sidebar-toggle secondary-toggle btn btn-sm" title = "Toggle secondary sidebar" data-bs-placement = "bottom" data-bs-toggle = "tooltip" >
2023-12-06 04:10:03 +08:00
< span class = "fa-solid fa-list" > < / span >
2024-10-15 23:12:17 +08:00
< / button >
2023-12-06 04:10:03 +08:00
< / div > < / div >
< / div >
< / div >
< / div >
< div id = "jb-print-docs-body" class = "onlyprint" >
2024-05-21 00:40:17 +08:00
< h1 > Custom Extensions in MLX< / h1 >
2023-12-06 04:10:03 +08:00
<!-- Table of contents -->
< div id = "print-main-content" >
< div id = "jb-print-toc" >
< div >
< h2 > Contents < / h2 >
< / div >
< nav aria-label = "Page" >
< ul class = "visible nav section-nav flex-column" >
< li class = "toc-h2 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#introducing-the-example" > Introducing the Example< / a > < / li >
< li class = "toc-h2 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#operations-and-primitives" > Operations and Primitives< / a > < ul class = "visible nav section-nav flex-column" >
< li class = "toc-h3 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#operations" > Operations< / a > < / li >
< li class = "toc-h3 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#primitives" > Primitives< / a > < / li >
2024-04-12 08:33:33 +08:00
< li class = "toc-h3 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#using-the-primitive" > Using the Primitive< / a > < / li >
2023-12-06 04:10:03 +08:00
< / ul >
< / li >
< li class = "toc-h2 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#implementing-the-primitive" > Implementing the Primitive< / a > < ul class = "visible nav section-nav flex-column" >
2024-04-12 08:33:33 +08:00
< li class = "toc-h3 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#implementing-the-cpu-back-end" > Implementing the CPU Back-end< / a > < / li >
< li class = "toc-h3 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#implementing-the-gpu-back-end" > Implementing the GPU Back-end< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toc-h3 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#primitive-transforms" > Primitive Transforms< / a > < / li >
< / ul >
< / li >
< li class = "toc-h2 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#building-and-binding" > Building and Binding< / a > < ul class = "visible nav section-nav flex-column" >
< li class = "toc-h3 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#binding-to-python" > Binding to Python< / a > < / li >
< li class = "toc-h3 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#building-with-cmake" > Building with CMake< / a > < / li >
< li class = "toc-h3 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#building-with-setuptools" > Building with < code class = "docutils literal notranslate" > < span class = "pre" > setuptools< / span > < / code > < / a > < / li >
< / ul >
< / li >
< li class = "toc-h2 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#usage" > Usage< / a > < ul class = "visible nav section-nav flex-column" >
< li class = "toc-h3 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#results" > Results< / a > < / li >
< / ul >
< / li >
< li class = "toc-h2 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#scripts" > Scripts< / a > < / li >
< / ul >
< / nav >
< / div >
< / div >
< / div >
< div id = "searchbox" > < / div >
2024-10-15 23:12:17 +08:00
< article class = "bd-article" >
2023-12-06 04:10:03 +08:00
2024-05-21 00:40:17 +08:00
< section id = "custom-extensions-in-mlx" >
< h1 > Custom Extensions in MLX< a class = "headerlink" href = "#custom-extensions-in-mlx" title = "Link to this heading" > #< / a > < / h1 >
2024-04-12 08:33:33 +08:00
< p > You can extend MLX with custom operations on the CPU or GPU. This guide
explains how to do that with a simple example.< / p >
2023-11-30 04:41:56 +08:00
< section id = "introducing-the-example" >
2024-03-31 08:32:20 +08:00
< h2 > Introducing the Example< a class = "headerlink" href = "#introducing-the-example" title = "Link to this heading" > #< / a > < / h2 >
2024-04-12 08:33:33 +08:00
< p > Let’s say you would like an operation that takes in two arrays, < code class = "docutils literal notranslate" > < span class = "pre" > x< / span > < / code > and
< code class = "docutils literal notranslate" > < span class = "pre" > y< / span > < / code > , scales them both by coefficients < code class = "docutils literal notranslate" > < span class = "pre" > alpha< / span > < / code > and < code class = "docutils literal notranslate" > < span class = "pre" > beta< / span > < / code > respectively,
and then adds them together to get the result < code class = "docutils literal notranslate" > < span class = "pre" > z< / span > < span class = "pre" > =< / span > < span class = "pre" > alpha< / span > < span class = "pre" > *< / span > < span class = "pre" > x< / span > < span class = "pre" > +< / span > < span class = "pre" > beta< / span > < span class = "pre" > *< / span > < span class = "pre" > y< / span > < / code > .
You can do that in MLX directly:< / p >
2025-01-10 05:56:20 +08:00
< div class = "highlight-python notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "kn" > import< / span > < span class = "w" > < / span > < span class = "nn" > mlx.core< / span > < span class = "w" > < / span > < span class = "k" > as< / span > < span class = "w" > < / span > < span class = "nn" > mx< / span >
2023-11-30 04:41:56 +08:00
2025-01-10 05:56:20 +08:00
< span class = "k" > def< / span > < span class = "w" > < / span > < span class = "nf" > simple_axpby< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > :< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > array< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > :< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > array< / span > < span class = "p" > ,< / span > < span class = "n" > alpha< / span > < span class = "p" > :< / span > < span class = "nb" > float< / span > < span class = "p" > ,< / span > < span class = "n" > beta< / span > < span class = "p" > :< / span > < span class = "nb" > float< / span > < span class = "p" > )< / span > < span class = "o" > -> < / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > array< / span > < span class = "p" > :< / span >
2023-11-30 04:41:56 +08:00
< span class = "k" > return< / span > < span class = "n" > alpha< / span > < span class = "o" > *< / span > < span class = "n" > x< / span > < span class = "o" > +< / span > < span class = "n" > beta< / span > < span class = "o" > *< / span > < span class = "n" > y< / span >
< / pre > < / div >
< / div >
2024-04-12 08:33:33 +08:00
< p > This function performs that operation while leaving the implementation and
function transformations to MLX.< / p >
2025-03-21 06:37:22 +08:00
< p > However, you may want to customize the underlying implementation, perhaps to
make it faster. In this tutorial we will go through adding custom extensions.
It will cover:< / p >
2023-11-30 04:41:56 +08:00
< ul class = "simple" >
2024-04-12 08:33:33 +08:00
< li > < p > The structure of the MLX library.< / p > < / li >
2025-03-21 06:37:22 +08:00
< li > < p > Implementing a CPU operation.< / p > < / li >
2024-04-12 08:33:33 +08:00
< li > < p > Implementing a GPU operation using Metal.< / p > < / li >
< li > < p > Adding the < code class = "docutils literal notranslate" > < span class = "pre" > vjp< / span > < / code > and < code class = "docutils literal notranslate" > < span class = "pre" > jvp< / span > < / code > function transformations.< / p > < / li >
< li > < p > Building a custom extension and binding it to Python.< / p > < / li >
2023-11-30 04:41:56 +08:00
< / ul >
< / section >
< section id = "operations-and-primitives" >
2024-03-31 08:32:20 +08:00
< h2 > Operations and Primitives< a class = "headerlink" href = "#operations-and-primitives" title = "Link to this heading" > #< / a > < / h2 >
2024-04-12 08:33:33 +08:00
< p > Operations in MLX build the computation graph. Primitives provide the rules for
evaluating and transforming the graph. Let’s start by discussing operations in
more detail.< / p >
2023-11-30 04:41:56 +08:00
< section id = "operations" >
2024-03-31 08:32:20 +08:00
< h3 > Operations< a class = "headerlink" href = "#operations" title = "Link to this heading" > #< / a > < / h3 >
2024-04-12 08:33:33 +08:00
< p > Operations are the front-end functions that operate on arrays. They are defined
in the C++ API (< a class = "reference internal" href = "../cpp/ops.html#cpp-ops" > < span class = "std std-ref" > Operations< / span > < / a > ), and the Python API (< a class = "reference internal" href = "../python/ops.html#ops" > < span class = "std std-ref" > Operations< / span > < / a > ) binds them.< / p >
2025-03-21 06:37:22 +08:00
< p > We would like an operation < code class = "xref py py-meth docutils literal notranslate" > < span class = "pre" > axpby()< / span > < / code > that takes in two arrays, < code class = "docutils literal notranslate" > < span class = "pre" > x< / span > < / code > and
2024-04-12 08:33:33 +08:00
< code class = "docutils literal notranslate" > < span class = "pre" > y< / span > < / code > , and two scalars, < code class = "docutils literal notranslate" > < span class = "pre" > alpha< / span > < / code > and < code class = "docutils literal notranslate" > < span class = "pre" > beta< / span > < / code > . This is how to define it in
C++:< / p >
2023-11-30 04:41:56 +08:00
< div class = "highlight-C++ notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "cm" > /**< / span >
2024-01-04 12:14:05 +08:00
< span class = "cm" > * Scale and sum two vectors element-wise< / span >
2023-11-30 04:41:56 +08:00
< span class = "cm" > * z = alpha * x + beta * y< / span >
< span class = "cm" > *< / span >
2025-03-21 06:37:22 +08:00
< span class = "cm" > * Use NumPy-style broadcasting between x and y< / span >
2023-11-30 04:41:56 +08:00
< span class = "cm" > * Inputs are upcasted to floats if needed< / span >
2024-03-31 08:32:20 +08:00
< span class = "cm" > **/< / span >
< span class = "n" > array< / span > < span class = "w" > < / span > < span class = "nf" > axpby< / span > < span class = "p" > (< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > array< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "c1" > // Input array x< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > array< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > y< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "c1" > // Input array y< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "kt" > float< / span > < span class = "w" > < / span > < span class = "n" > alpha< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "c1" > // Scaling factor for x< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "kt" > float< / span > < span class = "w" > < / span > < span class = "n" > beta< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "c1" > // Scaling factor for y< / span >
< span class = "w" > < / span > < span class = "n" > StreamOrDevice< / span > < span class = "w" > < / span > < span class = "n" > s< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "p" > {}< / span > < span class = "w" > < / span > < span class = "c1" > // Stream on which to schedule the operation< / span >
2024-03-31 08:32:20 +08:00
< span class = "p" > );< / span >
2023-11-30 04:41:56 +08:00
< / pre > < / div >
< / div >
2025-03-21 06:37:22 +08:00
< p > The simplest way to implement this is with existing operations:< / p >
2024-03-31 08:32:20 +08:00
< div class = "highlight-C++ notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "n" > array< / span > < span class = "w" > < / span > < span class = "nf" > axpby< / span > < span class = "p" > (< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > array< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "c1" > // Input array x< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > array< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > y< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "c1" > // Input array y< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "kt" > float< / span > < span class = "w" > < / span > < span class = "n" > alpha< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "c1" > // Scaling factor for x< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "kt" > float< / span > < span class = "w" > < / span > < span class = "n" > beta< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "c1" > // Scaling factor for y< / span >
< span class = "w" > < / span > < span class = "n" > StreamOrDevice< / span > < span class = "w" > < / span > < span class = "n" > s< / span > < span class = "w" > < / span > < span class = "cm" > /* = {} */< / span > < span class = "w" > < / span > < span class = "c1" > // Stream on which to schedule the operation< / span >
2024-03-31 08:32:20 +08:00
< span class = "p" > )< / span > < span class = "w" > < / span > < span class = "p" > {< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Scale x and y on the provided stream< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "w" > < / span > < span class = "n" > ax< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > multiply< / span > < span class = "p" > (< / span > < span class = "n" > array< / span > < span class = "p" > (< / span > < span class = "n" > alpha< / span > < span class = "p" > ),< / span > < span class = "w" > < / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > s< / span > < span class = "p" > );< / span >
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "w" > < / span > < span class = "n" > by< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > multiply< / span > < span class = "p" > (< / span > < span class = "n" > array< / span > < span class = "p" > (< / span > < span class = "n" > beta< / span > < span class = "p" > ),< / span > < span class = "w" > < / span > < span class = "n" > y< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > s< / span > < span class = "p" > );< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Add and return< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "k" > return< / span > < span class = "w" > < / span > < span class = "n" > add< / span > < span class = "p" > (< / span > < span class = "n" > ax< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > by< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > s< / span > < span class = "p" > );< / span >
< span class = "p" > }< / span >
2023-11-30 04:41:56 +08:00
< / pre > < / div >
< / div >
2024-04-12 08:33:33 +08:00
< p > The operations themselves do not contain the implementations that act on the
data, nor do they contain the rules of transformations. Rather, they are an
easy-to-use interface that uses < code class = "xref py py-class docutils literal notranslate" > < span class = "pre" > Primitive< / span > < / code > building blocks.< / p >
2023-11-30 04:41:56 +08:00
< / section >
< section id = "primitives" >
2024-03-31 08:32:20 +08:00
< h3 > Primitives< a class = "headerlink" href = "#primitives" title = "Link to this heading" > #< / a > < / h3 >
2023-11-30 04:41:56 +08:00
< p > A < code class = "xref py py-class docutils literal notranslate" > < span class = "pre" > Primitive< / span > < / code > is part of the computation graph of an < code class = "xref py py-class docutils literal notranslate" > < span class = "pre" > array< / span > < / code > . It
2025-04-04 04:25:24 +08:00
defines how to create output arrays given input arrays. Further, a
2024-04-12 08:33:33 +08:00
< code class = "xref py py-class docutils literal notranslate" > < span class = "pre" > Primitive< / span > < / code > has methods to run on the CPU or GPU and for function
2025-04-04 04:25:24 +08:00
transformations such as < code class = "docutils literal notranslate" > < span class = "pre" > vjp< / span > < / code > and < code class = "docutils literal notranslate" > < span class = "pre" > jvp< / span > < / code > . Let’s go back to our example to be
2024-04-12 08:33:33 +08:00
more concrete:< / p >
2024-03-31 08:32:20 +08:00
< div class = "highlight-C++ notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "k" > class< / span > < span class = "w" > < / span > < span class = "nc" > Axpby< / span > < span class = "w" > < / span > < span class = "o" > :< / span > < span class = "w" > < / span > < span class = "k" > public< / span > < span class = "w" > < / span > < span class = "n" > Primitive< / span > < span class = "w" > < / span > < span class = "p" > {< / span >
< span class = "w" > < / span > < span class = "k" > public< / span > < span class = "o" > :< / span >
< span class = "w" > < / span > < span class = "k" > explicit< / span > < span class = "w" > < / span > < span class = "n" > Axpby< / span > < span class = "p" > (< / span > < span class = "n" > Stream< / span > < span class = "w" > < / span > < span class = "n" > stream< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "kt" > float< / span > < span class = "w" > < / span > < span class = "n" > alpha< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "kt" > float< / span > < span class = "w" > < / span > < span class = "n" > beta< / span > < span class = "p" > )< / span >
< span class = "w" > < / span > < span class = "o" > :< / span > < span class = "w" > < / span > < span class = "n" > Primitive< / span > < span class = "p" > (< / span > < span class = "n" > stream< / span > < span class = "p" > ),< / span > < span class = "w" > < / span > < span class = "n" > alpha_< / span > < span class = "p" > (< / span > < span class = "n" > alpha< / span > < span class = "p" > ),< / span > < span class = "w" > < / span > < span class = "n" > beta_< / span > < span class = "p" > (< / span > < span class = "n" > beta< / span > < span class = "p" > ){};< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "cm" > /**< / span >
< span class = "cm" > * A primitive must know how to evaluate itself on the CPU/GPU< / span >
< span class = "cm" > * for the given inputs and populate the output array.< / span >
< span class = "cm" > *< / span >
2023-12-07 00:13:20 +08:00
< span class = "cm" > * To avoid unnecessary allocations, the evaluation function< / span >
2023-11-30 04:41:56 +08:00
< span class = "cm" > * is responsible for allocating space for the array.< / span >
2024-03-31 08:32:20 +08:00
< span class = "cm" > */< / span >
2024-04-12 08:33:33 +08:00
< span class = "w" > < / span > < span class = "kt" > void< / span > < span class = "w" > < / span > < span class = "nf" > eval_cpu< / span > < span class = "p" > (< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > inputs< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > outputs< / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "k" > override< / span > < span class = "p" > ;< / span >
< span class = "w" > < / span > < span class = "kt" > void< / span > < span class = "w" > < / span > < span class = "nf" > eval_gpu< / span > < span class = "p" > (< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > inputs< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > outputs< / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "k" > override< / span > < span class = "p" > ;< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "cm" > /** The Jacobian-vector product. */< / span >
2024-04-12 08:33:33 +08:00
< span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > < / span > < span class = "w" > < / span > < span class = "n" > jvp< / span > < span class = "p" > (< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > primals< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > tangents< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "kt" > int< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > argnums< / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "k" > override< / span > < span class = "p" > ;< / span >
< span class = "w" > < / span > < span class = "cm" > /** The vector-Jacobian product. */< / span >
< span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > < / span > < span class = "w" > < / span > < span class = "n" > vjp< / span > < span class = "p" > (< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > primals< / span > < span class = "p" > ,< / span >
2025-04-04 04:25:24 +08:00
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > cotangents< / span > < span class = "p" > ,< / span >
2024-04-12 08:33:33 +08:00
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "kt" > int< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > argnums< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > outputs< / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "k" > override< / span > < span class = "p" > ;< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "cm" > /**< / span >
2023-12-18 05:23:03 +08:00
< span class = "cm" > * The primitive must know how to vectorize itself across< / span >
2023-11-30 04:41:56 +08:00
< span class = "cm" > * the given axes. The output is a pair containing the array< / span >
< span class = "cm" > * representing the vectorized computation and the axis which< / span >
< span class = "cm" > * corresponds to the output vectorized dimension.< / span >
2024-03-31 08:32:20 +08:00
< span class = "cm" > */< / span >
2024-04-12 08:33:33 +08:00
< span class = "w" > < / span > < span class = "k" > virtual< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > pair< / span > < span class = "o" > < < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > < / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "kt" > int< / span > < span class = "o" > > > < / span > < span class = "w" > < / span > < span class = "n" > vmap< / span > < span class = "p" > (< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > inputs< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "kt" > int< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > axes< / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "k" > override< / span > < span class = "p" > ;< / span >
< span class = "w" > < / span > < span class = "cm" > /** Print the primitive. */< / span >
< span class = "w" > < / span > < span class = "kt" > void< / span > < span class = "w" > < / span > < span class = "nf" > print< / span > < span class = "p" > (< / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > ostream< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > os< / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "k" > override< / span > < span class = "w" > < / span > < span class = "p" > {< / span >
< span class = "w" > < / span > < span class = "n" > os< / span > < span class = "w" > < / span > < span class = "o" > < < < / span > < span class = "w" > < / span > < span class = "s" > " Axpby" < / span > < span class = "p" > ;< / span >
< span class = "w" > < / span > < span class = "p" > }< / span >
< span class = "w" > < / span > < span class = "cm" > /** Equivalence check **/< / span >
< span class = "w" > < / span > < span class = "kt" > bool< / span > < span class = "w" > < / span > < span class = "nf" > is_equivalent< / span > < span class = "p" > (< / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > Primitive< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > other< / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "k" > override< / span > < span class = "p" > ;< / span >
< span class = "w" > < / span > < span class = "k" > private< / span > < span class = "o" > :< / span >
< span class = "w" > < / span > < span class = "kt" > float< / span > < span class = "w" > < / span > < span class = "n" > alpha_< / span > < span class = "p" > ;< / span >
< span class = "w" > < / span > < span class = "kt" > float< / span > < span class = "w" > < / span > < span class = "n" > beta_< / span > < span class = "p" > ;< / span >
< span class = "p" > };< / span >
2023-11-30 04:41:56 +08:00
< / pre > < / div >
< / div >
2024-04-12 08:33:33 +08:00
< p > The < code class = "xref py py-class docutils literal notranslate" > < span class = "pre" > Axpby< / span > < / code > class derives from the base < code class = "xref py py-class docutils literal notranslate" > < span class = "pre" > Primitive< / span > < / code > class. The
< code class = "xref py py-class docutils literal notranslate" > < span class = "pre" > Axpby< / span > < / code > treats < code class = "docutils literal notranslate" > < span class = "pre" > alpha< / span > < / code > and < code class = "docutils literal notranslate" > < span class = "pre" > beta< / span > < / code > as parameters. It then provides
implementations of how the output array is produced given the inputs through
< code class = "xref py py-meth docutils literal notranslate" > < span class = "pre" > Axpby::eval_cpu()< / span > < / code > and < code class = "xref py py-meth docutils literal notranslate" > < span class = "pre" > Axpby::eval_gpu()< / span > < / code > . It also provides
transformation rules in < code class = "xref py py-meth docutils literal notranslate" > < span class = "pre" > Axpby::jvp()< / span > < / code > , < code class = "xref py py-meth docutils literal notranslate" > < span class = "pre" > Axpby::vjp()< / span > < / code > , and
< code class = "xref py py-meth docutils literal notranslate" > < span class = "pre" > Axpby::vmap()< / span > < / code > .< / p >
2023-11-30 04:41:56 +08:00
< / section >
2024-04-12 08:33:33 +08:00
< section id = "using-the-primitive" >
< h3 > Using the Primitive< a class = "headerlink" href = "#using-the-primitive" title = "Link to this heading" > #< / a > < / h3 >
< p > Operations can use this < code class = "xref py py-class docutils literal notranslate" > < span class = "pre" > Primitive< / span > < / code > to add a new < code class = "xref py py-class docutils literal notranslate" > < span class = "pre" > array< / span > < / code > to the
computation graph. An < code class = "xref py py-class docutils literal notranslate" > < span class = "pre" > array< / span > < / code > can be constructed by providing its data
type, shape, the < code class = "xref py py-class docutils literal notranslate" > < span class = "pre" > Primitive< / span > < / code > that computes it, and the < code class = "xref py py-class docutils literal notranslate" > < span class = "pre" > array< / span > < / code >
inputs that are passed to the primitive.< / p >
< p > Let’s reimplement our operation now in terms of our < code class = "xref py py-class docutils literal notranslate" > < span class = "pre" > Axpby< / span > < / code > primitive.< / p >
2024-03-31 08:32:20 +08:00
< div class = "highlight-C++ notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "n" > array< / span > < span class = "w" > < / span > < span class = "nf" > axpby< / span > < span class = "p" > (< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > array< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "c1" > // Input array x< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > array< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > y< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "c1" > // Input array y< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "kt" > float< / span > < span class = "w" > < / span > < span class = "n" > alpha< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "c1" > // Scaling factor for x< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "kt" > float< / span > < span class = "w" > < / span > < span class = "n" > beta< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "c1" > // Scaling factor for y< / span >
< span class = "w" > < / span > < span class = "n" > StreamOrDevice< / span > < span class = "w" > < / span > < span class = "n" > s< / span > < span class = "w" > < / span > < span class = "cm" > /* = {} */< / span > < span class = "w" > < / span > < span class = "c1" > // Stream on which to schedule the operation< / span >
2024-03-31 08:32:20 +08:00
< span class = "p" > )< / span > < span class = "w" > < / span > < span class = "p" > {< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Promote dtypes between x and y as needed< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "w" > < / span > < span class = "n" > promoted_dtype< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > promote_types< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > (),< / span > < span class = "w" > < / span > < span class = "n" > y< / span > < span class = "p" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > ());< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Upcast to float32 for non-floating point inputs x and y< / span >
2025-03-21 06:37:22 +08:00
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "w" > < / span > < span class = "n" > out_dtype< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > issubdtype< / span > < span class = "p" > (< / span > < span class = "n" > promoted_dtype< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "o" > ?< / span > < span class = "w" > < / span > < span class = "n" > promoted_dtype< / span >
< span class = "w" > < / span > < span class = "o" > :< / span > < span class = "w" > < / span > < span class = "n" > promote_types< / span > < span class = "p" > (< / span > < span class = "n" > promoted_dtype< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > float32< / span > < span class = "p" > );< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Cast x and y up to the determined dtype (on the same stream s)< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "w" > < / span > < span class = "n" > x_casted< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > astype< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > out_dtype< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > s< / span > < span class = "p" > );< / span >
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "w" > < / span > < span class = "n" > y_casted< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > astype< / span > < span class = "p" > (< / span > < span class = "n" > y< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > out_dtype< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > s< / span > < span class = "p" > );< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Broadcast the shapes of x and y (on the same stream s)< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "w" > < / span > < span class = "n" > broadcasted_inputs< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > broadcast_arrays< / span > < span class = "p" > ({< / span > < span class = "n" > x_casted< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > y_casted< / span > < span class = "p" > },< / span > < span class = "w" > < / span > < span class = "n" > s< / span > < span class = "p" > );< / span >
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "w" > < / span > < span class = "n" > out_shape< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > broadcasted_inputs< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ].< / span > < span class = "n" > shape< / span > < span class = "p" > ();< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Construct the array as the output of the Axpby primitive< / span >
< span class = "w" > < / span > < span class = "c1" > // with the broadcasted and upcasted arrays as inputs< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "k" > return< / span > < span class = "w" > < / span > < span class = "n" > array< / span > < span class = "p" > (< / span >
< span class = "w" > < / span > < span class = "cm" > /* const std::vector< int> & shape = */< / span > < span class = "w" > < / span > < span class = "n" > out_shape< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "cm" > /* Dtype dtype = */< / span > < span class = "w" > < / span > < span class = "n" > out_dtype< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "cm" > /* std::shared_ptr< Primitive> primitive = */< / span >
< span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > make_shared< / span > < span class = "o" > < < / span > < span class = "n" > Axpby< / span > < span class = "o" > > < / span > < span class = "p" > (< / span > < span class = "n" > to_stream< / span > < span class = "p" > (< / span > < span class = "n" > s< / span > < span class = "p" > ),< / span > < span class = "w" > < / span > < span class = "n" > alpha< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > beta< / span > < span class = "p" > ),< / span >
< span class = "w" > < / span > < span class = "cm" > /* const std::vector< array> & inputs = */< / span > < span class = "w" > < / span > < span class = "n" > broadcasted_inputs< / span > < span class = "p" > );< / span >
< span class = "p" > }< / span >
2023-11-30 04:41:56 +08:00
< / pre > < / div >
< / div >
< p > This operation now handles the following:< / p >
< ol class = "arabic simple" >
2024-01-04 12:14:05 +08:00
< li > < p > Upcast inputs and resolve the output data type.< / p > < / li >
2023-11-30 04:41:56 +08:00
< li > < p > Broadcast the inputs and resolve the output shape.< / p > < / li >
< li > < p > Construct the primitive < code class = "xref py py-class docutils literal notranslate" > < span class = "pre" > Axpby< / span > < / code > using the given stream, < code class = "docutils literal notranslate" > < span class = "pre" > alpha< / span > < / code > , and < code class = "docutils literal notranslate" > < span class = "pre" > beta< / span > < / code > .< / p > < / li >
< li > < p > Construct the output < code class = "xref py py-class docutils literal notranslate" > < span class = "pre" > array< / span > < / code > using the primitive and the inputs.< / p > < / li >
< / ol >
< / section >
< / section >
< section id = "implementing-the-primitive" >
2024-03-31 08:32:20 +08:00
< h2 > Implementing the Primitive< a class = "headerlink" href = "#implementing-the-primitive" title = "Link to this heading" > #< / a > < / h2 >
2024-04-12 08:33:33 +08:00
< p > No computation happens when we call the operation alone. The operation only
builds the computation graph. When we evaluate the output array, MLX schedules
the execution of the computation graph, and calls < code class = "xref py py-meth docutils literal notranslate" > < span class = "pre" > Axpby::eval_cpu()< / span > < / code > or
< code class = "xref py py-meth docutils literal notranslate" > < span class = "pre" > Axpby::eval_gpu()< / span > < / code > depending on the stream/device specified by the user.< / p >
2023-11-30 04:41:56 +08:00
< div class = "admonition warning" >
< p class = "admonition-title" > Warning< / p >
< p > When < code class = "xref py py-meth docutils literal notranslate" > < span class = "pre" > Primitive::eval_cpu()< / span > < / code > or < code class = "xref py py-meth docutils literal notranslate" > < span class = "pre" > Primitive::eval_gpu()< / span > < / code > is called,
no memory has been allocated for the output array. It falls on the implementation
2024-04-12 08:33:33 +08:00
of these functions to allocate memory as needed.< / p >
2023-11-30 04:41:56 +08:00
< / div >
2024-04-12 08:33:33 +08:00
< section id = "implementing-the-cpu-back-end" >
< h3 > Implementing the CPU Back-end< a class = "headerlink" href = "#implementing-the-cpu-back-end" title = "Link to this heading" > #< / a > < / h3 >
2025-03-21 06:37:22 +08:00
< p > Let’s start by implementing < code class = "xref py py-meth docutils literal notranslate" > < span class = "pre" > Axpby::eval_cpu()< / span > < / code > .< / p >
< p > The method will go over each element of the output array, find the
2023-11-30 04:41:56 +08:00
corresponding input elements of < code class = "docutils literal notranslate" > < span class = "pre" > x< / span > < / code > and < code class = "docutils literal notranslate" > < span class = "pre" > y< / span > < / code > and perform the operation
2024-04-12 08:33:33 +08:00
point-wise. This is captured in the templated function < code class = "xref py py-meth docutils literal notranslate" > < span class = "pre" > axpby_impl()< / span > < / code > .< / p >
2024-03-31 08:32:20 +08:00
< div class = "highlight-C++ notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "k" > template< / span > < span class = "w" > < / span > < span class = "o" > < < / span > < span class = "k" > typename< / span > < span class = "w" > < / span > < span class = "nc" > T< / span > < span class = "o" > > < / span >
< span class = "kt" > void< / span > < span class = "w" > < / span > < span class = "n" > axpby_impl< / span > < span class = "p" > (< / span >
2025-03-21 06:37:22 +08:00
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > mx< / span > < span class = "o" > ::< / span > < span class = "n" > array< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > x< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > mx< / span > < span class = "o" > ::< / span > < span class = "n" > array< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > y< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "n" > mx< / span > < span class = "o" > ::< / span > < span class = "n" > array< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > out< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "kt" > float< / span > < span class = "w" > < / span > < span class = "n" > alpha_< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "kt" > float< / span > < span class = "w" > < / span > < span class = "n" > beta_< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "n" > mx< / span > < span class = "o" > ::< / span > < span class = "n" > Stream< / span > < span class = "w" > < / span > < span class = "n" > stream< / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "p" > {< / span >
2025-03-25 04:24:41 +08:00
< span class = "w" > < / span > < span class = "n" > out< / span > < span class = "p" > .< / span > < span class = "n" > set_data< / span > < span class = "p" > (< / span > < span class = "n" > mx< / span > < span class = "o" > ::< / span > < span class = "n" > allocator< / span > < span class = "o" > ::< / span > < span class = "n" > malloc< / span > < span class = "p" > (< / span > < span class = "n" > out< / span > < span class = "p" > .< / span > < span class = "n" > nbytes< / span > < span class = "p" > ()));< / span >
2025-03-21 06:37:22 +08:00
< span class = "w" > < / span > < span class = "c1" > // Get the CPU command encoder and register input and output arrays< / span >
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > encoder< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > mx< / span > < span class = "o" > ::< / span > < span class = "n" > cpu< / span > < span class = "o" > ::< / span > < span class = "n" > get_command_encoder< / span > < span class = "p" > (< / span > < span class = "n" > stream< / span > < span class = "p" > );< / span >
< span class = "w" > < / span > < span class = "n" > encoder< / span > < span class = "p" > .< / span > < span class = "n" > set_input_array< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > );< / span >
< span class = "w" > < / span > < span class = "n" > encoder< / span > < span class = "p" > .< / span > < span class = "n" > set_input_array< / span > < span class = "p" > (< / span > < span class = "n" > y< / span > < span class = "p" > );< / span >
< span class = "w" > < / span > < span class = "n" > encoder< / span > < span class = "p" > .< / span > < span class = "n" > set_output_array< / span > < span class = "p" > (< / span > < span class = "n" > out< / span > < span class = "p" > );< / span >
< span class = "w" > < / span > < span class = "c1" > // Launch the CPU kernel< / span >
< span class = "w" > < / span > < span class = "n" > encoder< / span > < span class = "p" > .< / span > < span class = "n" > dispatch< / span > < span class = "p" > ([< / span > < span class = "n" > x_ptr< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > x< / span > < span class = "p" > .< / span > < span class = "n" > data< / span > < span class = "o" > < < / span > < span class = "n" > T< / span > < span class = "o" > > < / span > < span class = "p" > (),< / span >
< span class = "w" > < / span > < span class = "n" > y_ptr< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > y< / span > < span class = "p" > .< / span > < span class = "n" > data< / span > < span class = "o" > < < / span > < span class = "n" > T< / span > < span class = "o" > > < / span > < span class = "p" > (),< / span >
< span class = "w" > < / span > < span class = "n" > out_ptr< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > out< / span > < span class = "p" > .< / span > < span class = "n" > data< / span > < span class = "o" > < < / span > < span class = "n" > T< / span > < span class = "o" > > < / span > < span class = "p" > (),< / span >
< span class = "w" > < / span > < span class = "n" > size< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > out< / span > < span class = "p" > .< / span > < span class = "n" > size< / span > < span class = "p" > (),< / span >
< span class = "w" > < / span > < span class = "n" > shape< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > out< / span > < span class = "p" > .< / span > < span class = "n" > shape< / span > < span class = "p" > (),< / span >
< span class = "w" > < / span > < span class = "n" > x_strides< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > x< / span > < span class = "p" > .< / span > < span class = "n" > strides< / span > < span class = "p" > (),< / span >
< span class = "w" > < / span > < span class = "n" > y_strides< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > y< / span > < span class = "p" > .< / span > < span class = "n" > strides< / span > < span class = "p" > (),< / span >
< span class = "w" > < / span > < span class = "n" > alpha_< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "n" > beta_< / span > < span class = "p" > ]()< / span > < span class = "w" > < / span > < span class = "p" > {< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Cast alpha and beta to the relevant types< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "n" > T< / span > < span class = "w" > < / span > < span class = "n" > alpha< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "k" > static_cast< / span > < span class = "o" > < < / span > < span class = "n" > T< / span > < span class = "o" > > < / span > < span class = "p" > (< / span > < span class = "n" > alpha_< / span > < span class = "p" > );< / span >
< span class = "w" > < / span > < span class = "n" > T< / span > < span class = "w" > < / span > < span class = "n" > beta< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "k" > static_cast< / span > < span class = "o" > < < / span > < span class = "n" > T< / span > < span class = "o" > > < / span > < span class = "p" > (< / span > < span class = "n" > beta_< / span > < span class = "p" > );< / span >
2023-11-30 04:41:56 +08:00
2024-01-04 12:14:05 +08:00
< span class = "w" > < / span > < span class = "c1" > // Do the element-wise operation for each output< / span >
2025-03-21 06:37:22 +08:00
< span class = "w" > < / span > < span class = "k" > for< / span > < span class = "w" > < / span > < span class = "p" > (< / span > < span class = "kt" > size_t< / span > < span class = "w" > < / span > < span class = "n" > out_idx< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "mi" > 0< / span > < span class = "p" > ;< / span > < span class = "w" > < / span > < span class = "n" > out_idx< / span > < span class = "w" > < / span > < span class = "o" > < < / span > < span class = "w" > < / span > < span class = "n" > size< / span > < span class = "p" > ;< / span > < span class = "w" > < / span > < span class = "n" > out_idx< / span > < span class = "o" > ++< / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "p" > {< / span >
< span class = "w" > < / span > < span class = "c1" > // Map linear indices to offsets in x and y< / span >
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "w" > < / span > < span class = "n" > x_offset< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > mx< / span > < span class = "o" > ::< / span > < span class = "n" > elem_to_loc< / span > < span class = "p" > (< / span > < span class = "n" > out_idx< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > shape< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > x_strides< / span > < span class = "p" > );< / span >
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "w" > < / span > < span class = "n" > y_offset< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > mx< / span > < span class = "o" > ::< / span > < span class = "n" > elem_to_loc< / span > < span class = "p" > (< / span > < span class = "n" > out_idx< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > shape< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > y_strides< / span > < span class = "p" > );< / span >
< span class = "w" > < / span > < span class = "c1" > // We allocate the output to be contiguous and regularly strided< / span >
< span class = "w" > < / span > < span class = "c1" > // (defaults to row major) and hence it doesn't need additional mapping< / span >
< span class = "w" > < / span > < span class = "n" > out_ptr< / span > < span class = "p" > [< / span > < span class = "n" > out_idx< / span > < span class = "p" > ]< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > alpha< / span > < span class = "w" > < / span > < span class = "o" > *< / span > < span class = "w" > < / span > < span class = "n" > x_ptr< / span > < span class = "p" > [< / span > < span class = "n" > x_offset< / span > < span class = "p" > ]< / span > < span class = "w" > < / span > < span class = "o" > +< / span > < span class = "w" > < / span > < span class = "n" > beta< / span > < span class = "w" > < / span > < span class = "o" > *< / span > < span class = "w" > < / span > < span class = "n" > y_ptr< / span > < span class = "p" > [< / span > < span class = "n" > y_offset< / span > < span class = "p" > ];< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "p" > }< / span >
2025-03-21 06:37:22 +08:00
< span class = "w" > < / span > < span class = "p" > });< / span >
2024-03-31 08:32:20 +08:00
< span class = "p" > }< / span >
2023-11-30 04:41:56 +08:00
< / pre > < / div >
< / div >
2024-04-12 08:33:33 +08:00
< p > Our implementation should work for all incoming floating point arrays.
Accordingly, we add dispatches for < code class = "docutils literal notranslate" > < span class = "pre" > float32< / span > < / code > , < code class = "docutils literal notranslate" > < span class = "pre" > float16< / span > < / code > , < code class = "docutils literal notranslate" > < span class = "pre" > bfloat16< / span > < / code > and
< code class = "docutils literal notranslate" > < span class = "pre" > complex64< / span > < / code > . We throw an error if we encounter an unexpected type.< / p >
2025-03-21 06:37:22 +08:00
< div class = "highlight-C++ notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "kt" > void< / span > < span class = "w" > < / span > < span class = "nf" > Axpby::eval_cpu< / span > < span class = "p" > (< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > mx< / span > < span class = "o" > ::< / span > < span class = "n" > array< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > inputs< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > mx< / span > < span class = "o" > ::< / span > < span class = "n" > array< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > outputs< / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "p" > {< / span >
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > x< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > inputs< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ];< / span >
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > y< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > inputs< / span > < span class = "p" > [< / span > < span class = "mi" > 1< / span > < span class = "p" > ];< / span >
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > out< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > outputs< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ];< / span >
< span class = "w" > < / span > < span class = "c1" > // Dispatch to the correct dtype< / span >
< span class = "w" > < / span > < span class = "k" > if< / span > < span class = "w" > < / span > < span class = "p" > (< / span > < span class = "n" > out< / span > < span class = "p" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > ()< / span > < span class = "w" > < / span > < span class = "o" > ==< / span > < span class = "w" > < / span > < span class = "n" > mx< / span > < span class = "o" > ::< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "p" > {< / span >
< span class = "w" > < / span > < span class = "k" > return< / span > < span class = "w" > < / span > < span class = "n" > axpby_impl< / span > < span class = "o" > < < / span > < span class = "kt" > float< / span > < span class = "o" > > < / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > y< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > out< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > alpha_< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > beta_< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > stream< / span > < span class = "p" > ());< / span >
< span class = "w" > < / span > < span class = "p" > }< / span > < span class = "w" > < / span > < span class = "k" > else< / span > < span class = "w" > < / span > < span class = "k" > if< / span > < span class = "w" > < / span > < span class = "p" > (< / span > < span class = "n" > out< / span > < span class = "p" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > ()< / span > < span class = "w" > < / span > < span class = "o" > ==< / span > < span class = "w" > < / span > < span class = "n" > mx< / span > < span class = "o" > ::< / span > < span class = "n" > float16< / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "p" > {< / span >
< span class = "w" > < / span > < span class = "k" > return< / span > < span class = "w" > < / span > < span class = "n" > axpby_impl< / span > < span class = "o" > < < / span > < span class = "n" > mx< / span > < span class = "o" > ::< / span > < span class = "n" > float16_t< / span > < span class = "o" > > < / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > y< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > out< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > alpha_< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > beta_< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > stream< / span > < span class = "p" > ());< / span >
< span class = "w" > < / span > < span class = "p" > }< / span > < span class = "w" > < / span > < span class = "k" > else< / span > < span class = "w" > < / span > < span class = "k" > if< / span > < span class = "w" > < / span > < span class = "p" > (< / span > < span class = "n" > out< / span > < span class = "p" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > ()< / span > < span class = "w" > < / span > < span class = "o" > ==< / span > < span class = "w" > < / span > < span class = "n" > mx< / span > < span class = "o" > ::< / span > < span class = "n" > bfloat16< / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "p" > {< / span >
< span class = "w" > < / span > < span class = "k" > return< / span > < span class = "w" > < / span > < span class = "n" > axpby_impl< / span > < span class = "o" > < < / span > < span class = "n" > mx< / span > < span class = "o" > ::< / span > < span class = "n" > bfloat16_t< / span > < span class = "o" > > < / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > y< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > out< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > alpha_< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > beta_< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > stream< / span > < span class = "p" > ());< / span >
< span class = "w" > < / span > < span class = "p" > }< / span > < span class = "w" > < / span > < span class = "k" > else< / span > < span class = "w" > < / span > < span class = "k" > if< / span > < span class = "w" > < / span > < span class = "p" > (< / span > < span class = "n" > out< / span > < span class = "p" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > ()< / span > < span class = "w" > < / span > < span class = "o" > ==< / span > < span class = "w" > < / span > < span class = "n" > mx< / span > < span class = "o" > ::< / span > < span class = "n" > complex64< / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "p" > {< / span >
< span class = "w" > < / span > < span class = "k" > return< / span > < span class = "w" > < / span > < span class = "n" > axpby_impl< / span > < span class = "o" > < < / span > < span class = "n" > mx< / span > < span class = "o" > ::< / span > < span class = "n" > complex64_t< / span > < span class = "o" > > < / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > y< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > out< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > alpha_< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > beta_< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > stream< / span > < span class = "p" > ());< / span >
< span class = "w" > < / span > < span class = "p" > }< / span > < span class = "w" > < / span > < span class = "k" > else< / span > < span class = "w" > < / span > < span class = "p" > {< / span >
< span class = "w" > < / span > < span class = "k" > throw< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > runtime_error< / span > < span class = "p" > (< / span >
< span class = "w" > < / span > < span class = "s" > "Axpby is only supported for floating point types."< / span > < span class = "p" > );< / span >
< span class = "w" > < / span > < span class = "p" > }< / span >
2024-03-31 08:32:20 +08:00
< span class = "p" > }< / span >
2023-11-30 04:41:56 +08:00
< / pre > < / div >
< / div >
2024-04-12 08:33:33 +08:00
< p > Just this much is enough to run the operation < code class = "xref py py-meth docutils literal notranslate" > < span class = "pre" > axpby()< / span > < / code > on a CPU stream! If
you do not plan on running the operation on the GPU or using transforms on
2023-11-30 04:41:56 +08:00
computation graphs that contain < code class = "xref py py-class docutils literal notranslate" > < span class = "pre" > Axpby< / span > < / code > , you can stop implementing the
2025-03-21 06:37:22 +08:00
primitive here.< / p >
2023-11-30 04:41:56 +08:00
< / section >
2024-04-12 08:33:33 +08:00
< section id = "implementing-the-gpu-back-end" >
< h3 > Implementing the GPU Back-end< a class = "headerlink" href = "#implementing-the-gpu-back-end" title = "Link to this heading" > #< / a > < / h3 >
2023-11-30 04:41:56 +08:00
< p > Apple silicon devices address their GPUs using the < a class = "reference external" href = "https://developer.apple.com/documentation/metal?language=objc" > Metal< / a > shading language, and
2024-04-12 08:33:33 +08:00
GPU kernels in MLX are written using Metal.< / p >
2023-11-30 04:41:56 +08:00
< div class = "admonition note" >
< p class = "admonition-title" > Note< / p >
2024-04-12 08:33:33 +08:00
< p > Here are some helpful resources if you are new to Metal:< / p >
2023-11-30 04:41:56 +08:00
< ul class = "simple" >
< li > < p > A walkthrough of the Metal compute pipeline: < a class = "reference external" href = "https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc" > Metal Example< / a > < / p > < / li >
< li > < p > Documentation for the Metal Shading Language: < a class = "reference external" href = "https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf" > Metal Specification< / a > < / p > < / li >
< li > < p > Using Metal from C++: < a class = "reference external" href = "https://developer.apple.com/metal/cpp/" > Metal-cpp< / a > < / p > < / li >
< / ul >
< / div >
2024-04-12 08:33:33 +08:00
< p > Let’s keep the GPU kernel simple. We will launch exactly as many threads as
there are elements in the output. Each thread will pick the element it needs
from < code class = "docutils literal notranslate" > < span class = "pre" > x< / span > < / code > and < code class = "docutils literal notranslate" > < span class = "pre" > y< / span > < / code > , do the point-wise operation, and update its assigned
2023-11-30 04:41:56 +08:00
element in the output.< / p >
2024-03-31 08:32:20 +08:00
< div class = "highlight-C++ notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "k" > template< / span > < span class = "w" > < / span > < span class = "o" > < < / span > < span class = "k" > typename< / span > < span class = "w" > < / span > < span class = "nc" > T< / span > < span class = "o" > > < / span >
< span class = "p" > [[< / span > < span class = "n" > kernel< / span > < span class = "p" > ]]< / span > < span class = "w" > < / span > < span class = "kt" > void< / span > < span class = "w" > < / span > < span class = "n" > axpby_general< / span > < span class = "p" > (< / span >
< span class = "w" > < / span > < span class = "n" > device< / span > < span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > T< / span > < span class = "o" > *< / span > < span class = "w" > < / span > < span class = "n" > x< / span > < span class = "w" > < / span > < span class = "p" > [[< / span > < span class = "n" > buffer< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )]],< / span >
< span class = "w" > < / span > < span class = "n" > device< / span > < span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > T< / span > < span class = "o" > *< / span > < span class = "w" > < / span > < span class = "n" > y< / span > < span class = "w" > < / span > < span class = "p" > [[< / span > < span class = "n" > buffer< / span > < span class = "p" > (< / span > < span class = "mi" > 1< / span > < span class = "p" > )]],< / span >
< span class = "w" > < / span > < span class = "n" > device< / span > < span class = "w" > < / span > < span class = "n" > T< / span > < span class = "o" > *< / span > < span class = "w" > < / span > < span class = "n" > out< / span > < span class = "w" > < / span > < span class = "p" > [[< / span > < span class = "n" > buffer< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "p" > )]],< / span >
< span class = "w" > < / span > < span class = "n" > constant< / span > < span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "kt" > float< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > alpha< / span > < span class = "w" > < / span > < span class = "p" > [[< / span > < span class = "n" > buffer< / span > < span class = "p" > (< / span > < span class = "mi" > 3< / span > < span class = "p" > )]],< / span >
< span class = "w" > < / span > < span class = "n" > constant< / span > < span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "kt" > float< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > beta< / span > < span class = "w" > < / span > < span class = "p" > [[< / span > < span class = "n" > buffer< / span > < span class = "p" > (< / span > < span class = "mi" > 4< / span > < span class = "p" > )]],< / span >
< span class = "w" > < / span > < span class = "n" > constant< / span > < span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "kt" > int< / span > < span class = "o" > *< / span > < span class = "w" > < / span > < span class = "n" > shape< / span > < span class = "w" > < / span > < span class = "p" > [[< / span > < span class = "n" > buffer< / span > < span class = "p" > (< / span > < span class = "mi" > 5< / span > < span class = "p" > )]],< / span >
2025-01-10 05:56:20 +08:00
< span class = "w" > < / span > < span class = "n" > constant< / span > < span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "kt" > int64_t< / span > < span class = "o" > *< / span > < span class = "w" > < / span > < span class = "n" > x_strides< / span > < span class = "w" > < / span > < span class = "p" > [[< / span > < span class = "n" > buffer< / span > < span class = "p" > (< / span > < span class = "mi" > 6< / span > < span class = "p" > )]],< / span >
< span class = "w" > < / span > < span class = "n" > constant< / span > < span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "kt" > int64_t< / span > < span class = "o" > *< / span > < span class = "w" > < / span > < span class = "n" > y_strides< / span > < span class = "w" > < / span > < span class = "p" > [[< / span > < span class = "n" > buffer< / span > < span class = "p" > (< / span > < span class = "mi" > 7< / span > < span class = "p" > )]],< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "n" > constant< / span > < span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "kt" > int< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > ndim< / span > < span class = "w" > < / span > < span class = "p" > [[< / span > < span class = "n" > buffer< / span > < span class = "p" > (< / span > < span class = "mi" > 8< / span > < span class = "p" > )]],< / span >
< span class = "w" > < / span > < span class = "n" > uint< / span > < span class = "w" > < / span > < span class = "n" > index< / span > < span class = "w" > < / span > < span class = "p" > [[< / span > < span class = "n" > thread_position_in_grid< / span > < span class = "p" > ]])< / span > < span class = "w" > < / span > < span class = "p" > {< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Convert linear indices to offsets in array< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "w" > < / span > < span class = "n" > x_offset< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > elem_to_loc< / span > < span class = "p" > (< / span > < span class = "n" > index< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > shape< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > x_strides< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > ndim< / span > < span class = "p" > );< / span >
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "w" > < / span > < span class = "n" > y_offset< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > elem_to_loc< / span > < span class = "p" > (< / span > < span class = "n" > index< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > shape< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > y_strides< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > ndim< / span > < span class = "p" > );< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Do the operation and update the output< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "n" > out< / span > < span class = "p" > [< / span > < span class = "n" > index< / span > < span class = "p" > ]< / span > < span class = "w" > < / span > < span class = "o" > =< / span >
< span class = "w" > < / span > < span class = "k" > static_cast< / span > < span class = "o" > < < / span > < span class = "n" > T< / span > < span class = "o" > > < / span > < span class = "p" > (< / span > < span class = "n" > alpha< / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "o" > *< / span > < span class = "w" > < / span > < span class = "n" > x< / span > < span class = "p" > [< / span > < span class = "n" > x_offset< / span > < span class = "p" > ]< / span > < span class = "w" > < / span > < span class = "o" > +< / span > < span class = "w" > < / span > < span class = "k" > static_cast< / span > < span class = "o" > < < / span > < span class = "n" > T< / span > < span class = "o" > > < / span > < span class = "p" > (< / span > < span class = "n" > beta< / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "o" > *< / span > < span class = "w" > < / span > < span class = "n" > y< / span > < span class = "p" > [< / span > < span class = "n" > y_offset< / span > < span class = "p" > ];< / span >
< span class = "p" > }< / span >
2023-11-30 04:41:56 +08:00
< / pre > < / div >
< / div >
< p > We then need to instantiate this template for all floating point types and give
2024-04-12 08:33:33 +08:00
each instantiation a unique host name so we can identify it.< / p >
2025-01-10 05:56:20 +08:00
< div class = "highlight-C++ notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "n" > instantiate_kernel< / span > < span class = "p" > (< / span > < span class = "s" > "axpby_general_float32"< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > axpby_general< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "kt" > float< / span > < span class = "p" > )< / span >
< span class = "n" > instantiate_kernel< / span > < span class = "p" > (< / span > < span class = "s" > "axpby_general_float16"< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > axpby_general< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > float16_t< / span > < span class = "p" > )< / span >
< span class = "n" > instantiate_kernel< / span > < span class = "p" > (< / span > < span class = "s" > "axpby_general_bfloat16"< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > axpby_general< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > bfloat16_t< / span > < span class = "p" > )< / span >
< span class = "n" > instantiate_kernel< / span > < span class = "p" > (< / span > < span class = "s" > "axpby_general_complex64"< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > axpby_general< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > complex64_t< / span > < span class = "p" > )< / span >
2023-11-30 04:41:56 +08:00
< / pre > < / div >
< / div >
2024-04-12 08:33:33 +08:00
< p > The logic to determine the kernel, set the inputs, resolve the grid dimensions,
and dispatch to the GPU is contained in < code class = "xref py py-meth docutils literal notranslate" > < span class = "pre" > Axpby::eval_gpu()< / span > < / code > as shown
2023-11-30 04:41:56 +08:00
below.< / p >
2024-03-31 08:32:20 +08:00
< div class = "highlight-C++ notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "cm" > /** Evaluate primitive on GPU */< / span >
2024-04-12 08:33:33 +08:00
< span class = "kt" > void< / span > < span class = "w" > < / span > < span class = "nf" > Axpby::eval_gpu< / span > < span class = "p" > (< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > inputs< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > outputs< / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "p" > {< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Prepare inputs< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "n" > assert< / span > < span class = "p" > (< / span > < span class = "n" > inputs< / span > < span class = "p" > .< / span > < span class = "n" > size< / span > < span class = "p" > ()< / span > < span class = "w" > < / span > < span class = "o" > ==< / span > < span class = "w" > < / span > < span class = "mi" > 2< / span > < span class = "p" > );< / span >
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > x< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > inputs< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ];< / span >
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > y< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > inputs< / span > < span class = "p" > [< / span > < span class = "mi" > 1< / span > < span class = "p" > ];< / span >
2024-04-12 08:33:33 +08:00
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > out< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > outputs< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ];< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Each primitive carries the stream it should execute on< / span >
< span class = "w" > < / span > < span class = "c1" > // and each stream carries its device identifiers< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > s< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > stream< / span > < span class = "p" > ();< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // We get the needed metal device using the stream< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > d< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > metal< / span > < span class = "o" > ::< / span > < span class = "n" > device< / span > < span class = "p" > (< / span > < span class = "n" > s< / span > < span class = "p" > .< / span > < span class = "n" > device< / span > < span class = "p" > );< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Allocate output memory< / span >
2025-03-25 04:24:41 +08:00
< span class = "w" > < / span > < span class = "n" > out< / span > < span class = "p" > .< / span > < span class = "n" > set_data< / span > < span class = "p" > (< / span > < span class = "n" > allocator< / span > < span class = "o" > ::< / span > < span class = "n" > malloc< / span > < span class = "p" > (< / span > < span class = "n" > out< / span > < span class = "p" > .< / span > < span class = "n" > nbytes< / span > < span class = "p" > ()));< / span >
2023-11-30 04:41:56 +08:00
2024-04-12 08:33:33 +08:00
< span class = "w" > < / span > < span class = "c1" > // Resolve name of kernel< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > ostringstream< / span > < span class = "w" > < / span > < span class = "n" > kname< / span > < span class = "p" > ;< / span >
< span class = "w" > < / span > < span class = "n" > kname< / span > < span class = "w" > < / span > < span class = "o" > < < < / span > < span class = "w" > < / span > < span class = "s" > " axpby_" < / span > < span class = "w" > < / span > < span class = "o" > < < < / span > < span class = "w" > < / span > < span class = "s" > " general_" < / span > < span class = "w" > < / span > < span class = "o" > < < < / span > < span class = "w" > < / span > < span class = "n" > type_to_name< / span > < span class = "p" > (< / span > < span class = "n" > out< / span > < span class = "p" > );< / span >
2023-11-30 04:41:56 +08:00
2024-08-11 00:24:35 +08:00
< span class = "w" > < / span > < span class = "c1" > // Make sure the metal library is available< / span >
< span class = "w" > < / span > < span class = "n" > d< / span > < span class = "p" > .< / span > < span class = "n" > register_library< / span > < span class = "p" > (< / span > < span class = "s" > " mlx_ext" < / span > < span class = "p" > );< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Make a kernel from this metal library< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "w" > < / span > < span class = "n" > kernel< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > d< / span > < span class = "p" > .< / span > < span class = "n" > get_kernel< / span > < span class = "p" > (< / span > < span class = "n" > kname< / span > < span class = "p" > .< / span > < span class = "n" > str< / span > < span class = "p" > (),< / span > < span class = "w" > < / span > < span class = "s" > " mlx_ext" < / span > < span class = "p" > );< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Prepare to encode kernel< / span >
2024-05-21 00:40:17 +08:00
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "o" > & < / span > < span class = "w" > < / span > < span class = "n" > compute_encoder< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > d< / span > < span class = "p" > .< / span > < span class = "n" > get_command_encoder< / span > < span class = "p" > (< / span > < span class = "n" > s< / span > < span class = "p" > .< / span > < span class = "n" > index< / span > < span class = "p" > );< / span >
2024-11-23 04:24:16 +08:00
< span class = "w" > < / span > < span class = "n" > compute_encoder< / span > < span class = "p" > .< / span > < span class = "n" > set_compute_pipeline_state< / span > < span class = "p" > (< / span > < span class = "n" > kernel< / span > < span class = "p" > );< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Kernel parameters are registered with buffer indices corresponding to< / span >
2024-01-04 12:14:05 +08:00
< span class = "w" > < / span > < span class = "c1" > // those in the kernel declaration at axpby.metal< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "kt" > int< / span > < span class = "w" > < / span > < span class = "n" > ndim< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > out< / span > < span class = "p" > .< / span > < span class = "n" > ndim< / span > < span class = "p" > ();< / span >
< span class = "w" > < / span > < span class = "kt" > size_t< / span > < span class = "w" > < / span > < span class = "n" > nelem< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > out< / span > < span class = "p" > .< / span > < span class = "n" > size< / span > < span class = "p" > ();< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Encode input arrays to kernel< / span >
2024-05-21 00:40:17 +08:00
< span class = "w" > < / span > < span class = "n" > compute_encoder< / span > < span class = "p" > .< / span > < span class = "n" > set_input_array< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "mi" > 0< / span > < span class = "p" > );< / span >
< span class = "w" > < / span > < span class = "n" > compute_encoder< / span > < span class = "p" > .< / span > < span class = "n" > set_input_array< / span > < span class = "p" > (< / span > < span class = "n" > y< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "mi" > 1< / span > < span class = "p" > );< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Encode output arrays to kernel< / span >
2024-05-21 00:40:17 +08:00
< span class = "w" > < / span > < span class = "n" > compute_encoder< / span > < span class = "p" > .< / span > < span class = "n" > set_output_array< / span > < span class = "p" > (< / span > < span class = "n" > out< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "mi" > 2< / span > < span class = "p" > );< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Encode alpha and beta< / span >
2024-11-23 04:24:16 +08:00
< span class = "w" > < / span > < span class = "n" > compute_encoder< / span > < span class = "p" > .< / span > < span class = "n" > set_bytes< / span > < span class = "p" > (< / span > < span class = "n" > alpha_< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "mi" > 3< / span > < span class = "p" > );< / span >
< span class = "w" > < / span > < span class = "n" > compute_encoder< / span > < span class = "p" > .< / span > < span class = "n" > set_bytes< / span > < span class = "p" > (< / span > < span class = "n" > beta_< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "mi" > 4< / span > < span class = "p" > );< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Encode shape, strides and ndim< / span >
2024-11-23 04:24:16 +08:00
< span class = "w" > < / span > < span class = "n" > compute_encoder< / span > < span class = "p" > .< / span > < span class = "n" > set_vector_bytes< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > .< / span > < span class = "n" > shape< / span > < span class = "p" > (),< / span > < span class = "w" > < / span > < span class = "mi" > 5< / span > < span class = "p" > );< / span >
< span class = "w" > < / span > < span class = "n" > compute_encoder< / span > < span class = "p" > .< / span > < span class = "n" > set_vector_bytes< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > .< / span > < span class = "n" > strides< / span > < span class = "p" > (),< / span > < span class = "w" > < / span > < span class = "mi" > 6< / span > < span class = "p" > );< / span >
< span class = "w" > < / span > < span class = "n" > compute_encoder< / span > < span class = "p" > .< / span > < span class = "n" > set_vector_bytes< / span > < span class = "p" > (< / span > < span class = "n" > y< / span > < span class = "p" > .< / span > < span class = "n" > strides< / span > < span class = "p" > (),< / span > < span class = "w" > < / span > < span class = "mi" > 7< / span > < span class = "p" > );< / span >
< span class = "w" > < / span > < span class = "n" > compute_encoder< / span > < span class = "p" > .< / span > < span class = "n" > set_bytes< / span > < span class = "p" > (< / span > < span class = "n" > ndim< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "mi" > 8< / span > < span class = "p" > );< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // We launch 1 thread for each output element and make sure that the number of< / span >
< span class = "w" > < / span > < span class = "c1" > // threads in any given threadgroup is not higher than the max allowed< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "kt" > size_t< / span > < span class = "w" > < / span > < span class = "n" > tgp_size< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > min< / span > < span class = "p" > (< / span > < span class = "n" > nelem< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > kernel< / span > < span class = "o" > -> < / span > < span class = "n" > maxTotalThreadsPerThreadgroup< / span > < span class = "p" > ());< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Fix the 3D size of each threadgroup (in terms of threads)< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "n" > MTL< / span > < span class = "o" > ::< / span > < span class = "n" > Size< / span > < span class = "w" > < / span > < span class = "n" > group_dims< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > MTL< / span > < span class = "o" > ::< / span > < span class = "n" > Size< / span > < span class = "p" > (< / span > < span class = "n" > tgp_size< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "mi" > 1< / span > < span class = "p" > );< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Fix the 3D size of the launch grid (in terms of threads)< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "n" > MTL< / span > < span class = "o" > ::< / span > < span class = "n" > Size< / span > < span class = "w" > < / span > < span class = "n" > grid_dims< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > MTL< / span > < span class = "o" > ::< / span > < span class = "n" > Size< / span > < span class = "p" > (< / span > < span class = "n" > nelem< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "mi" > 1< / span > < span class = "p" > );< / span >
2023-11-30 04:41:56 +08:00
2024-01-04 12:14:05 +08:00
< span class = "w" > < / span > < span class = "c1" > // Launch the grid with the given number of threads divided among< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // the given threadgroups< / span >
2024-11-23 04:24:16 +08:00
< span class = "w" > < / span > < span class = "n" > compute_encoder< / span > < span class = "p" > .< / span > < span class = "n" > dispatch_threads< / span > < span class = "p" > (< / span > < span class = "n" > grid_dims< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > group_dims< / span > < span class = "p" > );< / span >
2024-03-31 08:32:20 +08:00
< span class = "p" > }< / span >
2023-11-30 04:41:56 +08:00
< / pre > < / div >
< / div >
< p > We can now call the < code class = "xref py py-meth docutils literal notranslate" > < span class = "pre" > axpby()< / span > < / code > operation on both the CPU and the GPU!< / p >
2024-04-12 08:33:33 +08:00
< p > A few things to note about MLX and Metal before moving on. MLX keeps track of
the active < code class = "docutils literal notranslate" > < span class = "pre" > command_buffer< / span > < / code > and the < code class = "docutils literal notranslate" > < span class = "pre" > MTLCommandBuffer< / span > < / code > with which it is
associated. We rely on < code class = "xref py py-meth docutils literal notranslate" > < span class = "pre" > d.get_command_encoder()< / span > < / code > to give us the active
Metal compute command encoder instead of building a new one and calling
< code class = "xref py py-meth docutils literal notranslate" > < span class = "pre" > compute_encoder-> end_encoding()< / span > < / code > at the end. MLX adds kernels (compute
pipelines) to the active command buffer until some specified limit is hit or
the command buffer needs to be flushed for synchronization.< / p >
2023-11-30 04:41:56 +08:00
< / section >
< section id = "primitive-transforms" >
2024-03-31 08:32:20 +08:00
< h3 > Primitive Transforms< a class = "headerlink" href = "#primitive-transforms" title = "Link to this heading" > #< / a > < / h3 >
2024-04-12 08:33:33 +08:00
< p > Next, let’s add implementations for transformations in a < code class = "xref py py-class docutils literal notranslate" > < span class = "pre" > Primitive< / span > < / code > .
These transformations can be built on top of other operations, including the
one we just defined:< / p >
2024-03-31 08:32:20 +08:00
< div class = "highlight-C++ notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "cm" > /** The Jacobian-vector product. */< / span >
2024-04-12 08:33:33 +08:00
< span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > < / span > < span class = "w" > < / span > < span class = "n" > Axpby< / span > < span class = "o" > ::< / span > < span class = "n" > jvp< / span > < span class = "p" > (< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > primals< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > tangents< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "kt" > int< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > argnums< / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "p" > {< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Forward mode diff that pushes along the tangents< / span >
2025-04-04 04:25:24 +08:00
< span class = "w" > < / span > < span class = "c1" > // The jvp transform on the primitive can be built with ops< / span >
2024-01-04 12:14:05 +08:00
< span class = "w" > < / span > < span class = "c1" > // that are scheduled on the same stream as the primitive< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // If argnums = {0}, we only push along x in which case the< / span >
< span class = "w" > < / span > < span class = "c1" > // jvp is just the tangent scaled by alpha< / span >
< span class = "w" > < / span > < span class = "c1" > // Similarly, if argnums = {1}, the jvp is just the tangent< / span >
< span class = "w" > < / span > < span class = "c1" > // scaled by beta< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "k" > if< / span > < span class = "w" > < / span > < span class = "p" > (< / span > < span class = "n" > argnums< / span > < span class = "p" > .< / span > < span class = "n" > size< / span > < span class = "p" > ()< / span > < span class = "w" > < / span > < span class = "o" > ==< / span > < span class = "w" > < / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "p" > {< / span >
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "w" > < / span > < span class = "n" > scale< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > argnums< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span > < span class = "w" > < / span > < span class = "o" > ==< / span > < span class = "w" > < / span > < span class = "mi" > 0< / span > < span class = "w" > < / span > < span class = "o" > ?< / span > < span class = "w" > < / span > < span class = "n" > alpha_< / span > < span class = "w" > < / span > < span class = "o" > :< / span > < span class = "w" > < / span > < span class = "n" > beta_< / span > < span class = "p" > ;< / span >
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "w" > < / span > < span class = "n" > scale_arr< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > array< / span > < span class = "p" > (< / span > < span class = "n" > scale< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > tangents< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ].< / span > < span class = "n" > dtype< / span > < span class = "p" > ());< / span >
2024-04-12 08:33:33 +08:00
< span class = "w" > < / span > < span class = "k" > return< / span > < span class = "w" > < / span > < span class = "p" > {< / span > < span class = "n" > multiply< / span > < span class = "p" > (< / span > < span class = "n" > scale_arr< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > tangents< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ],< / span > < span class = "w" > < / span > < span class = "n" > stream< / span > < span class = "p" > ())};< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "p" > }< / span >
2025-04-04 04:25:24 +08:00
< span class = "w" > < / span > < span class = "c1" > // If argnums = {0, 1}, we take contributions from both< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // which gives us jvp = tangent_x * alpha + tangent_y * beta< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "k" > else< / span > < span class = "w" > < / span > < span class = "p" > {< / span >
2024-04-12 08:33:33 +08:00
< span class = "w" > < / span > < span class = "k" > return< / span > < span class = "w" > < / span > < span class = "p" > {< / span > < span class = "n" > axpby< / span > < span class = "p" > (< / span > < span class = "n" > tangents< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ],< / span > < span class = "w" > < / span > < span class = "n" > tangents< / span > < span class = "p" > [< / span > < span class = "mi" > 1< / span > < span class = "p" > ],< / span > < span class = "w" > < / span > < span class = "n" > alpha_< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > beta_< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > stream< / span > < span class = "p" > ())};< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "p" > }< / span >
< span class = "p" > }< / span >
2023-11-30 04:41:56 +08:00
< / pre > < / div >
< / div >
2024-03-31 08:32:20 +08:00
< div class = "highlight-C++ notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "cm" > /** The vector-Jacobian product. */< / span >
< span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > < / span > < span class = "w" > < / span > < span class = "n" > Axpby< / span > < span class = "o" > ::< / span > < span class = "n" > vjp< / span > < span class = "p" > (< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > primals< / span > < span class = "p" > ,< / span >
2024-04-12 08:33:33 +08:00
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > cotangents< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "kt" > int< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > argnums< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "kt" > int< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "cm" > /* unused */< / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "p" > {< / span >
2023-11-30 04:41:56 +08:00
< span class = "w" > < / span > < span class = "c1" > // Reverse mode diff< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > < / span > < span class = "w" > < / span > < span class = "n" > vjps< / span > < span class = "p" > ;< / span >
< span class = "w" > < / span > < span class = "k" > for< / span > < span class = "w" > < / span > < span class = "p" > (< / span > < span class = "k" > auto< / span > < span class = "w" > < / span > < span class = "n" > arg< / span > < span class = "w" > < / span > < span class = "o" > :< / span > < span class = "w" > < / span > < span class = "n" > argnums< / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "p" > {< / span >
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "w" > < / span > < span class = "n" > scale< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > arg< / span > < span class = "w" > < / span > < span class = "o" > ==< / span > < span class = "w" > < / span > < span class = "mi" > 0< / span > < span class = "w" > < / span > < span class = "o" > ?< / span > < span class = "w" > < / span > < span class = "n" > alpha_< / span > < span class = "w" > < / span > < span class = "o" > :< / span > < span class = "w" > < / span > < span class = "n" > beta_< / span > < span class = "p" > ;< / span >
2024-04-12 08:33:33 +08:00
< span class = "w" > < / span > < span class = "k" > auto< / span > < span class = "w" > < / span > < span class = "n" > scale_arr< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > array< / span > < span class = "p" > (< / span > < span class = "n" > scale< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > cotangents< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ].< / span > < span class = "n" > dtype< / span > < span class = "p" > ());< / span >
< span class = "w" > < / span > < span class = "n" > vjps< / span > < span class = "p" > .< / span > < span class = "n" > push_back< / span > < span class = "p" > (< / span > < span class = "n" > multiply< / span > < span class = "p" > (< / span > < span class = "n" > scale_arr< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > cotangents< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ],< / span > < span class = "w" > < / span > < span class = "n" > stream< / span > < span class = "p" > ()));< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "p" > }< / span >
< span class = "w" > < / span > < span class = "k" > return< / span > < span class = "w" > < / span > < span class = "n" > vjps< / span > < span class = "p" > ;< / span >
< span class = "p" > }< / span >
2023-11-30 04:41:56 +08:00
< / pre > < / div >
< / div >
2024-04-12 08:33:33 +08:00
< p > Note, a transformation does not need to be fully defined to start using
the < code class = "xref py py-class docutils literal notranslate" > < span class = "pre" > Primitive< / span > < / code > .< / p >
2024-03-31 08:32:20 +08:00
< div class = "highlight-C++ notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "cm" > /** Vectorize primitive along given axis */< / span >
2024-04-12 08:33:33 +08:00
< span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > pair< / span > < span class = "o" > < < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > < / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "kt" > int< / span > < span class = "o" > > > < / span > < span class = "w" > < / span > < span class = "n" > Axpby< / span > < span class = "o" > ::< / span > < span class = "n" > vmap< / span > < span class = "p" > (< / span >
2024-03-31 08:32:20 +08:00
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "n" > array< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > inputs< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > vector< / span > < span class = "o" > < < / span > < span class = "kt" > int< / span > < span class = "o" > > & < / span > < span class = "w" > < / span > < span class = "n" > axes< / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "p" > {< / span >
2024-04-12 08:33:33 +08:00
< span class = "w" > < / span > < span class = "k" > throw< / span > < span class = "w" > < / span > < span class = "n" > std< / span > < span class = "o" > ::< / span > < span class = "n" > runtime_error< / span > < span class = "p" > (< / span > < span class = "s" > " [Axpby] vmap not implemented." < / span > < span class = "p" > );< / span >
2024-03-31 08:32:20 +08:00
< span class = "p" > }< / span >
2023-11-30 04:41:56 +08:00
< / pre > < / div >
< / div >
< / section >
< / section >
< section id = "building-and-binding" >
2024-03-31 08:32:20 +08:00
< h2 > Building and Binding< a class = "headerlink" href = "#building-and-binding" title = "Link to this heading" > #< / a > < / h2 >
2023-11-30 04:41:56 +08:00
< p > Let’s look at the overall directory structure first.< / p >
< div class = "line-block" >
< div class = "line" > extensions< / div >
< div class = "line" > ├── axpby< / div >
< div class = "line" > │ ├── axpby.cpp< / div >
< div class = "line" > │ ├── axpby.h< / div >
< div class = "line" > │ └── axpby.metal< / div >
< div class = "line" > ├── mlx_sample_extensions< / div >
< div class = "line" > │ └── __init__.py< / div >
< div class = "line" > ├── bindings.cpp< / div >
< div class = "line" > ├── CMakeLists.txt< / div >
< div class = "line" > └── setup.py< / div >
< / div >
< ul class = "simple" >
< li > < p > < code class = "docutils literal notranslate" > < span class = "pre" > extensions/axpby/< / span > < / code > defines the C++ extension library< / p > < / li >
2024-01-04 12:14:05 +08:00
< li > < p > < code class = "docutils literal notranslate" > < span class = "pre" > extensions/mlx_sample_extensions< / span > < / code > sets out the structure for the
2024-04-12 08:33:33 +08:00
associated Python package< / p > < / li >
< li > < p > < code class = "docutils literal notranslate" > < span class = "pre" > extensions/bindings.cpp< / span > < / code > provides Python bindings for our operation< / p > < / li >
2023-11-30 04:41:56 +08:00
< li > < p > < code class = "docutils literal notranslate" > < span class = "pre" > extensions/CMakeLists.txt< / span > < / code > holds CMake rules to build the library and
2024-04-12 08:33:33 +08:00
Python bindings< / p > < / li >
2023-11-30 04:41:56 +08:00
< li > < p > < code class = "docutils literal notranslate" > < span class = "pre" > extensions/setup.py< / span > < / code > holds the < code class = "docutils literal notranslate" > < span class = "pre" > setuptools< / span > < / code > rules to build and install
2024-04-12 08:33:33 +08:00
the Python package< / p > < / li >
2023-11-30 04:41:56 +08:00
< / ul >
< section id = "binding-to-python" >
2024-03-31 08:32:20 +08:00
< h3 > Binding to Python< a class = "headerlink" href = "#binding-to-python" title = "Link to this heading" > #< / a > < / h3 >
2024-04-12 08:33:33 +08:00
< p > We use < a class = "reference external" href = "https://nanobind.readthedocs.io/en/latest/" > nanobind< / a > to build a Python API for the C++ library. Since bindings for
2024-02-18 05:25:37 +08:00
components such as < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.html#mlx.core.array" title = "mlx.core.array" > < code class = "xref py py-class docutils literal notranslate" > < span class = "pre" > mlx.core.array< / span > < / code > < / a > , < a class = "reference internal" href = "../python/_autosummary/mlx.core.stream.html#mlx.core.stream" title = "mlx.core.stream" > < code class = "xref py py-class docutils literal notranslate" > < span class = "pre" > mlx.core.stream< / span > < / code > < / a > , etc. are
2024-04-12 08:33:33 +08:00
already provided, adding our < code class = "xref py py-meth docutils literal notranslate" > < span class = "pre" > axpby()< / span > < / code > is simple.< / p >
< div class = "highlight-C++ notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "n" > NB_MODULE< / span > < span class = "p" > (< / span > < span class = "n" > _ext< / span > < span class = "p" > ,< / span > < span class = "w" > < / span > < span class = "n" > m< / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "p" > {< / span >
< span class = "w" > < / span > < span class = "n" > m< / span > < span class = "p" > .< / span > < span class = "n" > doc< / span > < span class = "p" > ()< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "s" > " Sample extension for MLX" < / span > < span class = "p" > ;< / span >
< span class = "w" > < / span > < span class = "n" > m< / span > < span class = "p" > .< / span > < span class = "n" > def< / span > < span class = "p" > (< / span >
< span class = "w" > < / span > < span class = "s" > " axpby" < / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "o" > & < / span > < span class = "n" > axpby< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "s" > " x" < / span > < span class = "n" > _a< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "s" > " y" < / span > < span class = "n" > _a< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "s" > " alpha" < / span > < span class = "n" > _a< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "s" > " beta" < / span > < span class = "n" > _a< / span > < span class = "p" > ,< / span >
< span class = "w" > < / span > < span class = "n" > nb< / span > < span class = "o" > ::< / span > < span class = "n" > kw_only< / span > < span class = "p" > (),< / span >
< span class = "w" > < / span > < span class = "s" > " stream" < / span > < span class = "n" > _a< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > nb< / span > < span class = "o" > ::< / span > < span class = "n" > none< / span > < span class = "p" > (),< / span >
< span class = "w" > < / span > < span class = "sa" > R< / span > < span class = "s" > " < / span > < span class = "dl" > (< / span >
< span class = "s" > Scale and sum two vectors element-wise< / span >
< span class = "s" > ``z = alpha * x + beta * y``< / span >
< span class = "s" > Follows numpy style broadcasting between ``x`` and ``y``< / span >
< span class = "s" > Inputs are upcasted to floats if needed< / span >
< span class = "s" > Args:< / span >
< span class = "s" > x (array): Input array.< / span >
< span class = "s" > y (array): Input array.< / span >
< span class = "s" > alpha (float): Scaling factor for ``x``.< / span >
< span class = "s" > beta (float): Scaling factor for ``y``.< / span >
< span class = "s" > Returns:< / span >
< span class = "s" > array: ``alpha * x + beta * y``< / span >
< span class = "s" > < / span > < span class = "dl" > )< / span > < span class = "s" > " < / span > < span class = "p" > );< / span >
< span class = "w" > < / span > < span class = "p" > }< / span >
2023-11-30 04:41:56 +08:00
< / pre > < / div >
< / div >
< p > Most of the complexity in the above example comes from additional bells and
whistles such as the literal names and doc-strings.< / p >
< div class = "admonition warning" >
< p class = "admonition-title" > Warning< / p >
2024-04-12 08:33:33 +08:00
< p > < code class = "xref py py-mod docutils literal notranslate" > < span class = "pre" > mlx.core< / span > < / code > must be imported before importing
< code class = "xref py py-mod docutils literal notranslate" > < span class = "pre" > mlx_sample_extensions< / span > < / code > as defined by the nanobind module above to
2023-11-30 04:41:56 +08:00
ensure that the casters for < code class = "xref py py-mod docutils literal notranslate" > < span class = "pre" > mlx.core< / span > < / code > components like
< a class = "reference internal" href = "../python/_autosummary/mlx.core.array.html#mlx.core.array" title = "mlx.core.array" > < code class = "xref py py-class docutils literal notranslate" > < span class = "pre" > mlx.core.array< / span > < / code > < / a > are available.< / p >
< / div >
< / section >
< section id = "building-with-cmake" >
2024-03-31 08:32:20 +08:00
< span id = "id1" > < / span > < h3 > Building with CMake< a class = "headerlink" href = "#building-with-cmake" title = "Link to this heading" > #< / a > < / h3 >
2024-04-12 08:33:33 +08:00
< p > Building the C++ extension library only requires that you < code class = "docutils literal notranslate" > < span class = "pre" > find_package(MLX< / span >
< span class = "pre" > CONFIG)< / span > < / code > and then link it to your library.< / p >
2023-11-30 04:41:56 +08:00
< div class = "highlight-cmake notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "c" > # Add library< / span >
< span class = "nb" > add_library< / span > < span class = "p" > (< / span > < span class = "s" > mlx_ext< / span > < span class = "p" > )< / span >
< span class = "c" > # Add sources< / span >
< span class = "nb" > target_sources< / span > < span class = "p" > (< / span >
< span class = "w" > < / span > < span class = "s" > mlx_ext< / span >
< span class = "w" > < / span > < span class = "s" > PUBLIC< / span >
< span class = "w" > < / span > < span class = "o" > ${< / span > < span class = "nv" > CMAKE_CURRENT_LIST_DIR< / span > < span class = "o" > }< / span > < span class = "s" > /axpby/axpby.cpp< / span >
< span class = "p" > )< / span >
< span class = "c" > # Add include headers< / span >
< span class = "nb" > target_include_directories< / span > < span class = "p" > (< / span >
< span class = "w" > < / span > < span class = "s" > mlx_ext< / span > < span class = "w" > < / span > < span class = "s" > PUBLIC< / span > < span class = "w" > < / span > < span class = "o" > ${< / span > < span class = "nv" > CMAKE_CURRENT_LIST_DIR< / span > < span class = "o" > }< / span >
< span class = "p" > )< / span >
< span class = "c" > # Link to mlx< / span >
< span class = "nb" > target_link_libraries< / span > < span class = "p" > (< / span > < span class = "s" > mlx_ext< / span > < span class = "w" > < / span > < span class = "s" > PUBLIC< / span > < span class = "w" > < / span > < span class = "s" > mlx< / span > < span class = "p" > )< / span >
< / pre > < / div >
< / div >
2024-04-12 08:33:33 +08:00
< p > We also need to build the attached Metal library. For convenience, we provide a
2023-11-30 04:41:56 +08:00
< code class = "xref py py-meth docutils literal notranslate" > < span class = "pre" > mlx_build_metallib()< / span > < / code > function that builds a < code class = "docutils literal notranslate" > < span class = "pre" > .metallib< / span > < / code > target given
sources, headers, destinations, etc. (defined in < code class = "docutils literal notranslate" > < span class = "pre" > cmake/extension.cmake< / span > < / code > and
automatically imported with the MLX package).< / p >
2024-04-12 08:33:33 +08:00
< p > Here is what that looks like in practice:< / p >
2023-11-30 04:41:56 +08:00
< div class = "highlight-cmake notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "c" > # Build metallib< / span >
< span class = "nb" > if< / span > < span class = "p" > (< / span > < span class = "s" > MLX_BUILD_METAL< / span > < span class = "p" > )< / span >
< span class = "nb" > mlx_build_metallib< / span > < span class = "p" > (< / span >
< span class = "w" > < / span > < span class = "s" > TARGET< / span > < span class = "w" > < / span > < span class = "s" > mlx_ext_metallib< / span >
< span class = "w" > < / span > < span class = "s" > TITLE< / span > < span class = "w" > < / span > < span class = "s" > mlx_ext< / span >
< span class = "w" > < / span > < span class = "s" > SOURCES< / span > < span class = "w" > < / span > < span class = "o" > ${< / span > < span class = "nv" > CMAKE_CURRENT_LIST_DIR< / span > < span class = "o" > }< / span > < span class = "s" > /axpby/axpby.metal< / span >
< span class = "w" > < / span > < span class = "s" > INCLUDE_DIRS< / span > < span class = "w" > < / span > < span class = "o" > ${< / span > < span class = "nv" > PROJECT_SOURCE_DIR< / span > < span class = "o" > }< / span > < span class = "w" > < / span > < span class = "o" > ${< / span > < span class = "nv" > MLX_INCLUDE_DIRS< / span > < span class = "o" > }< / span >
< span class = "w" > < / span > < span class = "s" > OUTPUT_DIRECTORY< / span > < span class = "w" > < / span > < span class = "o" > ${< / span > < span class = "nv" > CMAKE_LIBRARY_OUTPUT_DIRECTORY< / span > < span class = "o" > }< / span >
< span class = "p" > )< / span >
< span class = "nb" > add_dependencies< / span > < span class = "p" > (< / span >
< span class = "w" > < / span > < span class = "s" > mlx_ext< / span >
< span class = "w" > < / span > < span class = "s" > mlx_ext_metallib< / span >
< span class = "p" > )< / span >
< span class = "nb" > endif< / span > < span class = "p" > ()< / span >
< / pre > < / div >
< / div >
2024-04-12 08:33:33 +08:00
< p > Finally, we build the < a class = "reference external" href = "https://nanobind.readthedocs.io/en/latest/" > nanobind< / a > bindings:< / p >
< div class = "highlight-cmake notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "nb" > nanobind_add_module< / span > < span class = "p" > (< / span >
< span class = "w" > < / span > < span class = "s" > _ext< / span >
< span class = "w" > < / span > < span class = "s" > NB_STATIC< / span > < span class = "w" > < / span > < span class = "s" > STABLE_ABI< / span > < span class = "w" > < / span > < span class = "s" > LTO< / span > < span class = "w" > < / span > < span class = "s" > NOMINSIZE< / span >
< span class = "w" > < / span > < span class = "s" > NB_DOMAIN< / span > < span class = "w" > < / span > < span class = "s" > mlx< / span >
< span class = "w" > < / span > < span class = "o" > ${< / span > < span class = "nv" > CMAKE_CURRENT_LIST_DIR< / span > < span class = "o" > }< / span > < span class = "s" > /bindings.cpp< / span >
2023-11-30 04:41:56 +08:00
< span class = "p" > )< / span >
2024-04-12 08:33:33 +08:00
< span class = "nb" > target_link_libraries< / span > < span class = "p" > (< / span > < span class = "s" > _ext< / span > < span class = "w" > < / span > < span class = "s" > PRIVATE< / span > < span class = "w" > < / span > < span class = "s" > mlx_ext< / span > < span class = "p" > )< / span >
2023-11-30 04:41:56 +08:00
< span class = "nb" > if< / span > < span class = "p" > (< / span > < span class = "s" > BUILD_SHARED_LIBS< / span > < span class = "p" > )< / span >
2024-04-12 08:33:33 +08:00
< span class = "w" > < / span > < span class = "nb" > target_link_options< / span > < span class = "p" > (< / span > < span class = "s" > _ext< / span > < span class = "w" > < / span > < span class = "s" > PRIVATE< / span > < span class = "w" > < / span > < span class = "s" > -Wl,-rpath,@loader_path< / span > < span class = "p" > )< / span >
2023-11-30 04:41:56 +08:00
< span class = "nb" > endif< / span > < span class = "p" > ()< / span >
< / pre > < / div >
< / div >
< / section >
< section id = "building-with-setuptools" >
2024-03-31 08:32:20 +08:00
< h3 > Building with < code class = "docutils literal notranslate" > < span class = "pre" > setuptools< / span > < / code > < a class = "headerlink" href = "#building-with-setuptools" title = "Link to this heading" > #< / a > < / h3 >
2023-11-30 04:41:56 +08:00
< p > Once we have set out the CMake build rules as described above, we can use the
2024-04-12 08:33:33 +08:00
build utilities defined in < code class = "xref py py-mod docutils literal notranslate" > < span class = "pre" > mlx.extension< / span > < / code > :< / p >
2025-01-10 05:56:20 +08:00
< div class = "highlight-python notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "kn" > from< / span > < span class = "w" > < / span > < span class = "nn" > mlx< / span > < span class = "w" > < / span > < span class = "kn" > import< / span > < span class = "n" > extension< / span >
< span class = "kn" > from< / span > < span class = "w" > < / span > < span class = "nn" > setuptools< / span > < span class = "w" > < / span > < span class = "kn" > import< / span > < span class = "n" > setup< / span >
2023-11-30 04:41:56 +08:00
< span class = "k" > if< / span > < span class = "vm" > __name__< / span > < span class = "o" > ==< / span > < span class = "s2" > " __main__" < / span > < span class = "p" > :< / span >
< span class = "n" > setup< / span > < span class = "p" > (< / span >
< span class = "n" > name< / span > < span class = "o" > =< / span > < span class = "s2" > " mlx_sample_extensions" < / span > < span class = "p" > ,< / span >
< span class = "n" > version< / span > < span class = "o" > =< / span > < span class = "s2" > " 0.0.0" < / span > < span class = "p" > ,< / span >
< span class = "n" > description< / span > < span class = "o" > =< / span > < span class = "s2" > " Sample C++ and Metal extensions for MLX primitives." < / span > < span class = "p" > ,< / span >
2024-04-12 08:33:33 +08:00
< span class = "n" > ext_modules< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > extension< / span > < span class = "o" > .< / span > < span class = "n" > CMakeExtension< / span > < span class = "p" > (< / span > < span class = "s2" > " mlx_sample_extensions._ext" < / span > < span class = "p" > )],< / span >
2023-11-30 04:41:56 +08:00
< span class = "n" > cmdclass< / span > < span class = "o" > =< / span > < span class = "p" > {< / span > < span class = "s2" > " build_ext" < / span > < span class = "p" > :< / span > < span class = "n" > extension< / span > < span class = "o" > .< / span > < span class = "n" > CMakeBuild< / span > < span class = "p" > },< / span >
2024-04-12 08:33:33 +08:00
< span class = "n" > packages< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s2" > " mlx_sample_extensions" < / span > < span class = "p" > ],< / span >
< span class = "n" > package_data< / span > < span class = "o" > =< / span > < span class = "p" > {< / span > < span class = "s2" > " mlx_sample_extensions" < / span > < span class = "p" > :< / span > < span class = "p" > [< / span > < span class = "s2" > " *.so" < / span > < span class = "p" > ,< / span > < span class = "s2" > " *.dylib" < / span > < span class = "p" > ,< / span > < span class = "s2" > " *.metallib" < / span > < span class = "p" > ]},< / span >
< span class = "n" > extras_require< / span > < span class = "o" > =< / span > < span class = "p" > {< / span > < span class = "s2" > " dev" < / span > < span class = "p" > :[]},< / span >
2023-11-30 04:41:56 +08:00
< span class = "n" > zip_safe< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span > < span class = "p" > ,< / span >
2024-04-12 08:33:33 +08:00
< span class = "n" > python_requires< / span > < span class = "o" > =< / span > < span class = "s2" > " > =3.8" < / span > < span class = "p" > ,< / span >
2023-11-30 04:41:56 +08:00
< span class = "p" > )< / span >
< / pre > < / div >
< / div >
< div class = "admonition note" >
< p class = "admonition-title" > Note< / p >
< p > We treat < code class = "docutils literal notranslate" > < span class = "pre" > extensions/mlx_sample_extensions< / span > < / code > as the package directory
even though it only contains a < code class = "docutils literal notranslate" > < span class = "pre" > __init__.py< / span > < / code > to ensure the following:< / p >
< ul class = "simple" >
2024-04-12 08:33:33 +08:00
< li > < p > < code class = "xref py py-mod docutils literal notranslate" > < span class = "pre" > mlx.core< / span > < / code > must be imported before importing < code class = "xref py py-mod docutils literal notranslate" > < span class = "pre" > _ext< / span > < / code > < / p > < / li >
2023-11-30 04:41:56 +08:00
< li > < p > The C++ extension library and the Metal library are co-located with the Python
bindings and copied together if the package is installed< / p > < / li >
< / ul >
< / div >
2024-04-12 08:33:33 +08:00
< p > To build the package, first install the build dependencies with < code class = "docutils literal notranslate" > < span class = "pre" > pip< / span > < span class = "pre" > install< / span >
< span class = "pre" > -r< / span > < span class = "pre" > requirements.txt< / span > < / code > . You can then build in place for development using
2023-11-30 04:41:56 +08:00
< code class = "docutils literal notranslate" > < span class = "pre" > python< / span > < span class = "pre" > setup.py< / span > < span class = "pre" > build_ext< / span > < span class = "pre" > -j8< / span > < span class = "pre" > --inplace< / span > < / code > (in < code class = "docutils literal notranslate" > < span class = "pre" > extensions/< / span > < / code > ).< / p >
2024-04-12 08:33:33 +08:00
< p > This results in the directory structure:< / p >
2023-11-30 04:41:56 +08:00
< div class = "line-block" >
< div class = "line" > extensions< / div >
< div class = "line" > ├── mlx_sample_extensions< / div >
< div class = "line" > │ ├── __init__.py< / div >
< div class = "line" > │ ├── libmlx_ext.dylib # C++ extension library< / div >
< div class = "line" > │ ├── mlx_ext.metallib # Metal library< / div >
2024-04-12 08:33:33 +08:00
< div class = "line" > │ └── _ext.cpython-3x-darwin.so # Python Binding< / div >
2023-11-30 04:41:56 +08:00
< div class = "line" > …< / div >
< / div >
2024-04-12 08:33:33 +08:00
< p > When you try to install using the command < code class = "docutils literal notranslate" > < span class = "pre" > python< / span > < span class = "pre" > -m< / span > < span class = "pre" > pip< / span > < span class = "pre" > install< / span > < span class = "pre" > .< / span > < / code > (in
< code class = "docutils literal notranslate" > < span class = "pre" > extensions/< / span > < / code > ), the package will be installed with the same structure as
< code class = "docutils literal notranslate" > < span class = "pre" > extensions/mlx_sample_extensions< / span > < / code > and the C++ and Metal libraries will be
copied along with the Python binding since they are specified as
< code class = "docutils literal notranslate" > < span class = "pre" > package_data< / span > < / code > .< / p >
2023-11-30 04:41:56 +08:00
< / section >
< / section >
< section id = "usage" >
2024-03-31 08:32:20 +08:00
< h2 > Usage< a class = "headerlink" href = "#usage" title = "Link to this heading" > #< / a > < / h2 >
2023-11-30 04:41:56 +08:00
< p > After installing the extension as described above, you should be able to simply
2024-04-12 08:33:33 +08:00
import the Python package and play with it as you would any other MLX operation.< / p >
< p > Let’s look at a simple script and its results:< / p >
2025-01-10 05:56:20 +08:00
< div class = "highlight-python notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "kn" > import< / span > < span class = "w" > < / span > < span class = "nn" > mlx.core< / span > < span class = "w" > < / span > < span class = "k" > as< / span > < span class = "w" > < / span > < span class = "nn" > mx< / span >
< span class = "kn" > from< / span > < span class = "w" > < / span > < span class = "nn" > mlx_sample_extensions< / span > < span class = "w" > < / span > < span class = "kn" > import< / span > < span class = "n" > axpby< / span >
2023-11-30 04:41:56 +08:00
< span class = "n" > a< / span > < span class = "o" > =< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > ones< / span > < span class = "p" > ((< / span > < span class = "mi" > 3< / span > < span class = "p" > ,< / span > < span class = "mi" > 4< / span > < span class = "p" > ))< / span >
< span class = "n" > b< / span > < span class = "o" > =< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > ones< / span > < span class = "p" > ((< / span > < span class = "mi" > 3< / span > < span class = "p" > ,< / span > < span class = "mi" > 4< / span > < span class = "p" > ))< / span >
< span class = "n" > c< / span > < span class = "o" > =< / span > < span class = "n" > axpby< / span > < span class = "p" > (< / span > < span class = "n" > a< / span > < span class = "p" > ,< / span > < span class = "n" > b< / span > < span class = "p" > ,< / span > < span class = "mf" > 4.0< / span > < span class = "p" > ,< / span > < span class = "mf" > 2.0< / span > < span class = "p" > ,< / span > < span class = "n" > stream< / span > < span class = "o" > =< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > cpu< / span > < span class = "p" > )< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "sa" > f< / span > < span class = "s2" > " c shape: < / span > < span class = "si" > {< / span > < span class = "n" > c< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "si" > }< / span > < span class = "s2" > " < / span > < span class = "p" > )< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "sa" > f< / span > < span class = "s2" > " c dtype: < / span > < span class = "si" > {< / span > < span class = "n" > c< / span > < span class = "o" > .< / span > < span class = "n" > dtype< / span > < span class = "si" > }< / span > < span class = "s2" > " < / span > < span class = "p" > )< / span >
2025-04-04 04:25:24 +08:00
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "sa" > f< / span > < span class = "s2" > " c is correct: < / span > < span class = "si" > {< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > all< / span > < span class = "p" > (< / span > < span class = "n" > c< / span > < span class = "w" > < / span > < span class = "o" > ==< / span > < span class = "w" > < / span > < span class = "mf" > 6.0< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > item< / span > < span class = "p" > ()< / span > < span class = "si" > }< / span > < span class = "s2" > " < / span > < span class = "p" > )< / span >
2023-11-30 04:41:56 +08:00
< / pre > < / div >
< / div >
< p > Output:< / p >
< div class = "highlight-python notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "n" > c< / span > < span class = "n" > shape< / span > < span class = "p" > :< / span > < span class = "p" > [< / span > < span class = "mi" > 3< / span > < span class = "p" > ,< / span > < span class = "mi" > 4< / span > < span class = "p" > ]< / span >
< span class = "n" > c< / span > < span class = "n" > dtype< / span > < span class = "p" > :< / span > < span class = "n" > float32< / span >
2025-04-04 04:25:24 +08:00
< span class = "n" > c< / span > < span class = "ow" > is< / span > < span class = "n" > correct< / span > < span class = "p" > :< / span > < span class = "kc" > True< / span >
2023-11-30 04:41:56 +08:00
< / pre > < / div >
< / div >
< section id = "results" >
2024-03-31 08:32:20 +08:00
< h3 > Results< a class = "headerlink" href = "#results" title = "Link to this heading" > #< / a > < / h3 >
2023-11-30 04:41:56 +08:00
< p > Let’s run a quick benchmark and see how our new < code class = "docutils literal notranslate" > < span class = "pre" > axpby< / span > < / code > operation compares
2025-03-21 06:37:22 +08:00
with the naive < code class = "xref py py-meth docutils literal notranslate" > < span class = "pre" > simple_axpby()< / span > < / code > we first defined.< / p >
2025-01-10 05:56:20 +08:00
< div class = "highlight-python notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "kn" > import< / span > < span class = "w" > < / span > < span class = "nn" > mlx.core< / span > < span class = "w" > < / span > < span class = "k" > as< / span > < span class = "w" > < / span > < span class = "nn" > mx< / span >
< span class = "kn" > from< / span > < span class = "w" > < / span > < span class = "nn" > mlx_sample_extensions< / span > < span class = "w" > < / span > < span class = "kn" > import< / span > < span class = "n" > axpby< / span >
< span class = "kn" > import< / span > < span class = "w" > < / span > < span class = "nn" > time< / span >
2023-11-30 04:41:56 +08:00
2025-01-10 05:56:20 +08:00
< span class = "k" > def< / span > < span class = "w" > < / span > < span class = "nf" > simple_axpby< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > :< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > array< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > :< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > array< / span > < span class = "p" > ,< / span > < span class = "n" > alpha< / span > < span class = "p" > :< / span > < span class = "nb" > float< / span > < span class = "p" > ,< / span > < span class = "n" > beta< / span > < span class = "p" > :< / span > < span class = "nb" > float< / span > < span class = "p" > )< / span > < span class = "o" > -> < / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > array< / span > < span class = "p" > :< / span >
2023-11-30 04:41:56 +08:00
< span class = "k" > return< / span > < span class = "n" > alpha< / span > < span class = "o" > *< / span > < span class = "n" > x< / span > < span class = "o" > +< / span > < span class = "n" > beta< / span > < span class = "o" > *< / span > < span class = "n" > y< / span >
2025-03-21 06:37:22 +08:00
< span class = "n" > M< / span > < span class = "o" > =< / span > < span class = "mi" > 4096< / span >
< span class = "n" > N< / span > < span class = "o" > =< / span > < span class = "mi" > 4096< / span >
2023-11-30 04:41:56 +08:00
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > random< / span > < span class = "o" > .< / span > < span class = "n" > normal< / span > < span class = "p" > ((< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ))< / span >
< span class = "n" > y< / span > < span class = "o" > =< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > random< / span > < span class = "o" > .< / span > < span class = "n" > normal< / span > < span class = "p" > ((< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ))< / span >
< span class = "n" > alpha< / span > < span class = "o" > =< / span > < span class = "mf" > 4.0< / span >
< span class = "n" > beta< / span > < span class = "o" > =< / span > < span class = "mf" > 2.0< / span >
2024-04-12 08:33:33 +08:00
< span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > eval< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > )< / span >
2023-11-30 04:41:56 +08:00
2025-01-10 05:56:20 +08:00
< span class = "k" > def< / span > < span class = "w" > < / span > < span class = "nf" > bench< / span > < span class = "p" > (< / span > < span class = "n" > f< / span > < span class = "p" > ):< / span >
2023-11-30 04:41:56 +08:00
< span class = "c1" > # Warm up< / span >
2025-03-21 06:37:22 +08:00
< span class = "k" > for< / span > < span class = "n" > i< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "mi" > 5< / span > < span class = "p" > ):< / span >
2023-11-30 04:41:56 +08:00
< span class = "n" > z< / span > < span class = "o" > =< / span > < span class = "n" > f< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ,< / span > < span class = "n" > alpha< / span > < span class = "p" > ,< / span > < span class = "n" > beta< / span > < span class = "p" > )< / span >
< span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > eval< / span > < span class = "p" > (< / span > < span class = "n" > z< / span > < span class = "p" > )< / span >
< span class = "c1" > # Timed run< / span >
< span class = "n" > s< / span > < span class = "o" > =< / span > < span class = "n" > time< / span > < span class = "o" > .< / span > < span class = "n" > time< / span > < span class = "p" > ()< / span >
2025-03-21 06:37:22 +08:00
< span class = "k" > for< / span > < span class = "n" > i< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "mi" > 100< / span > < span class = "p" > ):< / span >
2023-11-30 04:41:56 +08:00
< span class = "n" > z< / span > < span class = "o" > =< / span > < span class = "n" > f< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ,< / span > < span class = "n" > alpha< / span > < span class = "p" > ,< / span > < span class = "n" > beta< / span > < span class = "p" > )< / span >
< span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > eval< / span > < span class = "p" > (< / span > < span class = "n" > z< / span > < span class = "p" > )< / span >
< span class = "n" > e< / span > < span class = "o" > =< / span > < span class = "n" > time< / span > < span class = "o" > .< / span > < span class = "n" > time< / span > < span class = "p" > ()< / span >
2025-03-21 06:37:22 +08:00
< span class = "k" > return< / span > < span class = "mi" > 1000< / span > < span class = "o" > *< / span > < span class = "p" > (< / span > < span class = "n" > e< / span > < span class = "o" > -< / span > < span class = "n" > s< / span > < span class = "p" > )< / span > < span class = "o" > /< / span > < span class = "mi" > 100< / span >
2023-11-30 04:41:56 +08:00
< span class = "n" > simple_time< / span > < span class = "o" > =< / span > < span class = "n" > bench< / span > < span class = "p" > (< / span > < span class = "n" > simple_axpby< / span > < span class = "p" > )< / span >
< span class = "n" > custom_time< / span > < span class = "o" > =< / span > < span class = "n" > bench< / span > < span class = "p" > (< / span > < span class = "n" > axpby< / span > < span class = "p" > )< / span >
2025-03-21 06:37:22 +08:00
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "sa" > f< / span > < span class = "s2" > " Simple axpby: < / span > < span class = "si" > {< / span > < span class = "n" > simple_time< / span > < span class = "si" > :< / span > < span class = "s2" > .3f< / span > < span class = "si" > }< / span > < span class = "s2" > ms | Custom axpby: < / span > < span class = "si" > {< / span > < span class = "n" > custom_time< / span > < span class = "si" > :< / span > < span class = "s2" > .3f< / span > < span class = "si" > }< / span > < span class = "s2" > ms" < / span > < span class = "p" > )< / span >
2023-11-30 04:41:56 +08:00
< / pre > < / div >
< / div >
2025-03-21 06:37:22 +08:00
< p > The results are < code class = "docutils literal notranslate" > < span class = "pre" > Simple< / span > < span class = "pre" > axpby:< / span > < span class = "pre" > 1.559< / span > < span class = "pre" > ms< / span > < span class = "pre" > |< / span > < span class = "pre" > Custom< / span > < span class = "pre" > axpby:< / span > < span class = "pre" > 0.774< / span > < span class = "pre" > ms< / span > < / code > . We see
2024-04-12 08:33:33 +08:00
modest improvements right away!< / p >
2024-02-09 04:44:23 +08:00
< p > This operation can now be used to build other operations, in
< a class = "reference internal" href = "../python/nn/module.html#mlx.nn.Module" title = "mlx.nn.Module" > < code class = "xref py py-class docutils literal notranslate" > < span class = "pre" > mlx.nn.Module< / span > < / code > < / a > calls, and also as a part of graph transformations like
2024-04-12 08:33:33 +08:00
< code class = "xref py py-meth docutils literal notranslate" > < span class = "pre" > grad()< / span > < / code > .< / p >
2023-11-30 04:41:56 +08:00
< / section >
< / section >
< section id = "scripts" >
2024-03-31 08:32:20 +08:00
< h2 > Scripts< a class = "headerlink" href = "#scripts" title = "Link to this heading" > #< / a > < / h2 >
2023-11-30 04:41:56 +08:00
< div class = "admonition-download-the-code admonition" >
< p class = "admonition-title" > Download the code< / p >
2024-04-12 08:33:33 +08:00
< p > The full example code is available in the < a class = "reference external" href = "https://github.com/ml-explore/mlx/tree/main/examples/extensions/" > MLX repository< / a > .< / p >
2023-11-30 04:41:56 +08:00
< / div >
< / section >
< / section >
2023-12-06 04:10:03 +08:00
< / article >
2024-10-15 23:12:17 +08:00
< footer class = "prev-next-footer d-print-none" >
2023-12-06 04:10:03 +08:00
< div class = "prev-next-area" >
< a class = "left-prev"
href="../cpp/ops.html"
title="previous page">
< i class = "fa-solid fa-angle-left" > < / i >
< div class = "prev-next-info" >
< p class = "prev-next-subtitle" > previous< / p >
< p class = "prev-next-title" > Operations< / p >
< / div >
< / a >
2024-03-31 08:32:20 +08:00
< a class = "right-next"
href="metal_debugger.html"
title="next page">
< div class = "prev-next-info" >
< p class = "prev-next-subtitle" > next< / p >
< p class = "prev-next-title" > Metal Debugger< / p >
< / div >
< i class = "fa-solid fa-angle-right" > < / i >
< / a >
2023-12-06 04:10:03 +08:00
< / div >
< / footer >
< / div >
2025-03-06 05:30:09 +08:00
< div class = "bd-sidebar-secondary bd-toc" > < div class = "sidebar-secondary-items sidebar-secondary__inner" >
2023-12-06 04:10:03 +08:00
2024-10-15 23:12:17 +08:00
2023-12-06 04:10:03 +08:00
< div class = "sidebar-secondary-item" >
< div class = "page-toc tocsection onthispage" >
< i class = "fa-solid fa-list" > < / i > Contents
< / div >
< nav class = "bd-toc-nav page-toc" >
< ul class = "visible nav section-nav flex-column" >
< li class = "toc-h2 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#introducing-the-example" > Introducing the Example< / a > < / li >
< li class = "toc-h2 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#operations-and-primitives" > Operations and Primitives< / a > < ul class = "visible nav section-nav flex-column" >
< li class = "toc-h3 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#operations" > Operations< / a > < / li >
< li class = "toc-h3 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#primitives" > Primitives< / a > < / li >
2024-04-12 08:33:33 +08:00
< li class = "toc-h3 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#using-the-primitive" > Using the Primitive< / a > < / li >
2023-12-06 04:10:03 +08:00
< / ul >
< / li >
< li class = "toc-h2 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#implementing-the-primitive" > Implementing the Primitive< / a > < ul class = "visible nav section-nav flex-column" >
2024-04-12 08:33:33 +08:00
< li class = "toc-h3 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#implementing-the-cpu-back-end" > Implementing the CPU Back-end< / a > < / li >
< li class = "toc-h3 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#implementing-the-gpu-back-end" > Implementing the GPU Back-end< / a > < / li >
2023-12-06 04:10:03 +08:00
< li class = "toc-h3 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#primitive-transforms" > Primitive Transforms< / a > < / li >
< / ul >
< / li >
< li class = "toc-h2 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#building-and-binding" > Building and Binding< / a > < ul class = "visible nav section-nav flex-column" >
< li class = "toc-h3 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#binding-to-python" > Binding to Python< / a > < / li >
< li class = "toc-h3 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#building-with-cmake" > Building with CMake< / a > < / li >
< li class = "toc-h3 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#building-with-setuptools" > Building with < code class = "docutils literal notranslate" > < span class = "pre" > setuptools< / span > < / code > < / a > < / li >
< / ul >
< / li >
< li class = "toc-h2 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#usage" > Usage< / a > < ul class = "visible nav section-nav flex-column" >
< li class = "toc-h3 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#results" > Results< / a > < / li >
< / ul >
< / li >
< li class = "toc-h2 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#scripts" > Scripts< / a > < / li >
< / ul >
< / nav > < / div >
2023-11-30 04:41:56 +08:00
2023-12-06 04:10:03 +08:00
< / div > < / div >
< / div >
< footer class = "bd-footer-content" >
< div class = "bd-footer-content__inner container" >
< div class = "footer-item" >
< p class = "component-author" >
By MLX Contributors
< / p >
2023-11-30 04:41:56 +08:00
< / div >
2023-12-06 04:10:03 +08:00
< div class = "footer-item" >
2023-11-30 04:41:56 +08:00
2023-12-06 04:10:03 +08:00
< p class = "copyright" >
2025-06-03 07:29:32 +08:00
© Copyright 2023, Apple.
2023-12-06 04:10:03 +08:00
< br / >
< / p >
2023-11-30 04:41:56 +08:00
< / div >
2023-12-06 04:10:03 +08:00
< div class = "footer-item" >
< / div >
< div class = "footer-item" >
< / div >
< / div >
< / footer >
< / main >
< / div >
< / div >
<!-- Scripts loaded after <body> so the DOM is not blocked -->
2025-03-06 05:30:09 +08:00
< script src = "../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b" > < / script >
< script src = "../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" > < / script >
2023-11-30 04:41:56 +08:00
2023-12-06 04:10:03 +08:00
< footer class = "bd-footer" >
< / footer >
< / body >
2023-11-30 04:41:56 +08:00
< / html >