mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Update GEMM (#424)
* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/` * Update gemm elements for better performance * Add split-K specialization for gemm * Add `addmm` primitive, op and bindings for fused matmul and bias addition * Update tests and benchmarks as needed
This commit is contained in:
		@@ -3476,4 +3476,34 @@ void init_ops(py::module_& m) {
 | 
			
		||||
      Returns:
 | 
			
		||||
        result (array): The tiled array.
 | 
			
		||||
    )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "addmm",
 | 
			
		||||
      &addmm,
 | 
			
		||||
      "c"_a,
 | 
			
		||||
      "a"_a,
 | 
			
		||||
      "b"_a,
 | 
			
		||||
      py::pos_only(),
 | 
			
		||||
      "alpha"_a = 1.0f,
 | 
			
		||||
      "beta"_a = 1.0f,
 | 
			
		||||
      py::kw_only(),
 | 
			
		||||
      "stream"_a = none,
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        addmm(c: array, a: array, b: array, /, alpha: float = 1.0, beta: float = 1.0,  *, stream: Union[None, Stream, Device] = None) -> array
 | 
			
		||||
 | 
			
		||||
        Matrix multiplication with addition and optional scaling.
 | 
			
		||||
 | 
			
		||||
        Perform the (possibly batched) matrix multiplication of two arrays and add to the result
 | 
			
		||||
        with optional scaling factors.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            c (array): Input array or scalar.
 | 
			
		||||
            a (array): Input array or scalar.
 | 
			
		||||
            b (array): Input array or scalar.
 | 
			
		||||
            alpha (float, optional): Scaling factor for the 
 | 
			
		||||
                matrix product of ``a`` and ``b`` (default: ``1``)
 | 
			
		||||
            beta (float, optional): Scaling factor for ``c`` (default: ``1``)
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            array: ``alpha * (a @ b)  + beta * c``
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user