diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 02d0398bb..895ae2d52 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -533,6 +533,12 @@ array scaled_dot_product_attention(
       throw std::invalid_argument(msg.str());
     }
   }
+  if (mask and (*mask).ndim() > 4) {
+    std::ostringstream msg;
+    msg << "[scaled_dot_product_attention] the mask with shape "
+        << (*mask).shape() << " expected to have at most rank 4";
+    throw std::invalid_argument(msg.str());
+  }
 
   const size_t batch_dim = queries.shape(0);
   for (const auto& tensor : {keys, values}) {
@@ -599,8 +605,7 @@ array scaled_dot_product_attention(
     threshold = std::max(1, memory_efficient_threshold.value());
   }
 
-  bool needs_mask = mask.has_value();
-  auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s](
+  auto fallback = [scale, final_type, n_q_heads, n_kv_heads, &s](
                       const std::vector<array>& inputs) {
     auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
     int n_repeats = n_q_heads / n_kv_heads;
@@ -614,8 +619,12 @@ array scaled_dot_product_attention(
       v = expand_dims(v, 2, s);
     }
     auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
-    if (needs_mask) {
-      scores = add(scores, inputs[3], s);
+    if (inputs.size() > 3) {
+      auto mask_shape = inputs[0].shape();
+      mask_shape.back() = inputs[1].shape(-2);
+      auto mask = reshape(
+          broadcast_to(inputs[3], std::move(mask_shape), s), scores.shape(), s);
+      scores = add(scores, mask, s);
     }
     scores = softmax(scores, std::vector<int>{-1}, true, s);
     auto out = matmul(scores, v, s);
diff --git a/python/src/fast.cpp b/python/src/fast.cpp
index 829301ab5..b71baa183 100644
--- a/python/src/fast.cpp
+++ b/python/src/fast.cpp
@@ -140,12 +140,23 @@ void init_fast(nb::module_& parent_module) {
         Note: For Grouped Query Attention and Multi-Query Attention, the ``k``
         and ``v`` inputs should not be pre-tiled to match ``q``.
 
+        In the following the dimensions are given by:
+
+        * ``B``: The batch size.
+        * ``N_q``: The number of query heads.
+        * ``N_kv``: The number of key and value heads.
+        * ``T_q``: The number of queries per example.
+        * ``T_kv``: The number of keys and values per example.
+        * ``D``: The per-head dimension.
+
         Args:
-            q (array): Input query array.
-            k (array): Input keys array.
-            v (array): Input values array.
+            q (array): Queries with shape ``[B, N_q, T_q, D]``.
+            k (array): Keys with shape ``[B, N_kv, T_kv, D]``.
+            v (array): Values with shape ``[B, N_kv, T_kv, D]``.
             scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
-            mask (array, optional): An additive mask to apply to the query-key scores.
+            mask (array, optional): An additive mask to apply to the query-key
+              scores. The mask can have at most 4 dimensions and must be
+              broadcast-compatible with the shape ``[B, N_q, T_q, T_kv]``.
         Returns:
             array: The output array.
)pbdoc"); diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index c736abe93..1df48bc7f 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -7,14 +7,16 @@ import numpy as np # SDPA for MHA (n_heads == n_kv_heads) -def mlx_primitives_sdpa(q, k, v, scale): +def mlx_primitives_sdpa(q, k, v, scale, mask=None): p = (q * scale) @ k.transpose(0, 1, 3, 2) + if mask is not None: + p += mask scores = mx.softmax(p.astype(mx.float32), axis=-1).astype(p.dtype) return scores @ v # SDPA for GQA (n_heads > n_kv_heads, n_kv_heads > 1, n_heads % n_kv_heads == 0) -def mlx_primitives_sdpa_with_gqa(q, k, v, scale): +def mlx_primitives_sdpa_with_gqa(q, k, v, scale, mask=None): n_repeats = q.shape[1] // k.shape[1] # borrowing kv cache tiling from mlx-examples/llms/mistral/mistral.py @@ -28,7 +30,7 @@ def mlx_primitives_sdpa_with_gqa(q, k, v, scale): k, v = map(repeat, (k, v)) - return mlx_primitives_sdpa(q, k, v, scale) + return mlx_primitives_sdpa(q, k, v, scale, mask=mask) class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase): @@ -176,6 +178,15 @@ class TestFastSDPA(mlx_tests.MLXTestCase): y_hat = mx.fast.scaled_dot_product_attention(q, k, v[:, :, :32], scale=scale) self.assertTrue(mx.allclose(y, y_hat, atol=atol)) + # Test with per-example mask + q = mx.random.normal(shape=(2, 8, 4, 32)) + k = mx.random.normal(shape=(2, 2, 8, 32)) + v = mx.random.normal(shape=(2, 2, 8, 32)) + mask = 10 * mx.random.normal(shape=(2, 1, 4, 8)) + y = mlx_primitives_sdpa_with_gqa(q, k, v, scale, mask=mask) + y_hat = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask) + self.assertTrue(mx.allclose(y, y_hat, atol=atol)) + if __name__ == "__main__": unittest.main(failfast=True)