Fix gemv broadcasting bug (#6)

* Fix broadcasting bug in gemv
* Add relevant tests in test_blas.py
This commit is contained in:
Jagrit Digani
2023-12-05 14:15:43 -08:00
committed by GitHub
parent 49cda449b1
commit d518b3b6a5
4 changed files with 492 additions and 203 deletions

View File

@@ -340,6 +340,7 @@ class TestBlas(mlx_tests.MLXTestCase):
((32, 128, 64), (32, 64, 1)),
((128, 64), (32, 64, 1)),
((32, 128, 64), (64, 1)),
((2, 1, 8, 1, 6, 128), (2, 1, 8, 4, 128, 1)),
):
self.__gemv_test(
shape_mat, shape_vec, mat_first=True, np_dtype=np_dtype
@@ -350,6 +351,7 @@ class TestBlas(mlx_tests.MLXTestCase):
((32, 1, 128), (32, 128, 64)),
((32, 1, 128), (128, 64)),
((1, 128), (32, 128, 64)),
((1, 8, 4, 1, 128), (1, 8, 1, 128, 6)),
):
self.__gemv_test(
shape_mat, shape_vec, mat_first=False, np_dtype=np_dtype