mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-05 07:34:42 +08:00
Use same accumulation precision in gemv as gemm (#1962)
* use same accumulation precision in gemv as gemm * faster * fix compile
This commit is contained in:
@@ -1146,6 +1146,18 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(r.shape, t.shape)
|
||||
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
|
||||
|
||||
def test_gemv_gemm_same_precision(self):
|
||||
mx.random.seed(0)
|
||||
N = 256
|
||||
if mx.metal.is_available():
|
||||
t = mx.bfloat16
|
||||
a = mx.random.normal([1, N]).astype(t)
|
||||
b = mx.concatenate([a, a], axis=0).astype(t)
|
||||
c = mx.random.normal([N, 64]).astype(t)
|
||||
out_gemv = a @ c
|
||||
out_gemm = (b @ c)[0]
|
||||
self.assertTrue(mx.allclose(out_gemv, out_gemm))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user