mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
enable complex gemm (#2017)
This commit is contained in:
@@ -341,7 +341,7 @@ struct GEMVTKernel {
|
||||
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
auto vc = float(v_coeff[tm]);
|
||||
auto vc = static_cast<AccT>(v_coeff[tm]);
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
||||
}
|
||||
|
||||
21
mlx/ops.cpp
21
mlx/ops.cpp
@@ -2826,6 +2826,19 @@ array matmul(
|
||||
}
|
||||
// Type promotion
|
||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||
// Complex matmul in terms of real matmuls
|
||||
if (out_type == complex64) {
|
||||
auto a_real = real(a, s);
|
||||
auto b_real = real(b, s);
|
||||
auto a_imag = imag(a, s);
|
||||
auto b_imag = imag(b, s);
|
||||
auto c_real =
|
||||
subtract(matmul(a_real, b_real, s), matmul(a_imag, b_imag, s), s);
|
||||
auto c_imag = add(matmul(a_real, b_imag, s), matmul(a_imag, b_real, s), s);
|
||||
return add(
|
||||
c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s);
|
||||
}
|
||||
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[matmul] Only real floating point types are supported but "
|
||||
@@ -4150,6 +4163,14 @@ array addmm(
|
||||
|
||||
// Type promotion
|
||||
auto out_type = result_type(a, b, c);
|
||||
|
||||
if (out_type == complex64) {
|
||||
return add(
|
||||
multiply(matmul(a, b, s), array(alpha), s),
|
||||
multiply(array(beta), c, s),
|
||||
s);
|
||||
}
|
||||
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[addmm] Only real floating point types are supported but "
|
||||
|
||||
Reference in New Issue
Block a user