enable complex gemm (#2017)

This commit is contained in:
Awni Hannun 2025-03-28 10:45:13 -07:00 committed by GitHub
parent 5580b47291
commit 98b901ad66
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 71 additions and 1 deletions

View File

@ -341,7 +341,7 @@ struct GEMVTKernel {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
auto vc = float(v_coeff[tm]);
auto vc = static_cast<AccT>(v_coeff[tm]);
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}

View File

@ -2826,6 +2826,19 @@ array matmul(
}
// Type promotion
auto out_type = promote_types(a.dtype(), b.dtype());
// Complex matmul in terms of real matmuls
if (out_type == complex64) {
auto a_real = real(a, s);
auto b_real = real(b, s);
auto a_imag = imag(a, s);
auto b_imag = imag(b, s);
auto c_real =
subtract(matmul(a_real, b_real, s), matmul(a_imag, b_imag, s), s);
auto c_imag = add(matmul(a_real, b_imag, s), matmul(a_imag, b_real, s), s);
return add(
c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s);
}
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;
msg << "[matmul] Only real floating point types are supported but "
@ -4150,6 +4163,14 @@ array addmm(
// Type promotion
auto out_type = result_type(a, b, c);
if (out_type == complex64) {
return add(
multiply(matmul(a, b, s), array(alpha), s),
multiply(array(beta), c, s),
s);
}
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;
msg << "[addmm] Only real floating point types are supported but "

View File

@ -1158,6 +1158,55 @@ class TestBlas(mlx_tests.MLXTestCase):
out_gemm = (b @ c)[0]
self.assertTrue(mx.allclose(out_gemv, out_gemm))
def test_complex_gemv(self):
M = 16
N = 50
def rand(shape):
return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape)
a = rand((M, N))
b = rand((N, 1))
c = mx.matmul(a, b)
c_np = np.matmul(a, b)
self.assertTrue(np.allclose(c, c_np))
# Transposed
a = rand((N, M))
b = rand((N, 1))
c = mx.matmul(a.T, b)
c_np = np.matmul(np.array(a).T, b)
self.assertTrue(np.allclose(c, c_np))
def test_complex_gemm(self):
M = 16
K = 50
N = 32
def rand(shape):
return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape)
a = rand((M, K))
b = rand((K, N))
c = mx.matmul(a, b)
c_np = np.matmul(a, b)
self.assertTrue(np.allclose(c, c_np))
# Test addmm
M = 16
K = 50
N = 32
def rand(shape):
return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape)
a = rand((M, K))
b = rand((K, N))
c = rand((M, N))
out = mx.addmm(c, a, b, 2.0, 2.0)
out_np = 2.0 * np.matmul(a, b) + 2.0 * c
self.assertTrue(np.allclose(out, out_np))
if __name__ == "__main__":
unittest.main()