add python testing for cuda with ability to skip list of tests (#2295)

This commit is contained in:
Awni Hannun
2025-06-15 10:56:48 -07:00
committed by GitHub
parent 580776559b
commit 4fda5fbdf9
36 changed files with 220 additions and 35 deletions

View File

@@ -607,7 +607,7 @@ class TestSDPA(mlx_tests.MLXTestCase):
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
def test_sdpa_prommote_mask(self):
def test_sdpa_promote_mask(self):
mask = mx.array(2.0, mx.bfloat16)
D = 64
Nq = 4
@@ -653,4 +653,4 @@ class TestSDPA(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main(failfast=True)
mlx_tests.MLXTestRunner(failfast=True)