Add the remainder op (#85)

* Add remainder in the C++ backend
* Add the python binding and test
This commit is contained in:
Angelos Katharopoulos 2023-12-08 15:08:52 -08:00 committed by GitHub
parent 69a24e6a1e
commit 2b714714e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 229 additions and 0 deletions

View File

@ -322,6 +322,45 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
// TODO: Avoid code duplication with the common backend.
struct RemainderFn {
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return std::fmod(numerator, denominator);
}
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return numerator % denominator;
}
};
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == float32) {
binary(
a,
b,
out,
RemainderFn{},
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* o, auto n) {
int num_el = n;
vvremainderf((float*)o, (const float*)a, (const float*)b, &num_el);
});
} else {
binary(a, b, out, RemainderFn{});
}
}
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];

View File

@ -82,6 +82,29 @@ void Divide::eval(const std::vector<array>& inputs, array& out) {
binary(a, b, out, [](auto x, auto y) { return x / y; });
}
struct RemainderFn {
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return std::fmod(numerator, denominator);
}
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return numerator % denominator;
}
};
void Remainder::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, RemainderFn{});
}
void Equal::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (equal_nan_) {

View File

@ -35,6 +35,7 @@ DEFAULT(Copy)
DEFAULT(Cos)
DEFAULT(Cosh)
DEFAULT(Divide)
DEFAULT(Remainder)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)

View File

@ -14,6 +14,13 @@ struct Divide {
template <typename T> T operator()(T x, T y) { return x / y; }
};
struct Remainder {
template <typename T> T operator()(T x, T y) { return x % y; }
template <> float operator()(float x, float y) { return fmod(x, y); }
template <> half operator()(half x, half y) { return fmod(x, y); }
template <> bfloat operator()(bfloat x, bfloat y) { return fmod(x, y); }
};
struct Equal {
template <typename T> bool operator()(T x, T y) { return x == y; }
};
@ -363,6 +370,7 @@ instantiate_binary_types(min, Minimum)
instantiate_binary_types(mul, Multiply)
instantiate_binary_types(sub, Subtract)
instantiate_binary_types(pow, Power)
instantiate_binary_types(rem, Remainder)
// NaNEqual only needed for floating point types with boolean output
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)

View File

@ -110,3 +110,7 @@ constexpr complex64_t operator-(complex64_t a, complex64_t b) {
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
}
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
return {fmod(a.real, b.real), fmod(a.imag, b.imag)};
}

View File

@ -363,6 +363,10 @@ void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "div");
}
void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "rem");
}
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
}

View File

@ -30,6 +30,7 @@ NO_GPU(Copy)
NO_GPU(Cos)
NO_GPU(Cosh)
NO_GPU(Divide)
NO_GPU(Remainder)
NO_GPU(Equal)
NO_GPU(Erf)
NO_GPU(ErfInv)

View File

