<!-- The doctype must be the very first token in the document; any text before it
     (such as a stray timestamp line) switches the browser into quirks mode. -->
<!DOCTYPE html>
< html lang = "en" data-content_root = "../" >
<head>
  <meta charset="utf-8">
  <!-- A single viewport declaration; duplicates were previously emitted. -->
  <meta name="viewport" content="width=device-width, initial-scale=1">
  <title>Custom Metal Kernels — MLX 0.26.1 documentation</title>

  <!-- Copy the saved colour mode/theme from localStorage onto <html> data attributes.
       This runs before the theme stylesheets below are linked. -->
  <script data-cfasync="false">
    (function () {
      var mode = "";
      var theme = "";
      try {
        mode = localStorage.getItem("mode") || "";
        theme = localStorage.getItem("theme") || "";
      } catch (err) {
        // Accessing localStorage can throw (e.g. SecurityError when storage is blocked);
        // keep the empty defaults rather than aborting this script.
      }
      document.documentElement.dataset.mode = mode;
      document.documentElement.dataset.theme = theme;
    })();
  </script>

  <!-- Loaded before other Sphinx assets -->
  <link rel="stylesheet" href="../_static/styles/theme.css?digest=dfe6caa3a7d634c4db9b">
  <link rel="stylesheet" href="../_static/styles/bootstrap.css?digest=dfe6caa3a7d634c4db9b">
  <link rel="stylesheet" href="../_static/styles/pydata-sphinx-theme.css?digest=dfe6caa3a7d634c4db9b">
  <link rel="stylesheet" href="../_static/vendor/fontawesome/6.5.2/css/all.min.css?digest=dfe6caa3a7d634c4db9b">
  <!-- Font preloads need crossorigin to match the CORS-mode font request. -->
  <link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-solid-900.woff2">
  <link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-brands-400.woff2">
  <link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-regular-400.woff2">

  <link rel="stylesheet" href="../_static/pygments.css?v=03e43079">
  <link rel="stylesheet" href="../_static/styles/sphinx-book-theme.css?v=eba8b062">

  <!-- Pre-loaded scripts that we'll load fully later -->
  <link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b">
  <link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b">
  <script src="../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script>

  <script src="../_static/documentation_options.js?v=3724ff34"></script>
  <script src="../_static/doctools.js?v=9a2dae69"></script>
  <script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
  <script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
  <!-- DOCUMENTATION_OPTIONS is defined by documentation_options.js above. -->
  <script>DOCUMENTATION_OPTIONS.pagename = 'dev/custom_metal_kernels';</script>

  <link rel="icon" href="../_static/mlx_logo.png">
  <link rel="index" title="Index" href="../genindex.html">
  <link rel="search" title="Search" href="../search.html">
  <link rel="next" title="Using MLX in C++" href="mlx_in_cpp.html">
  <link rel="prev" title="Metal Debugger" href="metal_debugger.html">
  <meta name="docsearch:language" content="en">
</head>
< body data-bs-spy = "scroll" data-bs-target = ".bd-toc-nav" data-offset = "180" data-bs-root-margin = "0px 0px -60%" data-default-mode = "" >
<!-- Keyboard skip link to the page's main content. -->
<div id="pst-skip-link" class="skip-link d-print-none"><a href="#main-content">Skip to main content</a></div>

<div id="pst-scroll-pixel-helper"></div>
<!-- The icon is decorative; the visible text provides the button's accessible name. -->
<button type="button" class="btn rounded-pill" id="pst-back-to-top">
  <i class="fa-solid fa-arrow-up" aria-hidden="true"></i> Back to top</button>
<!-- Checkbox/label pairs for the primary and secondary sidebars. The class names
     suggest a CSS-driven toggle with a click overlay; the behaviour itself lives in
     the theme's CSS/JS, not in this markup. -->
<input class="sidebar-toggle" id="pst-primary-sidebar-checkbox" type="checkbox">
<label class="overlay overlay-primary" for="pst-primary-sidebar-checkbox"></label>
<input class="sidebar-toggle" id="pst-secondary-sidebar-checkbox" type="checkbox">
<label class="overlay overlay-secondary" for="pst-secondary-sidebar-checkbox"></label>
<!-- Search overlay: a GET form that submits the query as ?q= to the Sphinx search page. -->
<div class="search-button__wrapper">
  <div class="search-button__overlay"></div>
  <div class="search-button__search-container">
    <form class="bd-search d-flex align-items-center"
          action="../search.html"
          method="get">
      <i class="fa-solid fa-magnifying-glass" aria-hidden="true"></i>
      <input type="search"
             class="form-control"
             name="q"
             id="search-input"
             placeholder="Search..."
             aria-label="Search"
             autocomplete="off"
             autocorrect="off"
             autocapitalize="off"
             spellcheck="false">
      <span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd>K</kbd></span>
    </form>
  </div>
</div>
<!-- Version-warning slot. Both the wrapper and the aside start hidden (d-none);
     presumably the theme's scripts reveal and fill it when needed. -->
<div class="pst-async-banner-revealer d-none">
  <aside class="d-none d-print-none" id="bd-header-version-warning" aria-label="Version warning"></aside>
</div>
<!-- Top navbar; this build renders it with no content. -->
<header class="bd-header navbar navbar-expand-lg bd-navbar d-print-none">
</header>
< div class = "bd-container" >
< div class = "bd-container__inner bd-page-width" >
2024-10-15 23:12:17 +08:00
2025-03-06 05:30:09 +08:00
< div class = "bd-sidebar-primary bd-sidebar" >
2024-08-24 03:14:53 +08:00
< div class = "sidebar-header-items sidebar-primary__section" >
< / div >
< div class = "sidebar-primary-items__start sidebar-primary__section" >
<div class="sidebar-primary-item">
  <!-- Brand link to the docs home. The light logo is static markup; the dark logo is
       written by script, presumably so that it only exists when JavaScript runs. -->
  <a class="navbar-brand logo" href="../index.html">
    <img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.26.1 documentation - Home">
    <script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.26.1 documentation - Home"/>`);</script>
  </a>
</div>
<div class="sidebar-primary-item">
  <!-- Sidebar search trigger, injected by script. It has an explicit type="button" so it
       can never act as a submit button, and its icon is hidden from assistive tech
       because the visible text and aria-label already name the button. -->
  <script>
    document.write(`
      <button type="button" class="btn search-button-field search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
        <i class="fa-solid fa-magnifying-glass" aria-hidden="true"></i>
        <span class="search-button__default-text">Search</span>
        <span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span>
      </button>
    `);
  </script>
