diff --git a/docs/src/python/array.rst b/docs/src/python/array.rst
index 96ddd32b3..e96e7234d 100644
--- a/docs/src/python/array.rst
+++ b/docs/src/python/array.rst
@@ -34,6 +34,7 @@ Array
    array.prod
    array.reciprocal
    array.reshape
+   array.round
    array.rsqrt
    array.sin
    array.split
diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst
index dcdf0ffd9..811728307 100644
--- a/docs/src/python/ops.rst
+++ b/docs/src/python/ops.rst
@@ -73,6 +73,7 @@ Operations
    prod
    reciprocal
    reshape
+   round
    rsqrt
    save
    savez
diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp
index ce00559f2..912c460e6 100644
--- a/mlx/backend/accelerate/primitives.cpp
+++ b/mlx/backend/accelerate/primitives.cpp
@@ -47,6 +47,7 @@ DEFAULT(Pad)
 DEFAULT(Partition)
 DEFAULT(RandomBits)
 DEFAULT(Reshape)
+DEFAULT(Round)
 DEFAULT(Scatter)
 DEFAULT(Sigmoid)
 DEFAULT(Sign)
diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp
index a0985f5f8..126d953f5 100644
--- a/mlx/backend/common/default_primitives.cpp
+++ b/mlx/backend/common/default_primitives.cpp
@@ -65,6 +65,7 @@ DEFAULT(Power)
 DEFAULT(RandomBits)
 DEFAULT(Reduce)
 DEFAULT(Reshape)
+DEFAULT(Round)
 DEFAULT(Scan)
 DEFAULT(Scatter)
 DEFAULT(Sigmoid)
diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp
index 62775c11d..2218c708c 100644
--- a/mlx/backend/common/primitives.cpp
+++ b/mlx/backend/common/primitives.cpp
@@ -466,6 +466,17 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
+void Round::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  auto& in = inputs[0];
+  if (not is_integral(in.dtype())) {
+    unary_fp(in, out, RoundOp());
+  } else {
+    // No-op for integer types
+    out.copy_shared_buffer(in);
+  }
+}
+
 void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
diff --git a/mlx/backend/common/unary.h b/mlx/backend/common/unary.h
index dbbaf6993..bbf118b9b 100644
--- a/mlx/backend/common/unary.h
+++ b/mlx/backend/common/unary.h
@@ -53,6 +53,17 @@ struct SignOp {
   }
 };
 
+struct RoundOp {
+  template <typename T>
+  T operator()(T x) {
+    return std::round(x);
+  }
+
+  complex64_t operator()(complex64_t x) {
+    return {std::round(x.real()), std::round(x.imag())};
+  }
+};
+
 template <typename T, typename Op>
 void unary_op(const array& a, array& out, Op op) {
   const T* a_ptr = a.data<T>();
diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal
index 80becce3f..4de326f64 100644
--- a/mlx/backend/metal/kernels/unary.metal
+++ b/mlx/backend/metal/kernels/unary.metal
@@ -133,6 +133,11 @@ struct Negative {
   template <typename T> T operator()(T x) { return -x; };
 };
 
+struct Round {
+  template <typename T> T operator()(T x) { return metal::round(x); };
+  template <> complex64_t operator()(complex64_t x) { return {metal::round(x.real), metal::round(x.imag)}; };
+};
+
 struct Sigmoid {
   template <typename T>
   T operator()(T x) {
@@ -300,6 +305,7 @@ instantiate_unary_float(sqrt, Sqrt)
 instantiate_unary_float(rsqrt, Rsqrt)
 instantiate_unary_float(tan, Tan)
 instantiate_unary_float(tanh, Tanh)
+instantiate_unary_float(round, Round)
 
 instantiate_unary_all(abs, complex64, complex64_t, Abs)
 instantiate_unary_all(cos, complex64, complex64_t, Cos)
@@ -310,5 +316,6 @@ instantiate_unary_all(sin, complex64, complex64_t, Sin)
 instantiate_unary_all(sinh, complex64, complex64_t, Sinh)
 instantiate_unary_all(tan, complex64, complex64_t, Tan)
 instantiate_unary_all(tanh, complex64, complex64_t, Tanh)
+instantiate_unary_all(round, complex64, complex64_t, Round)
 
 instantiate_unary_all(lnot, bool_, bool, LogicalNot)
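Both rounding kernels above (std::round on the CPU path, metal::round in the Metal kernel) round halfway cases away from zero, which differs from NumPy's default round-half-to-even. A minimal sketch of the difference, illustrative only; it assumes an MLX build that already includes this patch and uses NumPy purely for comparison:

    import mlx.core as mx
    import numpy as np

    x = [0.5, -0.5, 1.5, -1.5]
    # std::round / metal::round: ties away from zero (matches the Python tests below)
    print(mx.round(mx.array(x)).tolist())  # [1.0, -1.0, 2.0, -2.0]
    # NumPy resolves ties to the nearest even integer
    print(np.round(np.array(x)).tolist())  # [0.0, -0.0, 2.0, -2.0]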
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index cca10464d..578d4e382 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -563,6 +563,17 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 }
 
+void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  if (not is_integral(in.dtype())) {
+    unary_op(inputs, out, "round");
+  } else {
+    // No-op for integer types
+    out.copy_shared_buffer(in);
+  }
+}
+
 void Sigmoid::eval_gpu(const std::vector<array>& inputs, array& out) {
   unary_op(inputs, out, "sigmoid");
 }
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index 1be820bc8..b9cde4426 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -61,6 +61,7 @@ NO_GPU(Power)
 NO_GPU(RandomBits)
 NO_GPU(Reduce)
 NO_GPU(Reshape)
+NO_GPU(Round)
 NO_GPU(Scan)
 NO_GPU(Scatter)
 NO_GPU(Sigmoid)
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 125672af7..c9bafaac1 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -1834,6 +1834,21 @@ array stop_gradient(const array& a, StreamOrDevice s /* = {} */) {
       a.shape(), a.dtype(), std::make_unique<StopGradient>(to_stream(s)), {a});
 }
 
+array round(const array& a, int decimals, StreamOrDevice s /* = {} */) {
+  if (decimals == 0) {
+    return array(
+        a.shape(), a.dtype(), std::make_unique<Round>(to_stream(s)), {a});
+  }
+
+  auto dtype = at_least_float(a.dtype());
+  float scale = std::pow(10, decimals);
+  auto result = multiply(a, array(scale, dtype), s);
+  result = round(result, 0, s);
+  result = multiply(result, array(1 / scale, dtype), s);
+
+  return astype(result, a.dtype(), s);
+}
+
 array matmul(
     const array& in_a,
     const array& in_b,
diff --git a/mlx/ops.h b/mlx/ops.h
index 7046f7a39..0b99a282f 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -794,6 +794,12 @@ array erfinv(const array& a, StreamOrDevice s = {});
 /** Stop the flow of gradients. */
 array stop_gradient(const array& a, StreamOrDevice s = {});
 
+/** Round a floating point number */
+array round(const array& a, int decimals, StreamOrDevice s = {});
+inline array round(const array& a, StreamOrDevice s = {}) {
+  return round(a, 0, s);
+}
+
 /** Matrix-matrix multiplication.
  */
 array matmul(const array& a, const array& b, StreamOrDevice s = {});
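For non-zero decimals, the new `round` in mlx/ops.cpp above scales by 10^decimals, rounds, multiplies by the reciprocal scale, and casts back to the input dtype. A rough NumPy model of that flow, illustrative only (np.round resolves exact halves to even, so it can differ from std::round on ties):

    import numpy as np

    def round_model(a, decimals=0):
        scale = 10.0 ** decimals
        out = np.round(a * scale) / scale        # multiply, round, scale back
        return out.astype(np.asarray(a).dtype)   # mirror astype(result, a.dtype(), s)

    print(round_model(np.array([15, 122]), decimals=-1).tolist())  # [20, 120]
    print(round_model(np.array([15, 122]), decimals=-2).tolist())  # [0, 100]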
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index e1f17ae7d..6d275097c 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -1888,6 +1888,30 @@ bool Reduce::is_equivalent(const Primitive& other) const {
   return reduce_type_ == r_other.reduce_type_ && axes_ == r_other.axes_;
 }
 
+std::vector<array> Round::vjp(
+    const std::vector<array>& primals,
+    const array& cotan,
+    const std::vector<int>& argnums) {
+  return {jvp(primals, {cotan}, argnums)};
+}
+
+array Round::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  assert(primals.size() == 1);
+  assert(argnums.size() == 1);
+  return zeros_like(primals[0], stream());
+}
+
+std::pair<array, int> Round::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  assert(inputs.size() == 1);
+  assert(axes.size() == 1);
+  return {round(inputs[0], stream()), axes[0]};
+}
+
 std::pair<array, int> Scan::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 8141d0fca..8916c4fa1 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -1206,6 +1206,25 @@ class Reduce : public Primitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
+class Round : public Primitive {
+ public:
+  explicit Round(Stream stream) : Primitive(stream){};
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  std::pair<array, int> vmap(
+      const std::vector<array>& inputs,
+      const std::vector<int>& axes) override;
+
+  DEFINE_GRADS()
+  DEFINE_PRINT(Round)
+  DEFINE_DEFAULT_IS_EQUIVALENT()
+
+ private:
+  void eval(const std::vector<array>& inputs, array& out);
+};
+
 class Scan : public Primitive {
  public:
   enum ReduceType { Max, Min, Sum, Prod };
diff --git a/python/src/array.cpp b/python/src/array.cpp
index 36f856c8c..2c54d8008 100644
--- a/python/src/array.cpp
+++ b/python/src/array.cpp
@@ -1148,5 +1148,15 @@ void init_array(py::module_& m) {
           "reverse"_a = false,
           "inclusive"_a = true,
           "stream"_a = none,
-          "See :func:`cummin`.");
+          "See :func:`cummin`.")
+      .def(
+          "round",
+          [](const array& a, int decimals, StreamOrDevice s) {
+            return round(a, decimals, s);
+          },
+          py::pos_only(),
+          "decimals"_a = 0,
+          py::kw_only(),
+          "stream"_a = none,
+          "See :func:`round`.");
 }
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 438b6d80f..b16b3ef0f 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -2922,4 +2922,33 @@ void init_ops(py::module_& m) {
       Returns:
           result (array): The output containing elements selected from ``x`` and ``y``.
     )pbdoc");
+  m.def(
+      "round",
+      [](const array& a, int decimals, StreamOrDevice s) {
+        return round(a, decimals, s);
+      },
+      "a"_a,
+      py::pos_only(),
+      "decimals"_a = 0,
+      py::kw_only(),
+      "stream"_a = none,
+      R"pbdoc(
+        round(a: array, /, decimals: int = 0, stream: Union[None, Stream, Device] = None) -> array
+
+        Round to the given number of decimals.
+
+        Basically performs:
+
+        .. code-block:: python
+
+          s = 10**decimals
+          x = round(x * s) / s
+
+        Args:
+            a (array): Input array.
+            decimals (int): Number of decimal places to round to. (default: 0)
+
+        Returns:
+            result (array): An array of the same type as ``a`` rounded to the given number of decimals.
+      )pbdoc");
 }
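The bindings above expose the op both as `mx.round` and as an `array.round` method. A small usage sketch, assuming an MLX build that contains this patch; the expected values follow the docstring and the tests below:

    import mlx.core as mx

    a = mx.array([1.537, 1.471])
    b = mx.round(a, decimals=2)  # approximately [1.54, 1.47]; the tests below check with allclose
    c = a.round(1)               # method form added in array.cpp; approximately [1.5, 1.5]
    # Complex inputs round the real and imaginary parts separately,
    # e.g. 22.2 + 3.6j -> 22 + 4j.
    d = mx.round(mx.array([22.2 + 3.6j]))
    print(b, c, d)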
+ )pbdoc"); } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 0207fa630..5e0ac0062 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -372,7 +372,35 @@ class TestOps(mlx_tests.MLXTestCase): self.assertListEqual(mx.ceil(x).tolist(), expected) with self.assertRaises(ValueError): - mx.floor(mx.array([22 + 3j, 19 + 98j])) + mx.ceil(mx.array([22 + 3j, 19 + 98j])) + + def test_round(self): + # float + x = mx.array( + [0.5, -0.5, 1.5, -1.5, -22.03, 19.98, -27, 9, 0.0, -np.inf, np.inf] + ) + expected = [1, -1, 2, -2, -22, 20, -27, 9, 0, -np.inf, np.inf] + self.assertListEqual(mx.round(x).tolist(), expected) + + # complex + y = mx.round(mx.array([22.2 + 3.6j, 19.5 + 98.2j])) + self.assertListEqual(y.tolist(), [22 + 4j, 20 + 98j]) + + # decimals + y0 = mx.round(mx.array([15, 122], mx.int32), decimals=0) + y1 = mx.round(mx.array([15, 122], mx.int32), decimals=-1) + y2 = mx.round(mx.array([15, 122], mx.int32), decimals=-2) + self.assertEqual(y0.dtype, mx.int32) + self.assertEqual(y1.dtype, mx.int32) + self.assertEqual(y2.dtype, mx.int32) + self.assertListEqual(y0.tolist(), [15, 122]) + self.assertListEqual(y1.tolist(), [20, 120]) + self.assertListEqual(y2.tolist(), [0, 100]) + + y1 = mx.round(mx.array([1.537, 1.471], mx.float32), decimals=1) + y2 = mx.round(mx.array([1.537, 1.471], mx.float32), decimals=2) + self.assertTrue(mx.allclose(y1, mx.array([1.5, 1.5]))) + self.assertTrue(mx.allclose(y2, mx.array([1.54, 1.47]))) def test_transpose_noargs(self): x = mx.array([[0, 1, 1], [1, 0, 0]]) diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index c62fb4a39..6dccf4bc1 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -862,6 +862,15 @@ TEST_CASE("test arithmetic unary ops") { CHECK_THROWS_AS(ceil(x), std::invalid_argument); } + // Test round + { + array x({0.5, -0.5, 1.5, -1.5, 2.3, 2.6}); + CHECK(array_equal(round(x), array({1, -1, 2, -2, 2, 3})).item()); + + x = array({11, 222, 32}); + CHECK(array_equal(round(x, -1), array({10, 220, 30})).item()); + } + // Test exponential { array x(0.0);