diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 540fbce02..9f435f73c 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -359,7 +359,7 @@ void steel_matmul_regular_axpby( float beta /* = 0.0f */) { #ifdef MLX_ENABLE_NAX - if (metal::is_nax_available() && + if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) && (a.dtype() != float32 || env::enable_tf32())) { return steel_matmul_regular_axpby_nax( /* const Stream& s = */ s,