Fix routing for complex

This commit is contained in:
Jagrit Digani
2025-11-14 16:03:34 -08:00
parent 210400d112
commit a0558cd3af

View File

@@ -359,7 +359,7 @@ void steel_matmul_regular_axpby(
float beta /* = 0.0f */) { float beta /* = 0.0f */) {
#ifdef MLX_ENABLE_NAX #ifdef MLX_ENABLE_NAX
if (metal::is_nax_available() && if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) &&
(a.dtype() != float32 || env::enable_tf32())) { (a.dtype() != float32 || env::enable_tf32())) {
return steel_matmul_regular_axpby_nax<CHECK_AB>( return steel_matmul_regular_axpby_nax<CHECK_AB>(
/* const Stream& s = */ s, /* const Stream& s = */ s,