mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	 4f9b60dd53
			
		
	
	4f9b60dd53
	
	
	
		
			
			* Remove `using namespace mlx::core` from benchmarks/examples * Fix building the example extension * Fix one remaining occurrence in a comment * Fix building on Apple M-series chips
		
			
				
	
	
		
			40 lines
		
	
	
		
			880 B
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			40 lines
		
	
	
		
			880 B
		
	
	
	
		
			C++
		
	
	
	
	
	
| // Copyright © 2023-2024 Apple Inc.
 | |
| 
 | |
| #include <nanobind/nanobind.h>
 | |
| #include <nanobind/stl/variant.h>
 | |
| 
 | |
| #include "axpby/axpby.h"
 | |
| 
 | |
| namespace nb = nanobind;
 | |
| using namespace nb::literals;
 | |
| 
 | |
| NB_MODULE(_ext, m) {
 | |
|   m.doc() = "Sample extension for MLX";
 | |
| 
 | |
|   m.def(
 | |
|       "axpby",
 | |
|       &my_ext::axpby,
 | |
|       "x"_a,
 | |
|       "y"_a,
 | |
|       "alpha"_a,
 | |
|       "beta"_a,
 | |
|       nb::kw_only(),
 | |
|       "stream"_a = nb::none(),
 | |
|       R"(
 | |
|         Scale and sum two vectors element-wise
 | |
|         ``z = alpha * x + beta * y``
 | |
| 
 | |
|         Follows numpy style broadcasting between ``x`` and ``y``
 | |
|         Inputs are upcasted to floats if needed
 | |
| 
 | |
|         Args:
 | |
|             x (array): Input array.
 | |
|             y (array): Input array.
 | |
|             alpha (float): Scaling factor for ``x``.
 | |
|             beta (float): Scaling factor for ``y``.
 | |
| 
 | |
|         Returns:
 | |
|             array: ``alpha * x + beta * y``
 | |
|       )");
 | |
| }
 |