minor fixes (#631)

* minor fixes

* var with ddof >= nelements
This commit is contained in:
Awni Hannun
2024-02-05 13:27:49 -08:00
committed by GitHub
parent d75ae52ecd
commit d40a04f8dc
5 changed files with 40 additions and 5 deletions

View File

@@ -582,6 +582,25 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertEqual(r.shape, t.shape)
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
def test_empty_matmul(self):
a = mx.array([[], []]).T
b = mx.array([[1.0, 2.0], [2.0, 3.0]])
c = a @ b
mx.eval(c)
self.assertEqual(c.shape, (0, 2))
a = mx.array([[1.0, 2.0], [2.0, 3.0]])
b = mx.array([[], []])
c = a @ b
mx.eval(c)
self.assertEqual(c.shape, (2, 0))
a = mx.array([[], []]).T
b = mx.array([[], []])
c = a @ b
mx.eval(c)
self.assertEqual(c.shape, (0, 0))
if __name__ == "__main__":
unittest.main()