mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 19:38:16 +08:00
fix copies in sdpa (#2563)
This commit is contained in:
@@ -394,7 +394,7 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
|
|
||||||
// Define some copy functions to ensure the layout of the inputs is as
|
// Define some copy functions to ensure the layout of the inputs is as
|
||||||
// expected.
|
// expected.
|
||||||
copies.reserve(3);
|
copies.reserve(inputs.size());
|
||||||
auto copy_unless = [&copies, &s](
|
auto copy_unless = [&copies, &s](
|
||||||
auto predicate, const array& arr) -> const array& {
|
auto predicate, const array& arr) -> const array& {
|
||||||
if (!predicate(arr)) {
|
if (!predicate(arr)) {
|
||||||
|
@@ -619,6 +619,17 @@ class TestSDPA(mlx_tests.MLXTestCase):
|
|||||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
|
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
|
||||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||||
|
|
||||||
|
def test_sdpa_noncontiguous_inputs(self):
|
||||||
|
mask = mx.ones(shape=(4, 1, 7, 7), dtype=mx.bool_)
|
||||||
|
mx.random.seed(0)
|
||||||
|
q = mx.random.normal(shape=(4, 7, 32, 64)).swapaxes(1, 2)
|
||||||
|
|
||||||
|
k = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)
|
||||||
|
v = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)
|
||||||
|
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
|
||||||
|
ref = mlx_ref_attn(q, k, v, scale=1.0, mask=mask)
|
||||||
|
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||||
|
|
||||||
def test_sdpa_promote_mask(self):
|
def test_sdpa_promote_mask(self):
|
||||||
mask = mx.array(2.0, mx.bfloat16)
|
mask = mx.array(2.0, mx.bfloat16)
|
||||||
D = 64
|
D = 64
|
||||||
|
Reference in New Issue
Block a user