@ -1438,6 +1438,20 @@ array operator/(const array& a, double b) {
return divide(a, array(b));
}
array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
auto inputs = broadcast_arrays(
{astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
return array(
inputs[0].shape(),
dtype,
std::make_unique<Remainder>(to_stream(s)),
inputs);
}
array operator%(const array& a, const array& b) {
return remainder(a, b);
}
array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto out_type = promote_types(a.dtype(), b.dtype());
auto inputs =

View File

@ -636,6 +636,18 @@ array operator/(const array& a, const array& b);
array operator/(double a, const array& b);
array operator/(const array& a, double b);
/** Compute the element-wise remainder of division */
array remainder(const array& a, const array& b, StreamOrDevice s = {});
array operator%(const array& a, const array& b);
template <typename T>
array operator%(T a, const array& b) {
return remainder(array(a), b);
}
template <typename T>
array operator%(const array& a, T b) {
return remainder(a, array(b));
}
/** Element-wise maximum between two arrays. */
array maximum(const array& a, const array& b, StreamOrDevice s = {});

View File

@ -738,6 +738,53 @@ std::pair<array, int> Divide::vmap(
return {divide(a, b, stream()), to_ax};
}
std::vector<array> Remainder::vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<int>& argnums) {
std::vector<array> vjps;
for (auto arg : argnums) {
if (arg == 0) {
vjps.push_back(cotan);
} else {
auto x_over_y = divide(primals[0], primals[1], stream());
// TODO: Replace with a proper floor when available
x_over_y = astype(x_over_y, int32, stream());
vjps.push_back(negative(multiply(x_over_y, cotan, stream()), stream()));
}
}
return vjps;
}
array Remainder::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
auto jvp_fun = [&](int i) {
int arg = argnums[i];
if (arg == 0) {
return tangents[i];
} else {
auto x_over_y = divide(primals[0], primals[1], stream());
// TODO: Replace with a proper floor when available
x_over_y = astype(x_over_y, int32, stream());
return negative(multiply(x_over_y, tangents[i], stream()), stream());
}
};
auto out = jvp_fun(0);
if (argnums.size() > 1) {
out = add(out, jvp_fun(1), stream());
}
return out;
}
std::pair<array, int> Remainder::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
return {remainder(a, b, stream()), to_ax};
}
std::pair<array, int> Equal::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {

View File

@ -536,6 +536,25 @@ class Divide : public Primitive {
void eval(const std::vector<array>& inputs, array& out);
};
class Remainder : public Primitive {
public:
explicit Remainder(Stream stream) : Primitive(stream){};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
std::pair<array, int> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
DEFINE_GRADS()
DEFINE_PRINT(Remainder)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Equal : public Primitive {
public:
explicit Equal(Stream stream, bool equal_nan = false)

View File

@ -624,6 +624,18 @@ void init_array(py::module_& m) {
return divide(to_array(v, float32), a);
},
"other"_a)
.def(
"__mod__",
[](const array& a, const ScalarOrArray v) {
return remainder(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__rmod__",
[](const array& a, const ScalarOrArray v) {
return remainder(to_array(v, a.dtype()), a);
},
"other"_a)
.def(
"__eq__",
[](const array& a, const ScalarOrArray v) {

View File

@ -253,6 +253,31 @@ void init_ops(py::module_& m) {
Returns:
array: The quotient ``a / b``.
)pbdoc");
m.def(
"remainder",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
auto [a, b] = to_arrays(a_, b_);
return remainder(a, b, s);
},
"a"_a,
"b"_a,
py::pos_only(),
py::kw_only(),
"stream"_a = none,
R"pbdoc(
Element-wise remainder of division.
Computes the remainder of dividing a with b with numpy-style
broadcasting semantics. Either or both input arrays can also be
scalars.
Args:
a (array): Input array or scalar.
b (array): Input array or scalar.
Returns:
array: The remainder of ``a // b``.
)pbdoc");
m.def(
"equal",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {

View File

@ -115,6 +115,7 @@ class TestOps(mlx_tests.MLXTestCase):
"subtract",
"multiply",
"divide",
"remainder",
"equal",
"not_equal",
"less",
@ -235,6 +236,25 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(z.dtype, mx.float32)
self.assertEqual(z.item(), 0.5)
def test_remainder(self):
for dt in [mx.int32, mx.float32]:
x = mx.array(2, dtype=dt)
y = mx.array(4, dtype=dt)
z1 = mx.remainder(x, y)
z2 = mx.remainder(y, x)
self.assertEqual(z1.dtype, dt)
self.assertEqual(z1.item(), 2)
self.assertEqual(z2.item(), 0)
z = x % 4
self.assertEqual(z.dtype, dt)
self.assertEqual(z.item(), 2)
z = 1 % x
self.assertEqual(z.dtype, dt)
self.assertEqual(z.item(), 1)
def test_comparisons(self):
a = mx.array([0.0, 1.0, 5.0])
b = mx.array([-1.0, 2.0, 5.0])