mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Implements divide for integer types and adds floor_divide op (#228)
* Add floor_divide * Add floor_divide to the tests * Add floor_divide to the docs
This commit is contained in:
parent
de892cb66c
commit
2807c6aff0
@ -41,8 +41,9 @@ Operations
|
||||
exp
|
||||
expand_dims
|
||||
eye
|
||||
floor
|
||||
flatten
|
||||
floor
|
||||
floor_divide
|
||||
full
|
||||
greater
|
||||
greater_equal
|
||||
|
@ -357,7 +357,7 @@ template <typename T, typename U, typename Op>
|
||||
instantiate_binary_all(name, complex64, complex64_t, bool, op)
|
||||
|
||||
instantiate_binary_types(add, Add)
|
||||
instantiate_binary_float(div, Divide)
|
||||
instantiate_binary_types(div, Divide)
|
||||
instantiate_binary_types_bool(eq, Equal)
|
||||
instantiate_binary_types_bool(ge, Greater)
|
||||
instantiate_binary_types_bool(geq, GreaterEqual)
|
||||
|
@ -111,6 +111,13 @@ constexpr complex64_t operator*(complex64_t a, complex64_t b) {
|
||||
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
|
||||
}
|
||||
|
||||
constexpr complex64_t operator/(complex64_t a, complex64_t b) {
|
||||
auto denom = b.real * b.real + b.imag * b.imag;
|
||||
auto x = a.real * b.real + a.imag * b.imag;
|
||||
auto y = a.imag * b.real - a.real * b.imag;
|
||||
return {x / denom, y / denom};
|
||||
}
|
||||
|
||||
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
|
||||
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
|
||||
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
|
||||
|
14
mlx/ops.cpp
14
mlx/ops.cpp
@ -1636,6 +1636,20 @@ array operator/(const array& a, double b) {
|
||||
return divide(a, array(b));
|
||||
}
|
||||
|
||||
array floor_divide(
|
||||
const array& a,
|
||||
const array& b,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto dtype = promote_types(a.dtype(), b.dtype());
|
||||
if (is_floating_point(dtype)) {
|
||||
return floor(divide(a, b, s), s);
|
||||
}
|
||||
|
||||
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
|
||||
return array(
|
||||
inputs[0].shape(), dtype, std::make_unique<Divide>(to_stream(s)), inputs);
|
||||
}
|
||||
|
||||
array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
auto dtype = promote_types(a.dtype(), b.dtype());
|
||||
auto inputs = broadcast_arrays(
|
||||
|
@ -709,6 +709,9 @@ array operator/(const array& a, const array& b);
|
||||
array operator/(double a, const array& b);
|
||||
array operator/(const array& a, double b);
|
||||
|
||||
/** Compute integer division. Equivalent to doing floor(a / x). */
|
||||
array floor_divide(const array& a, const array& b, StreamOrDevice s = {});
|
||||
|
||||
/** Compute the element-wise remainder of division */
|
||||
array remainder(const array& a, const array& b, StreamOrDevice s = {});
|
||||
array operator%(const array& a, const array& b);
|
||||
|
@ -636,8 +636,7 @@ void init_array(py::module_& m) {
|
||||
"__floordiv__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
auto b = to_array(v, a.dtype());
|
||||
auto t = promote_types(a.dtype(), b.dtype());
|
||||
return astype(divide(a, b), t);
|
||||
return floor_divide(a, b);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
@ -650,8 +649,7 @@ void init_array(py::module_& m) {
|
||||
"__rfloordiv__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
auto b = to_array(v, a.dtype());
|
||||
auto t = promote_types(a.dtype(), b.dtype());
|
||||
return astype(divide(b, a), t);
|
||||
return floor_divide(b, a);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
|
@ -303,6 +303,32 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
array: The quotient ``a / b``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"floor_divide",
|
||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||
auto [a, b] = to_arrays(a_, b_);
|
||||
return floor_divide(a, b, s);
|
||||
},
|
||||
"a"_a,
|
||||
"b"_a,
|
||||
py::pos_only(),
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
floor_divide(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise integer division.
|
||||
|
||||
If either array is a floating point type then it is equivalent to
|
||||
calling :func:`floor` after :func:`divide`.
|
||||
|
||||
Args:
|
||||
a (array): Input array or scalar.
|
||||
b (array): Input array or scalar.
|
||||
|
||||
Returns:
|
||||
array: The quotient ``a // b``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"remainder",
|
||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||
|
@ -115,6 +115,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
"subtract",
|
||||
"multiply",
|
||||
"divide",
|
||||
"floor_divide",
|
||||
"remainder",
|
||||
"equal",
|
||||
"not_equal",
|
||||
@ -1096,6 +1097,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
"subtract",
|
||||
"multiply",
|
||||
"divide",
|
||||
"floor_divide",
|
||||
"maximum",
|
||||
"minimum",
|
||||
"power",
|
||||
@ -1111,19 +1113,21 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
"uint32",
|
||||
"uint64",
|
||||
]
|
||||
|
||||
float_dtypes = ["float16", "float32"]
|
||||
|
||||
dtypes = (
|
||||
float_dtypes
|
||||
if op in ("divide", "power")
|
||||
else (int_dtypes + float_dtypes)
|
||||
)
|
||||
dtypes = {
|
||||
"divide": float_dtypes,
|
||||
"power": float_dtypes,
|
||||
"floor_divide": ["float32"] + int_dtypes,
|
||||
}
|
||||
dtypes = dtypes.get(op, int_dtypes + float_dtypes)
|
||||
|
||||
for dtype in dtypes:
|
||||
atol = 1e-3 if dtype == "float16" else 1e-6
|
||||
with self.subTest(dtype=dtype):
|
||||
x1_ = x1.astype(getattr(np, dtype))
|
||||
x2_ = x2.astype(getattr(np, dtype))
|
||||
m = 10 if dtype in int_dtypes else 1
|
||||
x1_ = (x1 * m).astype(getattr(np, dtype))
|
||||
x2_ = (x2 * m).astype(getattr(np, dtype))
|
||||
y1_ = mx.array(x1_)
|
||||
y2_ = mx.array(x2_)
|
||||
test_ops(
|
||||
|
Loading…
Reference in New Issue
Block a user