mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-17 17:28:10 +08:00
Allow boolean mask in sdpa (#1753)
* allow boolean mask in sdpa
* more permissive donation in ternary
This commit is contained in:
@@ -67,7 +67,12 @@ void set_ternary_op_output_data(
|
||||
}
|
||||
break;
|
||||
case TernaryOpType::General:
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
// Try to donate an input which is row_contiguous
|
||||
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
|
||||
(b.flags().row_contiguous && maybe_donate(b)) ||
|
||||
(c.flags().row_contiguous && maybe_donate(c)))) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@@ -659,7 +659,12 @@ array scaled_dot_product_attention(
|
||||
mask = unflatten(mask, -3, {n_kv_heads, n_repeats}, s);
|
||||
}
|
||||
}
|
||||
scores = add(scores, mask, s);
|
||||
if (mask.dtype() == bool_) {
|
||||
scores = where(
|
||||
mask, scores, array(finfo(scores.dtype()).min, scores.dtype()));
|
||||
} else {
|
||||
scores = add(scores, mask, s);
|
||||
}
|
||||
}
|
||||
scores = softmax(scores, std::vector<int>{-1}, true, s);
|
||||
auto out = matmul(scores, v, s);
|
||||
|
Reference in New Issue
Block a user