mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix routing for complex
This commit is contained in:
@@ -359,7 +359,7 @@ void steel_matmul_regular_axpby(
|
||||
float beta /* = 0.0f */) {
|
||||
#ifdef MLX_ENABLE_NAX
|
||||
|
||||
if (metal::is_nax_available() &&
|
||||
if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) &&
|
||||
(a.dtype() != float32 || env::enable_tf32())) {
|
||||
return steel_matmul_regular_axpby_nax<CHECK_AB>(
|
||||
/* const Stream& s = */ s,
|
||||
|
||||
Reference in New Issue
Block a user