mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-01 08:38:12 +08:00
Fix cudnn routing
This commit is contained in:
@@ -1068,8 +1068,7 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
flags);
|
flags);
|
||||||
}
|
}
|
||||||
|
|
||||||
// return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
|
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
|
||||||
return sdpa_cudnn(s, encoder, q, k, v, scale_, o, do_causal_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Full attention mode
|
// Full attention mode
|
||||||
|
|||||||
Reference in New Issue
Block a user