Allow boolean mask in sdpa (#1753)

* allow boolean mask in sdpa

* more permissive donation in ternary
This commit is contained in:
Awni Hannun
2025-01-06 16:57:07 -08:00
committed by GitHub
parent 25b3a3e541
commit d5ec172c95
4 changed files with 28 additions and 5 deletions

View File

@@ -67,7 +67,12 @@ void set_ternary_op_output_data(
}
break;
case TernaryOpType::General:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Try to donate an input which is row_contiguous
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
(b.flags().row_contiguous && maybe_donate(b)) ||
(c.flags().row_contiguous && maybe_donate(c)))) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
break;
}
}