Implements divide for integer types and adds floor_divide op (#228)

* Add floor_divide
* Add floor_divide to the tests
* Add floor_divide to the docs
This commit is contained in:
Angelos Katharopoulos 2023-12-19 20:12:19 -08:00 committed by GitHub
parent de892cb66c
commit 2807c6aff0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 67 additions and 14 deletions

View File

@ -41,8 +41,9 @@ Operations
exp exp
expand_dims expand_dims
eye eye
floor
flatten flatten
floor
floor_divide
full full
greater greater
greater_equal greater_equal

View File

@ -357,7 +357,7 @@ template <typename T, typename U, typename Op>
instantiate_binary_all(name, complex64, complex64_t, bool, op) instantiate_binary_all(name, complex64, complex64_t, bool, op)
instantiate_binary_types(add, Add) instantiate_binary_types(add, Add)
instantiate_binary_float(div, Divide) instantiate_binary_types(div, Divide)
instantiate_binary_types_bool(eq, Equal) instantiate_binary_types_bool(eq, Equal)
instantiate_binary_types_bool(ge, Greater) instantiate_binary_types_bool(ge, Greater)
instantiate_binary_types_bool(geq, GreaterEqual) instantiate_binary_types_bool(geq, GreaterEqual)

View File

@ -111,6 +111,13 @@ constexpr complex64_t operator*(complex64_t a, complex64_t b) {
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real}; return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
} }
constexpr complex64_t operator/(complex64_t a, complex64_t b) {
auto denom = b.real * b.real + b.imag * b.imag;
auto x = a.real * b.real + a.imag * b.imag;
auto y = a.imag * b.real - a.real * b.imag;
return {x / denom, y / denom};
}
constexpr complex64_t operator%(complex64_t a, complex64_t b) { constexpr complex64_t operator%(complex64_t a, complex64_t b) {
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real)); auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag)); auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));

View File

@ -1636,6 +1636,20 @@ array operator/(const array& a, double b) {
return divide(a, array(b)); return divide(a, array(b));
} }
array floor_divide(
const array& a,
const array& b,
StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
if (is_floating_point(dtype)) {
return floor(divide(a, b, s), s);
}
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
return array(
inputs[0].shape(), dtype, std::make_unique<Divide>(to_stream(s)), inputs);
}
array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) { array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype()); auto dtype = promote_types(a.dtype(), b.dtype());
auto inputs = broadcast_arrays( auto inputs = broadcast_arrays(

View File

@ -709,6 +709,9 @@ array operator/(const array& a, const array& b);
array operator/(double a, const array& b); array operator/(double a, const array& b);
array operator/(const array& a, double b); array operator/(const array& a, double b);
/** Compute integer division. Equivalent to doing floor(a / x). */
array floor_divide(const array& a, const array& b, StreamOrDevice s = {});
/** Compute the element-wise remainder of division */ /** Compute the element-wise remainder of division */
array remainder(const array& a, const array& b, StreamOrDevice s = {}); array remainder(const array& a, const array& b, StreamOrDevice s = {});
array operator%(const array& a, const array& b); array operator%(const array& a, const array& b);

View File

@ -636,8 +636,7 @@ void init_array(py::module_& m) {
"__floordiv__", "__floordiv__",
[](const array& a, const ScalarOrArray v) { [](const array& a, const ScalarOrArray v) {
auto b = to_array(v, a.dtype()); auto b = to_array(v, a.dtype());
auto t = promote_types(a.dtype(), b.dtype()); return floor_divide(a, b);
return astype(divide(a, b), t);
}, },
"other"_a) "other"_a)
.def( .def(
@ -650,8 +649,7 @@ void init_array(py::module_& m) {
"__rfloordiv__", "__rfloordiv__",
[](const array& a, const ScalarOrArray v) { [](const array& a, const ScalarOrArray v) {
auto b = to_array(v, a.dtype()); auto b = to_array(v, a.dtype());
auto t = promote_types(a.dtype(), b.dtype()); return floor_divide(b, a);
return astype(divide(b, a), t);
}, },
"other"_a) "other"_a)
.def( .def(

View File

@ -303,6 +303,32 @@ void init_ops(py::module_& m) {
Returns: Returns:
array: The quotient ``a / b``. array: The quotient ``a / b``.
)pbdoc"); )pbdoc");
m.def(
"floor_divide",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
auto [a, b] = to_arrays(a_, b_);
return floor_divide(a, b, s);
},
"a"_a,
"b"_a,
py::pos_only(),
py::kw_only(),
"stream"_a = none,
R"pbdoc(
floor_divide(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
Element-wise integer division.
If either array is a floating point type then it is equivalent to
calling :func:`floor` after :func:`divide`.
Args:
a (array): Input array or scalar.
b (array): Input array or scalar.
Returns:
array: The quotient ``a // b``.
)pbdoc");
m.def( m.def(
"remainder", "remainder",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {

View File

@ -115,6 +115,7 @@ class TestOps(mlx_tests.MLXTestCase):
"subtract", "subtract",
"multiply", "multiply",
"divide", "divide",
"floor_divide",
"remainder", "remainder",
"equal", "equal",
"not_equal", "not_equal",
@ -1096,6 +1097,7 @@ class TestOps(mlx_tests.MLXTestCase):
"subtract", "subtract",
"multiply", "multiply",
"divide", "divide",
"floor_divide",
"maximum", "maximum",
"minimum", "minimum",
"power", "power",
@ -1111,19 +1113,21 @@ class TestOps(mlx_tests.MLXTestCase):
"uint32", "uint32",
"uint64", "uint64",
] ]
float_dtypes = ["float16", "float32"] float_dtypes = ["float16", "float32"]
dtypes = ( dtypes = {
float_dtypes "divide": float_dtypes,
if op in ("divide", "power") "power": float_dtypes,
else (int_dtypes + float_dtypes) "floor_divide": ["float32"] + int_dtypes,
) }
dtypes = dtypes.get(op, int_dtypes + float_dtypes)
for dtype in dtypes: for dtype in dtypes:
atol = 1e-3 if dtype == "float16" else 1e-6 atol = 1e-3 if dtype == "float16" else 1e-6
with self.subTest(dtype=dtype): with self.subTest(dtype=dtype):
x1_ = x1.astype(getattr(np, dtype)) m = 10 if dtype in int_dtypes else 1
x2_ = x2.astype(getattr(np, dtype)) x1_ = (x1 * m).astype(getattr(np, dtype))
x2_ = (x2 * m).astype(getattr(np, dtype))
y1_ = mx.array(x1_) y1_ = mx.array(x1_)
y2_ = mx.array(x2_) y2_ = mx.array(x2_)
test_ops( test_ops(