mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Fix cudnn routing
This commit is contained in:
@@ -1068,8 +1068,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
flags);
|
||||
}
|
||||
|
||||
// return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
|
||||
return sdpa_cudnn(s, encoder, q, k, v, scale_, o, do_causal_);
|
||||
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
|
||||
}
|
||||
|
||||
// Full attention mode
|
||||
|
Reference in New Issue
Block a user