mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 16:48:10 +08:00
fix copies in sdpa (#2563)
This commit is contained in:
@@ -394,7 +394,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
|
||||
// Define some copy functions to ensure the layout of the inputs is as
|
||||
// expected.
|
||||
copies.reserve(3);
|
||||
copies.reserve(inputs.size());
|
||||
auto copy_unless = [&copies, &s](
|
||||
auto predicate, const array& arr) -> const array& {
|
||||
if (!predicate(arr)) {
|
||||
|
Reference in New Issue
Block a user