diff --git a/.circleci/config.yml b/.circleci/config.yml
index 2d7c9f771..25fb71fb5 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -42,7 +42,7 @@ jobs:
       - run:
           name: Run Python tests
           command: |
-            python3 -m unittest discover python/tests
+            python3 -m unittest discover python/tests -v
       # TODO: Reenable when extension api becomes stable
       # - run:
       #     name: Build example extension
diff --git a/benchmarks/python/single_ops.py b/benchmarks/python/single_ops.py
index 5a7720b0e..3160a1833 100644
--- a/benchmarks/python/single_ops.py
+++ b/benchmarks/python/single_ops.py
@@ -44,6 +44,13 @@ def time_matmul():
     time_fn(mx.matmul, a, b)
 
 
+def time_maximum():
+    a = mx.random.uniform(shape=(32, 1024, 1024))
+    b = mx.random.uniform(shape=(32, 1024, 1024))
+    mx.eval(a, b)
+    time_fn(mx.maximum, a, b)
+
+
 def time_negative():
     a = mx.random.uniform(shape=(10000, 1000))
     mx.eval(a)
@@ -101,6 +108,7 @@ if __name__ == "__main__":
 
     time_add()
     time_matmul()
+    time_maximum()
     time_exp()
     time_negative()
     time_logsumexp()
diff --git a/mlx/backend/accelerate/matmul.cpp b/mlx/backend/accelerate/matmul.cpp
index 254f3fc4c..654be8074 100644
--- a/mlx/backend/accelerate/matmul.cpp
+++ b/mlx/backend/accelerate/matmul.cpp
@@ -46,6 +46,11 @@ inline void matmul_cblas_general(
   size_t N = b.shape(-1);
   size_t K = a.shape(-1);
 
+  if (K == 0) {
+    std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
+    return;
+  }
+
   for (int i = 0; i < (a.size() / (M * K)); ++i) {
     cblas_sgemm(
         CblasRowMajor,
@@ -89,6 +94,11 @@ inline void matmul_bnns_general(
   size_t N = b.shape(-1);
   size_t K = a.shape(-1);
 
+  if (K == 0) {
+    std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
+    return;
+  }
+
   BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
 
   const BNNSLayerParametersBroadcastMatMul gemm_params{
@@ -201,4 +211,4 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   return matmul_bnns_general(inputs[0], inputs[1], out, alpha_, beta_);
 }
 
-} // namespace mlx::core
\ No newline at end of file
+} // namespace mlx::core
diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp
index be8741a54..bee736b50 100644
--- a/mlx/backend/accelerate/primitives.cpp
+++ b/mlx/backend/accelerate/primitives.cpp
@@ -50,6 +50,8 @@ DEFAULT(LogicalNot)
 DEFAULT(LogicalAnd)
 DEFAULT(LogicalOr)
 DEFAULT(LogAddExp)
+DEFAULT(Maximum)
+DEFAULT(Minimum)
 DEFAULT(NotEqual)
 DEFAULT(Pad)
 DEFAULT(Partition)
@@ -396,47 +398,6 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
   }
 }
 
-void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 2);
-  auto& a = inputs[0];
-  auto& b = inputs[1];
-  if (out.dtype() == float32) {
-    binary(
-        a,
-        b,
-        out,
-        [](auto x, auto y) { return (x > y) ? x : y; },
-        UseDefaultBinaryOp(),
-        UseDefaultBinaryOp(),
-        [](const auto* a, const auto* b, auto* out, int n) {
-          vDSP_vmax((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
-        });
-  } else {
-    binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
-  }
-}
-
-void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 2);
-  auto& a = inputs[0];
-  auto& b = inputs[1];
-
-  if (out.dtype() == float32) {
-    binary(
-        a,
-        b,
-        out,
-        [](auto x, auto y) { return (x < y) ? x : y; },
-        UseDefaultBinaryOp(),
-        UseDefaultBinaryOp(),
-        [](const auto* a, const auto* b, auto* out, int n) {
-          vDSP_vmin((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
-        });
-  } else {
-    binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
-  }
-}
-
 void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
diff --git a/mlx/backend/common/binary.cpp b/mlx/backend/common/binary.cpp
index 8e3de02d3..a51d22d0f 100644
--- a/mlx/backend/common/binary.cpp
+++ b/mlx/backend/common/binary.cpp
@@ -233,14 +233,33 @@ void Maximum::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
+
+  if (is_floating_point(out.dtype())) {
+    binary(a, b, out, [](auto x, auto y) {
+      if (std::isnan(x)) {
+        return x;
+      }
+      return (x > y) ? x : y;
+    });
+  } else {
+    binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
+  }
 }
 
 void Minimum::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
+  if (is_floating_point(out.dtype())) {
+    binary(a, b, out, [](auto x, auto y) {
+      if (std::isnan(x)) {
+        return x;
+      }
+      return (x < y) ? x : y;
+    });
+  } else {
+    binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
+  }
 }
 
 void Multiply::eval(const std::vector<array>& inputs, array& out) {
diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp
index 945451a2a..622f33fd4 100644
--- a/mlx/backend/common/default_primitives.cpp
+++ b/mlx/backend/common/default_primitives.cpp
@@ -6,6 +6,8 @@
 #include <cblas.h>
 #endif
 
+#include <cstring>
+
 #include "mlx/array.h"
 #include "mlx/backend/common/copy.h"
 #include "mlx/backend/common/utils.h"
@@ -128,6 +130,11 @@ inline void matmul_common_general(
   size_t N = b.shape(-1);
   size_t K = a.shape(-1);
 
+  if (K == 0) {
+    std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
+    return;
+  }
+
   for (int i = 0; i < (a.size() / (M * K)); ++i) {
     cblas_sgemm(
         CblasRowMajor,
diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal
index 12b6d1b1b..1b84c70a5 100644
--- a/mlx/backend/metal/kernels/binary.metal
+++ b/mlx/backend/metal/kernels/binary.metal
@@ -58,6 +58,9 @@ struct LessEqual {
 struct LogAddExp {
   template <typename T>
   T operator()(T x, T y) {
+    if (metal::isnan(x) || metal::isnan(y)) {
+      return metal::numeric_limits<T>::quiet_NaN();
+    }
     constexpr T inf = metal::numeric_limits<T>::infinity();
     T maxval = metal::max(x, y);
     T minval = metal::min(x, y);
@@ -67,20 +70,48 @@
 };
 
 struct Maximum {
-  template <typename T> T operator()(T x, T y) { return metal::max(x, y); }
+  template <typename T>
+  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
+    return metal::max(x, y);
+  }
+
+  template <typename T>
+  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
+    if (metal::isnan(x)) {
+      return x;
+    }
+    return x > y ? x : y;
+  }
 
   template <>
   complex64_t operator()(complex64_t x, complex64_t y) {
-    return x >= y ? x : y;
+    if (metal::isnan(x.real) || metal::isnan(x.imag)) {
+      return x;
+    }
+    return x > y ? x : y;
   }
 };
 
 struct Minimum {
-  template <typename T> T operator()(T x, T y) { return metal::min(x, y); }
+  template <typename T>
+  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
+    return metal::min(x, y);
+  }
+
+  template <typename T>
+  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
+    if (metal::isnan(x)) {
+      return x;
+    }
+    return x < y ? x : y;
+  }
 
   template <>
   complex64_t operator()(complex64_t x, complex64_t y) {
-    return x <= y ? x : y;
+    if (metal::isnan(x.real) || metal::isnan(x.imag)) {
+      return x;
+    }
+    return x < y ? x : y;
   }
 };
 
@@ -389,4 +420,4 @@
 instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual)
 instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
 instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
-instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
\ No newline at end of file
+instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index f1f83a6b7..d2768d32c 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -189,6 +189,9 @@ array full(
     const array& vals,
     Dtype dtype,
     StreamOrDevice s /* = {} */) {
+  if (std::any_of(shape.begin(), shape.end(), [](auto i) { return i < 0; })) {
+    throw std::invalid_argument("[full] Negative dimensions not allowed.");
+  }
   auto in = broadcast_to(astype(vals, dtype, s), shape, s);
   return array(shape, dtype, std::make_unique<Full>(to_stream(s)), {in});
 }
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index bbc6e6a84..0797c85eb 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -386,6 +386,11 @@ class TestOps(mlx_tests.MLXTestCase):
         expected = [0, -7, 3]
         self.assertListEqual(mx.minimum(x, y).tolist(), expected)
 
+        a = mx.array([float("nan")])
+        b = mx.array([0.0])
+        self.assertTrue(math.isnan(mx.minimum(a, b).item()))
+        self.assertTrue(math.isnan(mx.minimum(b, a).item()))
+
     def test_maximum(self):
         x = mx.array([0.0, -5, 10.0])
         y = mx.array([1.0, -7.0, 3.0])
@@ -393,6 +398,11 @@ class TestOps(mlx_tests.MLXTestCase):
         expected = [1, -5, 10]
         self.assertListEqual(mx.maximum(x, y).tolist(), expected)
 
+        a = mx.array([float("nan")])
+        b = mx.array([0.0])
+        self.assertTrue(math.isnan(mx.maximum(a, b).item()))
+        self.assertTrue(math.isnan(mx.maximum(b, a).item()))
+
     def test_floor(self):
         x = mx.array([-22.03, 19.98, -27, 9, 0.0, -np.inf, np.inf])
         expected = [-23, 19, -27, 9, 0, -np.inf, np.inf]
@@ -760,6 +770,10 @@ class TestOps(mlx_tests.MLXTestCase):
 
         self.assertTrue(np.allclose(result, expected))
 
+        a = mx.array([float("nan")])
+        b = mx.array([0.0])
+        self.assertTrue(math.isnan(mx.logaddexp(a, b).item()))
+
     def test_log(self):
         a = mx.array([1, 0.5, 10, 100])
         result = mx.log(a)
@@ -1761,6 +1775,16 @@ class TestOps(mlx_tests.MLXTestCase):
         )
         self.assertCmpNumpy([(3,), [2, 2, 2]], mx.tile, np.tile)
 
+    def test_empty_matmuls(self):
+        a = mx.array([])
+        b = mx.array([])
+        self.assertEqual(mx.inner(a, b).item(), 0.0)
+
+        a = mx.zeros((10, 0))
+        b = mx.zeros((0, 10))
+        out = a @ b
+        self.assertTrue(mx.array_equal(out, mx.zeros((10, 10))))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/creations_tests.cpp b/tests/creations_tests.cpp
index ea28638af..528e9fc90 100644
--- a/tests/creations_tests.cpp
+++ b/tests/creations_tests.cpp
@@ -151,6 +151,12 @@ TEST_CASE("test astype") {
 }
 
 TEST_CASE("test full") {
+  // Check throws on bad shape
+  {
+    CHECK_THROWS(full({-5, 0}, 0));
+    CHECK_THROWS(full({0, -5}, 0));
+  }
+
   // Check full works for different types
   {
     auto x = full({}, 0);
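
A quick sketch of the user-visible behavior this patch pins down, written against the patch's own Python tests. It is not part of the diff; it assumes a build of mlx.core that includes these changes, and that std::invalid_argument surfaces in Python as ValueError (pybind11's default exception translation).

    import math
    import mlx.core as mx

    # NaN now propagates through maximum/minimum/logaddexp on both the CPU
    # and Metal backends, regardless of which argument holds the NaN.
    a = mx.array([float("nan")])
    b = mx.array([0.0])
    assert math.isnan(mx.maximum(a, b).item())
    assert math.isnan(mx.minimum(b, a).item())
    assert math.isnan(mx.logaddexp(a, b).item())

    # Matmuls with a zero-sized inner dimension (K == 0) now memset the
    # output to zeros instead of leaving the buffer uninitialized.
    out = mx.zeros((10, 0)) @ mx.zeros((0, 10))
    assert mx.array_equal(out, mx.zeros((10, 10)))

    # full() rejects negative dimensions up front instead of attempting
    # to broadcast to them.
    try:
        mx.full((-5, 0), 0)
    except ValueError:  # assumed mapping of std::invalid_argument
        pass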