mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-21 04:31:48 +08:00
Add the remainder op (#85)
* Add remainder in the C++ backend * Add the python binding and test
This commit is contained in:

committed by
GitHub

parent
69a24e6a1e
commit
2b714714e1
@@ -624,6 +624,18 @@ void init_array(py::module_& m) {
|
||||
return divide(to_array(v, float32), a);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__mod__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
return remainder(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__rmod__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
return remainder(to_array(v, a.dtype()), a);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__eq__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
|
@@ -253,6 +253,31 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
array: The quotient ``a / b``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"remainder",
|
||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||
auto [a, b] = to_arrays(a_, b_);
|
||||
return remainder(a, b, s);
|
||||
},
|
||||
"a"_a,
|
||||
"b"_a,
|
||||
py::pos_only(),
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Element-wise remainder of division.
|
||||
|
||||
Computes the remainder of dividing a with b with numpy-style
|
||||
broadcasting semantics. Either or both input arrays can also be
|
||||
scalars.
|
||||
|
||||
Args:
|
||||
a (array): Input array or scalar.
|
||||
b (array): Input array or scalar.
|
||||
|
||||
Returns:
|
||||
array: The remainder of ``a // b``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"equal",
|
||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||
|
Reference in New Issue
Block a user