mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	* Remove "using namespace mlx::core" in benchmarks/examples
	* Fix building example extension
	* Remove one missed occurrence in a comment
	* Fix building on M chips
		
			
				
	
	
		
			40 lines
		
	
	
		
			880 B
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			40 lines
		
	
	
		
			880 B
		
	
	
	
		
			C++
		
	
	
	
	
	
// Copyright © 2023-2024 Apple Inc.

#include <nanobind/nanobind.h>
// Type caster for std::variant. Presumably needed for the `stream` argument
// of axpby. NOTE(review): confirm against the signature in axpby/axpby.h.
#include <nanobind/stl/variant.h>

#include "axpby/axpby.h"

namespace nb = nanobind;
// Brings the `_a` user-defined literal into scope, so binding arguments can
// be named like "x"_a.
using namespace nb::literals;

// Python bindings for the sample MLX extension.
//
// NB_MODULE(_ext, m) defines the entry point of a compiled Python module
// named `_ext`. `m` is the module object that the bindings below attach to.
NB_MODULE(_ext, m) {
  m.doc() = "Sample extension for MLX";

  // Expose my_ext::axpby to Python as `axpby`.
  // - x, y, alpha and beta may be passed positionally or by keyword.
  // - nb::kw_only() makes every argument after it keyword-only. Here that is
  //   only `stream`, which defaults to None.
  // NOTE(review): assumes the C++ parameter bound to `stream` accepts None
  // (e.g. a std::variant handled by nanobind/stl/variant.h). Confirm against
  // axpby/axpby.h.
  m.def(
      "axpby",
      &my_ext::axpby,
      "x"_a,
      "y"_a,
      "alpha"_a,
      "beta"_a,
      nb::kw_only(),
      "stream"_a = nb::none(),
      R"(
        Scale and sum two vectors element-wise
        ``z = alpha * x + beta * y``

        Follows numpy style broadcasting between ``x`` and ``y``
        Inputs are upcasted to floats if needed

        Args:
            x (array): Input array.
            y (array): Input array.
            alpha (float): Scaling factor for ``x``.
            beta (float): Scaling factor for ``y``.

        Returns:
            array: ``alpha * x + beta * y``
      )");
}