mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-07 00:54:37 +08:00
@@ -582,6 +582,25 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(r.shape, t.shape)
|
||||
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
|
||||
|
||||
def test_empty_matmul(self):
|
||||
a = mx.array([[], []]).T
|
||||
b = mx.array([[1.0, 2.0], [2.0, 3.0]])
|
||||
c = a @ b
|
||||
mx.eval(c)
|
||||
self.assertEqual(c.shape, (0, 2))
|
||||
|
||||
a = mx.array([[1.0, 2.0], [2.0, 3.0]])
|
||||
b = mx.array([[], []])
|
||||
c = a @ b
|
||||
mx.eval(c)
|
||||
self.assertEqual(c.shape, (2, 0))
|
||||
|
||||
a = mx.array([[], []]).T
|
||||
b = mx.array([[], []])
|
||||
c = a @ b
|
||||
mx.eval(c)
|
||||
self.assertEqual(c.shape, (0, 0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user