From b93c4cf378082941e0a0d85156a41cacdc8eb5bd Mon Sep 17 00:00:00 2001 From: Luca Arnaboldi Date: Thu, 14 Dec 2023 19:00:23 +0100 Subject: [PATCH] Floor and Ceil (#150) * Implements Floor and Ceil Ops --- docs/src/python/ops.rst | 2 + mlx/backend/accelerate/primitives.cpp | 2 + mlx/backend/common/default_primitives.cpp | 2 + mlx/backend/common/primitives.cpp | 22 +++++++++ mlx/backend/metal/kernels/unary.metal | 28 ++++++++++++ mlx/backend/metal/primitives.cpp | 8 ++++ mlx/backend/no_metal/primitives.cpp | 2 + mlx/ops.cpp | 15 +++++++ mlx/ops.h | 6 +++ mlx/primitives.cpp | 54 +++++++++++++++++++++-- mlx/primitives.h | 38 ++++++++++++++++ python/src/ops.cpp | 36 +++++++++++++++ python/tests/test_ops.py | 16 +++++++ tests/ops_tests.cpp | 23 ++++++++++ 14 files changed, 250 insertions(+), 4 deletions(-) diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index b9a4c9066..5ef1f54c2 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -26,6 +26,7 @@ Operations argsort array_equal broadcast_to + ceil concatenate convolve conv1d @@ -39,6 +40,7 @@ Operations exp expand_dims eye + floor full greater greater_equal diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 3334fb830..ce00559f2 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -26,12 +26,14 @@ DEFAULT(ArgReduce) DEFAULT(ArgSort) DEFAULT(AsStrided) DEFAULT(Broadcast) +DEFAULT(Ceil) DEFAULT(Concatenate) DEFAULT(Copy) DEFAULT(Equal) DEFAULT(Erf) DEFAULT(ErfInv) DEFAULT(FFT) +DEFAULT(Floor) DEFAULT(Gather) DEFAULT(Greater) DEFAULT(GreaterEqual) diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index db0ae57bb..a0985f5f8 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -29,6 +29,7 @@ DEFAULT(ArgSort) DEFAULT(AsType) DEFAULT(AsStrided) DEFAULT(Broadcast) +DEFAULT(Ceil) DEFAULT(Concatenate) DEFAULT(Convolution) DEFAULT(Copy) @@ -41,6 +42,7 @@ DEFAULT(Erf) DEFAULT(ErfInv) DEFAULT(Exp) DEFAULT(FFT) +DEFAULT(Floor) DEFAULT(Full) DEFAULT(Gather) DEFAULT(Greater) diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp index 4049af2eb..62775c11d 100644 --- a/mlx/backend/common/primitives.cpp +++ b/mlx/backend/common/primitives.cpp @@ -167,6 +167,17 @@ void Broadcast::eval(const std::vector& inputs, array& out) { out.copy_shared_buffer(in, strides, flags, in.data_size()); } +void Ceil::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + if (not is_integral(in.dtype())) { + unary_fp(in, out, [](auto x) { return std::ceil(x); }); + } else { + // No-op integer types + out.copy_shared_buffer(in); + } +} + void Concatenate::eval(const std::vector& inputs, array& out) { std::vector sizes; sizes.push_back(0); @@ -287,6 +298,17 @@ void Exp::eval(const std::vector& inputs, array& out) { } } +void Floor::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + if (not is_integral(in.dtype())) { + unary_fp(in, out, [](auto x) { return std::floor(x); }); + } else { + // No-op integer types + out.copy_shared_buffer(in); + } +} + void Full::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index bfda9938c..80becce3f 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -43,6 +43,19 @@ struct ArcTanh { template T operator()(T x) { return metal::precise::atanh(x); }; }; +struct Ceil { + template T operator()(T x) { return metal::ceil(x); }; + template <> int8_t operator()(int8_t x) { return x; }; + template <> int16_t operator()(int16_t x) { return x; }; + template <> int32_t operator()(int32_t x) { return x; }; + template <> int64_t operator()(int64_t x) { return x; }; + template <> uint8_t operator()(uint8_t x) { return x; }; + template <> uint16_t operator()(uint16_t x) { return x; }; + template <> uint32_t operator()(uint32_t x) { return x; }; + template <> uint64_t operator()(uint64_t x) { return x; }; + template <> bool operator()(bool x) { return x; }; +}; + struct Cos { template T operator()(T x) { return metal::precise::cos(x); }; @@ -83,6 +96,19 @@ struct Exp { } }; +struct Floor { + template T operator()(T x) { return metal::floor(x); }; + template <> int8_t operator()(int8_t x) { return x; }; + template <> int16_t operator()(int16_t x) { return x; }; + template <> int32_t operator()(int32_t x) { return x; }; + template <> int64_t operator()(int64_t x) { return x; }; + template <> uint8_t operator()(uint8_t x) { return x; }; + template <> uint16_t operator()(uint16_t x) { return x; }; + template <> uint32_t operator()(uint32_t x) { return x; }; + template <> uint64_t operator()(uint64_t x) { return x; }; + template <> bool operator()(bool x) { return x; }; +}; + struct Log { template T operator()(T x) { return metal::precise::log(x); }; }; @@ -253,9 +279,11 @@ instantiate_unary_float(arcsin, ArcSin) instantiate_unary_float(arcsinh, ArcSinh) instantiate_unary_float(arctan, ArcTan) instantiate_unary_float(arctanh, ArcTanh) +instantiate_unary_types(ceil, Ceil) instantiate_unary_float(cos, Cos) instantiate_unary_float(cosh, Cosh) instantiate_unary_float(exp, Exp) +instantiate_unary_types(floor, Floor) instantiate_unary_float(log, Log) instantiate_unary_float(log2, Log2) instantiate_unary_float(log10, Log10) diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 52be03d5c..cca10464d 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -450,6 +450,14 @@ void Minimum::eval_gpu(const std::vector& inputs, array& out) { binary_op(inputs, out, "min"); } +void Floor::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "floor"); +} + +void Ceil::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "ceil"); +} + void Multiply::eval_gpu(const std::vector& inputs, array& out) { binary_op(inputs, out, "mul"); } diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index ed4704465..1be820bc8 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -24,6 +24,7 @@ NO_GPU(ArgSort) NO_GPU(AsType) NO_GPU(AsStrided) NO_GPU(Broadcast) +NO_GPU(Ceil) NO_GPU(Concatenate) NO_GPU(Convolution) NO_GPU(Copy) @@ -36,6 +37,7 @@ NO_GPU(Erf) NO_GPU(ErfInv) NO_GPU(Exp) NO_GPU(FFT) +NO_GPU(Floor) NO_GPU(Full) NO_GPU(Gather) NO_GPU(Greater) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index e85683a48..066242486 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1498,6 +1498,21 @@ array minimum(const array& a, const array& b, StreamOrDevice s /* = {} */) { inputs); } +array floor(const array& a, StreamOrDevice s /* = {} */) { + if (a.dtype() == complex64) { + throw std::invalid_argument("[floor] Not supported for complex64."); + } + return array( + a.shape(), a.dtype(), std::make_unique(to_stream(s)), {a}); +} + +array ceil(const array& a, StreamOrDevice s /* = {} */) { + if (a.dtype() == complex64) { + throw std::invalid_argument("[floor] Not supported for complex64."); + } + return array(a.shape(), a.dtype(), std::make_unique(to_stream(s)), {a}); +} + array square(const array& a, StreamOrDevice s /* = {} */) { return array( a.shape(), a.dtype(), std::make_unique(to_stream(s)), {a}); diff --git a/mlx/ops.h b/mlx/ops.h index 560697259..686dbfb79 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -677,6 +677,12 @@ array maximum(const array& a, const array& b, StreamOrDevice s = {}); /** Element-wise minimum between two arrays. */ array minimum(const array& a, const array& b, StreamOrDevice s = {}); +/** Floor the element of an array. **/ +array floor(const array& a, StreamOrDevice s = {}); + +/** Ceil the element of an array. **/ +array ceil(const array& a, StreamOrDevice s = {}); + /** Square the elements of an array. */ array square(const array& a, StreamOrDevice s = {}); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 58e539e7a..58d952e92 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -441,6 +441,30 @@ bool Broadcast::is_equivalent(const Primitive& other) const { return shape_ == b_other.shape_; } +std::vector Ceil::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array Ceil::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + return zeros_like(primals[0], stream()); +} + +std::pair Ceil::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {ceil(inputs[0], stream()), axes[0]}; +} + std::vector Concatenate::vjp( const std::vector& primals, const array& cotan, @@ -748,8 +772,7 @@ std::vector Remainder::vjp( vjps.push_back(cotan); } else { auto x_over_y = divide(primals[0], primals[1], stream()); - // TODO: Replace with a proper floor when available - x_over_y = astype(x_over_y, int32, stream()); + x_over_y = floor(x_over_y, stream()); vjps.push_back(negative(multiply(x_over_y, cotan, stream()), stream())); } } @@ -766,8 +789,7 @@ array Remainder::jvp( return tangents[i]; } else { auto x_over_y = divide(primals[0], primals[1], stream()); - // TODO: Replace with a proper floor when available - x_over_y = astype(x_over_y, int32, stream()); + x_over_y = floor(x_over_y, stream()); return negative(multiply(x_over_y, tangents[i], stream()), stream()); } }; @@ -976,6 +998,30 @@ array FFT::jvp( } } +std::vector Floor::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array Floor::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + return zeros_like(primals[0], stream()); +} + +std::pair Floor::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {floor(inputs[0], stream()), axes[0]}; +} + std::vector Full::vjp( const std::vector& primals, const array& cotan, diff --git a/mlx/primitives.h b/mlx/primitives.h index c6ea72b3f..c70de50d8 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -404,6 +404,25 @@ class Broadcast : public Primitive { void eval(const std::vector& inputs, array& out); }; +class Ceil : public Primitive { + public: + explicit Ceil(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Ceil) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + class Concatenate : public Primitive { public: explicit Concatenate(Stream stream, int axis) @@ -662,6 +681,25 @@ class FFT : public Primitive { void eval(const std::vector& inputs, array& out); }; +class Floor : public Primitive { + public: + explicit Floor(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Floor) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + class Full : public Primitive { public: explicit Full(Stream stream) : Primitive(stream){}; diff --git a/python/src/ops.cpp b/python/src/ops.cpp index a5ae163f4..76a057b20 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1555,6 +1555,42 @@ void init_ops(py::module_& m) { Returns: array: The max of ``a`` and ``b``. )pbdoc"); + m.def( + "floor", + &mlx::core::floor, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + floor(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + + Element-wise floor. + + Args: + a (array): Input array. + + Returns: + array: The floor of ``a``. + )pbdoc"); + m.def( + "ceil", + &mlx::core::ceil, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + ceil(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + + Element-wise ceil. + + Args: + a (array): Input array. + + Returns: + array: The ceil of ``a``. + )pbdoc"); m.def( "transpose", [](const array& a, diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index d0ca54365..09263fefe 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -334,6 +334,22 @@ class TestOps(mlx_tests.MLXTestCase): expected = [1, -5, 10] self.assertListEqual(mx.maximum(x, y).tolist(), expected) + def test_floor(self): + x = mx.array([-22.03, 19.98, -27, 9, 0.0, -np.inf, np.inf]) + expected = [-23, 19, -27, 9, 0, -np.inf, np.inf] + self.assertListEqual(mx.floor(x).tolist(), expected) + + with self.assertRaises(ValueError): + mx.floor(mx.array([22 + 3j, 19 + 98j])) + + def test_ceil(self): + x = mx.array([-22.03, 19.98, -27, 9, 0.0, -np.inf, np.inf]) + expected = [-22, 20, -27, 9, 0, -np.inf, np.inf] + self.assertListEqual(mx.ceil(x).tolist(), expected) + + with self.assertRaises(ValueError): + mx.floor(mx.array([22 + 3j, 19 + 98j])) + def test_transpose_noargs(self): x = mx.array([[0, 1, 1], [1, 0, 0]]) diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index b3b32ea1d..d9df0262d 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -773,6 +773,29 @@ TEST_CASE("test arithmetic unary ops") { constexpr float neginf = -std::numeric_limits::infinity(); + // Test floor and ceil + { + array x(1.0f); + CHECK_EQ(floor(x).item(), 1.0f); + CHECK_EQ(ceil(x).item(), 1.0f); + + x = array(1.5f); + CHECK_EQ(floor(x).item(), 1.0f); + CHECK_EQ(ceil(x).item(), 2.0f); + + x = array(-1.5f); + CHECK_EQ(floor(x).item(), -2.0f); + CHECK_EQ(ceil(x).item(), -1.0f); + + x = array(neginf); + CHECK_EQ(floor(x).item(), neginf); + CHECK_EQ(ceil(x).item(), neginf); + + x = array(std::complex(1.0f, 1.0f)); + CHECK_THROWS_AS(floor(x), std::invalid_argument); + CHECK_THROWS_AS(ceil(x), std::invalid_argument); + } + // Test exponential { array x(0.0);