Use same accumulation precision in gemv as gemm (#1962)

* use same accumulation precision in gemv as gemm

* faster

* fix compile
This commit is contained in:
Awni Hannun
2025-03-16 07:13:24 -07:00
committed by GitHub
parent 2770a10240
commit c6ea2ba329
3 changed files with 79 additions and 53 deletions

View File

@@ -1146,6 +1146,18 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertEqual(r.shape, t.shape)
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
def test_gemv_gemm_same_precision(self):
    """Compare a one-row bfloat16 matmul with the same row of a two-row matmul.

    A ``[1, N]`` input is multiplied by an ``[N, 64]`` matrix, and the
    result is checked (with ``mx.allclose`` default tolerances) against
    row 0 of the product obtained after stacking that input twice into a
    ``[2, N]`` matrix. Going by the test name, the one-row product is
    expected to exercise the gemv path and the two-row product the gemm
    path -- NOTE(review): the kernel dispatch is not visible from here.

    The comparison only runs when ``mx.metal.is_available()`` is true;
    otherwise the test finishes without asserting anything.
    """
    mx.random.seed(0)
    N = 256

    # Guard clause: without Metal there is nothing to compare.
    if not mx.metal.is_available():
        return

    dtype = mx.bfloat16
    # Keep the draw order (row first, then weights) so the seeded values
    # are the same as before.
    row = mx.random.normal([1, N]).astype(dtype)
    stacked = mx.concatenate([row, row], axis=0).astype(dtype)
    weights = mx.random.normal([N, 64]).astype(dtype)

    single_row_product = row @ weights
    stacked_first_row = (stacked @ weights)[0]

    outputs_match = mx.allclose(single_row_product, stacked_first_row)
    self.assertTrue(outputs_match)
# When this file is executed directly (rather than imported), hand control
# to the unittest command-line runner.
if __name__ == "__main__":
    unittest.main()