Mirror of https://github.com/ml-explore/mlx.git, synced 2025-10-31 07:58:14 +08:00.
			
		
		
		
Enable complex GEMM (#2017).
This commit is contained in:
		| @@ -341,7 +341,7 @@ struct GEMVTKernel { | ||||
|  | ||||
|         MLX_MTL_PRAGMA_UNROLL | ||||
|         for (int tm = 0; tm < TM; tm++) { | ||||
|           auto vc = float(v_coeff[tm]); | ||||
|           auto vc = static_cast<AccT>(v_coeff[tm]); | ||||
|           for (int tn = 0; tn < TN; tn++) { | ||||
|             inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; | ||||
|           } | ||||
|   | ||||
							
								
								
									
										21
									
								
								mlx/ops.cpp
									
									
									
									
									
								
							
							
						
						
									
										21
									
								
								mlx/ops.cpp
									
									
									
									
									
								
							| @@ -2826,6 +2826,19 @@ array matmul( | ||||
|   } | ||||
|   // Type promotion | ||||
|   auto out_type = promote_types(a.dtype(), b.dtype()); | ||||
|   // Complex matmul in terms of real matmuls | ||||
|   if (out_type == complex64) { | ||||
|     auto a_real = real(a, s); | ||||
|     auto b_real = real(b, s); | ||||
|     auto a_imag = imag(a, s); | ||||
|     auto b_imag = imag(b, s); | ||||
|     auto c_real = | ||||
|         subtract(matmul(a_real, b_real, s), matmul(a_imag, b_imag, s), s); | ||||
|     auto c_imag = add(matmul(a_real, b_imag, s), matmul(a_imag, b_real, s), s); | ||||
|     return add( | ||||
|         c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s); | ||||
|   } | ||||
|  | ||||
|   if (!issubdtype(out_type, floating)) { | ||||
|     std::ostringstream msg; | ||||
|     msg << "[matmul] Only real floating point types are supported but " | ||||
| @@ -4150,6 +4163,14 @@ array addmm( | ||||
|  | ||||
|   // Type promotion | ||||
|   auto out_type = result_type(a, b, c); | ||||
|  | ||||
|   if (out_type == complex64) { | ||||
|     return add( | ||||
|         multiply(matmul(a, b, s), array(alpha), s), | ||||
|         multiply(array(beta), c, s), | ||||
|         s); | ||||
|   } | ||||
|  | ||||
|   if (!issubdtype(out_type, floating)) { | ||||
|     std::ostringstream msg; | ||||
|     msg << "[addmm] Only real floating point types are supported but " | ||||
|   | ||||
| @@ -1158,6 +1158,55 @@ class TestBlas(mlx_tests.MLXTestCase): | ||||
|             out_gemm = (b @ c)[0] | ||||
|             self.assertTrue(mx.allclose(out_gemv, out_gemm)) | ||||
|  | ||||
|     def test_complex_gemv(self): | ||||
|         M = 16 | ||||
|         N = 50 | ||||
|  | ||||
|         def rand(shape): | ||||
|             return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape) | ||||
|  | ||||
|         a = rand((M, N)) | ||||
|         b = rand((N, 1)) | ||||
|         c = mx.matmul(a, b) | ||||
|         c_np = np.matmul(a, b) | ||||
|         self.assertTrue(np.allclose(c, c_np)) | ||||
|  | ||||
|         # Transposed | ||||
|         a = rand((N, M)) | ||||
|         b = rand((N, 1)) | ||||
|         c = mx.matmul(a.T, b) | ||||
|         c_np = np.matmul(np.array(a).T, b) | ||||
|         self.assertTrue(np.allclose(c, c_np)) | ||||
|  | ||||
|     def test_complex_gemm(self): | ||||
|         M = 16 | ||||
|         K = 50 | ||||
|         N = 32 | ||||
|  | ||||
|         def rand(shape): | ||||
|             return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape) | ||||
|  | ||||
|         a = rand((M, K)) | ||||
|         b = rand((K, N)) | ||||
|         c = mx.matmul(a, b) | ||||
|         c_np = np.matmul(a, b) | ||||
|         self.assertTrue(np.allclose(c, c_np)) | ||||
|  | ||||
|         # Test addmm | ||||
|         M = 16 | ||||
|         K = 50 | ||||
|         N = 32 | ||||
|  | ||||
|         def rand(shape): | ||||
|             return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape) | ||||
|  | ||||
|         a = rand((M, K)) | ||||
|         b = rand((K, N)) | ||||
|         c = rand((M, N)) | ||||
|         out = mx.addmm(c, a, b, 2.0, 2.0) | ||||
|         out_np = 2.0 * np.matmul(a, b) + 2.0 * c | ||||
|         self.assertTrue(np.allclose(out, out_np)) | ||||
|  | ||||
|  | ||||
# Allow running this test module directly (discovers and runs all TestCase
# methods defined above).
if __name__ == "__main__":
    unittest.main()
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun