Add the remainder op (#85)

* Add remainder in the C++ backend
* Add the python binding and test
This commit is contained in:
Angelos Katharopoulos
2023-12-08 15:08:52 -08:00
committed by GitHub
parent 69a24e6a1e
commit 2b714714e1
14 changed files with 229 additions and 0 deletions

View File

@@ -624,6 +624,18 @@ void init_array(py::module_& m) {
return divide(to_array(v, float32), a);
},
"other"_a)
.def(
"__mod__",
[](const array& a, const ScalarOrArray v) {
return remainder(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__rmod__",
[](const array& a, const ScalarOrArray v) {
return remainder(to_array(v, a.dtype()), a);
},
"other"_a)
.def(
"__eq__",
[](const array& a, const ScalarOrArray v) {

View File

@@ -253,6 +253,31 @@ void init_ops(py::module_& m) {
Returns:
array: The quotient ``a / b``.
)pbdoc");
m.def(
"remainder",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
auto [a, b] = to_arrays(a_, b_);
return remainder(a, b, s);
},
"a"_a,
"b"_a,
py::pos_only(),
py::kw_only(),
"stream"_a = none,
R"pbdoc(
Element-wise remainder of division.
Computes the remainder of dividing a with b with numpy-style
broadcasting semantics. Either or both input arrays can also be
scalars.
Args:
a (array): Input array or scalar.
b (array): Input array or scalar.
Returns:
array: The remainder of ``a // b``.
)pbdoc");
m.def(
"equal",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {

View File

@@ -115,6 +115,7 @@ class TestOps(mlx_tests.MLXTestCase):
"subtract",
"multiply",
"divide",
"remainder",
"equal",
"not_equal",
"less",
@@ -235,6 +236,25 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(z.dtype, mx.float32)
self.assertEqual(z.item(), 0.5)
def test_remainder(self):
for dt in [mx.int32, mx.float32]:
x = mx.array(2, dtype=dt)
y = mx.array(4, dtype=dt)
z1 = mx.remainder(x, y)
z2 = mx.remainder(y, x)
self.assertEqual(z1.dtype, dt)
self.assertEqual(z1.item(), 2)
self.assertEqual(z2.item(), 0)
z = x % 4
self.assertEqual(z.dtype, dt)
self.assertEqual(z.item(), 2)
z = 1 % x
self.assertEqual(z.dtype, dt)
self.assertEqual(z.item(), 1)
def test_comparisons(self):
a = mx.array([0.0, 1.0, 5.0])
b = mx.array([-1.0, 2.0, 5.0])