mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
enable complex gemm (#2017)
This commit is contained in:
21
mlx/ops.cpp
21
mlx/ops.cpp
@@ -2826,6 +2826,19 @@ array matmul(
|
||||
}
|
||||
// Type promotion
|
||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||
// Complex matmul in terms of real matmuls
|
||||
if (out_type == complex64) {
|
||||
auto a_real = real(a, s);
|
||||
auto b_real = real(b, s);
|
||||
auto a_imag = imag(a, s);
|
||||
auto b_imag = imag(b, s);
|
||||
auto c_real =
|
||||
subtract(matmul(a_real, b_real, s), matmul(a_imag, b_imag, s), s);
|
||||
auto c_imag = add(matmul(a_real, b_imag, s), matmul(a_imag, b_real, s), s);
|
||||
return add(
|
||||
c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s);
|
||||
}
|
||||
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[matmul] Only real floating point types are supported but "
|
||||
@@ -4150,6 +4163,14 @@ array addmm(
|
||||
|
||||
// Type promotion
|
||||
auto out_type = result_type(a, b, c);
|
||||
|
||||
if (out_type == complex64) {
|
||||
return add(
|
||||
multiply(matmul(a, b, s), array(alpha), s),
|
||||
multiply(array(beta), c, s),
|
||||
s);
|
||||
}
|
||||
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[addmm] Only real floating point types are supported but "
|
||||
|
||||
Reference in New Issue
Block a user