</div>
< div class = "sidebar-primary-item" > < nav class = "bd-links bd-docs-nav" aria-label = "Main" >
< div class = "bd-toc-item navbar-nav active" >
< p aria-level = "2" class = "caption" role = "heading" > < span class = "caption-text" > Install< / span > < / p >
< ul class = "nav bd-sidenav" >
< li class = "toctree-l1" > < a class = "reference internal" href = "../install.html" > Build and Install< / a > < / li >
< / ul >
< p aria-level = "2" class = "caption" role = "heading" > < span class = "caption-text" > Usage< / span > < / p >
< ul class = "nav bd-sidenav" >
< li class = "toctree-l1" > < a class = "reference internal" href = "../usage/quick_start.html" > Quick Start Guide< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../usage/lazy_evaluation.html" > Lazy Evaluation< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../usage/unified_memory.html" > Unified Memory< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../usage/indexing.html" > Indexing Arrays< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../usage/saving_and_loading.html" > Saving and Loading Arrays< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../usage/function_transforms.html" > Function Transforms< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../usage/compile.html" > Compilation< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../usage/numpy.html" > Conversion to NumPy and Other Frameworks< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../usage/distributed.html" > Distributed Communication< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../usage/using_streams.html" > Using Streams< / a > < / li >
2025-01-10 05:56:20 +08:00
< li class = "toctree-l1" > < a class = "reference internal" href = "../usage/export.html" > Exporting Functions< / a > < / li >
2024-08-24 03:14:53 +08:00
< / ul >
< p aria-level = "2" class = "caption" role = "heading" > < span class = "caption-text" > Examples< / span > < / p >
< ul class = "nav bd-sidenav" >
< li class = "toctree-l1" > < a class = "reference internal" href = "../examples/linear_regression.html" > Linear Regression< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../examples/mlp.html" > Multi-Layer Perceptron< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../examples/llama-inference.html" > LLM inference< / a > < / li >
< / ul >
< p aria-level = "2" class = "caption" role = "heading" > < span class = "caption-text" > Python API Reference< / span > < / p >
< ul class = "nav bd-sidenav" >
2024-10-15 23:12:17 +08:00
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/array.html" > Array< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.html" > mlx.core.array< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.astype.html" > mlx.core.array.astype< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.at.html" > mlx.core.array.at< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.item.html" > mlx.core.array.item< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.tolist.html" > mlx.core.array.tolist< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.dtype.html" > mlx.core.array.dtype< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.itemsize.html" > mlx.core.array.itemsize< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.nbytes.html" > mlx.core.array.nbytes< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.ndim.html" > mlx.core.array.ndim< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.shape.html" > mlx.core.array.shape< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.size.html" > mlx.core.array.size< / a > < / li >
2025-06-03 07:29:32 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.real.html" > mlx.core.array.real< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.imag.html" > mlx.core.array.imag< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.abs.html" > mlx.core.array.abs< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.all.html" > mlx.core.array.all< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.any.html" > mlx.core.array.any< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.argmax.html" > mlx.core.array.argmax< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.argmin.html" > mlx.core.array.argmin< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.conj.html" > mlx.core.array.conj< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.cos.html" > mlx.core.array.cos< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.cummax.html" > mlx.core.array.cummax< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.cummin.html" > mlx.core.array.cummin< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.cumprod.html" > mlx.core.array.cumprod< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.cumsum.html" > mlx.core.array.cumsum< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.diag.html" > mlx.core.array.diag< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.diagonal.html" > mlx.core.array.diagonal< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.exp.html" > mlx.core.array.exp< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.flatten.html" > mlx.core.array.flatten< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.log.html" > mlx.core.array.log< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.log10.html" > mlx.core.array.log10< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.log1p.html" > mlx.core.array.log1p< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.log2.html" > mlx.core.array.log2< / a > < / li >
2025-04-18 06:29:33 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.logcumsumexp.html" > mlx.core.array.logcumsumexp< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.logsumexp.html" > mlx.core.array.logsumexp< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.max.html" > mlx.core.array.max< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.mean.html" > mlx.core.array.mean< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.min.html" > mlx.core.array.min< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.moveaxis.html" > mlx.core.array.moveaxis< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.prod.html" > mlx.core.array.prod< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.reciprocal.html" > mlx.core.array.reciprocal< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.reshape.html" > mlx.core.array.reshape< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.round.html" > mlx.core.array.round< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.rsqrt.html" > mlx.core.array.rsqrt< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.sin.html" > mlx.core.array.sin< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.split.html" > mlx.core.array.split< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.sqrt.html" > mlx.core.array.sqrt< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.square.html" > mlx.core.array.square< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.squeeze.html" > mlx.core.array.squeeze< / a > < / li >
2024-09-18 03:06:14 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.std.html" > mlx.core.array.std< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.sum.html" > mlx.core.array.sum< / a > < / li >
2024-09-18 03:06:14 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.swapaxes.html" > mlx.core.array.swapaxes< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.transpose.html" > mlx.core.array.transpose< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.T.html" > mlx.core.array.T< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.var.html" > mlx.core.array.var< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array.view.html" > mlx.core.array.view< / a > < / li >
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/data_types.html" > Data Types< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.Dtype.html" > mlx.core.Dtype< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.DtypeCategory.html" > mlx.core.DtypeCategory< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.issubdtype.html" > mlx.core.issubdtype< / a > < / li >
2025-01-10 05:56:20 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.finfo.html" > mlx.core.finfo< / a > < / li >
2024-08-24 03:14:53 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/devices_and_streams.html" > Devices and Streams< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.Device.html" > mlx.core.Device< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/stream_class.html" > mlx.core.Stream< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.default_device.html" > mlx.core.default_device< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.set_default_device.html" > mlx.core.set_default_device< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.default_stream.html" > mlx.core.default_stream< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.new_stream.html" > mlx.core.new_stream< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.set_default_stream.html" > mlx.core.set_default_stream< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.stream.html" > mlx.core.stream< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.synchronize.html" > mlx.core.synchronize< / a > < / li >
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
2025-01-10 05:56:20 +08:00
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/export.html" > Export Functions< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.export_function.html" > mlx.core.export_function< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.import_function.html" > mlx.core.import_function< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.exporter.html" > mlx.core.exporter< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.export_to_dot.html" > mlx.core.export_to_dot< / a > < / li >
< / ul >
< / details > < / li >
2024-10-15 23:12:17 +08:00
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/ops.html" > Operations< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.abs.html" > mlx.core.abs< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.add.html" > mlx.core.add< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.addmm.html" > mlx.core.addmm< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.all.html" > mlx.core.all< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.allclose.html" > mlx.core.allclose< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.any.html" > mlx.core.any< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.arange.html" > mlx.core.arange< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.arccos.html" > mlx.core.arccos< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.arccosh.html" > mlx.core.arccosh< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.arcsin.html" > mlx.core.arcsin< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.arcsinh.html" > mlx.core.arcsinh< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.arctan.html" > mlx.core.arctan< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.arctan2.html" > mlx.core.arctan2< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.arctanh.html" > mlx.core.arctanh< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.argmax.html" > mlx.core.argmax< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.argmin.html" > mlx.core.argmin< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.argpartition.html" > mlx.core.argpartition< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.argsort.html" > mlx.core.argsort< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.array_equal.html" > mlx.core.array_equal< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.as_strided.html" > mlx.core.as_strided< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.atleast_1d.html" > mlx.core.atleast_1d< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.atleast_2d.html" > mlx.core.atleast_2d< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.atleast_3d.html" > mlx.core.atleast_3d< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.bitwise_and.html" > mlx.core.bitwise_and< / a > < / li >
2025-02-15 05:44:39 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.bitwise_invert.html" > mlx.core.bitwise_invert< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.bitwise_or.html" > mlx.core.bitwise_or< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.bitwise_xor.html" > mlx.core.bitwise_xor< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.block_masked_mm.html" > mlx.core.block_masked_mm< / a > < / li >
2025-04-04 04:25:24 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.broadcast_arrays.html" > mlx.core.broadcast_arrays< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.broadcast_to.html" > mlx.core.broadcast_to< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.ceil.html" > mlx.core.ceil< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.clip.html" > mlx.core.clip< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.concatenate.html" > mlx.core.concatenate< / a > < / li >
2025-04-04 04:25:24 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.contiguous.html" > mlx.core.contiguous< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.conj.html" > mlx.core.conj< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.conjugate.html" > mlx.core.conjugate< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.convolve.html" > mlx.core.convolve< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.conv1d.html" > mlx.core.conv1d< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.conv2d.html" > mlx.core.conv2d< / a > < / li >
2024-09-18 03:06:14 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.conv3d.html" > mlx.core.conv3d< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.conv_transpose1d.html" > mlx.core.conv_transpose1d< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.conv_transpose2d.html" > mlx.core.conv_transpose2d< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.conv_transpose3d.html" > mlx.core.conv_transpose3d< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.conv_general.html" > mlx.core.conv_general< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.cos.html" > mlx.core.cos< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.cosh.html" > mlx.core.cosh< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.cummax.html" > mlx.core.cummax< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.cummin.html" > mlx.core.cummin< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.cumprod.html" > mlx.core.cumprod< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.cumsum.html" > mlx.core.cumsum< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.degrees.html" > mlx.core.degrees< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.dequantize.html" > mlx.core.dequantize< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.diag.html" > mlx.core.diag< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.diagonal.html" > mlx.core.diagonal< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.divide.html" > mlx.core.divide< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.divmod.html" > mlx.core.divmod< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.einsum.html" > mlx.core.einsum< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.einsum_path.html" > mlx.core.einsum_path< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.equal.html" > mlx.core.equal< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.erf.html" > mlx.core.erf< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.erfinv.html" > mlx.core.erfinv< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.exp.html" > mlx.core.exp< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.expm1.html" > mlx.core.expm1< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.expand_dims.html" > mlx.core.expand_dims< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.eye.html" > mlx.core.eye< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.flatten.html" > mlx.core.flatten< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.floor.html" > mlx.core.floor< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.floor_divide.html" > mlx.core.floor_divide< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.full.html" > mlx.core.full< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.gather_mm.html" > mlx.core.gather_mm< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.gather_qmm.html" > mlx.core.gather_qmm< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.greater.html" > mlx.core.greater< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.greater_equal.html" > mlx.core.greater_equal< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.hadamard_transform.html" > mlx.core.hadamard_transform< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.identity.html" > mlx.core.identity< / a > < / li >
2024-10-19 03:13:44 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.imag.html" > mlx.core.imag< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.inner.html" > mlx.core.inner< / a > < / li >
2024-09-18 03:06:14 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.isfinite.html" > mlx.core.isfinite< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.isclose.html" > mlx.core.isclose< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.isinf.html" > mlx.core.isinf< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.isnan.html" > mlx.core.isnan< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.isneginf.html" > mlx.core.isneginf< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.isposinf.html" > mlx.core.isposinf< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.issubdtype.html" > mlx.core.issubdtype< / a > < / li >
2025-01-10 05:56:20 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.kron.html" > mlx.core.kron< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.left_shift.html" > mlx.core.left_shift< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.less.html" > mlx.core.less< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.less_equal.html" > mlx.core.less_equal< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linspace.html" > mlx.core.linspace< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.load.html" > mlx.core.load< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.log.html" > mlx.core.log< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.log2.html" > mlx.core.log2< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.log10.html" > mlx.core.log10< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.log1p.html" > mlx.core.log1p< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.logaddexp.html" > mlx.core.logaddexp< / a > < / li >
2025-04-18 06:29:33 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.logcumsumexp.html" > mlx.core.logcumsumexp< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.logical_not.html" > mlx.core.logical_not< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.logical_and.html" > mlx.core.logical_and< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.logical_or.html" > mlx.core.logical_or< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.logsumexp.html" > mlx.core.logsumexp< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.matmul.html" > mlx.core.matmul< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.max.html" > mlx.core.max< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.maximum.html" > mlx.core.maximum< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.mean.html" > mlx.core.mean< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.meshgrid.html" > mlx.core.meshgrid< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.min.html" > mlx.core.min< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.minimum.html" > mlx.core.minimum< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.moveaxis.html" > mlx.core.moveaxis< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.multiply.html" > mlx.core.multiply< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.nan_to_num.html" > mlx.core.nan_to_num< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.negative.html" > mlx.core.negative< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.not_equal.html" > mlx.core.not_equal< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.ones.html" > mlx.core.ones< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.ones_like.html" > mlx.core.ones_like< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.outer.html" > mlx.core.outer< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.partition.html" > mlx.core.partition< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.pad.html" > mlx.core.pad< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.power.html" > mlx.core.power< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.prod.html" > mlx.core.prod< / a > < / li >
2024-09-29 02:04:59 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.put_along_axis.html" > mlx.core.put_along_axis< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.quantize.html" > mlx.core.quantize< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.quantized_matmul.html" > mlx.core.quantized_matmul< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.radians.html" > mlx.core.radians< / a > < / li >
2024-10-19 03:13:44 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.real.html" > mlx.core.real< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.reciprocal.html" > mlx.core.reciprocal< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.remainder.html" > mlx.core.remainder< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.repeat.html" > mlx.core.repeat< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.reshape.html" > mlx.core.reshape< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.right_shift.html" > mlx.core.right_shift< / a > < / li >
2024-10-15 04:10:48 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.roll.html" > mlx.core.roll< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.round.html" > mlx.core.round< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.rsqrt.html" > mlx.core.rsqrt< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.save.html" > mlx.core.save< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.savez.html" > mlx.core.savez< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.savez_compressed.html" > mlx.core.savez_compressed< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.save_gguf.html" > mlx.core.save_gguf< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.save_safetensors.html" > mlx.core.save_safetensors< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.sigmoid.html" > mlx.core.sigmoid< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.sign.html" > mlx.core.sign< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.sin.html" > mlx.core.sin< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.sinh.html" > mlx.core.sinh< / a > < / li >
2025-01-10 05:56:20 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.slice.html" > mlx.core.slice< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.slice_update.html" > mlx.core.slice_update< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.softmax.html" > mlx.core.softmax< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.sort.html" > mlx.core.sort< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.split.html" > mlx.core.split< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.sqrt.html" > mlx.core.sqrt< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.square.html" > mlx.core.square< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.squeeze.html" > mlx.core.squeeze< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.stack.html" > mlx.core.stack< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.std.html" > mlx.core.std< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.stop_gradient.html" > mlx.core.stop_gradient< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.subtract.html" > mlx.core.subtract< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.sum.html" > mlx.core.sum< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.swapaxes.html" > mlx.core.swapaxes< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.take.html" > mlx.core.take< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.take_along_axis.html" > mlx.core.take_along_axis< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.tan.html" > mlx.core.tan< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.tanh.html" > mlx.core.tanh< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.tensordot.html" > mlx.core.tensordot< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.tile.html" > mlx.core.tile< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.topk.html" > mlx.core.topk< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.trace.html" > mlx.core.trace< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.transpose.html" > mlx.core.transpose< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.tri.html" > mlx.core.tri< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.tril.html" > mlx.core.tril< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.triu.html" > mlx.core.triu< / a > < / li >
2025-01-10 05:56:20 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.unflatten.html" > mlx.core.unflatten< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.var.html" > mlx.core.var< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.view.html" > mlx.core.view< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.where.html" > mlx.core.where< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.zeros.html" > mlx.core.zeros< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.zeros_like.html" > mlx.core.zeros_like< / a > < / li >
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/random.html" > Random< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.bernoulli.html" > mlx.core.random.bernoulli< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.categorical.html" > mlx.core.random.categorical< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.gumbel.html" > mlx.core.random.gumbel< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.key.html" > mlx.core.random.key< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.normal.html" > mlx.core.random.normal< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.multivariate_normal.html" > mlx.core.random.multivariate_normal< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.randint.html" > mlx.core.random.randint< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.seed.html" > mlx.core.random.seed< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.split.html" > mlx.core.random.split< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.truncated_normal.html" > mlx.core.random.truncated_normal< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.uniform.html" > mlx.core.random.uniform< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.laplace.html" > mlx.core.random.laplace< / a > < / li >
2024-10-15 04:10:48 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.random.permutation.html" > mlx.core.random.permutation< / a > < / li >
2024-08-24 03:14:53 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/transforms.html" > Transforms< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.eval.html" > mlx.core.eval< / a > < / li >
2025-04-04 04:25:24 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.async_eval.html" > mlx.core.async_eval< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.compile.html" > mlx.core.compile< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.custom_function.html" > mlx.core.custom_function< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.disable_compile.html" > mlx.core.disable_compile< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.enable_compile.html" > mlx.core.enable_compile< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.grad.html" > mlx.core.grad< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.value_and_grad.html" > mlx.core.value_and_grad< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.jvp.html" > mlx.core.jvp< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.vjp.html" > mlx.core.vjp< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.vmap.html" > mlx.core.vmap< / a > < / li >
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/fast.html" > Fast< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fast.rms_norm.html" > mlx.core.fast.rms_norm< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fast.layer_norm.html" > mlx.core.fast.layer_norm< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fast.rope.html" > mlx.core.fast.rope< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fast.scaled_dot_product_attention.html" > mlx.core.fast.scaled_dot_product_attention< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fast.metal_kernel.html" > mlx.core.fast.metal_kernel< / a > < / li >
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/fft.html" > FFT< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.fft.html" > mlx.core.fft.fft< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.ifft.html" > mlx.core.fft.ifft< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.fft2.html" > mlx.core.fft.fft2< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.ifft2.html" > mlx.core.fft.ifft2< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.fftn.html" > mlx.core.fft.fftn< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.ifftn.html" > mlx.core.fft.ifftn< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.rfft.html" > mlx.core.fft.rfft< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.irfft.html" > mlx.core.fft.irfft< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.rfft2.html" > mlx.core.fft.rfft2< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.irfft2.html" > mlx.core.fft.irfft2< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.rfftn.html" > mlx.core.fft.rfftn< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.irfftn.html" > mlx.core.fft.irfftn< / a > < / li >
2025-05-10 05:42:00 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.fftshift.html" > mlx.core.fft.fftshift< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.fft.ifftshift.html" > mlx.core.fft.ifftshift< / a > < / li >
2024-08-24 03:14:53 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/linalg.html" > Linear Algebra< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.inv.html" > mlx.core.linalg.inv< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.tri_inv.html" > mlx.core.linalg.tri_inv< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.norm.html" > mlx.core.linalg.norm< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.cholesky.html" > mlx.core.linalg.cholesky< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.cholesky_inv.html" > mlx.core.linalg.cholesky_inv< / a > < / li >
2024-09-29 02:04:59 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.cross.html" > mlx.core.linalg.cross< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.qr.html" > mlx.core.linalg.qr< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.svd.html" > mlx.core.linalg.svd< / a > < / li >
2025-06-03 07:29:32 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.eigvals.html" > mlx.core.linalg.eigvals< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.eig.html" > mlx.core.linalg.eig< / a > < / li >
2024-10-26 04:23:45 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.eigvalsh.html" > mlx.core.linalg.eigvalsh< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.eigh.html" > mlx.core.linalg.eigh< / a > < / li >
2025-02-15 05:44:39 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.lu.html" > mlx.core.linalg.lu< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.lu_factor.html" > mlx.core.linalg.lu_factor< / a > < / li >
2025-04-04 04:25:24 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.pinv.html" > mlx.core.linalg.pinv< / a > < / li >
2025-02-15 05:44:39 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.solve.html" > mlx.core.linalg.solve< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.linalg.solve_triangular.html" > mlx.core.linalg.solve_triangular< / a > < / li >
2024-08-24 03:14:53 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/metal.html" > Metal< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.metal.is_available.html" > mlx.core.metal.is_available< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.metal.device_info.html" > mlx.core.metal.device_info< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.metal.start_capture.html" > mlx.core.metal.start_capture< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.metal.stop_capture.html" > mlx.core.metal.stop_capture< / a > < / li >
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
2025-03-25 04:24:41 +08:00
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/memory_management.html" > Memory Management< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.get_active_memory.html" > mlx.core.get_active_memory< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.get_peak_memory.html" > mlx.core.get_peak_memory< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.reset_peak_memory.html" > mlx.core.reset_peak_memory< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.get_cache_memory.html" > mlx.core.get_cache_memory< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.set_memory_limit.html" > mlx.core.set_memory_limit< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.set_cache_limit.html" > mlx.core.set_cache_limit< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.set_wired_limit.html" > mlx.core.set_wired_limit< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.clear_cache.html" > mlx.core.clear_cache< / a > < / li >
< / ul >
< / details > < / li >
2024-10-15 23:12:17 +08:00
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/nn.html" > Neural Networks< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.nn.value_and_grad.html" > mlx.nn.value_and_grad< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.nn.quantize.html" > mlx.nn.quantize< / a > < / li >
2025-03-06 05:30:09 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.nn.average_gradients.html" > mlx.nn.average_gradients< / a > < / li >
2024-10-15 23:12:17 +08:00
< li class = "toctree-l2 has-children" > < a class = "reference internal" href = "../python/nn/module.html" > Module< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.training.html" > mlx.nn.Module.training< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.state.html" > mlx.nn.Module.state< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.apply.html" > mlx.nn.Module.apply< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.apply_to_modules.html" > mlx.nn.Module.apply_to_modules< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.children.html" > mlx.nn.Module.children< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.eval.html" > mlx.nn.Module.eval< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.filter_and_map.html" > mlx.nn.Module.filter_and_map< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.freeze.html" > mlx.nn.Module.freeze< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.leaf_modules.html" > mlx.nn.Module.leaf_modules< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.load_weights.html" > mlx.nn.Module.load_weights< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.modules.html" > mlx.nn.Module.modules< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.named_modules.html" > mlx.nn.Module.named_modules< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.parameters.html" > mlx.nn.Module.parameters< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.save_weights.html" > mlx.nn.Module.save_weights< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.set_dtype.html" > mlx.nn.Module.set_dtype< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.train.html" > mlx.nn.Module.train< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.trainable_parameters.html" > mlx.nn.Module.trainable_parameters< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.unfreeze.html" > mlx.nn.Module.unfreeze< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.update.html" > mlx.nn.Module.update< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Module.update_modules.html" > mlx.nn.Module.update_modules< / a > < / li >
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l2 has-children" > < a class = "reference internal" href = "../python/nn/layers.html" > Layers< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.ALiBi.html" > mlx.nn.ALiBi< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.AvgPool1d.html" > mlx.nn.AvgPool1d< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.AvgPool2d.html" > mlx.nn.AvgPool2d< / a > < / li >
2024-11-23 04:24:16 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.AvgPool3d.html" > mlx.nn.AvgPool3d< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.BatchNorm.html" > mlx.nn.BatchNorm< / a > < / li >
2024-09-29 02:04:59 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.CELU.html" > mlx.nn.CELU< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Conv1d.html" > mlx.nn.Conv1d< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Conv2d.html" > mlx.nn.Conv2d< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Conv3d.html" > mlx.nn.Conv3d< / a > < / li >
2024-09-18 03:06:14 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.ConvTranspose1d.html" > mlx.nn.ConvTranspose1d< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.ConvTranspose2d.html" > mlx.nn.ConvTranspose2d< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.ConvTranspose3d.html" > mlx.nn.ConvTranspose3d< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Dropout.html" > mlx.nn.Dropout< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Dropout2d.html" > mlx.nn.Dropout2d< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Dropout3d.html" > mlx.nn.Dropout3d< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Embedding.html" > mlx.nn.Embedding< / a > < / li >
2024-09-29 02:04:59 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.ELU.html" > mlx.nn.ELU< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.GELU.html" > mlx.nn.GELU< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.GLU.html" > mlx.nn.GLU< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.GroupNorm.html" > mlx.nn.GroupNorm< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.GRU.html" > mlx.nn.GRU< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.HardShrink.html" > mlx.nn.HardShrink< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.HardTanh.html" > mlx.nn.HardTanh< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Hardswish.html" > mlx.nn.Hardswish< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.InstanceNorm.html" > mlx.nn.InstanceNorm< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.LayerNorm.html" > mlx.nn.LayerNorm< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.LeakyReLU.html" > mlx.nn.LeakyReLU< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Linear.html" > mlx.nn.Linear< / a > < / li >
2024-09-29 02:04:59 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.LogSigmoid.html" > mlx.nn.LogSigmoid< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.LogSoftmax.html" > mlx.nn.LogSoftmax< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.LSTM.html" > mlx.nn.LSTM< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.MaxPool1d.html" > mlx.nn.MaxPool1d< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.MaxPool2d.html" > mlx.nn.MaxPool2d< / a > < / li >
2024-11-23 04:24:16 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.MaxPool3d.html" > mlx.nn.MaxPool3d< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Mish.html" > mlx.nn.Mish< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.MultiHeadAttention.html" > mlx.nn.MultiHeadAttention< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.PReLU.html" > mlx.nn.PReLU< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.QuantizedEmbedding.html" > mlx.nn.QuantizedEmbedding< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.QuantizedLinear.html" > mlx.nn.QuantizedLinear< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.RMSNorm.html" > mlx.nn.RMSNorm< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.ReLU.html" > mlx.nn.ReLU< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.ReLU6.html" > mlx.nn.ReLU6< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.RNN.html" > mlx.nn.RNN< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.RoPE.html" > mlx.nn.RoPE< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.SELU.html" > mlx.nn.SELU< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Sequential.html" > mlx.nn.Sequential< / a > < / li >
2024-09-29 02:04:59 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Sigmoid.html" > mlx.nn.Sigmoid< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.SiLU.html" > mlx.nn.SiLU< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.html" > mlx.nn.SinusoidalPositionalEncoding< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Softmin.html" > mlx.nn.Softmin< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Softshrink.html" > mlx.nn.Softshrink< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Softsign.html" > mlx.nn.Softsign< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Softmax.html" > mlx.nn.Softmax< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Softplus.html" > mlx.nn.Softplus< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Step.html" > mlx.nn.Step< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Tanh.html" > mlx.nn.Tanh< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Transformer.html" > mlx.nn.Transformer< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.Upsample.html" > mlx.nn.Upsample< / a > < / li >
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l2 has-children" > < a class = "reference internal" href = "../python/nn/functions.html" > Functions< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.elu.html" > mlx.nn.elu< / a > < / li >
2024-09-29 02:04:59 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.celu.html" > mlx.nn.celu< / a > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.gelu.html" > mlx.nn.gelu< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.gelu_approx.html" > mlx.nn.gelu_approx< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.html" > mlx.nn.gelu_fast_approx< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.glu.html" > mlx.nn.glu< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.hard_shrink.html" > mlx.nn.hard_shrink< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.hard_tanh.html" > mlx.nn.hard_tanh< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.hardswish.html" > mlx.nn.hardswish< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.leaky_relu.html" > mlx.nn.leaky_relu< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.log_sigmoid.html" > mlx.nn.log_sigmoid< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.log_softmax.html" > mlx.nn.log_softmax< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.mish.html" > mlx.nn.mish< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.prelu.html" > mlx.nn.prelu< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.relu.html" > mlx.nn.relu< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.relu6.html" > mlx.nn.relu6< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.selu.html" > mlx.nn.selu< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.sigmoid.html" > mlx.nn.sigmoid< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.silu.html" > mlx.nn.silu< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.softmax.html" > mlx.nn.softmax< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.softmin.html" > mlx.nn.softmin< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.softplus.html" > mlx.nn.softplus< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.softshrink.html" > mlx.nn.softshrink< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.step.html" > mlx.nn.step< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.tanh.html" > mlx.nn.tanh< / a > < / li >
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l2 has-children" > < a class = "reference internal" href = "../python/nn/losses.html" > Loss Functions< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.html" > mlx.nn.losses.binary_cross_entropy< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.html" > mlx.nn.losses.cosine_similarity_loss< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.html" > mlx.nn.losses.cross_entropy< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss.html" > mlx.nn.losses.gaussian_nll_loss< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.html" > mlx.nn.losses.hinge_loss< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.huber_loss.html" > mlx.nn.losses.huber_loss< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.html" > mlx.nn.losses.kl_div_loss< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.html" > mlx.nn.losses.l1_loss< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.html" > mlx.nn.losses.log_cosh_loss< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.margin_ranking_loss.html" > mlx.nn.losses.margin_ranking_loss< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.html" > mlx.nn.losses.mse_loss< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.html" > mlx.nn.losses.nll_loss< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.html" > mlx.nn.losses.smooth_l1_loss< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.html" > mlx.nn.losses.triplet_loss< / a > < / li >
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l2 has-children" > < a class = "reference internal" href = "../python/nn/init.html" > Initializers< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.init.constant.html" > mlx.nn.init.constant< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.init.normal.html" > mlx.nn.init.normal< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.init.uniform.html" > mlx.nn.init.uniform< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.init.identity.html" > mlx.nn.init.identity< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.init.glorot_normal.html" > mlx.nn.init.glorot_normal< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.init.glorot_uniform.html" > mlx.nn.init.glorot_uniform< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.init.he_normal.html" > mlx.nn.init.he_normal< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/nn/_autosummary/mlx.nn.init.he_uniform.html" > mlx.nn.init.he_uniform< / a > < / li >
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
2024-08-24 03:14:53 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/optimizers.html" > Optimizers< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
< li class = "toctree-l2 has-children" > < a class = "reference internal" href = "../python/optimizers/optimizer.html" > Optimizer< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.Optimizer.state.html" > mlx.optimizers.Optimizer.state< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.Optimizer.apply_gradients.html" > mlx.optimizers.Optimizer.apply_gradients< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.Optimizer.init.html" > mlx.optimizers.Optimizer.init< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.Optimizer.update.html" > mlx.optimizers.Optimizer.update< / a > < / li >
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l2 has-children" > < a class = "reference internal" href = "../python/optimizers/common_optimizers.html" > Common Optimizers< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.SGD.html" > mlx.optimizers.SGD< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.RMSprop.html" > mlx.optimizers.RMSprop< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.Adagrad.html" > mlx.optimizers.Adagrad< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.Adafactor.html" > mlx.optimizers.Adafactor< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.AdaDelta.html" > mlx.optimizers.AdaDelta< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.Adam.html" > mlx.optimizers.Adam< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.AdamW.html" > mlx.optimizers.AdamW< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.Adamax.html" > mlx.optimizers.Adamax< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.Lion.html" > mlx.optimizers.Lion< / a > < / li >
2025-04-18 06:29:33 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.MultiOptimizer.html" > mlx.optimizers.MultiOptimizer< / a > < / li >
2024-08-24 03:14:53 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l2 has-children" > < a class = "reference internal" href = "../python/optimizers/schedulers.html" > Schedulers< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.cosine_decay.html" > mlx.optimizers.cosine_decay< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.exponential_decay.html" > mlx.optimizers.exponential_decay< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.join_schedules.html" > mlx.optimizers.join_schedules< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.linear_schedule.html" > mlx.optimizers.linear_schedule< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "../python/optimizers/_autosummary/mlx.optimizers.step_decay.html" > mlx.optimizers.step_decay< / a > < / li >
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.optimizers.clip_grad_norm.html" > mlx.optimizers.clip_grad_norm< / a > < / li >
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/distributed.html" > Distributed Communication< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.distributed.Group.html" > mlx.core.distributed.Group< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.distributed.is_available.html" > mlx.core.distributed.is_available< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.distributed.init.html" > mlx.core.distributed.init< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.distributed.all_sum.html" > mlx.core.distributed.all_sum< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.distributed.all_gather.html" > mlx.core.distributed.all_gather< / a > < / li >
2024-09-18 03:06:14 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.distributed.send.html" > mlx.core.distributed.send< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.distributed.recv.html" > mlx.core.distributed.recv< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.core.distributed.recv_like.html" > mlx.core.distributed.recv_like< / a > < / li >
2024-08-24 03:14:53 +08:00
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
< li class = "toctree-l1 has-children" > < a class = "reference internal" href = "../python/tree_utils.html" > Tree Utils< / a > < details > < summary > < span class = "toctree-toggle" role = "presentation" > < i class = "fa-solid fa-chevron-down" > < / i > < / span > < / summary > < ul >
2024-08-24 03:14:53 +08:00
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.utils.tree_flatten.html" > mlx.utils.tree_flatten< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.utils.tree_unflatten.html" > mlx.utils.tree_unflatten< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.utils.tree_map.html" > mlx.utils.tree_map< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.utils.tree_map_with_path.html" > mlx.utils.tree_map_with_path< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "../python/_autosummary/mlx.utils.tree_reduce.html" > mlx.utils.tree_reduce< / a > < / li >
< / ul >
2024-10-15 23:12:17 +08:00
< / details > < / li >
2024-08-24 03:14:53 +08:00
< / ul >
< p aria-level = "2" class = "caption" role = "heading" > < span class = "caption-text" > C++ API Reference< / span > < / p >
< ul class = "nav bd-sidenav" >
< li class = "toctree-l1" > < a class = "reference internal" href = "../cpp/ops.html" > Operations< / a > < / li >
< / ul >
< p aria-level = "2" class = "caption" role = "heading" > < span class = "caption-text" > Further Reading< / span > < / p >
< ul class = "current nav bd-sidenav" >
< li class = "toctree-l1" > < a class = "reference internal" href = "extensions.html" > Custom Extensions in MLX< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "metal_debugger.html" > Metal Debugger< / a > < / li >
< li class = "toctree-l1 current active" > < a class = "current reference internal" href = "#" > Custom Metal Kernels< / a > < / li >
2025-01-10 05:56:20 +08:00
< li class = "toctree-l1" > < a class = "reference internal" href = "mlx_in_cpp.html" > Using MLX in C++< / a > < / li >
2024-08-24 03:14:53 +08:00
< / ul >
< / div >
< / nav > < / div >
< / div >
< div class = "sidebar-primary-items__end sidebar-primary__section" >
< / div >
2025-03-06 05:30:09 +08:00
< div id = "rtd-footer-container" > < / div >
2024-08-24 03:14:53 +08:00
< / div >
2024-10-15 23:12:17 +08:00
< main id = "main-content" class = "bd-main" role = "main" >
2024-08-24 03:14:53 +08:00
< div class = "sbt-scroll-pixel-helper" > < / div >
< div class = "bd-content" >
< div class = "bd-article-container" >
2024-10-15 23:12:17 +08:00
< div class = "bd-header-article d-print-none" >
2024-08-24 03:14:53 +08:00
< div class = "header-article-items header-article__inner" >
< div class = "header-article-items__start" >
2024-10-15 23:12:17 +08:00
< div class = "header-article-item" > < button class = "sidebar-toggle primary-toggle btn btn-sm" title = "Toggle primary sidebar" data-bs-placement = "bottom" data-bs-toggle = "tooltip" >
2024-08-24 03:14:53 +08:00
< span class = "fa-solid fa-bars" > < / span >
2024-10-15 23:12:17 +08:00
< / button > < / div >
2024-08-24 03:14:53 +08:00
< / div >
< div class = "header-article-items__end" >
< div class = "header-article-item" >
< div class = "article-header-buttons" >
< a href = "https://github.com/ml-explore/mlx" target = "_blank"
class="btn btn-sm btn-source-repository-button"
title="Source repository"
data-bs-placement="bottom" data-bs-toggle="tooltip"
>
< span class = "btn__icon-container" >
< i class = "fab fa-github" > < / i >
< / span >
< / a >
< div class = "dropdown dropdown-download-buttons" >
< button class = "btn dropdown-toggle" type = "button" data-bs-toggle = "dropdown" aria-expanded = "false" aria-label = "Download this page" >
< i class = "fas fa-download" > < / i >
< / button >
< ul class = "dropdown-menu" >
< li > < a href = "../_sources/dev/custom_metal_kernels.rst" target = "_blank"
class="btn btn-sm btn-download-source-button dropdown-item"
title="Download source file"
data-bs-placement="left" data-bs-toggle="tooltip"
>
< span class = "btn__icon-container" >
< i class = "fas fa-file" > < / i >
< / span >
< span class = "btn__text-container" > .rst< / span >
< / a >
< / li >
< li >
< button onclick = "window.print()"
class="btn btn-sm btn-download-pdf-button dropdown-item"
title="Print to PDF"
data-bs-placement="left" data-bs-toggle="tooltip"
>
< span class = "btn__icon-container" >
< i class = "fas fa-file-pdf" > < / i >
< / span >
< span class = "btn__text-container" > .pdf< / span >
< / button >
< / li >
< / ul >
< / div >
< button onclick = "toggleFullScreen()"
class="btn btn-sm btn-fullscreen-button"
title="Fullscreen mode"
data-bs-placement="bottom" data-bs-toggle="tooltip"
>
< span class = "btn__icon-container" >
< i class = "fas fa-expand" > < / i >
< / span >
< / button >
2025-03-06 05:30:09 +08:00
< script >
document.write(`
< button class = "btn btn-sm nav-link pst-navbar-icon theme-switch-button" title = "light/dark" aria-label = "light/dark" data-bs-placement = "bottom" data-bs-toggle = "tooltip" >
< i class = "theme-switch fa-solid fa-sun fa-lg" data-mode = "light" > < / i >
< i class = "theme-switch fa-solid fa-moon fa-lg" data-mode = "dark" > < / i >
< i class = "theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode = "auto" > < / i >
< / button >
`);
< / script >
2024-08-24 03:14:53 +08:00
2025-03-06 05:30:09 +08:00
< script >
document.write(`
< button class = "btn btn-sm pst-navbar-icon search-button search-button__button" title = "Search" aria-label = "Search" data-bs-placement = "bottom" data-bs-toggle = "tooltip" >
2024-08-24 03:14:53 +08:00
< i class = "fa-solid fa-magnifying-glass fa-lg" > < / i >
2025-03-06 05:30:09 +08:00
< / button >
`);
< / script >
2024-10-15 23:12:17 +08:00
< button class = "sidebar-toggle secondary-toggle btn btn-sm" title = "Toggle secondary sidebar" data-bs-placement = "bottom" data-bs-toggle = "tooltip" >
2024-08-24 03:14:53 +08:00
< span class = "fa-solid fa-list" > < / span >
2024-10-15 23:12:17 +08:00
< / button >
2024-08-24 03:14:53 +08:00
< / div > < / div >
< / div >
< / div >
< / div >
< div id = "jb-print-docs-body" class = "onlyprint" >
< h1 > Custom Metal Kernels< / h1 >
<!-- Table of contents -->
< div id = "print-main-content" >
< div id = "jb-print-toc" >
< div >
< h2 > Contents < / h2 >
< / div >
< nav aria-label = "Page" >
< ul class = "visible nav section-nav flex-column" >
< li class = "toc-h2 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#simple-example" > Simple Example< / a > < / li >
< li class = "toc-h2 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#using-shape-strides" > Using Shape/Strides< / a > < / li >
2024-09-18 03:06:14 +08:00
< li class = "toc-h2 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#complex-example" > Complex Example< / a > < / li >
< li class = "toc-h2 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#grid-sample-vjp" > Grid Sample VJP< / a > < / li >
2024-08-24 03:14:53 +08:00
< / ul >
< / nav >
< / div >
< / div >
< / div >
< div id = "searchbox" > < / div >
2024-10-15 23:12:17 +08:00
< article class = "bd-article" >
2024-08-24 03:14:53 +08:00
< section id = "custom-metal-kernels" >
2024-11-06 03:54:16 +08:00
< span id = "id1" > < / span > < h1 > Custom Metal Kernels< a class = "headerlink" href = "#custom-metal-kernels" title = "Link to this heading" > #< / a > < / h1 >
2024-08-24 03:14:53 +08:00
< p > MLX supports writing custom Metal kernels through the Python and C++ APIs.< / p >
< section id = "simple-example" >
< h2 > Simple Example< a class = "headerlink" href = "#simple-example" title = "Link to this heading" > #< / a > < / h2 >
< p > Let’ s write a custom kernel that computes < code class = "docutils literal notranslate" > < span class = "pre" > exp< / span > < / code > elementwise:< / p >
2025-01-10 05:56:20 +08:00
< div class = "highlight-python notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "k" > def< / span > < span class = "w" > < / span > < span class = "nf" > exp_elementwise< / span > < span class = "p" > (< / span > < span class = "n" > a< / span > < span class = "p" > :< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > array< / span > < span class = "p" > ):< / span >
2024-08-24 03:14:53 +08:00
< span class = "n" > source< / span > < span class = "o" > =< / span > < span class = "s2" > " " " < / span >
< span class = "s2" > uint elem = thread_position_in_grid.x;< / span >
< span class = "s2" > T tmp = inp[elem];< / span >
< span class = "s2" > out[elem] = metal::exp(tmp);< / span >
< span class = "s2" > " " " < / span >
< span class = "n" > kernel< / span > < span class = "o" > =< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > fast< / span > < span class = "o" > .< / span > < span class = "n" > metal_kernel< / span > < span class = "p" > (< / span >
< span class = "n" > name< / span > < span class = "o" > =< / span > < span class = "s2" > " myexp" < / span > < span class = "p" > ,< / span >
2024-09-18 03:06:14 +08:00
< span class = "n" > input_names< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s2" > " inp" < / span > < span class = "p" > ],< / span >
< span class = "n" > output_names< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s2" > " out" < / span > < span class = "p" > ],< / span >
2024-08-24 03:14:53 +08:00
< span class = "n" > source< / span > < span class = "o" > =< / span > < span class = "n" > source< / span > < span class = "p" > ,< / span >
< span class = "p" > )< / span >
< span class = "n" > outputs< / span > < span class = "o" > =< / span > < span class = "n" > kernel< / span > < span class = "p" > (< / span >
2024-09-18 03:06:14 +08:00
< span class = "n" > inputs< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > a< / span > < span class = "p" > ],< / span >
< span class = "n" > template< / span > < span class = "o" > =< / span > < span class = "p" > [(< / span > < span class = "s2" > " T" < / span > < span class = "p" > ,< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )],< / span >
2024-08-24 03:14:53 +08:00
< span class = "n" > grid< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > a< / span > < span class = "o" > .< / span > < span class = "n" > size< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > ),< / span >
< span class = "n" > threadgroup< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "mi" > 256< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > ),< / span >
2024-09-18 03:06:14 +08:00
< span class = "n" > output_shapes< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > a< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > ],< / span >
< span class = "n" > output_dtypes< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > a< / span > < span class = "o" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > ],< / span >
2024-08-24 03:14:53 +08:00
< span class = "p" > )< / span >
2024-09-18 03:06:14 +08:00
< span class = "k" > return< / span > < span class = "n" > outputs< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span >
2024-08-24 03:14:53 +08:00
< span class = "n" > a< / span > < span class = "o" > =< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > random< / span > < span class = "o" > .< / span > < span class = "n" > normal< / span > < span class = "p" > (< / span > < span class = "n" > shape< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "mi" > 4< / span > < span class = "p" > ,< / span > < span class = "mi" > 16< / span > < span class = "p" > ))< / span > < span class = "o" > .< / span > < span class = "n" > astype< / span > < span class = "p" > (< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > float16< / span > < span class = "p" > )< / span >
< span class = "n" > b< / span > < span class = "o" > =< / span > < span class = "n" > exp_elementwise< / span > < span class = "p" > (< / span > < span class = "n" > a< / span > < span class = "p" > )< / span >
< span class = "k" > assert< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > allclose< / span > < span class = "p" > (< / span > < span class = "n" > b< / span > < span class = "p" > ,< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > exp< / span > < span class = "p" > (< / span > < span class = "n" > a< / span > < span class = "p" > ))< / span >
< / pre > < / div >
< / div >
< div class = "admonition note" >
< p class = "admonition-title" > Note< / p >
< p > We are only required to pass the body of the Metal kernel in < code class = "docutils literal notranslate" > < span class = "pre" > source< / span > < / code > .< / p >
< / div >
< p > The full function signature will be generated using:< / p >
< ul class = "simple" >
< li > < dl class = "simple" >
2024-09-18 03:06:14 +08:00
< dt > The shapes/dtypes of < code class = "docutils literal notranslate" > < span class = "pre" > inputs< / span > < / code > < / dt > < dd > < p > In the above, < code class = "docutils literal notranslate" > < span class = "pre" > a< / span > < / code > is an < code class = "docutils literal notranslate" > < span class = "pre" > mx.array< / span > < / code > of type < code class = "docutils literal notranslate" > < span class = "pre" > mx.float16< / span > < / code > and we name it < code class = "docutils literal notranslate" > < span class = "pre" > inp< / span > < / code > via < code class = "docutils literal notranslate" > < span class = "pre" > input_names< / span > < / code >
2024-08-24 03:14:53 +08:00
so we will add < code class = "docutils literal notranslate" > < span class = "pre" > const< / span > < span class = "pre" > device< / span > < span class = "pre" > float16_t*< / span > < span class = "pre" > inp< / span > < / code > to the signature.
2024-09-18 03:06:14 +08:00
< code class = "docutils literal notranslate" > < span class = "pre" > inp_shape< / span > < / code > , < code class = "docutils literal notranslate" > < span class = "pre" > inp_strides< / span > < / code > and < code class = "docutils literal notranslate" > < span class = "pre" > inp_ndim< / span > < / code > are also added for convenience if they are present
in < code class = "docutils literal notranslate" > < span class = "pre" > source< / span > < / code > .< / p >
2024-08-24 03:14:53 +08:00
< / dd >
< / dl >
< / li >
< li > < dl class = "simple" >
2024-09-18 03:06:14 +08:00
< dt > The list of < code class = "docutils literal notranslate" > < span class = "pre" > output_dtypes< / span > < / code > < / dt > < dd > < p > In the above, < code class = "docutils literal notranslate" > < span class = "pre" > out< / span > < / code > is an < code class = "docutils literal notranslate" > < span class = "pre" > mx.array< / span > < / code > of type < code class = "docutils literal notranslate" > < span class = "pre" > mx.float16< / span > < / code >
2024-08-24 03:14:53 +08:00
so we add < code class = "docutils literal notranslate" > < span class = "pre" > device< / span > < span class = "pre" > float16_t*< / span > < span class = "pre" > out< / span > < / code > .< / p >
< / dd >
< / dl >
< / li >
< li > < dl class = "simple" >
2024-09-18 03:06:14 +08:00
< dt > Template parameters passed using < code class = "docutils literal notranslate" > < span class = "pre" > template< / span > < / code > < / dt > < dd > < p > In the above, < code class = "docutils literal notranslate" > < span class = "pre" > template=[(" T" ,< / span > < span class = "pre" > mx.float32)]< / span > < / code > adds a template of < code class = "docutils literal notranslate" > < span class = "pre" > template< / span > < span class = "pre" > < typename< / span > < span class = "pre" > T> < / span > < / code > to the function
2024-08-24 03:14:53 +08:00
and instantiates the template with < code class = "docutils literal notranslate" > < span class = "pre" > custom_kernel_myexp_float< float> < / span > < / code > .
Template parameters can be < code class = "docutils literal notranslate" > < span class = "pre" > mx.core.Dtype< / span > < / code > , < code class = "docutils literal notranslate" > < span class = "pre" > int< / span > < / code > or < code class = "docutils literal notranslate" > < span class = "pre" > bool< / span > < / code > .< / p >
< / dd >
< / dl >
< / li >
< li > < dl class = "simple" >
< dt > Metal attributes used in < code class = "docutils literal notranslate" > < span class = "pre" > source< / span > < / code > such as < code class = "docutils literal notranslate" > < span class = "pre" > [[thread_position_in_grid]]< / span > < / code > < / dt > < dd > < p > These will be added as function arguments.
All the attributes defined in Table 5.8 of the < a class = "reference external" href = "https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf" > Metal Shading Language Specification< / a > are supported.< / p >
< / dd >
< / dl >
< / li >
< / ul >
< p > Putting this all together, the generated function signature for < code class = "docutils literal notranslate" > < span class = "pre" > myexp< / span > < / code > is as follows:< / p >
< div class = "highlight-cpp notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "k" > template< / span > < span class = "w" > < / span > < span class = "o" > < < / span > < span class = "k" > typename< / span > < span class = "w" > < / span > < span class = "nc" > T< / span > < span class = "o" > > < / span >
< span class = "p" > [[< / span > < span class = "n" > kernel< / span > < span class = "p" > ]]< / span > < span class = "w" > < / span > < span class = "kt" > void< / span > < span class = "w" > < / span > < span class = "n" > custom_kernel_myexp_float< / span > < span class = "p" > (< / span >
< span class = "w" > < / span > < span class = "k" > const< / span > < span class = "w" > < / span > < span class = "n" > device< / span > < span class = "w" > < / span > < span class = "n" > float16_t< / span > < span class = "o" > *< / span > < span class = "w" > < / span > < span class = "n" > inp< / span > < span class = "w" > < / span > < span class = "p" > [[< / span > < span class = "n" > buffer< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )]],< / span >
< span class = "w" > < / span > < span class = "n" > device< / span > < span class = "w" > < / span > < span class = "n" > float16_t< / span > < span class = "o" > *< / span > < span class = "w" > < / span > < span class = "n" > out< / span > < span class = "w" > < / span > < span class = "p" > [[< / span > < span class = "n" > buffer< / span > < span class = "p" > (< / span > < span class = "mi" > 1< / span > < span class = "p" > )]],< / span >
< span class = "w" > < / span > < span class = "n" > uint3< / span > < span class = "w" > < / span > < span class = "n" > thread_position_in_grid< / span > < span class = "w" > < / span > < span class = "p" > [[< / span > < span class = "n" > thread_position_in_grid< / span > < span class = "p" > ]])< / span > < span class = "w" > < / span > < span class = "p" > {< / span >
< span class = "w" > < / span > < span class = "n" > uint< / span > < span class = "w" > < / span > < span class = "n" > elem< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > thread_position_in_grid< / span > < span class = "p" > .< / span > < span class = "n" > x< / span > < span class = "p" > ;< / span >
< span class = "w" > < / span > < span class = "n" > T< / span > < span class = "w" > < / span > < span class = "n" > tmp< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > inp< / span > < span class = "p" > [< / span > < span class = "n" > elem< / span > < span class = "p" > ];< / span >
< span class = "w" > < / span > < span class = "n" > out< / span > < span class = "p" > [< / span > < span class = "n" > elem< / span > < span class = "p" > ]< / span > < span class = "w" > < / span > < span class = "o" > =< / span > < span class = "w" > < / span > < span class = "n" > metal< / span > < span class = "o" > ::< / span > < span class = "n" > exp< / span > < span class = "p" > (< / span > < span class = "n" > tmp< / span > < span class = "p" > );< / span >
< span class = "p" > }< / span >
< span class = "k" > template< / span > < span class = "w" > < / span > < span class = "p" > [[< / span > < span class = "n" > host_name< / span > < span class = "p" > (< / span > < span class = "s" > " custom_kernel_myexp_float" < / span > < span class = "p" > )]]< / span > < span class = "w" > < / span > < span class = "p" > [[< / span > < span class = "n" > kernel< / span > < span class = "p" > ]]< / span > < span class = "w" > < / span > < span class = "k" > decltype< / span > < span class = "p" > (< / span > < span class = "n" > custom_kernel_myexp_float< / span > < span class = "o" > < < / span > < span class = "kt" > float< / span > < span class = "o" > > < / span > < span class = "p" > )< / span > < span class = "w" > < / span > < span class = "n" > custom_kernel_myexp_float< / span > < span class = "o" > < < / span > < span class = "kt" > float< / span > < span class = "o" > > < / span > < span class = "p" > ;< / span >
< / pre > < / div >
< / div >
2024-11-06 03:54:16 +08:00
< p > Note: < code class = "docutils literal notranslate" > < span class = "pre" > grid< / span > < / code > and < code class = "docutils literal notranslate" > < span class = "pre" > threadgroup< / span > < / code > are parameters to the Metal < a class = "reference external" href = "https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads" > dispatchThreads< / a > function.
This means we will launch < code class = "docutils literal notranslate" > < span class = "pre" > mx.prod(grid)< / span > < / code > threads, subdivided into threadgroups of size < code class = "docutils literal notranslate" > < span class = "pre" > threadgroup< / span > < / code > .
For optimal performance, each threadgroup dimension should be less than or equal to the corresponding grid dimension.< / p >
2024-09-18 03:06:14 +08:00
< p > Passing < code class = "docutils literal notranslate" > < span class = "pre" > verbose=True< / span > < / code > to < code class = "docutils literal notranslate" > < span class = "pre" > mx.fast.metal_kernel.__call__< / span > < / code > will print the generated code for debugging purposes.< / p >
2024-08-24 03:14:53 +08:00
< / section >
< section id = "using-shape-strides" >
< h2 > Using Shape/Strides< a class = "headerlink" href = "#using-shape-strides" title = "Link to this heading" > #< / a > < / h2 >
< p > < code class = "docutils literal notranslate" > < span class = "pre" > mx.fast.metal_kernel< / span > < / code > supports an argument < code class = "docutils literal notranslate" > < span class = "pre" > ensure_row_contiguous< / span > < / code > which is < code class = "docutils literal notranslate" > < span class = "pre" > True< / span > < / code > by default.
This will copy the < code class = "docutils literal notranslate" > < span class = "pre" > mx.array< / span > < / code > inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
Generally this makes writing the kernel easier, since we don’ t have to worry about gaps or the ordering of the dims
when indexing.< / p >
< p > If we want to avoid this copy, < code class = "docutils literal notranslate" > < span class = "pre" > metal_kernel< / span > < / code > automatically passes < code class = "docutils literal notranslate" > < span class = "pre" > a_shape< / span > < / code > , < code class = "docutils literal notranslate" > < span class = "pre" > a_strides< / span > < / code > and < code class = "docutils literal notranslate" > < span class = "pre" > a_ndim< / span > < / code > for each
input array < code class = "docutils literal notranslate" > < span class = "pre" > a< / span > < / code > if any are present in < code class = "docutils literal notranslate" > < span class = "pre" > source< / span > < / code > .
We can then use MLX’ s built-in indexing utils to fetch the right elements for each thread.< / p >
< p > Let’ s convert < code class = "docutils literal notranslate" > < span class = "pre" > myexp< / span > < / code > above to support arbitrarily strided arrays without relying on a copy from < code class = "docutils literal notranslate" > < span class = "pre" > ensure_row_contiguous< / span > < / code > :< / p >
2025-01-10 05:56:20 +08:00
< div class = "highlight-python notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "k" > def< / span > < span class = "w" > < / span > < span class = "nf" > exp_elementwise< / span > < span class = "p" > (< / span > < span class = "n" > a< / span > < span class = "p" > :< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > array< / span > < span class = "p" > ):< / span >
2024-08-24 03:14:53 +08:00
< span class = "n" > source< / span > < span class = "o" > =< / span > < span class = "s2" > " " " < / span >
< span class = "s2" > uint elem = thread_position_in_grid.x;< / span >
< span class = "s2" > // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included< / span >
< span class = "s2" > uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);< / span >
< span class = "s2" > T tmp = inp[loc];< / span >
< span class = "s2" > // Output arrays are always row contiguous< / span >
< span class = "s2" > out[elem] = metal::exp(tmp);< / span >
< span class = "s2" > " " " < / span >
< span class = "n" > kernel< / span > < span class = "o" > =< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > fast< / span > < span class = "o" > .< / span > < span class = "n" > metal_kernel< / span > < span class = "p" > (< / span >
< span class = "n" > name< / span > < span class = "o" > =< / span > < span class = "s2" > " myexp_strided" < / span > < span class = "p" > ,< / span >
2024-09-18 03:06:14 +08:00
< span class = "n" > input_names< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s2" > " inp" < / span > < span class = "p" > ],< / span >
< span class = "n" > output_names< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s2" > " out" < / span > < span class = "p" > ],< / span >
2024-08-24 03:14:53 +08:00
< span class = "n" > source< / span > < span class = "o" > =< / span > < span class = "n" > source< / span >
< span class = "p" > )< / span >
< span class = "n" > outputs< / span > < span class = "o" > =< / span > < span class = "n" > kernel< / span > < span class = "p" > (< / span >
2024-09-18 03:06:14 +08:00
< span class = "n" > inputs< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > a< / span > < span class = "p" > ],< / span >
< span class = "n" > template< / span > < span class = "o" > =< / span > < span class = "p" > [(< / span > < span class = "s2" > " T" < / span > < span class = "p" > ,< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )],< / span >
2024-08-24 03:14:53 +08:00
< span class = "n" > grid< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > a< / span > < span class = "o" > .< / span > < span class = "n" > size< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > ),< / span >
< span class = "n" > threadgroup< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "mi" > 256< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > ),< / span >
2024-09-18 03:06:14 +08:00
< span class = "n" > output_shapes< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > a< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > ],< / span >
< span class = "n" > output_dtypes< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > a< / span > < span class = "o" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > ],< / span >
2024-08-24 03:14:53 +08:00
< span class = "n" > ensure_row_contiguous< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span > < span class = "p" > ,< / span >
< span class = "p" > )< / span >
2024-09-18 03:06:14 +08:00
< span class = "k" > return< / span > < span class = "n" > outputs< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span >
2024-08-24 03:14:53 +08:00
< span class = "n" > a< / span > < span class = "o" > =< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > random< / span > < span class = "o" > .< / span > < span class = "n" > normal< / span > < span class = "p" > (< / span > < span class = "n" > shape< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "mi" > 4< / span > < span class = "p" > ,< / span > < span class = "mi" > 16< / span > < span class = "p" > ))< / span > < span class = "o" > .< / span > < span class = "n" > astype< / span > < span class = "p" > (< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > float16< / span > < span class = "p" > )< / span >
< span class = "c1" > # make non-contiguous< / span >
< span class = "n" > a< / span > < span class = "o" > =< / span > < span class = "n" > a< / span > < span class = "p" > [::< / span > < span class = "mi" > 2< / span > < span class = "p" > ]< / span >
< span class = "n" > b< / span > < span class = "o" > =< / span > < span class = "n" > exp_elementwise< / span > < span class = "p" > (< / span > < span class = "n" > a< / span > < span class = "p" > )< / span >
< span class = "k" > assert< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > allclose< / span > < span class = "p" > (< / span > < span class = "n" > b< / span > < span class = "p" > ,< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > exp< / span > < span class = "p" > (< / span > < span class = "n" > a< / span > < span class = "p" > ))< / span >
< / pre > < / div >
< / div >
< / section >
2024-09-18 03:06:14 +08:00
< section id = "complex-example" >
< h2 > Complex Example< a class = "headerlink" href = "#complex-example" title = "Link to this heading" > #< / a > < / h2 >
< p > Let’ s implement a more complex example: < code class = "docutils literal notranslate" > < span class = "pre" > grid_sample< / span > < / code > in < code class = "docutils literal notranslate" > < span class = "pre" > " bilinear" < / span > < / code > mode.< / p >
< p > We’ ll start with the following MLX implementation using standard ops:< / p >
2025-01-10 05:56:20 +08:00
< div class = "highlight-python notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "k" > def< / span > < span class = "w" > < / span > < span class = "nf" > grid_sample_ref< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > grid< / span > < span class = "p" > ):< / span >
2024-09-18 03:06:14 +08:00
< span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > H_in< / span > < span class = "p" > ,< / span > < span class = "n" > W_in< / span > < span class = "p" > ,< / span > < span class = "n" > _< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span >
< span class = "n" > ix< / span > < span class = "o" > =< / span > < span class = "p" > ((< / span > < span class = "n" > grid< / span > < span class = "p" > [< / span > < span class = "o" > ...< / span > < span class = "p" > ,< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span > < span class = "o" > +< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span > < span class = "o" > *< / span > < span class = "n" > W_in< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span > < span class = "o" > /< / span > < span class = "mi" > 2< / span >
< span class = "n" > iy< / span > < span class = "o" > =< / span > < span class = "p" > ((< / span > < span class = "n" > grid< / span > < span class = "p" > [< / span > < span class = "o" > ...< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > ]< / span > < span class = "o" > +< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span > < span class = "o" > *< / span > < span class = "n" > H_in< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span > < span class = "o" > /< / span > < span class = "mi" > 2< / span >
< span class = "n" > ix_nw< / span > < span class = "o" > =< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > floor< / span > < span class = "p" > (< / span > < span class = "n" > ix< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > astype< / span > < span class = "p" > (< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > int32< / span > < span class = "p" > )< / span >
< span class = "n" > iy_nw< / span > < span class = "o" > =< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > floor< / span > < span class = "p" > (< / span > < span class = "n" > iy< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > astype< / span > < span class = "p" > (< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > int32< / span > < span class = "p" > )< / span >
< span class = "n" > ix_ne< / span > < span class = "o" > =< / span > < span class = "n" > ix_nw< / span > < span class = "o" > +< / span > < span class = "mi" > 1< / span >
< span class = "n" > iy_ne< / span > < span class = "o" > =< / span > < span class = "n" > iy_nw< / span >
< span class = "n" > ix_sw< / span > < span class = "o" > =< / span > < span class = "n" > ix_nw< / span >
< span class = "n" > iy_sw< / span > < span class = "o" > =< / span > < span class = "n" > iy_nw< / span > < span class = "o" > +< / span > < span class = "mi" > 1< / span >
< span class = "n" > ix_se< / span > < span class = "o" > =< / span > < span class = "n" > ix_nw< / span > < span class = "o" > +< / span > < span class = "mi" > 1< / span >
< span class = "n" > iy_se< / span > < span class = "o" > =< / span > < span class = "n" > iy_nw< / span > < span class = "o" > +< / span > < span class = "mi" > 1< / span >
< span class = "n" > nw< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > ix_se< / span > < span class = "o" > -< / span > < span class = "n" > ix< / span > < span class = "p" > )< / span > < span class = "o" > *< / span > < span class = "p" > (< / span > < span class = "n" > iy_se< / span > < span class = "o" > -< / span > < span class = "n" > iy< / span > < span class = "p" > )< / span >
< span class = "n" > ne< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > ix< / span > < span class = "o" > -< / span > < span class = "n" > ix_sw< / span > < span class = "p" > )< / span > < span class = "o" > *< / span > < span class = "p" > (< / span > < span class = "n" > iy_sw< / span > < span class = "o" > -< / span > < span class = "n" > iy< / span > < span class = "p" > )< / span >
< span class = "n" > sw< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > ix_ne< / span > < span class = "o" > -< / span > < span class = "n" > ix< / span > < span class = "p" > )< / span > < span class = "o" > *< / span > < span class = "p" > (< / span > < span class = "n" > iy< / span > < span class = "o" > -< / span > < span class = "n" > iy_ne< / span > < span class = "p" > )< / span >
< span class = "n" > se< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > ix< / span > < span class = "o" > -< / span > < span class = "n" > ix_nw< / span > < span class = "p" > )< / span > < span class = "o" > *< / span > < span class = "p" > (< / span > < span class = "n" > iy< / span > < span class = "o" > -< / span > < span class = "n" > iy_nw< / span > < span class = "p" > )< / span >
< span class = "n" > I_nw< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "p" > [< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "n" > N< / span > < span class = "p" > )[:,< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ],< / span > < span class = "n" > iy_nw< / span > < span class = "p" > ,< / span > < span class = "n" > ix_nw< / span > < span class = "p" > ,< / span > < span class = "p" > :]< / span >
< span class = "n" > I_ne< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "p" > [< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "n" > N< / span > < span class = "p" > )[:,< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ],< / span > < span class = "n" > iy_ne< / span > < span class = "p" > ,< / span > < span class = "n" > ix_ne< / span > < span class = "p" > ,< / span > < span class = "p" > :]< / span >
< span class = "n" > I_sw< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "p" > [< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "n" > N< / span > < span class = "p" > )[:,< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ],< / span > < span class = "n" > iy_sw< / span > < span class = "p" > ,< / span > < span class = "n" > ix_sw< / span > < span class = "p" > ,< / span > < span class = "p" > :]< / span >
< span class = "n" > I_se< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "p" > [< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "n" > N< / span > < span class = "p" > )[:,< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ],< / span > < span class = "n" > iy_se< / span > < span class = "p" > ,< / span > < span class = "n" > ix_se< / span > < span class = "p" > ,< / span > < span class = "p" > :]< / span >
< span class = "n" > mask_nw< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > iy_nw< / span > < span class = "o" > > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > & < / span > < span class = "p" > (< / span > < span class = "n" > iy_nw< / span > < span class = "o" > < =< / span > < span class = "n" > H_in< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span > < span class = "o" > & < / span > < span class = "p" > (< / span > < span class = "n" > ix_nw< / span > < span class = "o" > > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > & < / span > < span class = "p" > (< / span > < span class = "n" > ix_nw< / span > < span class = "o" > < =< / span > < span class = "n" > W_in< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span >
< span class = "n" > mask_ne< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > iy_ne< / span > < span class = "o" > > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > & < / span > < span class = "p" > (< / span > < span class = "n" > iy_ne< / span > < span class = "o" > < =< / span > < span class = "n" > H_in< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span > < span class = "o" > & < / span > < span class = "p" > (< / span > < span class = "n" > ix_ne< / span > < span class = "o" > > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > & < / span > < span class = "p" > (< / span > < span class = "n" > ix_ne< / span > < span class = "o" > < =< / span > < span class = "n" > W_in< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span >
< span class = "n" > mask_sw< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > iy_sw< / span > < span class = "o" > > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > & < / span > < span class = "p" > (< / span > < span class = "n" > iy_sw< / span > < span class = "o" > < =< / span > < span class = "n" > H_in< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span > < span class = "o" > & < / span > < span class = "p" > (< / span > < span class = "n" > ix_sw< / span > < span class = "o" > > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > & < / span > < span class = "p" > (< / span > < span class = "n" > ix_sw< / span > < span class = "o" > < =< / span > < span class = "n" > W_in< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span >
< span class = "n" > mask_se< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > iy_se< / span > < span class = "o" > > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > & < / span > < span class = "p" > (< / span > < span class = "n" > iy_se< / span > < span class = "o" > < =< / span > < span class = "n" > H_in< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span > < span class = "o" > & < / span > < span class = "p" > (< / span > < span class = "n" > ix_se< / span > < span class = "o" > > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > & < / span > < span class = "p" > (< / span > < span class = "n" > ix_se< / span > < span class = "o" > < =< / span > < span class = "n" > W_in< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span >
< span class = "n" > I_nw< / span > < span class = "o" > *=< / span > < span class = "n" > mask_nw< / span > < span class = "p" > [< / span > < span class = "o" > ...< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span >
< span class = "n" > I_ne< / span > < span class = "o" > *=< / span > < span class = "n" > mask_ne< / span > < span class = "p" > [< / span > < span class = "o" > ...< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span >
< span class = "n" > I_sw< / span > < span class = "o" > *=< / span > < span class = "n" > mask_sw< / span > < span class = "p" > [< / span > < span class = "o" > ...< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span >
< span class = "n" > I_se< / span > < span class = "o" > *=< / span > < span class = "n" > mask_se< / span > < span class = "p" > [< / span > < span class = "o" > ...< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span >
< span class = "n" > output< / span > < span class = "o" > =< / span > < span class = "n" > nw< / span > < span class = "p" > [< / span > < span class = "o" > ...< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > I_nw< / span > < span class = "o" > +< / span > < span class = "n" > ne< / span > < span class = "p" > [< / span > < span class = "o" > ...< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > I_ne< / span > < span class = "o" > +< / span > < span class = "n" > sw< / span > < span class = "p" > [< / span > < span class = "o" > ...< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > I_sw< / span > < span class = "o" > +< / span > < span class = "n" > se< / span > < span class = "p" > [< / span > < span class = "o" > ...< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > I_se< / span >
< span class = "k" > return< / span > < span class = "n" > output< / span >
< / pre > < / div >
< / div >
< p > Now let’s use < code class = "docutils literal notranslate" > < span class = "pre" > mx.custom_function< / span > < / code > together with < code class = "docutils literal notranslate" > < span class = "pre" > mx.fast.metal_kernel< / span > < / code >
to write fast GPU kernels for both the forward and backward passes.< / p >
< p > First, we’ll implement the forward pass as a fused kernel:< / p >
< div class = "highlight-python notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "nd" > @mx< / span > < span class = "o" > .< / span > < span class = "n" > custom_function< / span >
2025-01-10 05:56:20 +08:00
< span class = "k" > def< / span > < span class = "w" > < / span > < span class = "nf" > grid_sample< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > grid< / span > < span class = "p" > ):< / span >
2024-09-18 03:06:14 +08:00
< span class = "k" > assert< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > ndim< / span > < span class = "o" > ==< / span > < span class = "mi" > 4< / span > < span class = "p" > ,< / span > < span class = "s2" > " `x` must be 4D." < / span >
< span class = "k" > assert< / span > < span class = "n" > grid< / span > < span class = "o" > .< / span > < span class = "n" > ndim< / span > < span class = "o" > ==< / span > < span class = "mi" > 4< / span > < span class = "p" > ,< / span > < span class = "s2" > " `grid` must be 4D." < / span >
< span class = "n" > B< / span > < span class = "p" > ,< / span > < span class = "n" > _< / span > < span class = "p" > ,< / span > < span class = "n" > _< / span > < span class = "p" > ,< / span > < span class = "n" > C< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span >
< span class = "n" > _< / span > < span class = "p" > ,< / span > < span class = "n" > gN< / span > < span class = "p" > ,< / span > < span class = "n" > gM< / span > < span class = "p" > ,< / span > < span class = "n" > D< / span > < span class = "o" > =< / span > < span class = "n" > grid< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span >
< span class = "n" > out_shape< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > B< / span > < span class = "p" > ,< / span > < span class = "n" > gN< / span > < span class = "p" > ,< / span > < span class = "n" > gM< / span > < span class = "p" > ,< / span > < span class = "n" > C< / span > < span class = "p" > )< / span >
< span class = "k" > assert< / span > < span class = "n" > D< / span > < span class = "o" > ==< / span > < span class = "mi" > 2< / span > < span class = "p" > ,< / span > < span class = "s2" > " Last dim of `grid` must be size 2." < / span >
< span class = "n" > source< / span > < span class = "o" > =< / span > < span class = "s2" > " " " < / span >
< span class = "s2" > uint elem = thread_position_in_grid.x;< / span >
< span class = "s2" > int H = x_shape[1];< / span >
< span class = "s2" > int W = x_shape[2];< / span >
< span class = "s2" > int C = x_shape[3];< / span >
< span class = "s2" > int gH = grid_shape[1];< / span >
< span class = "s2" > int gW = grid_shape[2];< / span >
< span class = "s2" > int w_stride = C;< / span >
< span class = "s2" > int h_stride = W * w_stride;< / span >
< span class = "s2" > int b_stride = H * h_stride;< / span >
< span class = "s2" > uint grid_idx = elem / C * 2;< / span >
< span class = "s2" > float ix = ((grid[grid_idx] + 1) * W - 1) / 2;< / span >
< span class = "s2" > float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;< / span >
< span class = "s2" > int ix_nw = floor(ix);< / span >
< span class = "s2" > int iy_nw = floor(iy);< / span >
< span class = "s2" > int ix_ne = ix_nw + 1;< / span >
< span class = "s2" > int iy_ne = iy_nw;< / span >
< span class = "s2" > int ix_sw = ix_nw;< / span >
< span class = "s2" > int iy_sw = iy_nw + 1;< / span >
< span class = "s2" > int ix_se = ix_nw + 1;< / span >
< span class = "s2" > int iy_se = iy_nw + 1;< / span >
< span class = "s2" > T nw = (ix_se - ix) * (iy_se - iy);< / span >
< span class = "s2" > T ne = (ix - ix_sw) * (iy_sw - iy);< / span >
< span class = "s2" > T sw = (ix_ne - ix) * (iy - iy_ne);< / span >
< span class = "s2" > T se = (ix - ix_nw) * (iy - iy_nw);< / span >
< span class = "s2" > int batch_idx = elem / C / gH / gW * b_stride;< / span >
< span class = "s2" > int channel_idx = elem % C;< / span >
< span class = "s2" > int base_idx = batch_idx + channel_idx;< / span >
< span class = "s2" > T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];< / span >
< span class = "s2" > T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];< / span >
< span class = "s2" > T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];< / span >
< span class = "s2" > T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];< / span >
< span class = "s2" > I_nw = iy_nw > = 0 & & iy_nw < = H - 1 & & ix_nw > = 0 & & ix_nw < = W - 1 ? I_nw : 0;< / span >
< span class = "s2" > I_ne = iy_ne > = 0 & & iy_ne < = H - 1 & & ix_ne > = 0 & & ix_ne < = W - 1 ? I_ne : 0;< / span >
< span class = "s2" > I_sw = iy_sw > = 0 & & iy_sw < = H - 1 & & ix_sw > = 0 & & ix_sw < = W - 1 ? I_sw : 0;< / span >
< span class = "s2" > I_se = iy_se > = 0 & & iy_se < = H - 1 & & ix_se > = 0 & & ix_se < = W - 1 ? I_se : 0;< / span >
< span class = "s2" > out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;< / span >
< span class = "s2" > " " " < / span >
< span class = "n" > kernel< / span > < span class = "o" > =< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > fast< / span > < span class = "o" > .< / span > < span class = "n" > metal_kernel< / span > < span class = "p" > (< / span >
< span class = "n" > name< / span > < span class = "o" > =< / span > < span class = "s2" > " grid_sample" < / span > < span class = "p" > ,< / span >
< span class = "n" > input_names< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s2" > " x" < / span > < span class = "p" > ,< / span > < span class = "s2" > " grid" < / span > < span class = "p" > ],< / span >
< span class = "n" > output_names< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s2" > " out" < / span > < span class = "p" > ],< / span >
< span class = "n" > source< / span > < span class = "o" > =< / span > < span class = "n" > source< / span > < span class = "p" > ,< / span >
< span class = "p" > )< / span >
< span class = "n" > outputs< / span > < span class = "o" > =< / span > < span class = "n" > kernel< / span > < span class = "p" > (< / span >
< span class = "n" > inputs< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > grid< / span > < span class = "p" > ],< / span >
< span class = "n" > template< / span > < span class = "o" > =< / span > < span class = "p" > [(< / span > < span class = "s2" > " T" < / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > )],< / span >
< span class = "n" > output_shapes< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > out_shape< / span > < span class = "p" > ],< / span >
< span class = "n" > output_dtypes< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > ],< / span >
< span class = "n" > grid< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > prod< / span > < span class = "p" > (< / span > < span class = "n" > out_shape< / span > < span class = "p" > ),< / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > ),< / span >
< span class = "n" > threadgroup< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "mi" > 256< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > ),< / span >
< span class = "p" > )< / span >
< span class = "k" > return< / span > < span class = "n" > outputs< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span >
< / pre > < / div >
< / div >
< p > For a reasonably sized input such as:< / p >
< div class = "highlight-python notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "mi" > 8< / span > < span class = "p" > ,< / span > < span class = "mi" > 1024< / span > < span class = "p" > ,< / span > < span class = "mi" > 1024< / span > < span class = "p" > ,< / span > < span class = "mi" > 64< / span > < span class = "p" > )< / span >
< span class = "n" > grid< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "mi" > 8< / span > < span class = "p" > ,< / span > < span class = "mi" > 256< / span > < span class = "p" > ,< / span > < span class = "mi" > 256< / span > < span class = "p" > ,< / span > < span class = "mi" > 2< / span > < span class = "p" > )< / span >
< / pre > < / div >
< / div >
< p > On an M1 Max, we see a big performance improvement:< / p >
< p > < code class = "docutils literal notranslate" > < span class = "pre" > 55.7ms< / span > < span class = "pre" > -> < / span > < span class = "pre" > 6.7ms< / span > < span class = "pre" > => < / span > < span class = "pre" > 8x< / span > < span class = "pre" > speed< / span > < span class = "pre" > up< / span > < / code > < / p >
< / section >
< section id = "grid-sample-vjp" >
< h2 > Grid Sample VJP< a class = "headerlink" href = "#grid-sample-vjp" title = "Link to this heading" > #< / a > < / h2 >
< p > Since we decorated < code class = "docutils literal notranslate" > < span class = "pre" > grid_sample< / span > < / code > with < code class = "docutils literal notranslate" > < span class = "pre" > mx.custom_function< / span > < / code > , we can now define
its custom VJP (vector-Jacobian product) transform so MLX can differentiate it.< / p >
< p > The backward pass must update < code class = "docutils literal notranslate" > < span class = "pre" > x_grad< / span > < / code > /< code class = "docutils literal notranslate" > < span class = "pre" > grid_grad< / span > < / code > atomically, so
it relies on a few extra < code class = "docutils literal notranslate" > < span class = "pre" > mx.fast.metal_kernel< / span > < / code > features:< / p >
< ul class = "simple" >
< li > < dl class = "simple" >
< dt > < code class = "docutils literal notranslate" > < span class = "pre" > init_value=0< / span > < / code > < / dt > < dd > < p > Initialize all of the kernel’s outputs to this value before it runs. This lets the kernel update only part of the output arrays.< / p >
< / dd >
< / dl >
< / li >
< li > < dl class = "simple" >
< dt > < code class = "docutils literal notranslate" > < span class = "pre" > atomic_outputs=True< / span > < / code > < / dt > < dd > < p > Designate all of the kernel outputs as < code class = "docutils literal notranslate" > < span class = "pre" > atomic< / span > < / code > in the function signature.
This means we can use Metal’s < code class = "docutils literal notranslate" > < span class = "pre" > atomic< / span > < / code > features to simultaneously update the < code class = "docutils literal notranslate" > < span class = "pre" > x_grad< / span > < / code > and < code class = "docutils literal notranslate" > < span class = "pre" > grid_grad< / span > < / code > arrays from multiple threadgroups.
See section 6.15 of the < a class = "reference external" href = "https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf" > Metal Shading Language Specification< / a > for more details.< / p >
< / dd >
< / dl >
< / li >
< / ul >
< p > We can then implement the backward pass as follows:< / p >
< div class = "highlight-python notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "nd" > @grid_sample< / span > < span class = "o" > .< / span > < span class = "n" > vjp< / span >
2025-01-10 05:56:20 +08:00
< span class = "k" > def< / span > < span class = "w" > < / span > < span class = "nf" > grid_sample_vjp< / span > < span class = "p" > (< / span > < span class = "n" > primals< / span > < span class = "p" > ,< / span > < span class = "n" > cotangent< / span > < span class = "p" > ,< / span > < span class = "n" > _< / span > < span class = "p" > ):< / span >
2024-09-18 03:06:14 +08:00
< span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > grid< / span > < span class = "o" > =< / span > < span class = "n" > primals< / span >
< span class = "n" > B< / span > < span class = "p" > ,< / span > < span class = "n" > _< / span > < span class = "p" > ,< / span > < span class = "n" > _< / span > < span class = "p" > ,< / span > < span class = "n" > C< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span >
< span class = "n" > _< / span > < span class = "p" > ,< / span > < span class = "n" > gN< / span > < span class = "p" > ,< / span > < span class = "n" > gM< / span > < span class = "p" > ,< / span > < span class = "n" > D< / span > < span class = "o" > =< / span > < span class = "n" > grid< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span >
< span class = "k" > assert< / span > < span class = "n" > D< / span > < span class = "o" > ==< / span > < span class = "mi" > 2< / span > < span class = "p" > ,< / span > < span class = "s2" > " Last dim of `grid` must be size 2." < / span >
< span class = "n" > source< / span > < span class = "o" > =< / span > < span class = "s2" > " " " < / span >
< span class = "s2" > uint elem = thread_position_in_grid.x;< / span >
< span class = "s2" > int H = x_shape[1];< / span >
< span class = "s2" > int W = x_shape[2];< / span >
< span class = "s2" > int C = x_shape[3];< / span >
< span class = "s2" > // Pad C to the nearest larger simdgroup size multiple< / span >
< span class = "s2" > int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;< / span >
< span class = "s2" > int gH = grid_shape[1];< / span >
< span class = "s2" > int gW = grid_shape[2];< / span >
< span class = "s2" > int w_stride = C;< / span >
< span class = "s2" > int h_stride = W * w_stride;< / span >
< span class = "s2" > int b_stride = H * h_stride;< / span >
< span class = "s2" > uint grid_idx = elem / C_padded * 2;< / span >
< span class = "s2" > float ix = ((grid[grid_idx] + 1) * W - 1) / 2;< / span >
< span class = "s2" > float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;< / span >
< span class = "s2" > int ix_nw = floor(ix);< / span >
< span class = "s2" > int iy_nw = floor(iy);< / span >
< span class = "s2" > int ix_ne = ix_nw + 1;< / span >
< span class = "s2" > int iy_ne = iy_nw;< / span >
< span class = "s2" > int ix_sw = ix_nw;< / span >
< span class = "s2" > int iy_sw = iy_nw + 1;< / span >
< span class = "s2" > int ix_se = ix_nw + 1;< / span >
< span class = "s2" > int iy_se = iy_nw + 1;< / span >
< span class = "s2" > T nw = (ix_se - ix) * (iy_se - iy);< / span >
< span class = "s2" > T ne = (ix - ix_sw) * (iy_sw - iy);< / span >
< span class = "s2" > T sw = (ix_ne - ix) * (iy - iy_ne);< / span >
< span class = "s2" > T se = (ix - ix_nw) * (iy - iy_nw);< / span >
< span class = "s2" > int batch_idx = elem / C_padded / gH / gW * b_stride;< / span >
< span class = "s2" > int channel_idx = elem % C_padded;< / span >
< span class = "s2" > int base_idx = batch_idx + channel_idx;< / span >
< span class = "s2" > T gix = T(0);< / span >
< span class = "s2" > T giy = T(0);< / span >
< span class = "s2" > if (channel_idx < C) {< / span >
< span class = "s2" > int cot_index = elem / C_padded * C + channel_idx;< / span >
< span class = "s2" > T cot = cotangent[cot_index];< / span >
< span class = "s2" > if (iy_nw > = 0 & & iy_nw < = H - 1 & & ix_nw > = 0 & & ix_nw < = W - 1) {< / span >
< span class = "s2" > int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;< / span >
< span class = "s2" > atomic_fetch_add_explicit(& x_grad[offset], nw * cot, memory_order_relaxed);< / span >
< span class = "s2" > T I_nw = x[offset];< / span >
< span class = "s2" > gix -= I_nw * (iy_se - iy) * cot;< / span >
< span class = "s2" > giy -= I_nw * (ix_se - ix) * cot;< / span >
< span class = "s2" > }< / span >
< span class = "s2" > if (iy_ne > = 0 & & iy_ne < = H - 1 & & ix_ne > = 0 & & ix_ne < = W - 1) {< / span >
< span class = "s2" > int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;< / span >
< span class = "s2" > atomic_fetch_add_explicit(& x_grad[offset], ne * cot, memory_order_relaxed);< / span >
< span class = "s2" > T I_ne = x[offset];< / span >
< span class = "s2" > gix += I_ne * (iy_sw - iy) * cot;< / span >
< span class = "s2" > giy -= I_ne * (ix - ix_sw) * cot;< / span >
< span class = "s2" > }< / span >
< span class = "s2" > if (iy_sw > = 0 & & iy_sw < = H - 1 & & ix_sw > = 0 & & ix_sw < = W - 1) {< / span >
< span class = "s2" > int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;< / span >
< span class = "s2" > atomic_fetch_add_explicit(& x_grad[offset], sw * cot, memory_order_relaxed);< / span >
< span class = "s2" > T I_sw = x[offset];< / span >
< span class = "s2" > gix -= I_sw * (iy - iy_ne) * cot;< / span >
< span class = "s2" > giy += I_sw * (ix_ne - ix) * cot;< / span >
< span class = "s2" > }< / span >
< span class = "s2" > if (iy_se > = 0 & & iy_se < = H - 1 & & ix_se > = 0 & & ix_se < = W - 1) {< / span >
< span class = "s2" > int offset = base_idx + iy_se * h_stride + ix_se * w_stride;< / span >
< span class = "s2" > atomic_fetch_add_explicit(& x_grad[offset], se * cot, memory_order_relaxed);< / span >
< span class = "s2" > T I_se = x[offset];< / span >
< span class = "s2" > gix += I_se * (iy - iy_nw) * cot;< / span >
< span class = "s2" > giy += I_se * (ix - ix_nw) * cot;< / span >
< span class = "s2" > }< / span >
< span class = "s2" > }< / span >
< span class = "s2" > T gix_mult = W / 2;< / span >
< span class = "s2" > T giy_mult = H / 2;< / span >
< span class = "s2" > // Reduce across each simdgroup first.< / span >
< span class = "s2" > // This is much faster than relying purely on atomics.< / span >
< span class = "s2" > gix = simd_sum(gix);< / span >
< span class = "s2" > giy = simd_sum(giy);< / span >
< span class = "s2" > if (thread_index_in_simdgroup == 0) {< / span >
< span class = "s2" > atomic_fetch_add_explicit(& grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);< / span >
< span class = "s2" > atomic_fetch_add_explicit(& grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);< / span >
< span class = "s2" > }< / span >
< span class = "s2" > " " " < / span >
< span class = "n" > kernel< / span > < span class = "o" > =< / span > < span class = "n" > mx< / span > < span class = "o" > .< / span > < span class = "n" > fast< / span > < span class = "o" > .< / span > < span class = "n" > metal_kernel< / span > < span class = "p" > (< / span >
< span class = "n" > name< / span > < span class = "o" > =< / span > < span class = "s2" > " grid_sample_grad" < / span > < span class = "p" > ,< / span >
< span class = "n" > input_names< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s2" > " x" < / span > < span class = "p" > ,< / span > < span class = "s2" > " grid" < / span > < span class = "p" > ,< / span > < span class = "s2" > " cotangent" < / span > < span class = "p" > ],< / span >
< span class = "n" > output_names< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s2" > " x_grad" < / span > < span class = "p" > ,< / span > < span class = "s2" > " grid_grad" < / span > < span class = "p" > ],< / span >
< span class = "n" > source< / span > < span class = "o" > =< / span > < span class = "n" > source< / span > < span class = "p" > ,< / span >
< span class = "n" > atomic_outputs< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > ,< / span >
< span class = "p" > )< / span >
< span class = "c1" > # pad the output channels to simd group size< / span >
< span class = "c1" > # so that our `simd_sum`s don' t overlap.< / span >
< span class = "n" > simdgroup_size< / span > < span class = "o" > =< / span > < span class = "mi" > 32< / span >
< span class = "n" > C_padded< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > C< / span > < span class = "o" > +< / span > < span class = "n" > simdgroup_size< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span > < span class = "o" > //< / span > < span class = "n" > simdgroup_size< / span > < span class = "o" > *< / span > < span class = "n" > simdgroup_size< / span >
< span class = "n" > grid_size< / span > < span class = "o" > =< / span > < span class = "n" > B< / span > < span class = "o" > *< / span > < span class = "n" > gN< / span > < span class = "o" > *< / span > < span class = "n" > gM< / span > < span class = "o" > *< / span > < span class = "n" > C_padded< / span >
< span class = "n" > outputs< / span > < span class = "o" > =< / span > < span class = "n" > kernel< / span > < span class = "p" > (< / span >
< span class = "n" > inputs< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > grid< / span > < span class = "p" > ,< / span > < span class = "n" > cotangent< / span > < span class = "p" > ],< / span >
< span class = "n" > template< / span > < span class = "o" > =< / span > < span class = "p" > [(< / span > < span class = "s2" > " T" < / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > )],< / span >
< span class = "n" > output_shapes< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > ,< / span > < span class = "n" > grid< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > ],< / span >
< span class = "n" > output_dtypes< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > ],< / span >
< span class = "n" > grid< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > grid_size< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > ),< / span >
< span class = "n" > threadgroup< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "mi" > 256< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > ),< / span >
< span class = "n" > init_value< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span >
< span class = "p" > )< / span >
< span class = "k" > return< / span > < span class = "n" > outputs< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ],< / span > < span class = "n" > outputs< / span > < span class = "p" > [< / span > < span class = "mi" > 1< / span > < span class = "p" > ]< / span >
< / pre > < / div >
< / div >
< p > There’s an even larger speedup for the VJP:< / p >
< p > < code class = "docutils literal notranslate" > < span class = "pre" > 676.4ms< / span > < span class = "pre" > -> < / span > < span class = "pre" > 16.7ms< / span > < span class = "pre" > => < / span > < span class = "pre" > 40x< / span > < span class = "pre" > speed< / span > < span class = "pre" > up< / span > < / code > < / p >
< / section >
2024-08-24 03:14:53 +08:00
< / section >
< / article >
2024-10-15 23:12:17 +08:00
< footer class = "prev-next-footer d-print-none" >
2024-08-24 03:14:53 +08:00
< div class = "prev-next-area" >
< a class = "left-prev"
href="metal_debugger.html"
title="previous page">
< i class = "fa-solid fa-angle-left" > < / i >
< div class = "prev-next-info" >
< p class = "prev-next-subtitle" > previous< / p >
< p class = "prev-next-title" > Metal Debugger< / p >
< / div >
< / a >
2025-01-10 05:56:20 +08:00
< a class = "right-next"
href="mlx_in_cpp.html"
title="next page">
< div class = "prev-next-info" >
< p class = "prev-next-subtitle" > next< / p >
< p class = "prev-next-title" > Using MLX in C++< / p >
< / div >
< i class = "fa-solid fa-angle-right" > < / i >
< / a >
2024-08-24 03:14:53 +08:00
< / div >
< / footer >
< / div >
2025-03-06 05:30:09 +08:00
< div class = "bd-sidebar-secondary bd-toc" > < div class = "sidebar-secondary-items sidebar-secondary__inner" >
2024-08-24 03:14:53 +08:00
2024-10-15 23:12:17 +08:00
2024-08-24 03:14:53 +08:00
< div class = "sidebar-secondary-item" >
< div class = "page-toc tocsection onthispage" >
< i class = "fa-solid fa-list" > < / i > Contents
< / div >
< nav class = "bd-toc-nav page-toc" >
< ul class = "visible nav section-nav flex-column" >
< li class = "toc-h2 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#simple-example" > Simple Example< / a > < / li >
< li class = "toc-h2 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#using-shape-strides" > Using Shape/Strides< / a > < / li >
2024-09-18 03:06:14 +08:00
< li class = "toc-h2 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#complex-example" > Complex Example< / a > < / li >
< li class = "toc-h2 nav-item toc-entry" > < a class = "reference internal nav-link" href = "#grid-sample-vjp" > Grid Sample VJP< / a > < / li >
2024-08-24 03:14:53 +08:00
< / ul >
< / nav > < / div >
< / div > < / div >
< / div >
< footer class = "bd-footer-content" >
< div class = "bd-footer-content__inner container" >
< div class = "footer-item" >
< p class = "component-author" >
By MLX Contributors
< / p >
< / div >
< div class = "footer-item" >
< p class = "copyright" >
2025-06-03 07:29:32 +08:00
© Copyright 2023, Apple.
2024-08-24 03:14:53 +08:00
< br / >
< / p >
< / div >
< div class = "footer-item" >
< / div >
< div class = "footer-item" >
< / div >
< / div >
< / footer >
< / main >
< / div >
< / div >
<!-- Scripts loaded after <body> so the DOM is not blocked -->
2025-03-06 05:30:09 +08:00
< script src = "../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b" > < / script >
< script src = "../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" > < / script >
2024-08-24 03:14:53 +08:00
< footer class = "bd-footer" >
< / footer >
< / body >
< / html >