From 2807c6aff03864f232c78ca61f77657dcf64f5f4 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 19 Dec 2023 20:12:19 -0800 Subject: [PATCH] Implements divide for integer types and adds floor_divide op (#228) * Add floor_divide * Add floor_divide to the tests * Add floor_divide to the docs --- docs/src/python/ops.rst | 3 ++- mlx/backend/metal/kernels/binary.metal | 2 +- mlx/backend/metal/kernels/complex.h | 7 +++++++ mlx/ops.cpp | 14 ++++++++++++++ mlx/ops.h | 3 +++ python/src/array.cpp | 6 ++---- python/src/ops.cpp | 26 ++++++++++++++++++++++++++ python/tests/test_ops.py | 20 ++++++++++++-------- 8 files changed, 67 insertions(+), 14 deletions(-) diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 1662dc788..731300047 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -41,8 +41,9 @@ Operations exp expand_dims eye - floor flatten + floor + floor_divide full greater greater_equal diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal index b7bbe6f37..a4f599064 100644 --- a/mlx/backend/metal/kernels/binary.metal +++ b/mlx/backend/metal/kernels/binary.metal @@ -357,7 +357,7 @@ template instantiate_binary_all(name, complex64, complex64_t, bool, op) instantiate_binary_types(add, Add) -instantiate_binary_float(div, Divide) +instantiate_binary_types(div, Divide) instantiate_binary_types_bool(eq, Equal) instantiate_binary_types_bool(ge, Greater) instantiate_binary_types_bool(geq, GreaterEqual) diff --git a/mlx/backend/metal/kernels/complex.h b/mlx/backend/metal/kernels/complex.h index 6bd427bf8..c9fedb797 100644 --- a/mlx/backend/metal/kernels/complex.h +++ b/mlx/backend/metal/kernels/complex.h @@ -111,6 +111,13 @@ constexpr complex64_t operator*(complex64_t a, complex64_t b) { return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real}; } +constexpr complex64_t operator/(complex64_t a, complex64_t b) { + auto denom = b.real * b.real + b.imag * b.imag; + auto x = a.real * b.real + a.imag * b.imag; + auto y = a.imag * b.real - a.real * b.imag; + return {x / denom, y / denom}; +} + constexpr complex64_t operator%(complex64_t a, complex64_t b) { auto real = a.real - (b.real * static_cast(a.real / b.real)); auto imag = a.imag - (b.imag * static_cast(a.imag / b.imag)); diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 9a4a9b2a9..88fb30f53 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1636,6 +1636,20 @@ array operator/(const array& a, double b) { return divide(a, array(b)); } +array floor_divide( + const array& a, + const array& b, + StreamOrDevice s /* = {} */) { + auto dtype = promote_types(a.dtype(), b.dtype()); + if (is_floating_point(dtype)) { + return floor(divide(a, b, s), s); + } + + auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s); + return array( + inputs[0].shape(), dtype, std::make_unique(to_stream(s)), inputs); +} + array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); auto inputs = broadcast_arrays( diff --git a/mlx/ops.h b/mlx/ops.h index 977dfc970..192ddb8d7 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -709,6 +709,9 @@ array operator/(const array& a, const array& b); array operator/(double a, const array& b); array operator/(const array& a, double b); +/** Compute integer division. Equivalent to doing floor(a / x). */ +array floor_divide(const array& a, const array& b, StreamOrDevice s = {}); + /** Compute the element-wise remainder of division */ array remainder(const array& a, const array& b, StreamOrDevice s = {}); array operator%(const array& a, const array& b); diff --git a/python/src/array.cpp b/python/src/array.cpp index 2c54d8008..74223cdda 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -636,8 +636,7 @@ void init_array(py::module_& m) { "__floordiv__", [](const array& a, const ScalarOrArray v) { auto b = to_array(v, a.dtype()); - auto t = promote_types(a.dtype(), b.dtype()); - return astype(divide(a, b), t); + return floor_divide(a, b); }, "other"_a) .def( @@ -650,8 +649,7 @@ void init_array(py::module_& m) { "__rfloordiv__", [](const array& a, const ScalarOrArray v) { auto b = to_array(v, a.dtype()); - auto t = promote_types(a.dtype(), b.dtype()); - return astype(divide(b, a), t); + return floor_divide(b, a); }, "other"_a) .def( diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 39dbad5b8..3c4bdc018 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -303,6 +303,32 @@ void init_ops(py::module_& m) { Returns: array: The quotient ``a / b``. )pbdoc"); + m.def( + "floor_divide", + [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + auto [a, b] = to_arrays(a_, b_); + return floor_divide(a, b, s); + }, + "a"_a, + "b"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + floor_divide(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array + + Element-wise integer division. + + If either array is a floating point type then it is equivalent to + calling :func:`floor` after :func:`divide`. + + Args: + a (array): Input array or scalar. + b (array): Input array or scalar. + + Returns: + array: The quotient ``a // b``. + )pbdoc"); m.def( "remainder", [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 3f4cef9ed..ecfc6315a 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -115,6 +115,7 @@ class TestOps(mlx_tests.MLXTestCase): "subtract", "multiply", "divide", + "floor_divide", "remainder", "equal", "not_equal", @@ -1096,6 +1097,7 @@ class TestOps(mlx_tests.MLXTestCase): "subtract", "multiply", "divide", + "floor_divide", "maximum", "minimum", "power", @@ -1111,19 +1113,21 @@ class TestOps(mlx_tests.MLXTestCase): "uint32", "uint64", ] - float_dtypes = ["float16", "float32"] - dtypes = ( - float_dtypes - if op in ("divide", "power") - else (int_dtypes + float_dtypes) - ) + dtypes = { + "divide": float_dtypes, + "power": float_dtypes, + "floor_divide": ["float32"] + int_dtypes, + } + dtypes = dtypes.get(op, int_dtypes + float_dtypes) + for dtype in dtypes: atol = 1e-3 if dtype == "float16" else 1e-6 with self.subTest(dtype=dtype): - x1_ = x1.astype(getattr(np, dtype)) - x2_ = x2.astype(getattr(np, dtype)) + m = 10 if dtype in int_dtypes else 1 + x1_ = (x1 * m).astype(getattr(np, dtype)) + x2_ = (x2 * m).astype(getattr(np, dtype)) y1_ = mx.array(x1_) y2_ = mx.array(x2_) test_ops(