fix complex matmul shape

This commit is contained in:
Awni Hannun 2025-06-13 07:42:12 -07:00
parent 4a01d2d5d3
commit 628e36f7d9
3 changed files with 37 additions and 26 deletions

View File

@ -196,7 +196,7 @@ struct Power {
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (base.y == 0 && base.x == 0) {
if (isnan(exp.x) || isnan(exp.y)) {
auto nan = cuda::std::numeric_limits<T>::quiet_NaN();
auto nan = cuda::std::numeric_limits<float>::quiet_NaN();
return make_cuFloatComplex(nan, nan);
}
return make_cuFloatComplex(0.0, 0.0);

View File

@ -2847,21 +2847,6 @@ array matmul(
"[matmul] Got 0 dimension input. Inputs must "
"have at least one dimension.");
}
if (a.ndim() == 1) {
// Insert a singleton dim in the beginning
a = expand_dims(a, 0, s);
}
if (b.ndim() == 1) {
// Insert a singleton dim at the end
b = expand_dims(b, 1, s);
}
if (a.shape(-1) != b.shape(-2)) {
std::ostringstream msg;
msg << "[matmul] Last dimension of first input with shape " << a.shape()
<< " must match second to last dimension of"
<< " second input with shape " << b.shape() << ".";
throw std::invalid_argument(msg.str());
}
// complex matmul using Karatsuba's Algorithm
if (a.dtype() == complex64 || b.dtype() == complex64) {
@ -2883,6 +2868,22 @@ array matmul(
c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s);
}
if (a.ndim() == 1) {
// Insert a singleton dim in the beginning
a = expand_dims(a, 0, s);
}
if (b.ndim() == 1) {
// Insert a singleton dim at the end
b = expand_dims(b, 1, s);
}
if (a.shape(-1) != b.shape(-2)) {
std::ostringstream msg;
msg << "[matmul] Last dimension of first input with shape " << a.shape()
<< " must match second to last dimension of"
<< " second input with shape " << b.shape() << ".";
throw std::invalid_argument(msg.str());
}
// Type promotion
auto out_type = promote_types(a.dtype(), b.dtype());
@ -4240,6 +4241,16 @@ array addmm(
"have at least one dimension.");
}
// Type promotion
auto out_type = result_type(a, b, c);
if (out_type == complex64) {
return add(
multiply(matmul(a, b, s), array(alpha), s),
multiply(array(beta), c, s),
s);
}
if (a.ndim() == 1) {
// Insert a singleton dim in the beginning
a = expand_dims(a, 0, s);
@ -4257,16 +4268,6 @@ array addmm(
throw std::invalid_argument(msg.str());
}
// Type promotion
auto out_type = result_type(a, b, c);
if (out_type == complex64) {
return add(
multiply(matmul(a, b, s), array(alpha), s),
multiply(array(beta), c, s),
s);
}
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;
msg << "[addmm] Only real floating point types are supported but "

View File

@ -1195,6 +1195,16 @@ class TestBlas(mlx_tests.MLXTestCase):
c_np = np.matmul(np.array(a).T, b)
self.assertTrue(np.allclose(c, c_np))
# Check shapes
a = mx.random.normal((2, 3)).astype(mx.complex64)
b = mx.random.normal((3,))
self.assertEqual((a @ b).shape, (2,))
a = mx.random.normal((2, 3)).astype(mx.complex64)
b = mx.random.normal((3,))
c = mx.random.normal((2,))
self.assertEqual(mx.addmm(c, a, b).shape, (2,))
def test_complex_gemm(self):
M = 16
K = 50