2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								<!DOCTYPE html> 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< html  lang = "en"  data-content_root = ""  > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < head > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < meta  charset = "utf-8"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < meta  name = "viewport"  content = "width=device-width, initial-scale=1.0"  / > < meta  name = "generator"  content = "Docutils 0.18.1: http://docutils.sourceforge.net/"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-12-13 14:46:24 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < title > Developer Documentation —  MLX 0.0.5 documentation< / title > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < script  data-cfasync = "false" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    document.documentElement.dataset.theme = localStorage.getItem("theme") || "light";
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / script > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								  <!--  Loaded before other Sphinx assets  --> 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < link  href = "../_static/styles/theme.css?digest=5b4479735964841361fd"  rel = "stylesheet"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< link  href = "../_static/styles/bootstrap.css?digest=5b4479735964841361fd"  rel = "stylesheet"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< link  href = "../_static/styles/pydata-sphinx-theme.css?digest=5b4479735964841361fd"  rel = "stylesheet"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < link  href = "../_static/vendor/fontawesome/6.1.2/css/all.min.css?digest=5b4479735964841361fd"  rel = "stylesheet"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < link  rel = "preload"  as = "font"  type = "font/woff2"  crossorigin  href = "../_static/vendor/fontawesome/6.1.2/webfonts/fa-solid-900.woff2"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< link  rel = "preload"  as = "font"  type = "font/woff2"  crossorigin  href = "../_static/vendor/fontawesome/6.1.2/webfonts/fa-brands-400.woff2"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< link  rel = "preload"  as = "font"  type = "font/woff2"  crossorigin  href = "../_static/vendor/fontawesome/6.1.2/webfonts/fa-regular-400.woff2"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < link  rel = "stylesheet"  type = "text/css"  href = "../_static/pygments.css"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < link  rel = "stylesheet"  href = "../_static/styles/sphinx-book-theme.css?digest=14f4ca6b54d191a8c7657f6c759bf11a5fb86285"  type = "text/css"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  <!--  Pre - loaded scripts that we'll load fully later  --> 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < link  rel = "preload"  as = "script"  href = "../_static/scripts/bootstrap.js?digest=5b4479735964841361fd"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< link  rel = "preload"  as = "script"  href = "../_static/scripts/pydata-sphinx-theme.js?digest=5b4479735964841361fd"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < script  src = "../_static/vendor/fontawesome/6.1.2/js/all.min.js?digest=5b4479735964841361fd" > < / script > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < script  data-url_root = "../"  id = "documentation_options"  src = "../_static/documentation_options.js" > < / script > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < script  src = "../_static/jquery.js" > < / script > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < script  src = "../_static/underscore.js" > < / script > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < script  src = "../_static/_sphinx_javascript_frameworks_compat.js" > < / script > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < script  src = "../_static/doctools.js" > < / script > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < script  src = "../_static/scripts/sphinx-book-theme.js?digest=5a5c038af52cf7bc1a1ec88eea08e6366ee68824" > < / script > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < script > DOCUMENTATION _OPTIONS . pagename  =  'dev/extensions' ; < / script > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < link  rel = "index"  title = "Index"  href = "../genindex.html"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < link  rel = "search"  title = "Search"  href = "../search.html"  / > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < link  rel = "prev"  title = "Operations"  href = "../cpp/ops.html"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < meta  name = "viewport"  content = "width=device-width, initial-scale=1" / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < meta  name = "docsearch:language"  content = "en" / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / head > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < body  data-bs-spy = "scroll"  data-bs-target = ".bd-toc-nav"  data-offset = "180"  data-bs-root-margin = "0px 0px -60%"  data-default-mode = "" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < a  class = "skip-link"  href = "#main-content" > Skip to main content< / a > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < div  id = "pst-scroll-pixel-helper" > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < button  type = "button"  class = "btn rounded-pill"  id = "pst-back-to-top" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < i  class = "fa-solid fa-arrow-up" > < / i > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    Back to top
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / button > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < input  type = "checkbox" 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          class="sidebar-toggle"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          name="__primary"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          id="__primary"/>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < label  class = "overlay overlay-primary"  for = "__primary" > < / label > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < input  type = "checkbox" 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          class="sidebar-toggle"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          name="__secondary"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          id="__secondary"/>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < label  class = "overlay overlay-secondary"  for = "__secondary" > < / label > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < div  class = "search-button__wrapper" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < div  class = "search-button__overlay" > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < div  class = "search-button__search-container" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< form  class = "bd-search d-flex align-items-center" 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      action="../search.html"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      method="get">
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < i  class = "fa-solid fa-magnifying-glass" > < / i > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < input  type = "search" 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								         class="form-control"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								         name="q"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								         id="search-input"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								         placeholder="Search..."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								         aria-label="Search..."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								         autocomplete="off"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								         autocorrect="off"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								         autocapitalize="off"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								         spellcheck="false"/>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < span  class = "search-button__kbd-shortcut" > < kbd  class = "kbd-shortcut__modifier" > Ctrl< / kbd > +< kbd > K< / kbd > < / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / form > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < nav  class = "bd-header navbar navbar-expand-lg bd-navbar" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < / nav > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < div  class = "bd-container" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < div  class = "bd-container__inner bd-page-width" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < div  class = "bd-sidebar-primary bd-sidebar" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < div  class = "sidebar-header-items sidebar-primary__section" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < div  class = "sidebar-primary-items__start sidebar-primary__section" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < div  class = "sidebar-primary-item" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< a  class = "navbar-brand logo"  href = "../index.html" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
									
										
										
										
											2023-12-13 14:46:24 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < img  src = "../_static/mlx_logo.png"  class = "logo__image only-light"  alt = "MLX 0.0.5 documentation - Home" / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < script > document . write ( ` <img src="../_static/mlx_logo.png" class="logo__image only-dark" alt="MLX 0.0.5 documentation - Home"/> ` ) ; < / script > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / a > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < div  class = "sidebar-primary-item" > < nav  class = "bd-links"  id = "bd-docs-nav"  aria-label = "Main" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < div  class = "bd-toc-item navbar-nav active" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < p  aria-level = "2"  class = "caption"  role = "heading" > < span  class = "caption-text" > Install< / span > < / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< ul  class = "nav bd-sidenav" > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1" > < a  class = "reference internal"  href = "../install.html" > Build and Install< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  aria-level = "2"  class = "caption"  role = "heading" > < span  class = "caption-text" > Usage< / span > < / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< ul  class = "nav bd-sidenav" > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1" > < a  class = "reference internal"  href = "../quick_start.html" > Quick Start Guide< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-13 14:46:24 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1" > < a  class = "reference internal"  href = "../unified_memory.html" > Unified Memory< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1" > < a  class = "reference internal"  href = "../using_streams.html" > Using Streams< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  aria-level = "2"  class = "caption"  role = "heading" > < span  class = "caption-text" > Examples< / span > < / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< ul  class = "nav bd-sidenav" > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1" > < a  class = "reference internal"  href = "../examples/linear_regression.html" > Linear Regression< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1" > < a  class = "reference internal"  href = "../examples/mlp.html" > Multi-Layer Perceptron< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1" > < a  class = "reference internal"  href = "../examples/llama-inference.html" > LLM inference< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  aria-level = "2"  class = "caption"  role = "heading" > < span  class = "caption-text" > Python API Reference< / span > < / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< ul  class = "nav bd-sidenav" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1 has-children" > < a  class = "reference internal"  href = "../python/array.html" > Array< / a > < input  class = "toctree-checkbox"  id = "toctree-checkbox-1"  name = "toctree-checkbox-1"  type = "checkbox" / > < label  class = "toctree-toggle"  for = "toctree-checkbox-1" > < i  class = "fa-solid fa-chevron-down" > < / i > < / label > < ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.html" > mlx.core.array< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.astype.html" > mlx.core.array.astype< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.item.html" > mlx.core.array.item< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.tolist.html" > mlx.core.array.tolist< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.dtype.html" > mlx.core.array.dtype< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.ndim.html" > mlx.core.array.ndim< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.shape.html" > mlx.core.array.shape< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.size.html" > mlx.core.array.size< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.Dtype.html" > mlx.core.Dtype< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.abs.html" > mlx.core.array.abs< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.all.html" > mlx.core.array.all< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.any.html" > mlx.core.array.any< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.argmax.html" > mlx.core.array.argmax< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.argmin.html" > mlx.core.array.argmin< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.cos.html" > mlx.core.array.cos< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.dtype.html" > mlx.core.array.dtype< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.exp.html" > mlx.core.array.exp< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.log.html" > mlx.core.array.log< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.log1p.html" > mlx.core.array.log1p< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.logsumexp.html" > mlx.core.array.logsumexp< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.max.html" > mlx.core.array.max< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.mean.html" > mlx.core.array.mean< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.min.html" > mlx.core.array.min< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.prod.html" > mlx.core.array.prod< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.reciprocal.html" > mlx.core.array.reciprocal< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.reshape.html" > mlx.core.array.reshape< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.rsqrt.html" > mlx.core.array.rsqrt< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.sin.html" > mlx.core.array.sin< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.split.html" > mlx.core.array.split< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.sqrt.html" > mlx.core.array.sqrt< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.square.html" > mlx.core.array.square< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.sum.html" > mlx.core.array.sum< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.transpose.html" > mlx.core.array.transpose< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.T.html" > mlx.core.array.T< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.var.html" > mlx.core.array.var< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1 has-children" > < a  class = "reference internal"  href = "../python/devices_and_streams.html" > Devices and Streams< / a > < input  class = "toctree-checkbox"  id = "toctree-checkbox-2"  name = "toctree-checkbox-2"  type = "checkbox" / > < label  class = "toctree-toggle"  for = "toctree-checkbox-2" > < i  class = "fa-solid fa-chevron-down" > < / i > < / label > < ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.Device.html" > mlx.core.Device< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.default_device.html" > mlx.core.default_device< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.set_default_device.html" > mlx.core.set_default_device< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.Stream.html" > mlx.core.Stream< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.default_stream.html" > mlx.core.default_stream< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.new_stream.html" > mlx.core.new_stream< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.set_default_stream.html" > mlx.core.set_default_stream< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1 has-children" > < a  class = "reference internal"  href = "../python/ops.html" > Operations< / a > < input  class = "toctree-checkbox"  id = "toctree-checkbox-3"  name = "toctree-checkbox-3"  type = "checkbox" / > < label  class = "toctree-toggle"  for = "toctree-checkbox-3" > < i  class = "fa-solid fa-chevron-down" > < / i > < / label > < ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.abs.html" > mlx.core.abs< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.add.html" > mlx.core.add< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.all.html" > mlx.core.all< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.allclose.html" > mlx.core.allclose< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.any.html" > mlx.core.any< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.arange.html" > mlx.core.arange< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.arccos.html" > mlx.core.arccos< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.arccosh.html" > mlx.core.arccosh< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.arcsin.html" > mlx.core.arcsin< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.arcsinh.html" > mlx.core.arcsinh< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.arctan.html" > mlx.core.arctan< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.arctanh.html" > mlx.core.arctanh< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.argmax.html" > mlx.core.argmax< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.argmin.html" > mlx.core.argmin< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.argpartition.html" > mlx.core.argpartition< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.argsort.html" > mlx.core.argsort< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array_equal.html" > mlx.core.array_equal< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.broadcast_to.html" > mlx.core.broadcast_to< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-17 13:23:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.ceil.html" > mlx.core.ceil< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.concatenate.html" > mlx.core.concatenate< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.convolve.html" > mlx.core.convolve< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.conv1d.html" > mlx.core.conv1d< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.conv2d.html" > mlx.core.conv2d< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.cos.html" > mlx.core.cos< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.cosh.html" > mlx.core.cosh< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.divide.html" > mlx.core.divide< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.equal.html" > mlx.core.equal< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.erf.html" > mlx.core.erf< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.erfinv.html" > mlx.core.erfinv< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.exp.html" > mlx.core.exp< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.expand_dims.html" > mlx.core.expand_dims< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-13 14:46:24 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.eye.html" > mlx.core.eye< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-17 13:23:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.floor.html" > mlx.core.floor< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.flatten.html" > mlx.core.flatten< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.full.html" > mlx.core.full< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.greater.html" > mlx.core.greater< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.greater_equal.html" > mlx.core.greater_equal< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-13 14:46:24 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.identity.html" > mlx.core.identity< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.less.html" > mlx.core.less< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.less_equal.html" > mlx.core.less_equal< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.load.html" > mlx.core.load< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.log.html" > mlx.core.log< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.log2.html" > mlx.core.log2< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.log10.html" > mlx.core.log10< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.log1p.html" > mlx.core.log1p< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.logaddexp.html" > mlx.core.logaddexp< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.logical_not.html" > mlx.core.logical_not< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.logsumexp.html" > mlx.core.logsumexp< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.matmul.html" > mlx.core.matmul< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.max.html" > mlx.core.max< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.maximum.html" > mlx.core.maximum< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.mean.html" > mlx.core.mean< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.min.html" > mlx.core.min< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.minimum.html" > mlx.core.minimum< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-17 13:23:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.moveaxis.html" > mlx.core.moveaxis< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.multiply.html" > mlx.core.multiply< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.negative.html" > mlx.core.negative< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.ones.html" > mlx.core.ones< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.ones_like.html" > mlx.core.ones_like< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.partition.html" > mlx.core.partition< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.pad.html" > mlx.core.pad< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.prod.html" > mlx.core.prod< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.reciprocal.html" > mlx.core.reciprocal< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.reshape.html" > mlx.core.reshape< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.rsqrt.html" > mlx.core.rsqrt< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.save.html" > mlx.core.save< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.savez.html" > mlx.core.savez< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.savez_compressed.html" > mlx.core.savez_compressed< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.sigmoid.html" > mlx.core.sigmoid< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.sign.html" > mlx.core.sign< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.sin.html" > mlx.core.sin< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.sinh.html" > mlx.core.sinh< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.softmax.html" > mlx.core.softmax< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.sort.html" > mlx.core.sort< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.split.html" > mlx.core.split< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.sqrt.html" > mlx.core.sqrt< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.square.html" > mlx.core.square< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.squeeze.html" > mlx.core.squeeze< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-17 13:23:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.stack.html" > mlx.core.stack< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.stop_gradient.html" > mlx.core.stop_gradient< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.subtract.html" > mlx.core.subtract< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.sum.html" > mlx.core.sum< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-17 13:23:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.swapaxes.html" > mlx.core.swapaxes< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.take.html" > mlx.core.take< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.take_along_axis.html" > mlx.core.take_along_axis< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.tan.html" > mlx.core.tan< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.tanh.html" > mlx.core.tanh< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.transpose.html" > mlx.core.transpose< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-17 13:23:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.tri.html" > mlx.core.tri< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.tril.html" > mlx.core.tril< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.triu.html" > mlx.core.triu< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.var.html" > mlx.core.var< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.where.html" > mlx.core.where< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.zeros.html" > mlx.core.zeros< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.zeros_like.html" > mlx.core.zeros_like< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1 has-children" > < a  class = "reference internal"  href = "../python/random.html" > Random< / a > < input  class = "toctree-checkbox"  id = "toctree-checkbox-4"  name = "toctree-checkbox-4"  type = "checkbox" / > < label  class = "toctree-toggle"  for = "toctree-checkbox-4" > < i  class = "fa-solid fa-chevron-down" > < / i > < / label > < ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.random.seed.html" > mlx.core.random.seed< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.random.key.html" > mlx.core.random.key< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.random.split.html" > mlx.core.random.split< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.random.bernoulli.html" > mlx.core.random.bernoulli< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.random.categorical.html" > mlx.core.random.categorical< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.random.gumbel.html" > mlx.core.random.gumbel< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.random.normal.html" > mlx.core.random.normal< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.random.randint.html" > mlx.core.random.randint< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.random.uniform.html" > mlx.core.random.uniform< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.random.truncated_normal.html" > mlx.core.random.truncated_normal< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1 has-children" > < a  class = "reference internal"  href = "../python/transforms.html" > Transforms< / a > < input  class = "toctree-checkbox"  id = "toctree-checkbox-5"  name = "toctree-checkbox-5"  type = "checkbox" / > < label  class = "toctree-toggle"  for = "toctree-checkbox-5" > < i  class = "fa-solid fa-chevron-down" > < / i > < / label > < ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.eval.html" > mlx.core.eval< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.grad.html" > mlx.core.grad< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.value_and_grad.html" > mlx.core.value_and_grad< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.jvp.html" > mlx.core.jvp< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.vjp.html" > mlx.core.vjp< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.vmap.html" > mlx.core.vmap< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-17 13:23:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.simplify.html" > mlx.core.simplify< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1 has-children" > < a  class = "reference internal"  href = "../python/fft.html" > FFT< / a > < input  class = "toctree-checkbox"  id = "toctree-checkbox-6"  name = "toctree-checkbox-6"  type = "checkbox" / > < label  class = "toctree-toggle"  for = "toctree-checkbox-6" > < i  class = "fa-solid fa-chevron-down" > < / i > < / label > < ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.fft.fft.html" > mlx.core.fft.fft< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.fft.ifft.html" > mlx.core.fft.ifft< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.fft.fft2.html" > mlx.core.fft.fft2< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.fft.ifft2.html" > mlx.core.fft.ifft2< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.fft.fftn.html" > mlx.core.fft.fftn< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.fft.ifftn.html" > mlx.core.fft.ifftn< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.fft.rfft.html" > mlx.core.fft.rfft< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.fft.irfft.html" > mlx.core.fft.irfft< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.fft.rfft2.html" > mlx.core.fft.rfft2< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.fft.irfft2.html" > mlx.core.fft.irfft2< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.fft.rfftn.html" > mlx.core.fft.rfftn< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.core.fft.irfftn.html" > mlx.core.fft.irfftn< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1 has-children" > < a  class = "reference internal"  href = "../python/nn.html" > Neural Networks< / a > < input  class = "toctree-checkbox"  id = "toctree-checkbox-7"  name = "toctree-checkbox-7"  type = "checkbox" / > < label  class = "toctree-toggle"  for = "toctree-checkbox-7" > < i  class = "fa-solid fa-chevron-down" > < / i > < / label > < ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.nn.value_and_grad.html" > mlx.nn.value_and_grad< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-17 13:23:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.nn.Module.html" > mlx.nn.Module< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2 has-children" > < a  class = "reference internal"  href = "../python/nn/layers.html" > Layers< / a > < input  class = "toctree-checkbox"  id = "toctree-checkbox-8"  name = "toctree-checkbox-8"  type = "checkbox" / > < label  class = "toctree-toggle"  for = "toctree-checkbox-8" > < i  class = "fa-solid fa-chevron-down" > < / i > < / label > < ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary/mlx.nn.Embedding.html" > mlx.nn.Embedding< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary/mlx.nn.ReLU.html" > mlx.nn.ReLU< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary/mlx.nn.PReLU.html" > mlx.nn.PReLU< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary/mlx.nn.GELU.html" > mlx.nn.GELU< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary/mlx.nn.SiLU.html" > mlx.nn.SiLU< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary/mlx.nn.Step.html" > mlx.nn.Step< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary/mlx.nn.SELU.html" > mlx.nn.SELU< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary/mlx.nn.Mish.html" > mlx.nn.Mish< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary/mlx.nn.Linear.html" > mlx.nn.Linear< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary/mlx.nn.Conv1d.html" > mlx.nn.Conv1d< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary/mlx.nn.Conv2d.html" > mlx.nn.Conv2d< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary/mlx.nn.LayerNorm.html" > mlx.nn.LayerNorm< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary/mlx.nn.RMSNorm.html" > mlx.nn.RMSNorm< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary/mlx.nn.GroupNorm.html" > mlx.nn.GroupNorm< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary/mlx.nn.RoPE.html" > mlx.nn.RoPE< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary/mlx.nn.MultiHeadAttention.html" > mlx.nn.MultiHeadAttention< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary/mlx.nn.Sequential.html" > mlx.nn.Sequential< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-17 13:23:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2 has-children" > < a  class = "reference internal"  href = "../python/nn/functions.html" > Functions< / a > < input  class = "toctree-checkbox"  id = "toctree-checkbox-9"  name = "toctree-checkbox-9"  type = "checkbox" / > < label  class = "toctree-toggle"  for = "toctree-checkbox-9" > < i  class = "fa-solid fa-chevron-down" > < / i > < / label > < ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary_functions/mlx.nn.gelu.html" > mlx.nn.gelu< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary_functions/mlx.nn.gelu_approx.html" > mlx.nn.gelu_approx< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.html" > mlx.nn.gelu_fast_approx< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary_functions/mlx.nn.relu.html" > mlx.nn.relu< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary_functions/mlx.nn.prelu.html" > mlx.nn.prelu< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary_functions/mlx.nn.silu.html" > mlx.nn.silu< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary_functions/mlx.nn.step.html" > mlx.nn.step< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary_functions/mlx.nn.selu.html" > mlx.nn.selu< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary_functions/mlx.nn.mish.html" > mlx.nn.mish< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2 has-children" > < a  class = "reference internal"  href = "../python/nn/losses.html" > Loss Functions< / a > < input  class = "toctree-checkbox"  id = "toctree-checkbox-10"  name = "toctree-checkbox-10"  type = "checkbox" / > < label  class = "toctree-toggle"  for = "toctree-checkbox-10" > < i  class = "fa-solid fa-chevron-down" > < / i > < / label > < ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.html" > mlx.nn.losses.cross_entropy< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.html" > mlx.nn.losses.binary_cross_entropy< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.html" > mlx.nn.losses.l1_loss< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.html" > mlx.nn.losses.mse_loss< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.html" > mlx.nn.losses.nll_loss< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "../python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.html" > mlx.nn.losses.kl_div_loss< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1 has-children" > < a  class = "reference internal"  href = "../python/optimizers.html" > Optimizers< / a > < input  class = "toctree-checkbox"  id = "toctree-checkbox-11"  name = "toctree-checkbox-11"  type = "checkbox" / > < label  class = "toctree-toggle"  for = "toctree-checkbox-11" > < i  class = "fa-solid fa-chevron-down" > < / i > < / label > < ul > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.optimizers.OptimizerState.html" > mlx.optimizers.OptimizerState< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.optimizers.Optimizer.html" > mlx.optimizers.Optimizer< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.optimizers.SGD.html" > mlx.optimizers.SGD< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-17 13:23:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.optimizers.RMSprop.html" > mlx.optimizers.RMSprop< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.optimizers.Adagrad.html" > mlx.optimizers.Adagrad< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.optimizers.AdaDelta.html" > mlx.optimizers.AdaDelta< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.optimizers.Adam.html" > mlx.optimizers.Adam< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-17 13:23:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.optimizers.AdamW.html" > mlx.optimizers.AdamW< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.optimizers.Adamax.html" > mlx.optimizers.Adamax< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-17 13:23:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1 has-children" > < a  class = "reference internal"  href = "../python/tree_utils.html" > Tree Utils< / a > < input  class = "toctree-checkbox"  id = "toctree-checkbox-12"  name = "toctree-checkbox-12"  type = "checkbox" / > < label  class = "toctree-toggle"  for = "toctree-checkbox-12" > < i  class = "fa-solid fa-chevron-down" > < / i > < / label > < ul > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.utils.tree_flatten.html" > mlx.utils.tree_flatten< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.utils.tree_unflatten.html" > mlx.utils.tree_unflatten< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "../python/_autosummary/mlx.utils.tree_map.html" > mlx.utils.tree_map< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  aria-level = "2"  class = "caption"  role = "heading" > < span  class = "caption-text" > C++ API Reference< / span > < / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< ul  class = "nav bd-sidenav" > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1" > < a  class = "reference internal"  href = "../cpp/ops.html" > Operations< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  aria-level = "2"  class = "caption"  role = "heading" > < span  class = "caption-text" > Further Reading< / span > < / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< ul  class = "current nav bd-sidenav" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1 current active" > < a  class = "current reference internal"  href = "#" > Developer Documentation< / a > < / li > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / nav > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < div  class = "sidebar-primary-items__end sidebar-primary__section" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < div  id = "rtd-footer-container" > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < / div > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								      
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < main  id = "main-content"  class = "bd-main" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								<!-- Empty helper element; judging by its class name it is presumably a scroll-position sentinel for the theme scripts (TODO confirm). -->
								<div class="sbt-scroll-pixel-helper"></div>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          < div  class = "bd-content" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            < div  class = "bd-article-container" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								              
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								              < div  class = "bd-header-article" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "header-article-items header-article__inner" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < div  class = "header-article-items__start" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        <!-- Primary sidebar toggle: a label tied through "for" to the element with id "__primary" (defined outside this excerpt). -->
								        <div class="header-article-item">
								          <label class="sidebar-toggle primary-toggle btn btn-sm" for="__primary" title="Toggle primary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip">
								            <!-- The icon is decorative, so it is hidden from assistive technology. -->
								            <span class="fa-solid fa-bars" aria-hidden="true"></span>
								          </label>
								        </div>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < div  class = "header-article-items__end" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < div  class = "header-article-item" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "article-header-buttons" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								<!-- Link to the project's source repository; opens in a new tab. -->
								<!-- rel="noopener noreferrer" keeps the opened page from reaching back through window.opener. -->
								<!-- aria-label gives this icon-only link an accessible name that does not depend on the tooltip title. -->
								<a href="https://github.com/ml-explore/mlx" target="_blank" rel="noopener noreferrer"
								   class="btn btn-sm btn-source-repository-button"
								   title="Source repository"
								   aria-label="Source repository"
								   data-bs-placement="bottom" data-bs-toggle="tooltip"
								>
								  <span class="btn__icon-container">
								    <i class="fab fa-github" aria-hidden="true"></i>
								  </span>
								</a>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "dropdown dropdown-download-buttons" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  <!-- Toggle button for the download dropdown; its accessible name comes from aria-label. -->
								  <button class="btn dropdown-toggle" type="button" data-bs-toggle="dropdown" aria-expanded="false" aria-label="Download this page">
								    <!-- Decorative icon, hidden from assistive technology. -->
								    <i class="fas fa-download" aria-hidden="true"></i>
								  </button>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < ul  class = "dropdown-menu" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      <!-- Menu entry: opens the page's .rst source file in a new tab. -->
								      <li>
								        <a href="../_sources/dev/extensions.rst" target="_blank" rel="noopener"
								           class="btn btn-sm btn-download-source-button dropdown-item"
								           title="Download source file"
								           data-bs-placement="left" data-bs-toggle="tooltip"
								        >
								          <span class="btn__icon-container">
								            <!-- Decorative icon; the visible ".rst" text names the link. -->
								            <i class="fas fa-file" aria-hidden="true"></i>
								          </span>
								          <span class="btn__text-container">.rst</span>
								        </a>
								      </li>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      <!-- Menu entry: calls window.print() to open the browser's print dialog. -->
								      <li>
								        <!-- type="button" is explicit so the control can never act as a form submit button. -->
								        <button type="button" onclick="window.print()"
								          class="btn btn-sm btn-download-pdf-button dropdown-item"
								          title="Print to PDF"
								          data-bs-placement="left" data-bs-toggle="tooltip"
								        >
								          <span class="btn__icon-container">
								            <!-- Decorative icon; the visible ".pdf" text names the button. -->
								            <i class="fas fa-file-pdf" aria-hidden="true"></i>
								          </span>
								          <span class="btn__text-container">.pdf</span>
								        </button>
								      </li>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								<!-- Fullscreen toggle; calls toggleFullScreen(), which is defined outside this excerpt. -->
								<!-- type="button" prevents submit behaviour; aria-label names this icon-only control without relying on the tooltip title. -->
								<button type="button" onclick="toggleFullScreen()"
								  class="btn btn-sm btn-fullscreen-button"
								  title="Fullscreen mode"
								  aria-label="Fullscreen mode"
								  data-bs-placement="bottom" data-bs-toggle="tooltip"
								>
								  <span class="btn__icon-container">
								    <i class="fas fa-expand" aria-hidden="true"></i>
								  </span>
								</button>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								<script>
								// The theme switcher is emitted with document.write, so the button is only
								// present in the page when JavaScript actually runs. It is written at this
								// script's position in the document.
								// The button carries one icon per data-mode value: "light", "dark" and "auto".
								// type="button" is explicit so the control never acts as a form submit button;
								// the icons are decorative and hidden from assistive technology, while the
								// button's name comes from its aria-label.
								document.write(`
								  <button type="button" class="btn btn-sm navbar-btn theme-switch-button" title="light/dark" aria-label="light/dark" data-bs-placement="bottom" data-bs-toggle="tooltip">
								    <span class="theme-switch nav-link" data-mode="light"><i class="fa-solid fa-sun fa-lg" aria-hidden="true"></i></span>
								    <span class="theme-switch nav-link" data-mode="dark"><i class="fa-solid fa-moon fa-lg" aria-hidden="true"></i></span>
								    <span class="theme-switch nav-link" data-mode="auto"><i class="fa-solid fa-circle-half-stroke fa-lg" aria-hidden="true"></i></span>
								  </button>
								`);
								</script>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								<script>
								// The search button is emitted with document.write, so it is only present in
								// the page when JavaScript actually runs. It is written at this script's
								// position in the document.
								// type="button" is explicit so the control never acts as a form submit button;
								// the magnifying-glass icon is decorative and hidden from assistive technology,
								// while the button's name comes from its aria-label.
								document.write(`
								  <button type="button" class="btn btn-sm navbar-btn search-button search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
								    <i class="fa-solid fa-magnifying-glass fa-lg" aria-hidden="true"></i>
								  </button>
								`);
								</script>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								<!-- Secondary sidebar toggle: a label tied through "for" to the element with id "__secondary" (defined outside this excerpt). -->
								<label class="sidebar-toggle secondary-toggle btn btn-sm" for="__secondary" title="Toggle secondary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip">
								    <!-- The icon is decorative, so it is hidden from assistive technology. -->
								    <span class="fa-solid fa-list" aria-hidden="true"></span>
								</label>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								              
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								              
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  id = "jb-print-docs-body"  class = "onlyprint" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    <!-- Title of the print-only block (the enclosing container has class "onlyprint"). -->
								    <h1>Developer Documentation</h1>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    <!--  Table of contents  --> 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < div  id = "print-main-content" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < div  id = "jb-print-toc" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            <!-- Heading placed above the page-level table of contents. -->
								            <div>
								                <h2> Contents </h2>
								            </div>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            < nav  aria-label = "Page" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								                < ul  class = "visible nav section-nav flex-column" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								<!-- Second-level entry pointing at the #introducing-the-example anchor. -->
								<li class="toc-h2 nav-item toc-entry">
								  <a class="reference internal nav-link" href="#introducing-the-example">Introducing the Example</a>
								</li>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								<!-- Second-level entry for #operations-and-primitives, followed by its three third-level entries. -->
								<li class="toc-h2 nav-item toc-entry">
								  <a class="reference internal nav-link" href="#operations-and-primitives">Operations and Primitives</a>
								  <ul class="visible nav section-nav flex-column">
								    <li class="toc-h3 nav-item toc-entry">
								      <a class="reference internal nav-link" href="#operations">Operations</a>
								    </li>
								    <li class="toc-h3 nav-item toc-entry">
								      <a class="reference internal nav-link" href="#primitives">Primitives</a>
								    </li>
								    <li class="toc-h3 nav-item toc-entry">
								      <a class="reference internal nav-link" href="#using-the-primitives">Using the Primitives</a>
								    </li>
								  </ul>
								</li>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h2 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#implementing-the-primitive" > Implementing the Primitive< / a > < ul  class = "visible nav section-nav flex-column" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h3 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#implementing-the-cpu-backend" > Implementing the CPU Backend< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h3 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#implementing-the-gpu-backend" > Implementing the GPU Backend< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h3 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#primitive-transforms" > Primitive Transforms< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h2 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#building-and-binding" > Building and Binding< / a > < ul  class = "visible nav section-nav flex-column" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h3 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#binding-to-python" > Binding to Python< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h3 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#building-with-cmake" > Building with CMake< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h3 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#building-with-setuptools" > Building with < code  class = "docutils literal notranslate" > < span  class = "pre" > setuptools< / span > < / code > < / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h2 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#usage" > Usage< / a > < ul  class = "visible nav section-nav flex-column" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h3 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#results" > Results< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h2 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#scripts" > Scripts< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            < / nav > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								              
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								                
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  id = "searchbox" > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								                < article  class = "bd-article"  role = "main" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								                  
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < section  id = "developer-documentation" > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h1 > Developer Documentation< a  class = "headerlink"  href = "#developer-documentation"  title = "Permalink to this heading" > #< / a > < / h1 > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > MLX provides an open and flexible backend to which users may add operations
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								and specialized implementations without much hassle. While the library supplies
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								efficient operations that can be used and composed for any number of
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								applications, there may arise cases where new functionalities or highly
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								optimized implementations are needed. For such cases, you may design and
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								implement your own operations that link to and build on top of < code  class = "xref py py-mod docutils literal notranslate" > < span  class = "pre" > mlx.core< / span > < / code > .
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								We will introduce the inner workings of MLX and go over a simple example to
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								learn the steps involved in adding new operations to MLX with your own CPU
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								and GPU implementations.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< section  id = "introducing-the-example" > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h2 > Introducing the Example< a  class = "headerlink"  href = "#introducing-the-example"  title = "Permalink to this heading" > #< / a > < / h2 > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
										 
									
								 
							
							
								< p > Let’s say that you would like an operation that takes in two arrays,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< code  class = "docutils literal notranslate" > < span  class = "pre" > x< / span > < / code >  and < code  class = "docutils literal notranslate" > < span  class = "pre" > y< / span > < / code > , scales them both by some coefficients < code  class = "docutils literal notranslate" > < span  class = "pre" > alpha< / span > < / code >  and < code  class = "docutils literal notranslate" > < span  class = "pre" > beta< / span > < / code > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								respectively, and then adds them together to get the result
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< code  class = "docutils literal notranslate" > < span  class = "pre" > z< / span >  < span  class = "pre" > =< / span >  < span  class = "pre" > alpha< / span >  < span  class = "pre" > *< / span >  < span  class = "pre" > x< / span >  < span  class = "pre" > +< / span >  < span  class = "pre" > beta< / span >  < span  class = "pre" > *< / span >  < span  class = "pre" > y< / span > < / code > . Well, you can very easily do that by just
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								writing out a function as follows:< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-python notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "kn" > import< / span >  < span  class = "nn" > mlx.core< / span >  < span  class = "k" > as< / span >  < span  class = "nn" > mx< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "k" > def< / span >  < span  class = "nf" > simple_axpby< / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > :< / span >  < span  class = "n" > mx< / span > < span  class = "o" > .< / span > < span  class = "n" > array< / span > < span  class = "p" > ,< / span >  < span  class = "n" > y< / span > < span  class = "p" > :< / span >  < span  class = "n" > mx< / span > < span  class = "o" > .< / span > < span  class = "n" > array< / span > < span  class = "p" > ,< / span >  < span  class = "n" > alpha< / span > < span  class = "p" > :< / span >  < span  class = "nb" > float< / span > < span  class = "p" > ,< / span >  < span  class = "n" > beta< / span > < span  class = "p" > :< / span >  < span  class = "nb" > float< / span > < span  class = "p" > )< / span >  < span  class = "o" > -> < / span >  < span  class = "n" > mx< / span > < span  class = "o" > .< / span > < span  class = "n" > array< / span > < span  class = "p" > :< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "k" > return< / span >  < span  class = "n" > alpha< / span >  < span  class = "o" > *< / span >  < span  class = "n" > x< / span >  < span  class = "o" > +< / span >  < span  class = "n" > beta< / span >  < span  class = "o" > *< / span >  < span  class = "n" > y< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > This function performs that operation while leaving the implementations and
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								differentiation to MLX.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > However, you work with vector math libraries often and realize that the
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< code  class = "docutils literal notranslate" > < span  class = "pre" > axpby< / span > < / code >  routine defines the same operation < code  class = "docutils literal notranslate" > < span  class = "pre" > Y< / span >  < span  class = "pre" > =< / span >  < span  class = "pre" > (alpha< / span >  < span  class = "pre" > *< / span >  < span  class = "pre" > X)< / span >  < span  class = "pre" > +< / span >  < span  class = "pre" > (beta< / span >  < span  class = "pre" > *< / span >  < span  class = "pre" > Y)< / span > < / code > .
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								You would really like the part of your applications that does this operation
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								on the CPU to be very fast - so you decide that you want it to rely on the
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< code  class = "docutils literal notranslate" > < span  class = "pre" > axpby< / span > < / code >  routine provided by the < a  class = "reference external"  href = "https://developer.apple.com/documentation/accelerate/blas?language=objc" > Accelerate< / a >  framework. Continuing to impose
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
										 
									
								 
							
							
								our assumptions on to you, let’s also assume that you want to learn how to add
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								your own implementation for the gradients of your new operation while going
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								over the ins-and-outs of the MLX framework.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > Well, what a coincidence! You are in the right place. Over the course of this
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								example, we will learn:< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< ul  class = "simple" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > The structure of the MLX library from the frontend API to the backend implementations.< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > How to implement your own CPU backend that redirects to < a  class = "reference external"  href = "https://developer.apple.com/documentation/accelerate/blas?language=objc" > Accelerate< / a >  when appropriate (and a fallback if needed).< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > How to implement your own GPU implementation using Metal.< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > How to add your own < code  class = "docutils literal notranslate" > < span  class = "pre" > vjp< / span > < / code >  and < code  class = "docutils literal notranslate" > < span  class = "pre" > jvp< / span > < / code > .< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > How to build your implementations, link them to MLX, and bind them to Python.< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / section > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< section  id = "operations-and-primitives" > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h2 > Operations and Primitives< a  class = "headerlink"  href = "#operations-and-primitives"  title = "Permalink to this heading" > #< / a > < / h2 > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > In one sentence, operations in MLX build the computation graph, and primitives
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
										 
									
								 
							
							
								provide the rules for evaluation and transformations of said graph. Let’s start
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								by discussing operations in more detail.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< section  id = "operations" > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h3 > Operations< a  class = "headerlink"  href = "#operations"  title = "Permalink to this heading" > #< / a > < / h3 > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > Operations are the frontend functions that operate on arrays. They are defined
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								in the C++ API (< a  class = "reference internal"  href = "../cpp/ops.html#cpp-ops" > < span  class = "std std-ref" > Operations< / span > < / a > ) and then we provide bindings to these
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								operations in the Python API (< a  class = "reference internal"  href = "../python/ops.html#ops" > < span  class = "std std-ref" > Operations< / span > < / a > ).< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > We would like an operation, < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > axpby()< / span > < / code >  that takes in two arrays < code  class = "docutils literal notranslate" > < span  class = "pre" > x< / span > < / code >  and < code  class = "docutils literal notranslate" > < span  class = "pre" > y< / span > < / code > ,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								and two scalars, < code  class = "docutils literal notranslate" > < span  class = "pre" > alpha< / span > < / code >  and < code  class = "docutils literal notranslate" > < span  class = "pre" > beta< / span > < / code > . This is how we would define it in the
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								C++ API:< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-C++ notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "cm" > /**< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cm" > *  Scale and sum two vectors elementwise< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cm" > *  z = alpha * x + beta * y< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cm" > *< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cm" > *  Follow numpy style broadcasting between x and y< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cm" > *  Inputs are upcasted to floats if needed< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cm" > **/< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > array< / span > < span  class = "w" >  < / span > < span  class = "nf" > axpby< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > array< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "c1" > // Input array x< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > array< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > y< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "c1" > // Input array y< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "kt" > float< / span > < span  class = "w" >  < / span > < span  class = "n" > alpha< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "c1" > // Scaling factor for x< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "kt" > float< / span > < span  class = "w" >  < / span > < span  class = "n" > beta< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "c1" > // Scaling factor for y< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > StreamOrDevice< / span > < span  class = "w" >  < / span > < span  class = "n" > s< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "p" > {}< / span > < span  class = "w" >  < / span > < span  class = "c1" > // Stream on which to schedule the operation< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > This operation itself can call other operations within it if needed. So, the
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								simplest way to go about implementing this operation would be to do so in terms
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								of existing operations.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-C++ notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "n" > array< / span > < span  class = "w" >  < / span > < span  class = "nf" > axpby< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > array< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "c1" > // Input array x< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > array< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > y< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "c1" > // Input array y< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "kt" > float< / span > < span  class = "w" >  < / span > < span  class = "n" > alpha< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "c1" > // Scaling factor for x< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "kt" > float< / span > < span  class = "w" >  < / span > < span  class = "n" > beta< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "c1" > // Scaling factor for y< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > StreamOrDevice< / span > < span  class = "w" >  < / span > < span  class = "n" > s< / span > < span  class = "w" >  < / span > < span  class = "cm" > /* = {} */< / span > < span  class = "w" >  < / span > < span  class = "c1" > // Stream on which to schedule the operation< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Scale x and y on the provided stream< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > auto< / span > < span  class = "w" >  < / span > < span  class = "n" > ax< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > multiply< / span > < span  class = "p" > (< / span > < span  class = "n" > array< / span > < span  class = "p" > (< / span > < span  class = "n" > alpha< / span > < span  class = "p" > ),< / span > < span  class = "w" >  < / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > s< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > auto< / span > < span  class = "w" >  < / span > < span  class = "n" > by< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > multiply< / span > < span  class = "p" > (< / span > < span  class = "n" > array< / span > < span  class = "p" > (< / span > < span  class = "n" > beta< / span > < span  class = "p" > ),< / span > < span  class = "w" >  < / span > < span  class = "n" > y< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > s< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Add and return< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > return< / span > < span  class = "w" >  < / span > < span  class = "n" > add< / span > < span  class = "p" > (< / span > < span  class = "n" > ax< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > by< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > s< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > }< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > However, as we discussed earlier, this is not our goal. The operations themselves
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								do not contain the implementations that act on the data, nor do they contain the
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								rules of transformations. Rather, they are an easy-to-use interface that builds
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								on top of the building blocks we call < code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > Primitive< / span > < / code > .< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / section > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< section  id = "primitives" > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h3 > Primitives< a  class = "headerlink"  href = "#primitives"  title = "Permalink to this heading" > #< / a > < / h3 > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > A < code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > Primitive< / span > < / code >  is part of the computation graph of an < code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > array< / span > < / code > . It
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								defines how to create an output given a set of input < code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > array< / span > < / code >  objects. Further,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								a < code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > Primitive< / span > < / code >  is a class that contains rules on how it is evaluated
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								on the CPU or GPU, and how it acts under transformations such as < code  class = "docutils literal notranslate" > < span  class = "pre" > vjp< / span > < / code >  and
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< code  class = "docutils literal notranslate" > < span  class = "pre" > jvp< / span > < / code > . These words on their own can be a bit abstract, so let's take a step
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								back and go to our example to give ourselves a more concrete image.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-C++ notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "k" > class< / span > < span  class = "w" >  < / span > < span  class = "nc" > Axpby< / span > < span  class = "w" >  < / span > < span  class = "o" > :< / span > < span  class = "w" >  < / span > < span  class = "k" > public< / span > < span  class = "w" >  < / span > < span  class = "n" > Primitive< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >   < / span > < span  class = "k" > public< / span > < span  class = "o" > :< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > explicit< / span > < span  class = "w" >  < / span > < span  class = "n" > Axpby< / span > < span  class = "p" > (< / span > < span  class = "n" > Stream< / span > < span  class = "w" >  < / span > < span  class = "n" > stream< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "kt" > float< / span > < span  class = "w" >  < / span > < span  class = "n" > alpha< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "kt" > float< / span > < span  class = "w" >  < / span > < span  class = "n" > beta< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "o" > :< / span > < span  class = "w" >  < / span > < span  class = "n" > Primitive< / span > < span  class = "p" > (< / span > < span  class = "n" > stream< / span > < span  class = "p" > ),< / span > < span  class = "w" >  < / span > < span  class = "n" > alpha_< / span > < span  class = "p" > (< / span > < span  class = "n" > alpha< / span > < span  class = "p" > ),< / span > < span  class = "w" >  < / span > < span  class = "n" > beta_< / span > < span  class = "p" > (< / span > < span  class = "n" > beta< / span > < span  class = "p" > ){};< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "cm" > /**< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cm" >     * A primitive must know how to evaluate itself on the CPU/GPU< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cm" >     * for the given inputs and populate the output array.< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cm" >     *< / span > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-06 08:13:20 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cm" >     * To avoid unnecessary allocations, the evaluation function< / span > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cm" >     * is responsible for allocating space for the array.< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cm" >     */< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "kt" > void< / span > < span  class = "w" >  < / span > < span  class = "nf" > eval_cpu< / span > < span  class = "p" > (< / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > vector< / span > < span  class = "o" > < < / span > < span  class = "n" > array< / span > < span  class = "o" > > & < / span > < span  class = "w" >  < / span > < span  class = "n" > inputs< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > array< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "k" > override< / span > < span  class = "p" > ;< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "kt" > void< / span > < span  class = "w" >  < / span > < span  class = "nf" > eval_gpu< / span > < span  class = "p" > (< / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > vector< / span > < span  class = "o" > < < / span > < span  class = "n" > array< / span > < span  class = "o" > > & < / span > < span  class = "w" >  < / span > < span  class = "n" > inputs< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > array< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "k" > override< / span > < span  class = "p" > ;< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "cm" > /** The Jacobian-vector product. */< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > array< / span > < span  class = "w" >  < / span > < span  class = "nf" > jvp< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > vector< / span > < span  class = "o" > < < / span > < span  class = "n" > array< / span > < span  class = "o" > > & < / span > < span  class = "w" >  < / span > < span  class = "n" > primals< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > vector< / span > < span  class = "o" > < < / span > < span  class = "n" > array< / span > < span  class = "o" > > & < / span > < span  class = "w" >  < / span > < span  class = "n" > tangents< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > vector< / span > < span  class = "o" > < < / span > < span  class = "kt" > int< / span > < span  class = "o" > > & < / span > < span  class = "w" >  < / span > < span  class = "n" > argnums< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "k" > override< / span > < span  class = "p" > ;< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "cm" > /** The vector-Jacobian product. */< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > vector< / span > < span  class = "o" > < < / span > < span  class = "n" > array< / span > < span  class = "o" > > < / span > < span  class = "w" >  < / span > < span  class = "n" > vjp< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > vector< / span > < span  class = "o" > < < / span > < span  class = "n" > array< / span > < span  class = "o" > > & < / span > < span  class = "w" >  < / span > < span  class = "n" > primals< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > array< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > cotan< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > vector< / span > < span  class = "o" > < < / span > < span  class = "kt" > int< / span > < span  class = "o" > > & < / span > < span  class = "w" >  < / span > < span  class = "n" > argnums< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "k" > override< / span > < span  class = "p" > ;< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "cm" > /**< / span > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-17 13:23:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cm" >     * The primitive must know how to vectorize itself across< / span > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cm" >     * the given axes. The output is a pair containing the array< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cm" >     * representing the vectorized computation and the axis which< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cm" >     * corresponds to the output vectorized dimension.< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cm" >     */< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > pair< / span > < span  class = "o" > < < / span > < span  class = "n" > array< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "kt" > int< / span > < span  class = "o" > > < / span > < span  class = "w" >  < / span > < span  class = "n" > vmap< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > vector< / span > < span  class = "o" > < < / span > < span  class = "n" > array< / span > < span  class = "o" > > & < / span > < span  class = "w" >  < / span > < span  class = "n" > inputs< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > vector< / span > < span  class = "o" > < < / span > < span  class = "kt" > int< / span > < span  class = "o" > > & < / span > < span  class = "w" >  < / span > < span  class = "n" > axes< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "k" > override< / span > < span  class = "p" > ;< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "cm" > /** Print the primitive. */< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "kt" > void< / span > < span  class = "w" >  < / span > < span  class = "nf" > print< / span > < span  class = "p" > (< / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > ostream< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > os< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "k" > override< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "n" > os< / span > < span  class = "w" >  < / span > < span  class = "o" > < < < / span > < span  class = "w" >  < / span > < span  class = "s" > " Axpby" < / span > < span  class = "p" > ;< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "p" > }< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "cm" > /** Equivalence check **/< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "kt" > bool< / span > < span  class = "w" >  < / span > < span  class = "nf" > is_equivalent< / span > < span  class = "p" > (< / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > Primitive< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > other< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "k" > override< / span > < span  class = "p" > ;< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >   < / span > < span  class = "k" > private< / span > < span  class = "o" > :< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "kt" > float< / span > < span  class = "w" >  < / span > < span  class = "n" > alpha_< / span > < span  class = "p" > ;< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "kt" > float< / span > < span  class = "w" >  < / span > < span  class = "n" > beta_< / span > < span  class = "p" > ;< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "cm" > /** Fall back implementation for evaluation on CPU */< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "kt" > void< / span > < span  class = "w" >  < / span > < span  class = "nf" > eval< / span > < span  class = "p" > (< / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > vector< / span > < span  class = "o" > < < / span > < span  class = "n" > array< / span > < span  class = "o" > > & < / span > < span  class = "w" >  < / span > < span  class = "n" > inputs< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > array< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > };< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > The < code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > Axpby< / span > < / code >  class derives from the base < code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > Primitive< / span > < / code >  class and
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								follows the interface demonstrated above. < code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > Axpby< / span > < / code >  treats < code  class = "docutils literal notranslate" > < span  class = "pre" > alpha< / span > < / code >  and
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< code  class = "docutils literal notranslate" > < span  class = "pre" > beta< / span > < / code >  as parameters. It then provides implementations of how the array < code  class = "docutils literal notranslate" > < span  class = "pre" > out< / span > < / code > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								is produced given < code  class = "docutils literal notranslate" > < span  class = "pre" > inputs< / span > < / code >  through < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > Axpby::eval_cpu()< / span > < / code >  and
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > Axpby::eval_gpu()< / span > < / code > . Further, it provides rules of transformations in
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > Axpby::jvp()< / span > < / code > , < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > Axpby::vjp()< / span > < / code > , and < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > Axpby::vmap()< / span > < / code > .< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / section > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< section  id = "using-the-primitives" > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h3 > Using the Primitives< a  class = "headerlink"  href = "#using-the-primitives"  title = "Permalink to this heading" > #< / a > < / h3 > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > Operations can use this < code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > Primitive< / span > < / code >  to add a new < code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > array< / span > < / code >  to
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								the computation graph. An < code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > array< / span > < / code >  can be constructed by providing its
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								data type, shape, the < code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > Primitive< / span > < / code >  that computes it, and the
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > array< / span > < / code >  inputs that are passed to the primitive.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
										 
									
								 
							
							
								< p > Let’s re-implement our operation now in terms of our < code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > Axpby< / span > < / code >  primitive.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-C++ notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "n" > array< / span > < span  class = "w" >  < / span > < span  class = "nf" > axpby< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > array< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "c1" > // Input array x< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > array< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > y< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "c1" > // Input array y< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "kt" > float< / span > < span  class = "w" >  < / span > < span  class = "n" > alpha< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "c1" > // Scaling factor for x< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "kt" > float< / span > < span  class = "w" >  < / span > < span  class = "n" > beta< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "c1" > // Scaling factor for y< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > StreamOrDevice< / span > < span  class = "w" >  < / span > < span  class = "n" > s< / span > < span  class = "w" >  < / span > < span  class = "cm" > /* = {} */< / span > < span  class = "w" >  < / span > < span  class = "c1" > // Stream on which to schedule the operation< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Promote dtypes between x and y as needed< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > auto< / span > < span  class = "w" >  < / span > < span  class = "n" > promoted_dtype< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > promote_types< / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > .< / span > < span  class = "n" > dtype< / span > < span  class = "p" > (),< / span > < span  class = "w" >  < / span > < span  class = "n" > y< / span > < span  class = "p" > .< / span > < span  class = "n" > dtype< / span > < span  class = "p" > ());< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Upcast to float32 for non-floating point inputs x and y< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > auto< / span > < span  class = "w" >  < / span > < span  class = "n" > out_dtype< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > is_floating_point< / span > < span  class = "p" > (< / span > < span  class = "n" > promoted_dtype< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "o" > ?< / span > < span  class = "w" >  < / span > < span  class = "n" > promoted_dtype< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "o" > :< / span > < span  class = "w" >  < / span > < span  class = "n" > promote_types< / span > < span  class = "p" > (< / span > < span  class = "n" > promoted_dtype< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > float32< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Cast x and y up to the determined dtype (on the same stream s)< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > auto< / span > < span  class = "w" >  < / span > < span  class = "n" > x_casted< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > astype< / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > out_dtype< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > s< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > auto< / span > < span  class = "w" >  < / span > < span  class = "n" > y_casted< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > astype< / span > < span  class = "p" > (< / span > < span  class = "n" > y< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > out_dtype< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > s< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Broadcast the shapes of x and y (on the same stream s)< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > auto< / span > < span  class = "w" >  < / span > < span  class = "n" > broadcasted_inputs< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > broadcast_arrays< / span > < span  class = "p" > ({< / span > < span  class = "n" > x_casted< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > y_casted< / span > < span  class = "p" > },< / span > < span  class = "w" >  < / span > < span  class = "n" > s< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > auto< / span > < span  class = "w" >  < / span > < span  class = "n" > out_shape< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > broadcasted_inputs< / span > < span  class = "p" > [< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ].< / span > < span  class = "n" > shape< / span > < span  class = "p" > ();< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Construct the array as the output of the Axpby primitive< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // with the broadcasted and upcasted arrays as inputs< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > return< / span > < span  class = "w" >  < / span > < span  class = "n" > array< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "cm" > /* const std::vector< int> &  shape = */< / span > < span  class = "w" >  < / span > < span  class = "n" > out_shape< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "cm" > /* Dtype dtype = */< / span > < span  class = "w" >  < / span > < span  class = "n" > out_dtype< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "cm" > /* std::unique_ptr< Primitive>  primitive = */< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > make_unique< / span > < span  class = "o" > < < / span > < span  class = "n" > Axpby< / span > < span  class = "o" > > < / span > < span  class = "p" > (< / span > < span  class = "n" > to_stream< / span > < span  class = "p" > (< / span > < span  class = "n" > s< / span > < span  class = "p" > ),< / span > < span  class = "w" >  < / span > < span  class = "n" > alpha< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > beta< / span > < span  class = "p" > ),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "cm" > /* const std::vector< array> &  inputs = */< / span > < span  class = "w" >  < / span > < span  class = "n" > broadcasted_inputs< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > }< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > This operation now handles the following:< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< ol  class = "arabic simple" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > Upcast inputs and resolve the output data type.< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > Broadcast the inputs and resolve the output shape.< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > Construct the primitive < code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > Axpby< / span > < / code >  using the given stream, < code  class = "docutils literal notranslate" > < span  class = "pre" > alpha< / span > < / code > , and < code  class = "docutils literal notranslate" > < span  class = "pre" > beta< / span > < / code > .< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > Construct the output < code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > array< / span > < / code >  using the primitive and the inputs.< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ol > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / section > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / section > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< section  id = "implementing-the-primitive" > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h2 > Implementing the Primitive< a  class = "headerlink"  href = "#implementing-the-primitive"  title = "Permalink to this heading" > #< / a > < / h2 > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > No computation happens when we call the operation alone. In effect, the
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								operation only builds the computation graph. When we evaluate the output
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								array, MLX schedules the execution of the computation graph, and calls
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > Axpby::eval_cpu()< / span > < / code >  or < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > Axpby::eval_gpu()< / span > < / code >  depending on the
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								stream/device specified by the user.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "admonition warning" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  class = "admonition-title" > Warning< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > When < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > Primitive::eval_cpu()< / span > < / code >  or < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > Primitive::eval_gpu()< / span > < / code >  are called,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								no memory has been allocated for the output array. It falls on the implementation
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								of these functions to allocate memory as needed.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< section  id = "implementing-the-cpu-backend" > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h3 > Implementing the CPU Backend< a  class = "headerlink"  href = "#implementing-the-cpu-backend"  title = "Permalink to this heading" > #< / a > < / h3 > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
										 
									
								 
							
							
								< p > Let’s start by trying to implement a naive and generic version of
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > Axpby::eval_cpu()< / span > < / code > . We declared this as a private member function of
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > Axpby< / span > < / code >  earlier, under the name < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > Axpby::eval()< / span > < / code > .< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > Our naive method will go over each element of the output array, find the
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								corresponding input elements of < code  class = "docutils literal notranslate" > < span  class = "pre" > x< / span > < / code >  and < code  class = "docutils literal notranslate" > < span  class = "pre" > y< / span > < / code >  and perform the operation
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								pointwise. This is captured in the templated function < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > axpby_impl()< / span > < / code > .< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-C++ notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "k" > template< / span > < span  class = "w" >  < / span > < span  class = "o" > < < / span > < span  class = "k" > typename< / span > < span  class = "w" >  < / span > < span  class = "nc" > T< / span > < span  class = "o" > > < / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "kt" > void< / span > < span  class = "w" >  < / span > < span  class = "n" > axpby_impl< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > array< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > array< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > y< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "n" > array< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "kt" > float< / span > < span  class = "w" >  < / span > < span  class = "n" > alpha_< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "kt" > float< / span > < span  class = "w" >  < / span > < span  class = "n" > beta_< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // We only allocate memory when we are ready to fill the output< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // malloc_or_wait synchronously allocates available memory< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // There may be a wait executed here if the allocation is requested< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // under memory-pressured conditions< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > out< / span > < span  class = "p" > .< / span > < span  class = "n" > set_data< / span > < span  class = "p" > (< / span > < span  class = "n" > allocator< / span > < span  class = "o" > ::< / span > < span  class = "n" > malloc_or_wait< / span > < span  class = "p" > (< / span > < span  class = "n" > out< / span > < span  class = "p" > .< / span > < span  class = "n" > nbytes< / span > < span  class = "p" > ()));< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Collect input and output data pointers< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > T< / span > < span  class = "o" > *< / span > < span  class = "w" >  < / span > < span  class = "n" > x_ptr< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > x< / span > < span  class = "p" > .< / span > < span  class = "n" > data< / span > < span  class = "o" > < < / span > < span  class = "n" > T< / span > < span  class = "o" > > < / span > < span  class = "p" > ();< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > T< / span > < span  class = "o" > *< / span > < span  class = "w" >  < / span > < span  class = "n" > y_ptr< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > y< / span > < span  class = "p" > .< / span > < span  class = "n" > data< / span > < span  class = "o" > < < / span > < span  class = "n" > T< / span > < span  class = "o" > > < / span > < span  class = "p" > ();< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > T< / span > < span  class = "o" > *< / span > < span  class = "w" >  < / span > < span  class = "n" > out_ptr< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "p" > .< / span > < span  class = "n" > data< / span > < span  class = "o" > < < / span > < span  class = "n" > T< / span > < span  class = "o" > > < / span > < span  class = "p" > ();< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Cast alpha and beta to the relevant types< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > T< / span > < span  class = "w" >  < / span > < span  class = "n" > alpha< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "k" > static_cast< / span > < span  class = "o" > < < / span > < span  class = "n" > T< / span > < span  class = "o" > > < / span > < span  class = "p" > (< / span > < span  class = "n" > alpha_< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > T< / span > < span  class = "w" >  < / span > < span  class = "n" > beta< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "k" > static_cast< / span > < span  class = "o" > < < / span > < span  class = "n" > T< / span > < span  class = "o" > > < / span > < span  class = "p" > (< / span > < span  class = "n" > beta_< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Do the elementwise operation for each output< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > for< / span > < span  class = "w" >  < / span > < span  class = "p" > (< / span > < span  class = "kt" > size_t< / span > < span  class = "w" >  < / span > < span  class = "n" > out_idx< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "mi" > 0< / span > < span  class = "p" > ;< / span > < span  class = "w" >  < / span > < span  class = "n" > out_idx< / span > < span  class = "w" >  < / span > < span  class = "o" > < < / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "p" > .< / span > < span  class = "n" > size< / span > < span  class = "p" > ();< / span > < span  class = "w" >  < / span > < span  class = "n" > out_idx< / span > < span  class = "o" > ++< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "c1" > // Map linear indices to offsets in x and y< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > auto< / span > < span  class = "w" >  < / span > < span  class = "n" > x_offset< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > elem_to_loc< / span > < span  class = "p" > (< / span > < span  class = "n" > out_idx< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > x< / span > < span  class = "p" > .< / span > < span  class = "n" > shape< / span > < span  class = "p" > (),< / span > < span  class = "w" >  < / span > < span  class = "n" > x< / span > < span  class = "p" > .< / span > < span  class = "n" > strides< / span > < span  class = "p" > ());< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > auto< / span > < span  class = "w" >  < / span > < span  class = "n" > y_offset< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > elem_to_loc< / span > < span  class = "p" > (< / span > < span  class = "n" > out_idx< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > y< / span > < span  class = "p" > .< / span > < span  class = "n" > shape< / span > < span  class = "p" > (),< / span > < span  class = "w" >  < / span > < span  class = "n" > y< / span > < span  class = "p" > .< / span > < span  class = "n" > strides< / span > < span  class = "p" > ());< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "c1" > // We allocate the output to be contiguous and regularly strided< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "c1" > // (defaults to row major) and hence it doesn't need additional mapping< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "n" > out_ptr< / span > < span  class = "p" > [< / span > < span  class = "n" > out_idx< / span > < span  class = "p" > ]< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > alpha< / span > < span  class = "w" >  < / span > < span  class = "o" > *< / span > < span  class = "w" >  < / span > < span  class = "n" > x_ptr< / span > < span  class = "p" > [< / span > < span  class = "n" > x_offset< / span > < span  class = "p" > ]< / span > < span  class = "w" >  < / span > < span  class = "o" > +< / span > < span  class = "w" >  < / span > < span  class = "n" > beta< / span > < span  class = "w" >  < / span > < span  class = "o" > *< / span > < span  class = "w" >  < / span > < span  class = "n" > y_ptr< / span > < span  class = "p" > [< / span > < span  class = "n" > y_offset< / span > < span  class = "p" > ];< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "p" > }< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > }< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > Now, we would like our implementation to be able to do this pointwise operation
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								for all incoming floating point arrays. Accordingly, we add dispatches for
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< code  class = "docutils literal notranslate" > < span  class = "pre" > float32< / span > < / code > , < code  class = "docutils literal notranslate" > < span  class = "pre" > float16< / span > < / code > , < code  class = "docutils literal notranslate" > < span  class = "pre" > bfloat16< / span > < / code >  and < code  class = "docutils literal notranslate" > < span  class = "pre" > complex64< / span > < / code > . We throw an error
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								if we encounter an unexpected type.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-C++ notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "cm" > /** Fall back implementation for evaluation on CPU */< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "kt" > void< / span > < span  class = "w" >  < / span > < span  class = "nf" > Axpby::eval< / span > < span  class = "p" > (< / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > vector< / span > < span  class = "o" > < < / span > < span  class = "n" > array< / span > < span  class = "o" > > & < / span > < span  class = "w" >  < / span > < span  class = "n" > inputs< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > array< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Check the inputs (registered in the op while constructing the out array)< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > assert< / span > < span  class = "p" > (< / span > < span  class = "n" > inputs< / span > < span  class = "p" > .< / span > < span  class = "n" > size< / span > < span  class = "p" > ()< / span > < span  class = "w" >  < / span > < span  class = "o" > ==< / span > < span  class = "w" >  < / span > < span  class = "mi" > 2< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > auto< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > x< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > inputs< / span > < span  class = "p" > [< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ];< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > auto< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > y< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > inputs< / span > < span  class = "p" > [< / span > < span  class = "mi" > 1< / span > < span  class = "p" > ];< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Dispatch to the correct dtype< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > if< / span > < span  class = "w" >  < / span > < span  class = "p" > (< / span > < span  class = "n" > out< / span > < span  class = "p" > .< / span > < span  class = "n" > dtype< / span > < span  class = "p" > ()< / span > < span  class = "w" >  < / span > < span  class = "o" > ==< / span > < span  class = "w" >  < / span > < span  class = "n" > float32< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > return< / span > < span  class = "w" >  < / span > < span  class = "n" > axpby_impl< / span > < span  class = "o" > < < / span > < span  class = "kt" > float< / span > < span  class = "o" > > < / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > y< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > alpha_< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > beta_< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "p" > }< / span > < span  class = "w" >  < / span > < span  class = "k" > else< / span > < span  class = "w" >  < / span > < span  class = "k" > if< / span > < span  class = "w" >  < / span > < span  class = "p" > (< / span > < span  class = "n" > out< / span > < span  class = "p" > .< / span > < span  class = "n" > dtype< / span > < span  class = "p" > ()< / span > < span  class = "w" >  < / span > < span  class = "o" > ==< / span > < span  class = "w" >  < / span > < span  class = "n" > float16< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > return< / span > < span  class = "w" >  < / span > < span  class = "n" > axpby_impl< / span > < span  class = "o" > < < / span > < span  class = "n" > float16_t< / span > < span  class = "o" > > < / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > y< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > alpha_< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > beta_< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "p" > }< / span > < span  class = "w" >  < / span > < span  class = "k" > else< / span > < span  class = "w" >  < / span > < span  class = "k" > if< / span > < span  class = "w" >  < / span > < span  class = "p" > (< / span > < span  class = "n" > out< / span > < span  class = "p" > .< / span > < span  class = "n" > dtype< / span > < span  class = "p" > ()< / span > < span  class = "w" >  < / span > < span  class = "o" > ==< / span > < span  class = "w" >  < / span > < span  class = "n" > bfloat16< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > return< / span > < span  class = "w" >  < / span > < span  class = "n" > axpby_impl< / span > < span  class = "o" > < < / span > < span  class = "n" > bfloat16_t< / span > < span  class = "o" > > < / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > y< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > alpha_< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > beta_< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "p" > }< / span > < span  class = "w" >  < / span > < span  class = "k" > else< / span > < span  class = "w" >  < / span > < span  class = "k" > if< / span > < span  class = "w" >  < / span > < span  class = "p" > (< / span > < span  class = "n" > out< / span > < span  class = "p" > .< / span > < span  class = "n" > dtype< / span > < span  class = "p" > ()< / span > < span  class = "w" >  < / span > < span  class = "o" > ==< / span > < span  class = "w" >  < / span > < span  class = "n" > complex64< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > return< / span > < span  class = "w" >  < / span > < span  class = "n" > axpby_impl< / span > < span  class = "o" > < < / span > < span  class = "n" > complex64_t< / span > < span  class = "o" > > < / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > y< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > alpha_< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > beta_< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "p" > }< / span > < span  class = "w" >  < / span > < span  class = "k" > else< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > throw< / span > < span  class = "w" >  < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > runtime_error< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >             < / span > < span  class = "s" > "Axpby is only supported for floating point types."< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "p" > }< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > }< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > We have a fallback implementation! Now, to do what we are really here to do.
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								Remember we wanted to use the < code  class = "docutils literal notranslate" > < span  class = "pre" > axpby< / span > < / code >  routine provided by the < a  class = "reference external"  href = "https://developer.apple.com/documentation/accelerate/blas?language=objc" > Accelerate< / a > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								framework? Well, there are 3 complications to keep in mind:< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< ol  class = "arabic simple" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > Accelerate does not provide implementations of < code  class = "docutils literal notranslate" > < span  class = "pre" > axpby< / span > < / code >  for half precision
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								floats. We can only direct to it for < code  class = "docutils literal notranslate" > < span  class = "pre" > float32< / span > < / code >  types< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > Accelerate assumes the inputs < code  class = "docutils literal notranslate" > < span  class = "pre" > x< / span > < / code >  and < code  class = "docutils literal notranslate" > < span  class = "pre" > y< / span > < / code >  are contiguous and all elements
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								have fixed strides between them. Possibly due to broadcasts and transposes,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
										 
									
								 
							
							
								we aren’t guaranteed that the inputs fit this requirement. We can
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								only direct to Accelerate if both < code  class = "docutils literal notranslate" > < span  class = "pre" > x< / span > < / code >  and < code  class = "docutils literal notranslate" > < span  class = "pre" > y< / span > < / code >  are row contiguous or
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								column contiguous.< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > Accelerate performs the routine < code  class = "docutils literal notranslate" > < span  class = "pre" > Y< / span >  < span  class = "pre" > =< / span >  < span  class = "pre" > (alpha< / span >  < span  class = "pre" > *< / span >  < span  class = "pre" > X)< / span >  < span  class = "pre" > +< / span >  < span  class = "pre" > (beta< / span >  < span  class = "pre" > *< / span >  < span  class = "pre" > Y)< / span > < / code >  inplace.
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								MLX expects to write out the answer to a new array. We must copy the elements
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								of < code  class = "docutils literal notranslate" > < span  class = "pre" > y< / span > < / code >  into the output array and use that as an input to < code  class = "docutils literal notranslate" > < span  class = "pre" > axpby< / span > < / code > .< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ol > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
										 
									
								 
							
							
								< p > Let’s write out an implementation that uses Accelerate in the right conditions.
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								It must simply allocate data for the output, copy elements of < code  class = "docutils literal notranslate" > < span  class = "pre" > y< / span > < / code >  into it,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								and then call < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > catlas_saxpby()< / span > < / code >  from Accelerate.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-C++ notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "k" > template< / span > < span  class = "w" >  < / span > < span  class = "o" > < < / span > < span  class = "k" > typename< / span > < span  class = "w" >  < / span > < span  class = "nc" > T< / span > < span  class = "o" > > < / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "kt" > void< / span > < span  class = "w" >  < / span > < span  class = "n" > axpby_impl_accelerate< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > array< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > array< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > y< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "n" > array< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "kt" > float< / span > < span  class = "w" >  < / span > < span  class = "n" > alpha_< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "kt" > float< / span > < span  class = "w" >  < / span > < span  class = "n" > beta_< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Accelerate library provides catlas_saxpby which does< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Y = (alpha * X) + (beta * Y) in place< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // To use it, we first copy the data in y over to the output array< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // This specialization requires both x and y be contiguous in the same mode< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // i.e: corresponding linear indices in both point to corresponding elements< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // The data in the output array is allocated to match the strides in y< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // such that x, y, and out are contiguous in the same mode and< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // no transposition is needed< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > out< / span > < span  class = "p" > .< / span > < span  class = "n" > set_data< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "n" > allocator< / span > < span  class = "o" > ::< / span > < span  class = "n" > malloc_or_wait< / span > < span  class = "p" > (< / span > < span  class = "n" > y< / span > < span  class = "p" > .< / span > < span  class = "n" > data_size< / span > < span  class = "p" > ()< / span > < span  class = "w" >  < / span > < span  class = "o" > *< / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "p" > .< / span > < span  class = "n" > itemsize< / span > < span  class = "p" > ()),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "n" > y< / span > < span  class = "p" > .< / span > < span  class = "n" > data_size< / span > < span  class = "p" > (),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "n" > y< / span > < span  class = "p" > .< / span > < span  class = "n" > strides< / span > < span  class = "p" > (),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "n" > y< / span > < span  class = "p" > .< / span > < span  class = "n" > flags< / span > < span  class = "p" > ());< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // We then copy over the elements using the contiguous vector specialization< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > copy_inplace< / span > < span  class = "p" > (< / span > < span  class = "n" > y< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > CopyType< / span > < span  class = "o" > ::< / span > < span  class = "n" > Vector< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Get x and y pointers for catlas_saxpby< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > T< / span > < span  class = "o" > *< / span > < span  class = "w" >  < / span > < span  class = "n" > x_ptr< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > x< / span > < span  class = "p" > .< / span > < span  class = "n" > data< / span > < span  class = "o" > < < / span > < span  class = "n" > T< / span > < span  class = "o" > > < / span > < span  class = "p" > ();< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > T< / span > < span  class = "o" > *< / span > < span  class = "w" >  < / span > < span  class = "n" > y_ptr< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "p" > .< / span > < span  class = "n" > data< / span > < span  class = "o" > < < / span > < span  class = "n" > T< / span > < span  class = "o" > > < / span > < span  class = "p" > ();< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > T< / span > < span  class = "w" >  < / span > < span  class = "n" > alpha< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "k" > static_cast< / span > < span  class = "o" > < < / span > < span  class = "n" > T< / span > < span  class = "o" > > < / span > < span  class = "p" > (< / span > < span  class = "n" > alpha_< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > T< / span > < span  class = "w" >  < / span > < span  class = "n" > beta< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "k" > static_cast< / span > < span  class = "o" > < < / span > < span  class = "n" > T< / span > < span  class = "o" > > < / span > < span  class = "p" > (< / span > < span  class = "n" > beta_< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Call the inplace accelerate operator< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > catlas_saxpby< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "cm" > /* N = */< / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "p" > .< / span > < span  class = "n" > size< / span > < span  class = "p" > (),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "cm" > /* ALPHA = */< / span > < span  class = "w" >  < / span > < span  class = "n" > alpha< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "cm" > /* X = */< / span > < span  class = "w" >  < / span > < span  class = "n" > x_ptr< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "cm" > /* INCX = */< / span > < span  class = "w" >  < / span > < span  class = "mi" > 1< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "cm" > /* BETA = */< / span > < span  class = "w" >  < / span > < span  class = "n" > beta< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "cm" > /* Y = */< / span > < span  class = "w" >  < / span > < span  class = "n" > y_ptr< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "cm" > /* INCY = */< / span > < span  class = "w" >  < / span > < span  class = "mi" > 1< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > }< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > Great! But what about the inputs that do not fit the criteria for accelerate?
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								Luckily, we can always just direct back to < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > Axpby::eval()< / span > < / code > .< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > With this in mind, let’s finally implement our < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > Axpby::eval_cpu()< / span > < / code > .< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-C++ notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "cm" > /** Evaluate primitive on CPU using accelerate specializations */< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "kt" > void< / span > < span  class = "w" >  < / span > < span  class = "nf" > Axpby::eval_cpu< / span > < span  class = "p" > (< / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > vector< / span > < span  class = "o" > < < / span > < span  class = "n" > array< / span > < span  class = "o" > > & < / span > < span  class = "w" >  < / span > < span  class = "n" > inputs< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > array< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > assert< / span > < span  class = "p" > (< / span > < span  class = "n" > inputs< / span > < span  class = "p" > .< / span > < span  class = "n" > size< / span > < span  class = "p" > ()< / span > < span  class = "w" >  < / span > < span  class = "o" > ==< / span > < span  class = "w" >  < / span > < span  class = "mi" > 2< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > auto< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > x< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > inputs< / span > < span  class = "p" > [< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ];< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > auto< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > y< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > inputs< / span > < span  class = "p" > [< / span > < span  class = "mi" > 1< / span > < span  class = "p" > ];< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Accelerate specialization for contiguous single precision float arrays< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > if< / span > < span  class = "w" >  < / span > < span  class = "p" > (< / span > < span  class = "n" > out< / span > < span  class = "p" > .< / span > < span  class = "n" > dtype< / span > < span  class = "p" > ()< / span > < span  class = "w" >  < / span > < span  class = "o" > ==< / span > < span  class = "w" >  < / span > < span  class = "n" > float32< / span > < span  class = "w" >  < / span > < span  class = "o" > & & < / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "p" > ((< / span > < span  class = "n" > x< / span > < span  class = "p" > .< / span > < span  class = "n" > flags< / span > < span  class = "p" > ().< / span > < span  class = "n" > row_contiguous< / span > < span  class = "w" >  < / span > < span  class = "o" > & & < / span > < span  class = "w" >  < / span > < span  class = "n" > y< / span > < span  class = "p" > .< / span > < span  class = "n" > flags< / span > < span  class = "p" > ().< / span > < span  class = "n" > row_contiguous< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "o" > ||< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > .< / span > < span  class = "n" > flags< / span > < span  class = "p" > ().< / span > < span  class = "n" > col_contiguous< / span > < span  class = "w" >  < / span > < span  class = "o" > & & < / span > < span  class = "w" >  < / span > < span  class = "n" > y< / span > < span  class = "p" > .< / span > < span  class = "n" > flags< / span > < span  class = "p" > ().< / span > < span  class = "n" > col_contiguous< / span > < span  class = "p" > )))< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "n" > axpby_impl_accelerate< / span > < span  class = "o" > < < / span > < span  class = "kt" > float< / span > < span  class = "o" > > < / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > y< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > alpha_< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > beta_< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > return< / span > < span  class = "p" > ;< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "p" > }< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Fall back to common backend if specializations are not available< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > eval< / span > < span  class = "p" > (< / span > < span  class = "n" > inputs< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > }< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > We have now hit a milestone! Just this much is enough to run the operation
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > axpby()< / span > < / code >  on a CPU stream!< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > If you do not plan on running the operation on the GPU or using transforms on
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								computation graphs that contain < code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > Axpby< / span > < / code > , you can stop implementing the
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								primitive here and enjoy the speed-ups you get from the Accelerate library.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / section > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< section  id = "implementing-the-gpu-backend" > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h3 > Implementing the GPU Backend< a  class = "headerlink"  href = "#implementing-the-gpu-backend"  title = "Permalink to this heading" > #< / a > < / h3 > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > Apple silicon devices address their GPUs using the < a  class = "reference external"  href = "https://developer.apple.com/documentation/metal?language=objc" > Metal< / a >  shading language, and
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								all GPU kernels in MLX are written using Metal.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "admonition note" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  class = "admonition-title" > Note< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > Here are some helpful resources if you are new to Metal!< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< ul  class = "simple" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > A walkthrough of the metal compute pipeline: < a  class = "reference external"  href = "https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc" > Metal Example< / a > < / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > Documentation for metal shading language: < a  class = "reference external"  href = "https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf" > Metal Specification< / a > < / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > Using metal from C++: < a  class = "reference external"  href = "https://developer.apple.com/metal/cpp/" > Metal-cpp< / a > < / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
										 
									
								 
							
							
								< p > Let’s keep the GPU algorithm simple. We will launch exactly as many threads
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								as there are elements in the output. Each thread will pick the element it needs
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								from < code  class = "docutils literal notranslate" > < span  class = "pre" > x< / span > < / code >  and < code  class = "docutils literal notranslate" > < span  class = "pre" > y< / span > < / code > , do the pointwise operation, and then update its assigned
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								element in the output.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-C++ notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "k" > template< / span > < span  class = "w" >  < / span > < span  class = "o" > < < / span > < span  class = "k" > typename< / span > < span  class = "w" >  < / span > < span  class = "nc" > T< / span > < span  class = "o" > > < / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > [[< / span > < span  class = "n" > kernel< / span > < span  class = "p" > ]]< / span > < span  class = "w" >  < / span > < span  class = "kt" > void< / span > < span  class = "w" >  < / span > < span  class = "n" > axpby_general< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "n" > device< / span > < span  class = "w" >  < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > T< / span > < span  class = "o" > *< / span > < span  class = "w" >  < / span > < span  class = "n" > x< / span > < span  class = "w" >  < / span > < span  class = "p" > [[< / span > < span  class = "n" > buffer< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > )]],< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "n" > device< / span > < span  class = "w" >  < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > T< / span > < span  class = "o" > *< / span > < span  class = "w" >  < / span > < span  class = "n" > y< / span > < span  class = "w" >  < / span > < span  class = "p" > [[< / span > < span  class = "n" > buffer< / span > < span  class = "p" > (< / span > < span  class = "mi" > 1< / span > < span  class = "p" > )]],< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "n" > device< / span > < span  class = "w" >  < / span > < span  class = "n" > T< / span > < span  class = "o" > *< / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "w" >  < / span > < span  class = "p" > [[< / span > < span  class = "n" > buffer< / span > < span  class = "p" > (< / span > < span  class = "mi" > 2< / span > < span  class = "p" > )]],< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "n" > constant< / span > < span  class = "w" >  < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "kt" > float< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > alpha< / span > < span  class = "w" >  < / span > < span  class = "p" > [[< / span > < span  class = "n" > buffer< / span > < span  class = "p" > (< / span > < span  class = "mi" > 3< / span > < span  class = "p" > )]],< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "n" > constant< / span > < span  class = "w" >  < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "kt" > float< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > beta< / span > < span  class = "w" >  < / span > < span  class = "p" > [[< / span > < span  class = "n" > buffer< / span > < span  class = "p" > (< / span > < span  class = "mi" > 4< / span > < span  class = "p" > )]],< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "n" > constant< / span > < span  class = "w" >  < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "kt" > int< / span > < span  class = "o" > *< / span > < span  class = "w" >  < / span > < span  class = "n" > shape< / span > < span  class = "w" >  < / span > < span  class = "p" > [[< / span > < span  class = "n" > buffer< / span > < span  class = "p" > (< / span > < span  class = "mi" > 5< / span > < span  class = "p" > )]],< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "n" > constant< / span > < span  class = "w" >  < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "kt" > size_t< / span > < span  class = "o" > *< / span > < span  class = "w" >  < / span > < span  class = "n" > x_strides< / span > < span  class = "w" >  < / span > < span  class = "p" > [[< / span > < span  class = "n" > buffer< / span > < span  class = "p" > (< / span > < span  class = "mi" > 6< / span > < span  class = "p" > )]],< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "n" > constant< / span > < span  class = "w" >  < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "kt" > size_t< / span > < span  class = "o" > *< / span > < span  class = "w" >  < / span > < span  class = "n" > y_strides< / span > < span  class = "w" >  < / span > < span  class = "p" > [[< / span > < span  class = "n" > buffer< / span > < span  class = "p" > (< / span > < span  class = "mi" > 7< / span > < span  class = "p" > )]],< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "n" > constant< / span > < span  class = "w" >  < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "kt" > int< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > ndim< / span > < span  class = "w" >  < / span > < span  class = "p" > [[< / span > < span  class = "n" > buffer< / span > < span  class = "p" > (< / span > < span  class = "mi" > 8< / span > < span  class = "p" > )]],< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "n" > uint< / span > < span  class = "w" >  < / span > < span  class = "n" > index< / span > < span  class = "w" >  < / span > < span  class = "p" > [[< / span > < span  class = "n" > thread_position_in_grid< / span > < span  class = "p" > ]])< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Convert linear indices to offsets in array< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > auto< / span > < span  class = "w" >  < / span > < span  class = "n" > x_offset< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > elem_to_loc< / span > < span  class = "p" > (< / span > < span  class = "n" > index< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > shape< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > x_strides< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > ndim< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > auto< / span > < span  class = "w" >  < / span > < span  class = "n" > y_offset< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > elem_to_loc< / span > < span  class = "p" > (< / span > < span  class = "n" > index< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > shape< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > y_strides< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > ndim< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Do the operation and update the output< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > out< / span > < span  class = "p" > [< / span > < span  class = "n" > index< / span > < span  class = "p" > ]< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > static_cast< / span > < span  class = "o" > < < / span > < span  class = "n" > T< / span > < span  class = "o" > > < / span > < span  class = "p" > (< / span > < span  class = "n" > alpha< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "o" > *< / span > < span  class = "w" >  < / span > < span  class = "n" > x< / span > < span  class = "p" > [< / span > < span  class = "n" > x_offset< / span > < span  class = "p" > ]< / span > < span  class = "w" >  < / span > < span  class = "o" > +< / span > < span  class = "w" >  < / span > < span  class = "k" > static_cast< / span > < span  class = "o" > < < / span > < span  class = "n" > T< / span > < span  class = "o" > > < / span > < span  class = "p" > (< / span > < span  class = "n" > beta< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "o" > *< / span > < span  class = "w" >  < / span > < span  class = "n" > y< / span > < span  class = "p" > [< / span > < span  class = "n" > y_offset< / span > < span  class = "p" > ];< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > }< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > We then need to instantiate this template for all floating point types and give
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								each instantiation a unique host name so we can identify the right kernel for
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								each data type.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-C++ notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "cp" > #define instantiate_axpby(type_name, type)              \< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cp" >     template [[host_name("axpby_general_" #type_name)]] \< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cp" >     [[kernel]] void axpby_general< type> (                \< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cp" >         device const type* x [[buffer(0)]],             \< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cp" >         device const type* y [[buffer(1)]],             \< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cp" >         device type* out [[buffer(2)]],                 \< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cp" >         constant const float&  alpha [[buffer(3)]],      \< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cp" >         constant const float&  beta [[buffer(4)]],       \< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cp" >         constant const int* shape [[buffer(5)]],        \< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cp" >         constant const size_t* x_strides [[buffer(6)]], \< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cp" >         constant const size_t* y_strides [[buffer(7)]], \< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cp" >         constant const int&  ndim [[buffer(8)]],         \< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "cp" >         uint index [[thread_position_in_grid]]);< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > instantiate_axpby< / span > < span  class = "p" > (< / span > < span  class = "n" > float32< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "kt" > float< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > instantiate_axpby< / span > < span  class = "p" > (< / span > < span  class = "n" > float16< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > half< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > instantiate_axpby< / span > < span  class = "p" > (< / span > < span  class = "n" > bfloat16< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > bfloat16_t< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > instantiate_axpby< / span > < span  class = "p" > (< / span > < span  class = "n" > complex64< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > complex64_t< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > This kernel will be compiled into a metal library < code  class = "docutils literal notranslate" > < span  class = "pre" > mlx_ext.metallib< / span > < / code >  as we
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								will see later in < a  class = "reference internal"  href = "#building-with-cmake" > < span  class = "std std-ref" > Building with CMake< / span > < / a > . In the following example, we
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								assume that the library < code  class = "docutils literal notranslate" > < span  class = "pre" > mlx_ext.metallib< / span > < / code >  will always be co-located with
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								the executable/shared-library calling the < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > register_library()< / span > < / code >  function.
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
										 
									
								 
							
							
								The < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > register_library()< / span > < / code >  function takes the library’s name and potential
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								path (or in this case, a function that can produce the path of the metal
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
										 
									
								 
							
							
								library) and tries to load that library if it hasn’t already been registered
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								by the relevant static < code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > mlx::core::metal::Device< / span > < / code >  object. This is why
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								it is important to package your C++ library with the metal library. We will
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								go over this process in more detail later.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > The logic to determine the kernel, set the inputs, resolve the grid dimensions
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								and dispatch it to the GPU is contained in < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > Axpby::eval_gpu()< / span > < / code >  as shown
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								below.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-C++ notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "cm" > /** Evaluate primitive on GPU */< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "kt" > void< / span > < span  class = "w" >  < / span > < span  class = "nf" > Axpby::eval_gpu< / span > < span  class = "p" > (< / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > vector< / span > < span  class = "o" > < < / span > < span  class = "n" > array< / span > < span  class = "o" > > & < / span > < span  class = "w" >  < / span > < span  class = "n" > inputs< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > array< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Prepare inputs< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > assert< / span > < span  class = "p" > (< / span > < span  class = "n" > inputs< / span > < span  class = "p" > .< / span > < span  class = "n" > size< / span > < span  class = "p" > ()< / span > < span  class = "w" >  < / span > < span  class = "o" > ==< / span > < span  class = "w" >  < / span > < span  class = "mi" > 2< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > auto< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > x< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > inputs< / span > < span  class = "p" > [< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ];< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > auto< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > y< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > inputs< / span > < span  class = "p" > [< / span > < span  class = "mi" > 1< / span > < span  class = "p" > ];< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Each primitive carries the stream it should execute on< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // and each stream carries its device identifiers< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > auto< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > s< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > stream< / span > < span  class = "p" > ();< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // We get the needed metal device using the stream< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > auto< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > d< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > metal< / span > < span  class = "o" > ::< / span > < span  class = "n" > device< / span > < span  class = "p" > (< / span > < span  class = "n" > s< / span > < span  class = "p" > .< / span > < span  class = "n" > device< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Allocate output memory< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > out< / span > < span  class = "p" > .< / span > < span  class = "n" > set_data< / span > < span  class = "p" > (< / span > < span  class = "n" > allocator< / span > < span  class = "o" > ::< / span > < span  class = "n" > malloc_or_wait< / span > < span  class = "p" > (< / span > < span  class = "n" > out< / span > < span  class = "p" > .< / span > < span  class = "n" > nbytes< / span > < span  class = "p" > ()));< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Resolve name of kernel (corresponds to axpby.metal)< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > ostringstream< / span > < span  class = "w" >  < / span > < span  class = "n" > kname< / span > < span  class = "p" > ;< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > kname< / span > < span  class = "w" >  < / span > < span  class = "o" > < < < / span > < span  class = "w" >  < / span > < span  class = "s" > " axpby_" < / span > < span  class = "w" >  < / span > < span  class = "o" > < < < / span > < span  class = "w" >  < / span > < span  class = "s" > " general_" < / span > < span  class = "w" >  < / span > < span  class = "o" > < < < / span > < span  class = "w" >  < / span > < span  class = "n" > type_to_name< / span > < span  class = "p" > (< / span > < span  class = "n" > out< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Make sure the metal library is available and look for it< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // in the same folder as this executable if needed< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > d< / span > < span  class = "p" > .< / span > < span  class = "n" > register_library< / span > < span  class = "p" > (< / span > < span  class = "s" > " mlx_ext" < / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > metal< / span > < span  class = "o" > ::< / span > < span  class = "n" > get_colocated_mtllib_path< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Make a kernel from this metal library< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > auto< / span > < span  class = "w" >  < / span > < span  class = "n" > kernel< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > d< / span > < span  class = "p" > .< / span > < span  class = "n" > get_kernel< / span > < span  class = "p" > (< / span > < span  class = "n" > kname< / span > < span  class = "p" > .< / span > < span  class = "n" > str< / span > < span  class = "p" > (),< / span > < span  class = "w" >  < / span > < span  class = "s" > " mlx_ext" < / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Prepare to encode kernel< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > auto< / span > < span  class = "w" >  < / span > < span  class = "n" > compute_encoder< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > d< / span > < span  class = "p" > .< / span > < span  class = "n" > get_command_encoder< / span > < span  class = "p" > (< / span > < span  class = "n" > s< / span > < span  class = "p" > .< / span > < span  class = "n" > index< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > compute_encoder< / span > < span  class = "o" > -> < / span > < span  class = "n" > setComputePipelineState< / span > < span  class = "p" > (< / span > < span  class = "n" > kernel< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Kernel parameters are registered with buffer indices corresponding to< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // those in the kernel declaration at axpby.metal< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "kt" > int< / span > < span  class = "w" >  < / span > < span  class = "n" > ndim< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "p" > .< / span > < span  class = "n" > ndim< / span > < span  class = "p" > ();< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "kt" > size_t< / span > < span  class = "w" >  < / span > < span  class = "n" > nelem< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "p" > .< / span > < span  class = "n" > size< / span > < span  class = "p" > ();< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Encode input arrays to kernel< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > set_array_buffer< / span > < span  class = "p" > (< / span > < span  class = "n" > compute_encoder< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "mi" > 0< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > set_array_buffer< / span > < span  class = "p" > (< / span > < span  class = "n" > compute_encoder< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > y< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "mi" > 1< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Encode output arrays to kernel< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > set_array_buffer< / span > < span  class = "p" > (< / span > < span  class = "n" > compute_encoder< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > out< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "mi" > 2< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Encode alpha and beta< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > compute_encoder< / span > < span  class = "o" > -> < / span > < span  class = "n" > setBytes< / span > < span  class = "p" > (< / span > < span  class = "o" > & < / span > < span  class = "n" > alpha_< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "k" > sizeof< / span > < span  class = "p" > (< / span > < span  class = "kt" > float< / span > < span  class = "p" > ),< / span > < span  class = "w" >  < / span > < span  class = "mi" > 3< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > compute_encoder< / span > < span  class = "o" > -> < / span > < span  class = "n" > setBytes< / span > < span  class = "p" > (< / span > < span  class = "o" > & < / span > < span  class = "n" > beta_< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "k" > sizeof< / span > < span  class = "p" > (< / span > < span  class = "kt" > float< / span > < span  class = "p" > ),< / span > < span  class = "w" >  < / span > < span  class = "mi" > 4< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Encode shape, strides and ndim< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > compute_encoder< / span > < span  class = "o" > -> < / span > < span  class = "n" > setBytes< / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > .< / span > < span  class = "n" > shape< / span > < span  class = "p" > ().< / span > < span  class = "n" > data< / span > < span  class = "p" > (),< / span > < span  class = "w" >  < / span > < span  class = "n" > ndim< / span > < span  class = "w" >  < / span > < span  class = "o" > *< / span > < span  class = "w" >  < / span > < span  class = "k" > sizeof< / span > < span  class = "p" > (< / span > < span  class = "kt" > int< / span > < span  class = "p" > ),< / span > < span  class = "w" >  < / span > < span  class = "mi" > 5< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > compute_encoder< / span > < span  class = "o" > -> < / span > < span  class = "n" > setBytes< / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > .< / span > < span  class = "n" > strides< / span > < span  class = "p" > ().< / span > < span  class = "n" > data< / span > < span  class = "p" > (),< / span > < span  class = "w" >  < / span > < span  class = "n" > ndim< / span > < span  class = "w" >  < / span > < span  class = "o" > *< / span > < span  class = "w" >  < / span > < span  class = "k" > sizeof< / span > < span  class = "p" > (< / span > < span  class = "kt" > size_t< / span > < span  class = "p" > ),< / span > < span  class = "w" >  < / span > < span  class = "mi" > 6< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > compute_encoder< / span > < span  class = "o" > -> < / span > < span  class = "n" > setBytes< / span > < span  class = "p" > (< / span > < span  class = "n" > y< / span > < span  class = "p" > .< / span > < span  class = "n" > strides< / span > < span  class = "p" > ().< / span > < span  class = "n" > data< / span > < span  class = "p" > (),< / span > < span  class = "w" >  < / span > < span  class = "n" > ndim< / span > < span  class = "w" >  < / span > < span  class = "o" > *< / span > < span  class = "w" >  < / span > < span  class = "k" > sizeof< / span > < span  class = "p" > (< / span > < span  class = "kt" > size_t< / span > < span  class = "p" > ),< / span > < span  class = "w" >  < / span > < span  class = "mi" > 7< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > compute_encoder< / span > < span  class = "o" > -> < / span > < span  class = "n" > setBytes< / span > < span  class = "p" > (< / span > < span  class = "o" > & < / span > < span  class = "n" > ndim< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "k" > sizeof< / span > < span  class = "p" > (< / span > < span  class = "kt" > int< / span > < span  class = "p" > ),< / span > < span  class = "w" >  < / span > < span  class = "mi" > 8< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // We launch 1 thread for each input and make sure that the number of< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // threads in any given threadgroup is not higher than the max allowed< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "kt" > size_t< / span > < span  class = "w" >  < / span > < span  class = "n" > tgp_size< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > min< / span > < span  class = "p" > (< / span > < span  class = "n" > nelem< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > kernel< / span > < span  class = "o" > -> < / span > < span  class = "n" > maxTotalThreadsPerThreadgroup< / span > < span  class = "p" > ());< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Fix the 3D size of each threadgroup (in terms of threads)< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > MTL< / span > < span  class = "o" > ::< / span > < span  class = "n" > Size< / span > < span  class = "w" >  < / span > < span  class = "n" > group_dims< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > MTL< / span > < span  class = "o" > ::< / span > < span  class = "n" > Size< / span > < span  class = "p" > (< / span > < span  class = "n" > tgp_size< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "mi" > 1< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "mi" > 1< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Fix the 3D size of the launch grid (in terms of threads)< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > MTL< / span > < span  class = "o" > ::< / span > < span  class = "n" > Size< / span > < span  class = "w" >  < / span > < span  class = "n" > grid_dims< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > MTL< / span > < span  class = "o" > ::< / span > < span  class = "n" > Size< / span > < span  class = "p" > (< / span > < span  class = "n" > nelem< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "mi" > 1< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "mi" > 1< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Launch the grid with the given number of threads divided among< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // the given threadgroups< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > compute_encoder< / span > < span  class = "o" > -> < / span > < span  class = "n" > dispatchThreads< / span > < span  class = "p" > (< / span > < span  class = "n" > grid_dims< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > group_dims< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > }< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > We can now call the < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > axpby()< / span > < / code >  operation on both the CPU and the GPU!< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > A few things to note about MLX and metal before moving on. MLX keeps track
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								of the active < code  class = "docutils literal notranslate" > < span  class = "pre" > compute_encoder< / span > < / code > . We rely on < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > d.get_command_encoder()< / span > < / code > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								to give us the active metal compute command encoder instead of building a
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								new one and calling < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > compute_encoder-> end_encoding()< / span > < / code >  at the end.
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								MLX keeps adding kernels (compute pipelines) to the active command encoder
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								until some specified limit is hit or the compute encoder needs to be flushed
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								for synchronization. MLX also handles enqueuing and committing the associated
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								command buffers as needed. We suggest taking a deeper dive into
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > metal::Device< / span > < / code >  if you would like to study this routine further.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / section > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< section  id = "primitive-transforms" > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h3 > Primitive Transforms< a  class = "headerlink"  href = "#primitive-transforms"  title = "Permalink to this heading" > #< / a > < / h3 > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
										 
									
								 
							
							
								< p > Now that we have come this far, let’s also learn how to add implementations to
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								transformations in a < code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > Primitive< / span > < / code > . These transformations can be built on
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								top of our operations, including the one we just defined. This gives
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								us the following < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > Axpby::jvp()< / span > < / code >  and < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > Axpby::vjp()< / span > < / code >  implementations.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-C++ notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "cm" > /** The Jacobian-vector product. */< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > array< / span > < span  class = "w" >  < / span > < span  class = "nf" > Axpby::jvp< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > vector< / span > < span  class = "o" > < < / span > < span  class = "n" > array< / span > < span  class = "o" > > & < / span > < span  class = "w" >  < / span > < span  class = "n" > primals< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > vector< / span > < span  class = "o" > < < / span > < span  class = "n" > array< / span > < span  class = "o" > > & < / span > < span  class = "w" >  < / span > < span  class = "n" > tangents< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > vector< / span > < span  class = "o" > < < / span > < span  class = "kt" > int< / span > < span  class = "o" > > & < / span > < span  class = "w" >  < / span > < span  class = "n" > argnums< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Forward mode diff that pushes along the tangents< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // The jvp transform on the primitive can be built with ops< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // that are scheduled on the same stream as the primitive< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // If argnums = {0}, we only push along x in which case the< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // jvp is just the tangent scaled by alpha< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Similarly, if argnums = {1}, the jvp is just the tangent< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // scaled by beta< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > if< / span > < span  class = "w" >  < / span > < span  class = "p" > (< / span > < span  class = "n" > argnums< / span > < span  class = "p" > .< / span > < span  class = "n" > size< / span > < span  class = "p" > ()< / span > < span  class = "w" >  < / span > < span  class = "o" > > < / span > < span  class = "w" >  < / span > < span  class = "mi" > 1< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > auto< / span > < span  class = "w" >  < / span > < span  class = "n" > scale< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > argnums< / span > < span  class = "p" > [< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ]< / span > < span  class = "w" >  < / span > < span  class = "o" > ==< / span > < span  class = "w" >  < / span > < span  class = "mi" > 0< / span > < span  class = "w" >  < / span > < span  class = "o" > ?< / span > < span  class = "w" >  < / span > < span  class = "n" > alpha_< / span > < span  class = "w" >  < / span > < span  class = "o" > :< / span > < span  class = "w" >  < / span > < span  class = "n" > beta_< / span > < span  class = "p" > ;< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > auto< / span > < span  class = "w" >  < / span > < span  class = "n" > scale_arr< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > array< / span > < span  class = "p" > (< / span > < span  class = "n" > scale< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > tangents< / span > < span  class = "p" > [< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ].< / span > < span  class = "n" > dtype< / span > < span  class = "p" > ());< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > return< / span > < span  class = "w" >  < / span > < span  class = "n" > multiply< / span > < span  class = "p" > (< / span > < span  class = "n" > scale_arr< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > tangents< / span > < span  class = "p" > [< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ],< / span > < span  class = "w" >  < / span > < span  class = "n" > stream< / span > < span  class = "p" > ());< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "p" > }< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // If argnums = {0, 1}, we take contributions from both< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // which gives us jvp = tangent_x * alpha + tangent_y * beta< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > else< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > return< / span > < span  class = "w" >  < / span > < span  class = "n" > axpby< / span > < span  class = "p" > (< / span > < span  class = "n" > tangents< / span > < span  class = "p" > [< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ],< / span > < span  class = "w" >  < / span > < span  class = "n" > tangents< / span > < span  class = "p" > [< / span > < span  class = "mi" > 1< / span > < span  class = "p" > ],< / span > < span  class = "w" >  < / span > < span  class = "n" > alpha_< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > beta_< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > stream< / span > < span  class = "p" > ());< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "p" > }< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > }< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-C++ notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "cm" > /** The vector-Jacobian product. */< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > vector< / span > < span  class = "o" > < < / span > < span  class = "n" > array< / span > < span  class = "o" > > < / span > < span  class = "w" >  < / span > < span  class = "n" > Axpby< / span > < span  class = "o" > ::< / span > < span  class = "n" > vjp< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > vector< / span > < span  class = "o" > < < / span > < span  class = "n" > array< / span > < span  class = "o" > > & < / span > < span  class = "w" >  < / span > < span  class = "n" > primals< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > array< / span > < span  class = "o" > & < / span > < span  class = "w" >  < / span > < span  class = "n" > cotan< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > vector< / span > < span  class = "o" > < < / span > < span  class = "kt" > int< / span > < span  class = "o" > > & < / span > < span  class = "w" >  < / span > < span  class = "n" > argnums< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "c1" > // Reverse mode diff< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > vector< / span > < span  class = "o" > < < / span > < span  class = "n" > array< / span > < span  class = "o" > > < / span > < span  class = "w" >  < / span > < span  class = "n" > vjps< / span > < span  class = "p" > ;< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > for< / span > < span  class = "w" >  < / span > < span  class = "p" > (< / span > < span  class = "k" > auto< / span > < span  class = "w" >  < / span > < span  class = "n" > arg< / span > < span  class = "w" >  < / span > < span  class = "o" > :< / span > < span  class = "w" >  < / span > < span  class = "n" > argnums< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > auto< / span > < span  class = "w" >  < / span > < span  class = "n" > scale< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > arg< / span > < span  class = "w" >  < / span > < span  class = "o" > ==< / span > < span  class = "w" >  < / span > < span  class = "mi" > 0< / span > < span  class = "w" >  < / span > < span  class = "o" > ?< / span > < span  class = "w" >  < / span > < span  class = "n" > alpha_< / span > < span  class = "w" >  < / span > < span  class = "o" > :< / span > < span  class = "w" >  < / span > < span  class = "n" > beta_< / span > < span  class = "p" > ;< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > auto< / span > < span  class = "w" >  < / span > < span  class = "n" > scale_arr< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > array< / span > < span  class = "p" > (< / span > < span  class = "n" > scale< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > cotan< / span > < span  class = "p" > .< / span > < span  class = "n" > dtype< / span > < span  class = "p" > ());< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "n" > vjps< / span > < span  class = "p" > .< / span > < span  class = "n" > push_back< / span > < span  class = "p" > (< / span > < span  class = "n" > multiply< / span > < span  class = "p" > (< / span > < span  class = "n" > scale_arr< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > cotan< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > stream< / span > < span  class = "p" > ()));< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "p" > }< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > return< / span > < span  class = "w" >  < / span > < span  class = "n" > vjps< / span > < span  class = "p" > ;< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > }< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > Finally, you need not have a transformation fully defined to start using your
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								own < code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > Primitive< / span > < / code > .< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-C++ notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "cm" > /** Vectorize primitive along given axis */< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > pair< / span > < span  class = "o" > < < / span > < span  class = "n" > array< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "kt" > int< / span > < span  class = "o" > > < / span > < span  class = "w" >  < / span > < span  class = "n" > Axpby< / span > < span  class = "o" > ::< / span > < span  class = "n" > vmap< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > vector< / span > < span  class = "o" > < < / span > < span  class = "n" > array< / span > < span  class = "o" > > & < / span > < span  class = "w" >  < / span > < span  class = "n" > inputs< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "k" > const< / span > < span  class = "w" >  < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > vector< / span > < span  class = "o" > < < / span > < span  class = "kt" > int< / span > < span  class = "o" > > & < / span > < span  class = "w" >  < / span > < span  class = "n" > axes< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "k" > throw< / span > < span  class = "w" >  < / span > < span  class = "n" > std< / span > < span  class = "o" > ::< / span > < span  class = "n" > runtime_error< / span > < span  class = "p" > (< / span > < span  class = "s" > " Axpby has no vmap implementation." < / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > }< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / section > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / section > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< section  id = "building-and-binding" > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h2 > Building and Binding< a  class = "headerlink"  href = "#building-and-binding"  title = "Permalink to this heading" > #< / a > < / h2 > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
										 
									
								 
							
							
								< p > Let’s look at the overall directory structure first.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "line-block" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "line" > extensions< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "line" > ├── axpby< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "line" > │   ├── axpby.cpp< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "line" > │   ├── axpby.h< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "line" > │   └── axpby.metal< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "line" > ├── mlx_sample_extensions< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "line" > │   └── __init__.py< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "line" > ├── bindings.cpp< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "line" > ├── CMakeLists.txt< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "line" > └── setup.py< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< ul  class = "simple" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > < code  class = "docutils literal notranslate" > < span  class = "pre" > extensions/axpby/< / span > < / code >  defines the C++ extension library< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > < code  class = "docutils literal notranslate" > < span  class = "pre" > extensions/mlx_sample_extensions< / span > < / code >  sets out the structure for the
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								associated python package< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > < code  class = "docutils literal notranslate" > < span  class = "pre" > extensions/bindings.cpp< / span > < / code >  provides python bindings for our operation< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > < code  class = "docutils literal notranslate" > < span  class = "pre" > extensions/CMakeLists.txt< / span > < / code >  holds CMake rules to build the library and
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								python bindings< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > < code  class = "docutils literal notranslate" > < span  class = "pre" > extensions/setup.py< / span > < / code >  holds the < code  class = "docutils literal notranslate" > < span  class = "pre" > setuptools< / span > < / code >  rules to build and install
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								the python package< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< section  id = "binding-to-python" > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h3 > Binding to Python< a  class = "headerlink"  href = "#binding-to-python"  title = "Permalink to this heading" > #< / a > < / h3 > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > We use < a  class = "reference external"  href = "https://pybind11.readthedocs.io/en/stable/" > PyBind11< / a >  to build a Python API for the C++ library. Since bindings
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								for all needed components such as < cite > mlx.core.array< / cite > , < cite > mlx.core.stream< / cite > , etc.
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								are already provided, adding our < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > axpby()< / span > < / code >  becomes very simple!< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-C++ notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "n" > PYBIND11_MODULE< / span > < span  class = "p" > (< / span > < span  class = "n" > mlx_sample_extensions< / span > < span  class = "p" > ,< / span > < span  class = "w" >  < / span > < span  class = "n" > m< / span > < span  class = "p" > )< / span > < span  class = "w" >  < / span > < span  class = "p" > {< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > m< / span > < span  class = "p" > .< / span > < span  class = "n" > doc< / span > < span  class = "p" > ()< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "s" > " Sample C++ and metal extensions for MLX" < / span > < span  class = "p" > ;< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "n" > m< / span > < span  class = "p" > .< / span > < span  class = "n" > def< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "s" > " axpby" < / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "o" > & < / span > < span  class = "n" > axpby< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "s" > " x" < / span > < span  class = "n" > _a< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "s" > " y" < / span > < span  class = "n" > _a< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "n" > py< / span > < span  class = "o" > ::< / span > < span  class = "n" > pos_only< / span > < span  class = "p" > (),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "s" > " alpha" < / span > < span  class = "n" > _a< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "s" > " beta" < / span > < span  class = "n" > _a< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "n" > py< / span > < span  class = "o" > ::< / span > < span  class = "n" > kw_only< / span > < span  class = "p" > (),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "s" > " stream" < / span > < span  class = "n" > _a< / span > < span  class = "w" >  < / span > < span  class = "o" > =< / span > < span  class = "w" >  < / span > < span  class = "n" > py< / span > < span  class = "o" > ::< / span > < span  class = "n" > none< / span > < span  class = "p" > (),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >         < / span > < span  class = "sa" > R< / span > < span  class = "s" > " < / span > < span  class = "dl" > pbdoc(< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "s" >             Scale and sum two vectors elementwise< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "s" >             ``z = alpha * x + beta * y``< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "s" >             Follows numpy style broadcasting between ``x`` and ``y``< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "s" >             Inputs are upcasted to floats if needed< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "s" >             Args:< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "s" >                 x (array): Input array.< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "s" >                 y (array): Input array.< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "s" >                 alpha (float): Scaling factor for ``x``.< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "s" >                 beta (float): Scaling factor for ``y``.< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "s" >             Returns:< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "s" >                 array: ``alpha * x + beta * y``< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "s" >         < / span > < span  class = "dl" > )pbdoc< / span > < span  class = "s" > " < / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > }< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > Most of the complexity in the above example comes from additional bells and
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								whistles such as the literal names and docstrings.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "admonition warning" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  class = "admonition-title" > Warning< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > < code  class = "xref py py-mod docutils literal notranslate" > < span  class = "pre" > mlx.core< / span > < / code >  needs to be imported before importing
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< code  class = "xref py py-mod docutils literal notranslate" > < span  class = "pre" > mlx_sample_extensions< / span > < / code >  as defined by the pybind11 module above to
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								ensure that the casters for < code  class = "xref py py-mod docutils literal notranslate" > < span  class = "pre" > mlx.core< / span > < / code >  components like
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< a  class = "reference internal"  href = "../python/_autosummary/mlx.core.array.html#mlx.core.array"  title = "mlx.core.array" > < code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > mlx.core.array< / span > < / code > < / a >  are available.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / section > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< section  id = "building-with-cmake" > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  id = "id1" > < / span > < h3 > Building with CMake< a  class = "headerlink"  href = "#building-with-cmake"  title = "Permalink to this heading" > #< / a > < / h3 > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > Building the C++ extension library itself is simple; it only requires that you
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< code  class = "docutils literal notranslate" > < span  class = "pre" > find_package(MLX< / span >  < span  class = "pre" > CONFIG)< / span > < / code >  and then link it to your library.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-cmake notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "c" > # Add library< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "nb" > add_library< / span > < span  class = "p" > (< / span > < span  class = "s" > mlx_ext< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c" > # Add sources< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "nb" > target_sources< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "s" > mlx_ext< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "s" > PUBLIC< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "o" > ${< / span > < span  class = "nv" > CMAKE_CURRENT_LIST_DIR< / span > < span  class = "o" > }< / span > < span  class = "s" > /axpby/axpby.cpp< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c" > # Add include headers< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "nb" > target_include_directories< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "s" > mlx_ext< / span > < span  class = "w" >  < / span > < span  class = "s" > PUBLIC< / span > < span  class = "w" >  < / span > < span  class = "o" > ${< / span > < span  class = "nv" > CMAKE_CURRENT_LIST_DIR< / span > < span  class = "o" > }< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c" > # Link to mlx< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "nb" > target_link_libraries< / span > < span  class = "p" > (< / span > < span  class = "s" > mlx_ext< / span > < span  class = "w" >  < / span > < span  class = "s" > PUBLIC< / span > < span  class = "w" >  < / span > < span  class = "s" > mlx< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > We also need to build the attached Metal library. For convenience, we provide a
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > mlx_build_metallib()< / span > < / code >  function that builds a < code  class = "docutils literal notranslate" > < span  class = "pre" > .metallib< / span > < / code >  target given
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								sources, headers, destinations, etc. (defined in < code  class = "docutils literal notranslate" > < span  class = "pre" > cmake/extension.cmake< / span > < / code >  and
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								automatically imported with the MLX package).< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > Here is what that looks like in practice!< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-cmake notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "c" > # Build metallib< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "nb" > if< / span > < span  class = "p" > (< / span > < span  class = "s" > MLX_BUILD_METAL< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "nb" > mlx_build_metallib< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "s" > TARGET< / span > < span  class = "w" >  < / span > < span  class = "s" > mlx_ext_metallib< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "s" > TITLE< / span > < span  class = "w" >  < / span > < span  class = "s" > mlx_ext< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "s" > SOURCES< / span > < span  class = "w" >  < / span > < span  class = "o" > ${< / span > < span  class = "nv" > CMAKE_CURRENT_LIST_DIR< / span > < span  class = "o" > }< / span > < span  class = "s" > /axpby/axpby.metal< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "s" > INCLUDE_DIRS< / span > < span  class = "w" >  < / span > < span  class = "o" > ${< / span > < span  class = "nv" > PROJECT_SOURCE_DIR< / span > < span  class = "o" > }< / span > < span  class = "w" >  < / span > < span  class = "o" > ${< / span > < span  class = "nv" > MLX_INCLUDE_DIRS< / span > < span  class = "o" > }< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "s" > OUTPUT_DIRECTORY< / span > < span  class = "w" >  < / span > < span  class = "o" > ${< / span > < span  class = "nv" > CMAKE_LIBRARY_OUTPUT_DIRECTORY< / span > < span  class = "o" > }< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "nb" > add_dependencies< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "s" > mlx_ext< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "s" > mlx_ext_metallib< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "nb" > endif< / span > < span  class = "p" > ()< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > Finally, we build the < a  class = "reference external"  href = "https://pybind11.readthedocs.io/en/stable/" > Pybind11< / a >  bindings:< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-cmake notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "nb" > pybind11_add_module< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "s" > mlx_sample_extensions< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "o" > ${< / span > < span  class = "nv" > CMAKE_CURRENT_LIST_DIR< / span > < span  class = "o" > }< / span > < span  class = "s" > /bindings.cpp< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "nb" > target_link_libraries< / span > < span  class = "p" > (< / span > < span  class = "s" > mlx_sample_extensions< / span > < span  class = "w" >  < / span > < span  class = "s" > PRIVATE< / span > < span  class = "w" >  < / span > < span  class = "s" > mlx_ext< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "nb" > if< / span > < span  class = "p" > (< / span > < span  class = "s" > BUILD_SHARED_LIBS< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "w" >     < / span > < span  class = "nb" > target_link_options< / span > < span  class = "p" > (< / span > < span  class = "s" > mlx_sample_extensions< / span > < span  class = "w" >  < / span > < span  class = "s" > PRIVATE< / span > < span  class = "w" >  < / span > < span  class = "s" > -Wl,-rpath,@loader_path< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "nb" > endif< / span > < span  class = "p" > ()< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / section > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< section  id = "building-with-setuptools" > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h3 > Building with < code  class = "docutils literal notranslate" > < span  class = "pre" > setuptools< / span > < / code > < a  class = "headerlink"  href = "#building-with-setuptools"  title = "Permalink to this heading" > #< / a > < / h3 > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > Once we have set out the CMake build rules as described above, we can use the
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								build utilities defined in < code  class = "xref py py-mod docutils literal notranslate" > < span  class = "pre" > mlx.extension< / span > < / code >  for a simple build process.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-python notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "kn" > from< / span >  < span  class = "nn" > mlx< / span >  < span  class = "kn" > import< / span >  < span  class = "n" > extension< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "kn" > from< / span >  < span  class = "nn" > setuptools< / span >  < span  class = "kn" > import< / span >  < span  class = "n" > setup< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "k" > if< / span >  < span  class = "vm" > __name__< / span >  < span  class = "o" > ==< / span >  < span  class = "s2" > " __main__" < / span > < span  class = "p" > :< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > setup< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > name< / span > < span  class = "o" > =< / span > < span  class = "s2" > " mlx_sample_extensions" < / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > version< / span > < span  class = "o" > =< / span > < span  class = "s2" > " 0.0.0" < / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > description< / span > < span  class = "o" > =< / span > < span  class = "s2" > " Sample C++ and Metal extensions for MLX primitives." < / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > ext_modules< / span > < span  class = "o" > =< / span > < span  class = "p" > [< / span > < span  class = "n" > extension< / span > < span  class = "o" > .< / span > < span  class = "n" > CMakeExtension< / span > < span  class = "p" > (< / span > < span  class = "s2" > " mlx_sample_extensions" < / span > < span  class = "p" > )],< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > cmdclass< / span > < span  class = "o" > =< / span > < span  class = "p" > {< / span > < span  class = "s2" > " build_ext" < / span > < span  class = "p" > :< / span >  < span  class = "n" > extension< / span > < span  class = "o" > .< / span > < span  class = "n" > CMakeBuild< / span > < span  class = "p" > },< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > packages< / span >  < span  class = "o" > =< / span >  < span  class = "p" > [< / span > < span  class = "s2" > " mlx_sample_extensions" < / span > < span  class = "p" > ],< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > package_dir< / span >  < span  class = "o" > =< / span >  < span  class = "p" > {< / span > < span  class = "s2" > " " < / span > < span  class = "p" > :< / span >  < span  class = "s2" > " mlx_sample_extensions" < / span > < span  class = "p" > },< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > package_data< / span >  < span  class = "o" > =< / span >  < span  class = "p" > {< / span > < span  class = "s2" > " mlx_sample_extensions" < / span >  < span  class = "p" > :< / span >  < span  class = "p" > [< / span > < span  class = "s2" > " *.so" < / span > < span  class = "p" > ,< / span >  < span  class = "s2" > " *.dylib" < / span > < span  class = "p" > ,< / span >  < span  class = "s2" > " *.metallib" < / span > < span  class = "p" > ]},< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > zip_safe< / span > < span  class = "o" > =< / span > < span  class = "kc" > False< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > python_requires< / span > < span  class = "o" > =< / span > < span  class = "s2" > " > =3.7" < / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "admonition note" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  class = "admonition-title" > Note< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > We treat < code  class = "docutils literal notranslate" > < span  class = "pre" > extensions/mlx_sample_extensions< / span > < / code >  as the package directory
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								even though it only contains a < code  class = "docutils literal notranslate" > < span  class = "pre" > __init__.py< / span > < / code >  to ensure the following:< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< ul  class = "simple" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > < code  class = "xref py py-mod docutils literal notranslate" > < span  class = "pre" > mlx.core< / span > < / code >  is always imported before importing  < code  class = "xref py py-mod docutils literal notranslate" > < span  class = "pre" > mlx_sample_extensions< / span > < / code > < / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > The C++ extension library and the Metal library are co-located with the Python
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								bindings and copied together if the package is installed< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > You can build inplace for development using
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< code  class = "docutils literal notranslate" > < span  class = "pre" > python< / span >  < span  class = "pre" > setup.py< / span >  < span  class = "pre" > build_ext< / span >  < span  class = "pre" > -j8< / span >  < span  class = "pre" > --inplace< / span > < / code >  (in < code  class = "docutils literal notranslate" > < span  class = "pre" > extensions/< / span > < / code > )< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > This will result in a directory structure as follows:< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "line-block" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "line" > extensions< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "line" > ├── mlx_sample_extensions< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "line" > │   ├── __init__.py< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "line" > │   ├── libmlx_ext.dylib # C++ extension library< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "line" > │   ├── mlx_ext.metallib # Metal library< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "line" > │   └── mlx_sample_extensions.cpython-3x-darwin.so # Python Binding< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "line" > …< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > When you try to install using the command < code  class = "docutils literal notranslate" > < span  class = "pre" > python< / span >  < span  class = "pre" > -m< / span >  < span  class = "pre" > pip< / span >  < span  class = "pre" > install< / span >  < span  class = "pre" > .< / span > < / code > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								(in < code  class = "docutils literal notranslate" > < span  class = "pre" > extensions/< / span > < / code > ), the package will be installed with the same structure as
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< code  class = "docutils literal notranslate" > < span  class = "pre" > extensions/mlx_sample_extensions< / span > < / code >  and the C++ and Metal libraries will be
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								copied along with the python binding since they are specified as < code  class = "docutils literal notranslate" > < span  class = "pre" > package_data< / span > < / code > .< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / section > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / section > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< section  id = "usage" > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h2 > Usage< a  class = "headerlink"  href = "#usage"  title = "Permalink to this heading" > #< / a > < / h2 > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > After installing the extension as described above, you should be able to simply
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								import the python package and play with it as you would any other MLX operation!< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
										 
									
								 
							
							
								< p > Let’s look at a simple script and its results!< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-python notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "kn" > import< / span >  < span  class = "nn" > mlx.core< / span >  < span  class = "k" > as< / span >  < span  class = "nn" > mx< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "kn" > from< / span >  < span  class = "nn" > mlx_sample_extensions< / span >  < span  class = "kn" > import< / span >  < span  class = "n" > axpby< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > a< / span >  < span  class = "o" > =< / span >  < span  class = "n" > mx< / span > < span  class = "o" > .< / span > < span  class = "n" > ones< / span > < span  class = "p" > ((< / span > < span  class = "mi" > 3< / span > < span  class = "p" > ,< / span >  < span  class = "mi" > 4< / span > < span  class = "p" > ))< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > b< / span >  < span  class = "o" > =< / span >  < span  class = "n" > mx< / span > < span  class = "o" > .< / span > < span  class = "n" > ones< / span > < span  class = "p" > ((< / span > < span  class = "mi" > 3< / span > < span  class = "p" > ,< / span >  < span  class = "mi" > 4< / span > < span  class = "p" > ))< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > c< / span >  < span  class = "o" > =< / span >  < span  class = "n" > axpby< / span > < span  class = "p" > (< / span > < span  class = "n" > a< / span > < span  class = "p" > ,< / span >  < span  class = "n" > b< / span > < span  class = "p" > ,< / span >  < span  class = "mf" > 4.0< / span > < span  class = "p" > ,< / span >  < span  class = "mf" > 2.0< / span > < span  class = "p" > ,< / span >  < span  class = "n" > stream< / span > < span  class = "o" > =< / span > < span  class = "n" > mx< / span > < span  class = "o" > .< / span > < span  class = "n" > cpu< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "nb" > print< / span > < span  class = "p" > (< / span > < span  class = "sa" > f< / span > < span  class = "s2" > " c shape: < / span > < span  class = "si" > {< / span > < span  class = "n" > c< / span > < span  class = "o" > .< / span > < span  class = "n" > shape< / span > < span  class = "si" > }< / span > < span  class = "s2" > " < / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "nb" > print< / span > < span  class = "p" > (< / span > < span  class = "sa" > f< / span > < span  class = "s2" > " c dtype: < / span > < span  class = "si" > {< / span > < span  class = "n" > c< / span > < span  class = "o" > .< / span > < span  class = "n" > dtype< / span > < span  class = "si" > }< / span > < span  class = "s2" > " < / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "nb" > print< / span > < span  class = "p" > (< / span > < span  class = "sa" > f< / span > < span  class = "s2" > " c correctness: < / span > < span  class = "si" > {< / span > < span  class = "n" > mx< / span > < span  class = "o" > .< / span > < span  class = "n" > all< / span > < span  class = "p" > (< / span > < span  class = "n" > c< / span > < span  class = "w" >  < / span > < span  class = "o" > ==< / span > < span  class = "w" >  < / span > < span  class = "mf" > 6.0< / span > < span  class = "p" > )< / span > < span  class = "o" > .< / span > < span  class = "n" > item< / span > < span  class = "p" > ()< / span > < span  class = "si" > }< / span > < span  class = "s2" > " < / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > Output:< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-python notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "n" > c< / span >  < span  class = "n" > shape< / span > < span  class = "p" > :< / span >  < span  class = "p" > [< / span > < span  class = "mi" > 3< / span > < span  class = "p" > ,< / span >  < span  class = "mi" > 4< / span > < span  class = "p" > ]< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > c< / span >  < span  class = "n" > dtype< / span > < span  class = "p" > :< / span >  < span  class = "n" > float32< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > c< / span >  < span  class = "n" > correctness< / span > < span  class = "p" > :< / span >  < span  class = "kc" > True< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< section  id = "results" > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h3 > Results< a  class = "headerlink"  href = "#results"  title = "Permalink to this heading" > #< / a > < / h3 > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
										 
									
								 
							
							
								< p > Let’s run a quick benchmark and see how our new < code  class = "docutils literal notranslate" > < span  class = "pre" > axpby< / span > < / code >  operation compares
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								with the naive < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > simple_axpby()< / span > < / code >  we defined at first on the CPU.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-python notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "kn" > import< / span >  < span  class = "nn" > mlx.core< / span >  < span  class = "k" > as< / span >  < span  class = "nn" > mx< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "kn" > from< / span >  < span  class = "nn" > mlx_sample_extensions< / span >  < span  class = "kn" > import< / span >  < span  class = "n" > axpby< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "kn" > import< / span >  < span  class = "nn" > time< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > mx< / span > < span  class = "o" > .< / span > < span  class = "n" > set_default_device< / span > < span  class = "p" > (< / span > < span  class = "n" > mx< / span > < span  class = "o" > .< / span > < span  class = "n" > cpu< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "k" > def< / span >  < span  class = "nf" > simple_axpby< / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > :< / span >  < span  class = "n" > mx< / span > < span  class = "o" > .< / span > < span  class = "n" > array< / span > < span  class = "p" > ,< / span >  < span  class = "n" > y< / span > < span  class = "p" > :< / span >  < span  class = "n" > mx< / span > < span  class = "o" > .< / span > < span  class = "n" > array< / span > < span  class = "p" > ,< / span >  < span  class = "n" > alpha< / span > < span  class = "p" > :< / span >  < span  class = "nb" > float< / span > < span  class = "p" > ,< / span >  < span  class = "n" > beta< / span > < span  class = "p" > :< / span >  < span  class = "nb" > float< / span > < span  class = "p" > )< / span >  < span  class = "o" > -> < / span >  < span  class = "n" > mx< / span > < span  class = "o" > .< / span > < span  class = "n" > array< / span > < span  class = "p" > :< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "k" > return< / span >  < span  class = "n" > alpha< / span >  < span  class = "o" > *< / span >  < span  class = "n" > x< / span >  < span  class = "o" > +< / span >  < span  class = "n" > beta< / span >  < span  class = "o" > *< / span >  < span  class = "n" > y< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > M< / span >  < span  class = "o" > =< / span >  < span  class = "mi" > 256< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > N< / span >  < span  class = "o" > =< / span >  < span  class = "mi" > 512< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > x< / span >  < span  class = "o" > =< / span >  < span  class = "n" > mx< / span > < span  class = "o" > .< / span > < span  class = "n" > random< / span > < span  class = "o" > .< / span > < span  class = "n" > normal< / span > < span  class = "p" > ((< / span > < span  class = "n" > M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > N< / span > < span  class = "p" > ))< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > y< / span >  < span  class = "o" > =< / span >  < span  class = "n" > mx< / span > < span  class = "o" > .< / span > < span  class = "n" > random< / span > < span  class = "o" > .< / span > < span  class = "n" > normal< / span > < span  class = "p" > ((< / span > < span  class = "n" > M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > N< / span > < span  class = "p" > ))< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > alpha< / span >  < span  class = "o" > =< / span >  < span  class = "mf" > 4.0< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > beta< / span >  < span  class = "o" > =< / span >  < span  class = "mf" > 2.0< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > mx< / span > < span  class = "o" > .< / span > < span  class = "n" > eval< / span > < span  class = "p" > ((< / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span >  < span  class = "n" > y< / span > < span  class = "p" > ))< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "k" > def< / span >  < span  class = "nf" > bench< / span > < span  class = "p" > (< / span > < span  class = "n" > f< / span > < span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # Warm up< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "k" > for< / span >  < span  class = "n" > i< / span >  < span  class = "ow" > in< / span >  < span  class = "nb" > range< / span > < span  class = "p" > (< / span > < span  class = "mi" > 100< / span > < span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > z< / span >  < span  class = "o" > =< / span >  < span  class = "n" > f< / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span >  < span  class = "n" > y< / span > < span  class = "p" > ,< / span >  < span  class = "n" > alpha< / span > < span  class = "p" > ,< / span >  < span  class = "n" > beta< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > mx< / span > < span  class = "o" > .< / span > < span  class = "n" > eval< / span > < span  class = "p" > (< / span > < span  class = "n" > z< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # Timed run< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > s< / span >  < span  class = "o" > =< / span >  < span  class = "n" > time< / span > < span  class = "o" > .< / span > < span  class = "n" > time< / span > < span  class = "p" > ()< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "k" > for< / span >  < span  class = "n" > i< / span >  < span  class = "ow" > in< / span >  < span  class = "nb" > range< / span > < span  class = "p" > (< / span > < span  class = "mi" > 5000< / span > < span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > z< / span >  < span  class = "o" > =< / span >  < span  class = "n" > f< / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span >  < span  class = "n" > y< / span > < span  class = "p" > ,< / span >  < span  class = "n" > alpha< / span > < span  class = "p" > ,< / span >  < span  class = "n" > beta< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > mx< / span > < span  class = "o" > .< / span > < span  class = "n" > eval< / span > < span  class = "p" > (< / span > < span  class = "n" > z< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > e< / span >  < span  class = "o" > =< / span >  < span  class = "n" > time< / span > < span  class = "o" > .< / span > < span  class = "n" > time< / span > < span  class = "p" > ()< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "k" > return< / span >  < span  class = "n" > e< / span >  < span  class = "o" > -< / span >  < span  class = "n" > s< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > simple_time< / span >  < span  class = "o" > =< / span >  < span  class = "n" > bench< / span > < span  class = "p" > (< / span > < span  class = "n" > simple_axpby< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > custom_time< / span >  < span  class = "o" > =< / span >  < span  class = "n" > bench< / span > < span  class = "p" > (< / span > < span  class = "n" > axpby< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "nb" > print< / span > < span  class = "p" > (< / span > < span  class = "sa" > f< / span > < span  class = "s2" > " Simple axpby: < / span > < span  class = "si" > {< / span > < span  class = "n" > simple_time< / span > < span  class = "si" > :< / span > < span  class = "s2" > .3f< / span > < span  class = "si" > }< / span > < span  class = "s2" >  s | Custom axpby: < / span > < span  class = "si" > {< / span > < span  class = "n" > custom_time< / span > < span  class = "si" > :< / span > < span  class = "s2" > .3f< / span > < span  class = "si" > }< / span > < span  class = "s2" >  s" < / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > Results:< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-python notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "n" > Simple< / span >  < span  class = "n" > axpby< / span > < span  class = "p" > :< / span >  < span  class = "mf" > 0.114< / span >  < span  class = "n" > s< / span >  < span  class = "o" > |< / span >  < span  class = "n" > Custom< / span >  < span  class = "n" > axpby< / span > < span  class = "p" > :< / span >  < span  class = "mf" > 0.109< / span >  < span  class = "n" > s< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > We see some modest improvements right away!< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > This operation can now be used to build other operations,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-17 13:23:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								in < a  class = "reference internal"  href = "../python/_autosummary/mlx.nn.Module.html#mlx.nn.Module"  title = "mlx.nn.Module" > < code  class = "xref py py-class docutils literal notranslate" > < span  class = "pre" > mlx.nn.Module< / span > < / code > < / a >  calls, and also as a part of graph
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								transformations such as < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > grad()< / span > < / code >  and < code  class = "xref py py-meth docutils literal notranslate" > < span  class = "pre" > simplify()< / span > < / code > !< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / section > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / section > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< section  id = "scripts" > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h2 > Scripts< a  class = "headerlink"  href = "#scripts"  title = "Permalink to this heading" > #< / a > < / h2 > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "admonition-download-the-code admonition" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  class = "admonition-title" > Download the code< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > The full example code is available in < a  class = "reference external"  href = "https://github.com/ml-explore/mlx-examples" > mlx-examples< / a > .< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / section > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / section > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								                < / article > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								              
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								              
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								              
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								              
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								              
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								                < footer  class = "prev-next-footer" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								                  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "prev-next-area" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < a  class = "left-prev" 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								       href="../cpp/ops.html"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								       title="previous page">
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < i  class = "fa-solid fa-angle-left" > < / i > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < div  class = "prev-next-info" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < p  class = "prev-next-subtitle" > previous< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < p  class = "prev-next-title" > Operations< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < / a > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								                < / footer > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								              
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								              
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								                < div  class = "bd-sidebar-secondary bd-toc" > < div  class = "sidebar-secondary-items sidebar-secondary__inner" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < div  class = "sidebar-secondary-item" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < div  class = "page-toc tocsection onthispage" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < i  class = "fa-solid fa-list" > < / i >  Contents
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < nav  class = "bd-toc-nav page-toc" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < ul  class = "visible nav section-nav flex-column" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h2 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#introducing-the-example" > Introducing the Example< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h2 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#operations-and-primitives" > Operations and Primitives< / a > < ul  class = "visible nav section-nav flex-column" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h3 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#operations" > Operations< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h3 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#primitives" > Primitives< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h3 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#using-the-primitives" > Using the Primitives< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h2 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#implementing-the-primitive" > Implementing the Primitive< / a > < ul  class = "visible nav section-nav flex-column" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h3 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#implementing-the-cpu-backend" > Implementing the CPU Backend< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h3 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#implementing-the-gpu-backend" > Implementing the GPU Backend< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h3 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#primitive-transforms" > Primitive Transforms< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h2 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#building-and-binding" > Building and Binding< / a > < ul  class = "visible nav section-nav flex-column" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h3 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#binding-to-python" > Binding to Python< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h3 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#building-with-cmake" > Building with CMake< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h3 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#building-with-setuptools" > Building with < code  class = "docutils literal notranslate" > < span  class = "pre" > setuptools< / span > < / code > < / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h2 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#usage" > Usage< / a > < ul  class = "visible nav section-nav flex-column" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h3 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#results" > Results< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toc-h2 nav-item toc-entry" > < a  class = "reference internal nav-link"  href = "#scripts" > Scripts< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / nav > < / div > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								              
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          < footer  class = "bd-footer-content" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "bd-footer-content__inner container" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < div  class = "footer-item" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  class = "component-author" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								By MLX Contributors
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / p > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / div > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < div  class = "footer-item" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < p  class = "copyright" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      © Copyright 2023, MLX Contributors.
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < br / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / p > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / div > 
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < div  class = "footer-item" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < div  class = "footer-item" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          < / footer > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < / main > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  <!--  Scripts loaded after <body> so the DOM is not blocked  --> 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < script  src = "../_static/scripts/bootstrap.js?digest=5b4479735964841361fd" > < / script > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< script  src = "../_static/scripts/pydata-sphinx-theme.js?digest=5b4479735964841361fd" > < / script > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-12-05 12:10:03 -08:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < footer  class = "bd-footer" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / footer > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / body > 
							 
						 
					
						
							
								
									
										
										
										
											2023-11-29 12:41:56 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / html >