diff --git a/mlx/backend/accelerate/matmul.cpp b/mlx/backend/accelerate/matmul.cpp index 654be8074..b26a16bab 100644 --- a/mlx/backend/accelerate/matmul.cpp +++ b/mlx/backend/accelerate/matmul.cpp @@ -46,6 +46,9 @@ inline void matmul_cblas_general( size_t N = b.shape(-1); size_t K = a.shape(-1); + if (M == 0 || N == 0) { + return; + } if (K == 0) { std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes()); return; @@ -94,6 +97,9 @@ inline void matmul_bnns_general( size_t N = b.shape(-1); size_t K = a.shape(-1); + if (M == 0 || N == 0) { + return; + } if (K == 0) { std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes()); return; diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index 06c9db58d..6befc8eb9 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -132,7 +132,9 @@ inline void matmul_common_general( size_t M = a.shape(-2); size_t N = b.shape(-1); size_t K = a.shape(-1); - + if (M == 0 || N == 0) { + return; + } if (K == 0) { std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes()); return; diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 019624a70..01ee6d388 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1323,8 +1323,8 @@ array mean( for (int axis : axes) { if (axis < -ndim || axis >= ndim) { std::ostringstream msg; - msg << "[mean] axis " << axis + " is out of bounds for array with " - << ndim + " dimensions."; + msg << "[mean] axis " << axis << " is out of bounds for array with " + << ndim << " dimensions."; throw std::invalid_argument(msg.str()); } } @@ -1364,7 +1364,7 @@ array var( if (ddof != 0) { auto nelements = compute_number_of_elements(a, axes); - float factor = nelements / (nelements - ddof); + auto factor = nelements / static_cast<float>(std::max(nelements - ddof, 0)); v = multiply(v, array(factor, dtype), s); } diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index b7a24caf2..c2c1cc2a2 100644 --- a/python/tests/test_blas.py +++ 
b/python/tests/test_blas.py @@ -582,6 +582,25 @@ class TestBlas(mlx_tests.MLXTestCase): self.assertEqual(r.shape, t.shape) self.assertTrue(mx.allclose(r, t, atol=1e-4).item()) + def test_empty_matmul(self): + a = mx.array([[], []]).T + b = mx.array([[1.0, 2.0], [2.0, 3.0]]) + c = a @ b + mx.eval(c) + self.assertEqual(c.shape, (0, 2)) + + a = mx.array([[1.0, 2.0], [2.0, 3.0]]) + b = mx.array([[], []]) + c = a @ b + mx.eval(c) + self.assertEqual(c.shape, (2, 0)) + + a = mx.array([[], []]).T + b = mx.array([[], []]) + c = a @ b + mx.eval(c) + self.assertEqual(c.shape, (0, 0)) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index add06c729..6ac46779d 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1,4 +1,4 @@ -# Copyright © 2023 Apple Inc. +# Copyright © 2023-2024 Apple Inc. import math import unittest @@ -690,6 +690,14 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqual(mx.var(x, axis=0).tolist(), [1.0, 1.0]) self.assertEqual(mx.var(x, axis=1).tolist(), [0.25, 0.25]) + x = mx.array([1.0, 2.0]) + out = mx.var(x, ddof=2) + self.assertEqual(out.item(), float("inf")) + + x = mx.array([1.0, 2.0]) + out = mx.var(x, ddof=3) + self.assertEqual(out.item(), float("inf")) + def test_abs(self): a = mx.array([-1.0, 1.0, -2.0, 3.0]) result = mx.abs(a)