mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 15:41:13 +08:00
enable complex gemm (#2017)
This commit is contained in:
parent
5580b47291
commit
98b901ad66
@ -341,7 +341,7 @@ struct GEMVTKernel {
|
||||
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
auto vc = float(v_coeff[tm]);
|
||||
auto vc = static_cast<AccT>(v_coeff[tm]);
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
||||
}
|
||||
|
21
mlx/ops.cpp
21
mlx/ops.cpp
@ -2826,6 +2826,19 @@ array matmul(
|
||||
}
|
||||
// Type promotion
|
||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||
// Complex matmul in terms of real matmuls
|
||||
if (out_type == complex64) {
|
||||
auto a_real = real(a, s);
|
||||
auto b_real = real(b, s);
|
||||
auto a_imag = imag(a, s);
|
||||
auto b_imag = imag(b, s);
|
||||
auto c_real =
|
||||
subtract(matmul(a_real, b_real, s), matmul(a_imag, b_imag, s), s);
|
||||
auto c_imag = add(matmul(a_real, b_imag, s), matmul(a_imag, b_real, s), s);
|
||||
return add(
|
||||
c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s);
|
||||
}
|
||||
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[matmul] Only real floating point types are supported but "
|
||||
@ -4150,6 +4163,14 @@ array addmm(
|
||||
|
||||
// Type promotion
|
||||
auto out_type = result_type(a, b, c);
|
||||
|
||||
if (out_type == complex64) {
|
||||
return add(
|
||||
multiply(matmul(a, b, s), array(alpha), s),
|
||||
multiply(array(beta), c, s),
|
||||
s);
|
||||
}
|
||||
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[addmm] Only real floating point types are supported but "
|
||||
|
@ -1158,6 +1158,55 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
out_gemm = (b @ c)[0]
|
||||
self.assertTrue(mx.allclose(out_gemv, out_gemm))
|
||||
|
||||
def test_complex_gemv(self):
|
||||
M = 16
|
||||
N = 50
|
||||
|
||||
def rand(shape):
|
||||
return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape)
|
||||
|
||||
a = rand((M, N))
|
||||
b = rand((N, 1))
|
||||
c = mx.matmul(a, b)
|
||||
c_np = np.matmul(a, b)
|
||||
self.assertTrue(np.allclose(c, c_np))
|
||||
|
||||
# Transposed
|
||||
a = rand((N, M))
|
||||
b = rand((N, 1))
|
||||
c = mx.matmul(a.T, b)
|
||||
c_np = np.matmul(np.array(a).T, b)
|
||||
self.assertTrue(np.allclose(c, c_np))
|
||||
|
||||
def test_complex_gemm(self):
|
||||
M = 16
|
||||
K = 50
|
||||
N = 32
|
||||
|
||||
def rand(shape):
|
||||
return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape)
|
||||
|
||||
a = rand((M, K))
|
||||
b = rand((K, N))
|
||||
c = mx.matmul(a, b)
|
||||
c_np = np.matmul(a, b)
|
||||
self.assertTrue(np.allclose(c, c_np))
|
||||
|
||||
# Test addmm
|
||||
M = 16
|
||||
K = 50
|
||||
N = 32
|
||||
|
||||
def rand(shape):
|
||||
return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape)
|
||||
|
||||
a = rand((M, K))
|
||||
b = rand((K, N))
|
||||
c = rand((M, N))
|
||||
out = mx.addmm(c, a, b, 2.0, 2.0)
|
||||
out_np = 2.0 * np.matmul(a, b) + 2.0 * c
|
||||
self.assertTrue(np.allclose(out, out_np))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user