mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-21 16:51:15 +08:00
enable complex gemm (#2017)
This commit is contained in:
parent
5580b47291
commit
98b901ad66
@ -341,7 +341,7 @@ struct GEMVTKernel {
|
|||||||
|
|
||||||
MLX_MTL_PRAGMA_UNROLL
|
MLX_MTL_PRAGMA_UNROLL
|
||||||
for (int tm = 0; tm < TM; tm++) {
|
for (int tm = 0; tm < TM; tm++) {
|
||||||
auto vc = float(v_coeff[tm]);
|
auto vc = static_cast<AccT>(v_coeff[tm]);
|
||||||
for (int tn = 0; tn < TN; tn++) {
|
for (int tn = 0; tn < TN; tn++) {
|
||||||
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
||||||
}
|
}
|
||||||
|
21
mlx/ops.cpp
21
mlx/ops.cpp
@ -2826,6 +2826,19 @@ array matmul(
|
|||||||
}
|
}
|
||||||
// Type promotion
|
// Type promotion
|
||||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||||
|
// Complex matmul in terms of real matmuls
|
||||||
|
if (out_type == complex64) {
|
||||||
|
auto a_real = real(a, s);
|
||||||
|
auto b_real = real(b, s);
|
||||||
|
auto a_imag = imag(a, s);
|
||||||
|
auto b_imag = imag(b, s);
|
||||||
|
auto c_real =
|
||||||
|
subtract(matmul(a_real, b_real, s), matmul(a_imag, b_imag, s), s);
|
||||||
|
auto c_imag = add(matmul(a_real, b_imag, s), matmul(a_imag, b_real, s), s);
|
||||||
|
return add(
|
||||||
|
c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s);
|
||||||
|
}
|
||||||
|
|
||||||
if (!issubdtype(out_type, floating)) {
|
if (!issubdtype(out_type, floating)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[matmul] Only real floating point types are supported but "
|
msg << "[matmul] Only real floating point types are supported but "
|
||||||
@ -4150,6 +4163,14 @@ array addmm(
|
|||||||
|
|
||||||
// Type promotion
|
// Type promotion
|
||||||
auto out_type = result_type(a, b, c);
|
auto out_type = result_type(a, b, c);
|
||||||
|
|
||||||
|
if (out_type == complex64) {
|
||||||
|
return add(
|
||||||
|
multiply(matmul(a, b, s), array(alpha), s),
|
||||||
|
multiply(array(beta), c, s),
|
||||||
|
s);
|
||||||
|
}
|
||||||
|
|
||||||
if (!issubdtype(out_type, floating)) {
|
if (!issubdtype(out_type, floating)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[addmm] Only real floating point types are supported but "
|
msg << "[addmm] Only real floating point types are supported but "
|
||||||
|
@ -1158,6 +1158,55 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
out_gemm = (b @ c)[0]
|
out_gemm = (b @ c)[0]
|
||||||
self.assertTrue(mx.allclose(out_gemv, out_gemm))
|
self.assertTrue(mx.allclose(out_gemv, out_gemm))
|
||||||
|
|
||||||
|
def test_complex_gemv(self):
|
||||||
|
M = 16
|
||||||
|
N = 50
|
||||||
|
|
||||||
|
def rand(shape):
|
||||||
|
return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape)
|
||||||
|
|
||||||
|
a = rand((M, N))
|
||||||
|
b = rand((N, 1))
|
||||||
|
c = mx.matmul(a, b)
|
||||||
|
c_np = np.matmul(a, b)
|
||||||
|
self.assertTrue(np.allclose(c, c_np))
|
||||||
|
|
||||||
|
# Transposed
|
||||||
|
a = rand((N, M))
|
||||||
|
b = rand((N, 1))
|
||||||
|
c = mx.matmul(a.T, b)
|
||||||
|
c_np = np.matmul(np.array(a).T, b)
|
||||||
|
self.assertTrue(np.allclose(c, c_np))
|
||||||
|
|
||||||
|
def test_complex_gemm(self):
|
||||||
|
M = 16
|
||||||
|
K = 50
|
||||||
|
N = 32
|
||||||
|
|
||||||
|
def rand(shape):
|
||||||
|
return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape)
|
||||||
|
|
||||||
|
a = rand((M, K))
|
||||||
|
b = rand((K, N))
|
||||||
|
c = mx.matmul(a, b)
|
||||||
|
c_np = np.matmul(a, b)
|
||||||
|
self.assertTrue(np.allclose(c, c_np))
|
||||||
|
|
||||||
|
# Test addmm
|
||||||
|
M = 16
|
||||||
|
K = 50
|
||||||
|
N = 32
|
||||||
|
|
||||||
|
def rand(shape):
|
||||||
|
return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape)
|
||||||
|
|
||||||
|
a = rand((M, K))
|
||||||
|
b = rand((K, N))
|
||||||
|
c = rand((M, N))
|
||||||
|
out = mx.addmm(c, a, b, 2.0, 2.0)
|
||||||
|
out_np = 2.0 * np.matmul(a, b) + 2.0 * c
|
||||||
|
self.assertTrue(np.allclose(out, out_np))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user