mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-17 23:08:11 +08:00
fix gemv regression (#2445)
This commit is contained in:
@@ -47,7 +47,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(np.allclose(out_mlx, out_npy.astype(np_dtype), atol=1e-5))
|
||||
|
||||
def test_matmul_unaligned(self):
|
||||
if not mx.metal.is_available():
|
||||
if not mx.is_available(mx.gpu):
|
||||
return
|
||||
|
||||
for dtype in self.dtypes:
|
||||
@@ -61,8 +61,15 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
shape_b = (dim + p, dim + p)
|
||||
self.__gemm_test(shape_a, shape_b, np_dtype)
|
||||
|
||||
def test_matvec_unaligned(self):
|
||||
a = mx.random.normal(shape=(4, 128))
|
||||
b = mx.random.normal(shape=(129,))[1:]
|
||||
out = a @ b
|
||||
np_out = np.array(a) @ np.array(b)
|
||||
self.assertTrue(np.allclose(out, np_out))
|
||||
|
||||
def test_matmul_shapes(self):
|
||||
if not mx.metal.is_available():
|
||||
if not mx.is_available(mx.gpu):
|
||||
return
|
||||
|
||||
shapes = [
|
||||
@@ -1274,7 +1281,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
def test_gemv_gemm_same_precision(self):
|
||||
mx.random.seed(0)
|
||||
N = 256
|
||||
if mx.metal.is_available():
|
||||
if mx.is_available(mx.gpu):
|
||||
t = mx.bfloat16
|
||||
a = mx.random.normal([1, N]).astype(t)
|
||||
b = mx.concatenate([a, a], axis=0).astype(t)
|
||||
|
Reference in New Issue
Block a user