mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-05 16:13:52 +08:00
Fix gemv broadcasting bug (#6)
* Fix broadcasting bug in gemv * Add relevant tests in test_blas.py
This commit is contained in:
@@ -340,6 +340,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
((32, 128, 64), (32, 64, 1)),
|
||||
((128, 64), (32, 64, 1)),
|
||||
((32, 128, 64), (64, 1)),
|
||||
((2, 1, 8, 1, 6, 128), (2, 1, 8, 4, 128, 1)),
|
||||
):
|
||||
self.__gemv_test(
|
||||
shape_mat, shape_vec, mat_first=True, np_dtype=np_dtype
|
||||
@@ -350,6 +351,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
((32, 1, 128), (32, 128, 64)),
|
||||
((32, 1, 128), (128, 64)),
|
||||
((1, 128), (32, 128, 64)),
|
||||
((1, 8, 4, 1, 128), (1, 8, 1, 128, 6)),
|
||||
):
|
||||
self.__gemv_test(
|
||||
shape_mat, shape_vec, mat_first=False, np_dtype=np_dtype
|
||||
|
Reference in New Issue
Block a user