mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 07:18:29 +08:00
Sdpa fix (#1558)
This commit is contained in:

committed by
GitHub

parent
09bc32f62f
commit
62f297b51d
@@ -167,6 +167,15 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol))
|
||||
|
||||
q = mx.random.normal(shape=(1, 32, 1, Dk))
|
||||
k = mx.random.normal(shape=(1, 32, 32, Dk))
|
||||
v = mx.random.normal(shape=(1, 32, 128, Dk))
|
||||
|
||||
atol = 1e-6
|
||||
y = mlx_primitives_sdpa(q, k, v[:, :, :32], scale)
|
||||
y_hat = mx.fast.scaled_dot_product_attention(q, k, v[:, :, :32], scale=scale)
|
||||
self.assertTrue(mx.allclose(y, y_hat, atol=atol))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(failfast=True)
|
||||
|
Reference in New Issue
Block a user