mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:38:09 +08:00
Implements divide for integer types and adds floor_divide op (#228)
* Add floor_divide * Add floor_divide to the tests * Add floor_divide to the docs
This commit is contained in:

committed by
GitHub

parent
de892cb66c
commit
2807c6aff0
@@ -636,8 +636,7 @@ void init_array(py::module_& m) {
|
||||
"__floordiv__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
auto b = to_array(v, a.dtype());
|
||||
auto t = promote_types(a.dtype(), b.dtype());
|
||||
return astype(divide(a, b), t);
|
||||
return floor_divide(a, b);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
@@ -650,8 +649,7 @@ void init_array(py::module_& m) {
|
||||
"__rfloordiv__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
auto b = to_array(v, a.dtype());
|
||||
auto t = promote_types(a.dtype(), b.dtype());
|
||||
return astype(divide(b, a), t);
|
||||
return floor_divide(b, a);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
|
@@ -303,6 +303,32 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
array: The quotient ``a / b``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"floor_divide",
|
||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||
auto [a, b] = to_arrays(a_, b_);
|
||||
return floor_divide(a, b, s);
|
||||
},
|
||||
"a"_a,
|
||||
"b"_a,
|
||||
py::pos_only(),
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
floor_divide(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise integer division.
|
||||
|
||||
If either array is a floating point type then it is equivalent to
|
||||
calling :func:`floor` after :func:`divide`.
|
||||
|
||||
Args:
|
||||
a (array): Input array or scalar.
|
||||
b (array): Input array or scalar.
|
||||
|
||||
Returns:
|
||||
array: The quotient ``a // b``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"remainder",
|
||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||
|
Reference in New Issue
Block a user