Allow boolean mask in sdpa (#1753)

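* Allow a boolean mask in scaled_dot_product_attention (sdpa): positions where the mask is false have their scores replaced with the minimum finite value of the scores dtype; non-boolean masks are still added to the scores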

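* More permissive donation in ternary ops: the General case now tries to donate a row-contiguous input (a, b, or c) and only allocates a new output buffer if none can be donated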
This commit is contained in:
Awni Hannun
2025-01-06 16:57:07 -08:00
committed by GitHub
parent 25b3a3e541
commit d5ec172c95
4 changed files with 28 additions and 5 deletions

View File

@@ -67,7 +67,12 @@ void set_ternary_op_output_data(
}
break;
case TernaryOpType::General:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Try to donate an input which is row_contiguous
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
(b.flags().row_contiguous && maybe_donate(b)) ||
(c.flags().row_contiguous && maybe_donate(c)))) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
break;
}
}

View File

@@ -659,7 +659,12 @@ array scaled_dot_product_attention(
mask = unflatten(mask, -3, {n_kv_heads, n_repeats}, s);
}
}
scores = add(scores, mask, s);
if (mask.dtype() == bool_) {
scores = where(
mask, scores, array(finfo(scores.dtype()).min, scores.dtype()));
} else {
scores = add(scores, mask, s);
}
}
scores = softmax(scores, std::vector<int>{-1}, true, s);
auto out = matmul(scores, v, s);