mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Update attention tests to show diff, disable array masks (#1978)
This commit is contained in:
parent
9adcd1a650
commit
b42d13ec84
@@ -748,8 +748,8 @@ array scaled_dot_product_attention(
|
|||||||
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
|
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
|
||||||
|
|
||||||
const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask);
|
const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask);
|
||||||
const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
|
const bool sdpa_full_supported_mask =
|
||||||
(query_sequence_length <= key_sequence_length && do_causal);
|
!has_mask || (query_sequence_length <= key_sequence_length && do_causal);
|
||||||
|
|
||||||
const bool supports_sdpa_full = query_sequence_length >= threshold &&
|
const bool supports_sdpa_full = query_sequence_length >= threshold &&
|
||||||
sdpa_full_supported_mask && sdpa_full_supported_head_dim &&
|
sdpa_full_supported_mask && sdpa_full_supported_head_dim &&
|
||||||
|
@@ -72,8 +72,8 @@ def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
|
|||||||
|
|
||||||
scale = 1.0 / math.sqrt(D)
|
scale = 1.0 / math.sqrt(D)
|
||||||
|
|
||||||
q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
|
q_np = np.random.normal(0.0, 0.5, shape_q).astype(np_dtype)
|
||||||
k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
|
k_np = np.random.normal(0.0, 0.5, shape_kv).astype(np_dtype)
|
||||||
v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
|
v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
|
||||||
|
|
||||||
q_mx = mx.array(q_np)
|
q_mx = mx.array(q_np)
|
||||||
@@ -524,9 +524,8 @@ class TestSDPA(mlx_tests.MLXTestCase):
|
|||||||
list(out_ref.shape), list(out_fst.shape)
|
list(out_ref.shape), list(out_fst.shape)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(
|
diff = mx.abs(out_fst - out_ref) - atol * mx.abs(out_ref)
|
||||||
mx.allclose(out_fst, out_ref, atol=atol, rtol=atol)
|
self.assertLessEqual(mx.max(diff).item(), atol)
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user