enable complex gemm (#2017)

This commit is contained in:
Awni Hannun
2025-03-28 10:45:13 -07:00
committed by GitHub
parent 5580b47291
commit 98b901ad66
3 changed files with 71 additions and 1 deletions

View File

@@ -2826,6 +2826,19 @@ array matmul(
}
// Type promotion
auto out_type = promote_types(a.dtype(), b.dtype());
// Complex matmul in terms of real matmuls
if (out_type == complex64) {
auto a_real = real(a, s);
auto b_real = real(b, s);
auto a_imag = imag(a, s);
auto b_imag = imag(b, s);
auto c_real =
subtract(matmul(a_real, b_real, s), matmul(a_imag, b_imag, s), s);
auto c_imag = add(matmul(a_real, b_imag, s), matmul(a_imag, b_real, s), s);
return add(
c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s);
}
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;
msg << "[matmul] Only real floating point types are supported but "
@@ -4150,6 +4163,14 @@ array addmm(
// Type promotion
auto out_type = result_type(a, b, c);
if (out_type == complex64) {
return add(
multiply(matmul(a, b, s), array(alpha), s),
multiply(array(beta), c, s),
s);
}
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;
msg << "[addmm] Only real floating point types are supported but "