Update attention tests to show diff, disable array masks (#1978)

This commit is contained in:
Jagrit Digani
2025-03-20 14:25:38 -07:00
committed by GitHub
parent 9adcd1a650
commit b42d13ec84
2 changed files with 6 additions and 7 deletions

View File

@@ -72,8 +72,8 @@ def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
scale = 1.0 / math.sqrt(D)
q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
q_np = np.random.normal(0.0, 0.5, shape_q).astype(np_dtype)
k_np = np.random.normal(0.0, 0.5, shape_kv).astype(np_dtype)
v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
q_mx = mx.array(q_np)
@@ -524,9 +524,8 @@ class TestSDPA(mlx_tests.MLXTestCase):
list(out_ref.shape), list(out_fst.shape)
)
self.assertTrue(
mx.allclose(out_fst, out_ref, atol=atol, rtol=atol)
)
diff = mx.abs(out_fst - out_ref) - atol * mx.abs(out_ref)
self.assertLessEqual(mx.max(diff).item(), atol)
if __name__ == "__main__":