mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix routing for complex
This commit is contained in:
@@ -359,7 +359,7 @@ void steel_matmul_regular_axpby(
|
|||||||
float beta /* = 0.0f */) {
|
float beta /* = 0.0f */) {
|
||||||
#ifdef MLX_ENABLE_NAX
|
#ifdef MLX_ENABLE_NAX
|
||||||
|
|
||||||
if (metal::is_nax_available() &&
|
if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) &&
|
||||||
(a.dtype() != float32 || env::enable_tf32())) {
|
(a.dtype() != float32 || env::enable_tf32())) {
|
||||||
return steel_matmul_regular_axpby_nax<CHECK_AB>(
|
return steel_matmul_regular_axpby_nax<CHECK_AB>(
|
||||||
/* const Stream& s = */ s,
|
/* const Stream& s = */ s,
|
||||||
|
|||||||
Reference in New Issue
Block a user