fix gemv regression (#2445)

This commit is contained in:
Awni Hannun
2025-07-30 14:23:01 -07:00
committed by GitHub
parent b405591249
commit d32519c8ee
3 changed files with 36 additions and 9 deletions

View File

@@ -47,7 +47,7 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(out_mlx, out_npy.astype(np_dtype), atol=1e-5))
def test_matmul_unaligned(self):
if not mx.metal.is_available():
if not mx.is_available(mx.gpu):
return
for dtype in self.dtypes:
@@ -61,8 +61,15 @@ class TestBlas(mlx_tests.MLXTestCase):
shape_b = (dim + p, dim + p)
self.__gemm_test(shape_a, shape_b, np_dtype)
def test_matvec_unaligned(self):
a = mx.random.normal(shape=(4, 128))
b = mx.random.normal(shape=(129,))[1:]
out = a @ b
np_out = np.array(a) @ np.array(b)
self.assertTrue(np.allclose(out, np_out))
def test_matmul_shapes(self):
if not mx.metal.is_available():
if not mx.is_available(mx.gpu):
return
shapes = [
@@ -1274,7 +1281,7 @@ class TestBlas(mlx_tests.MLXTestCase):
def test_gemv_gemm_same_precision(self):
mx.random.seed(0)
N = 256
if mx.metal.is_available():
if mx.is_available(mx.gpu):
t = mx.bfloat16
a = mx.random.normal([1, N]).astype(t)
b = mx.concatenate([a, a], axis=0).astype(t)