mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Gather mm new kernel and small refactoring (#2040)
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							e9e268336b
						
					
				
				
					commit
					99eefd2ec0
				
			| @@ -4464,9 +4464,10 @@ void init_ops(nb::module_& m) { | ||||
|       "lhs_indices"_a = nb::none(), | ||||
|       "rhs_indices"_a = nb::none(), | ||||
|       nb::kw_only(), | ||||
|       "sorted_indices"_a = false, | ||||
|       "stream"_a = nb::none(), | ||||
|       nb::sig( | ||||
|           "def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"), | ||||
|           "def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"), | ||||
|       R"pbdoc( | ||||
|         Matrix multiplication with matrix-level gather. | ||||
|  | ||||
| @@ -4485,11 +4486,16 @@ void init_ops(nb::module_& m) { | ||||
|         For ``b`` with shape ``(B1, B2, ..., BS, K, N)``, ``rhs_indices`` | ||||
|         contains indices from the range ``[0, B1 * B2 * ... * BS)``. | ||||
|  | ||||
|         If only one index is passed and it is sorted, the ``sorted_indices`` | ||||
|         flag can be passed for a possibly faster implementation. | ||||
|  | ||||
|         Args: | ||||
|             a (array): Input array. | ||||
|             b (array): Input array. | ||||
|             lhs_indices (array, optional): Integer indices for ``a``. Default: ``None`` | ||||
|             rhs_indices (array, optional): Integer indices for ``b``. Default: ``None`` | ||||
|             sorted_indices (bool, optional): May allow a faster implementation | ||||
|               if the passed indices are sorted. Default: ``False``. | ||||
|  | ||||
|         Returns: | ||||
|             array: The output array. | ||||
|   | ||||
| @@ -1108,7 +1108,7 @@ class TestBlas(mlx_tests.MLXTestCase): | ||||
|             lhs_indices_ = mx.broadcast_to(lhs_indices, (3, 2)) | ||||
|             rhs_indices_ = mx.broadcast_to(rhs_indices, (3, 2)) | ||||
|             M = a.shape[-2] | ||||
|             N = b.shape[-2] | ||||
|             N = b.shape[-1] | ||||
|             K = a.shape[-1] | ||||
|  | ||||
|             a = a.reshape((-1, M, K)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user