mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-29 05:31:18 +08:00
Add the remainder op (#85)
* Add remainder in the C++ backend * Add the python binding and test
This commit is contained in:
parent
69a24e6a1e
commit
2b714714e1
@ -322,6 +322,45 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Avoid code duplication with the common backend.
|
||||
struct RemainderFn {
|
||||
template <typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
|
||||
T numerator,
|
||||
T denominator) {
|
||||
return std::fmod(numerator, denominator);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_integral_v<T>, T> operator()(
|
||||
T numerator,
|
||||
T denominator) {
|
||||
return numerator % denominator;
|
||||
}
|
||||
};
|
||||
|
||||
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
RemainderFn{},
|
||||
UseDefaultBinaryOp(),
|
||||
UseDefaultBinaryOp(),
|
||||
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||
int num_el = n;
|
||||
vvremainderf((float*)o, (const float*)a, (const float*)b, &num_el);
|
||||
});
|
||||
} else {
|
||||
binary(a, b, out, RemainderFn{});
|
||||
}
|
||||
}
|
||||
|
||||
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
|
@ -82,6 +82,29 @@ void Divide::eval(const std::vector<array>& inputs, array& out) {
|
||||
binary(a, b, out, [](auto x, auto y) { return x / y; });
|
||||
}
|
||||
|
||||
struct RemainderFn {
|
||||
template <typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
|
||||
T numerator,
|
||||
T denominator) {
|
||||
return std::fmod(numerator, denominator);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_integral_v<T>, T> operator()(
|
||||
T numerator,
|
||||
T denominator) {
|
||||
return numerator % denominator;
|
||||
}
|
||||
};
|
||||
|
||||
void Remainder::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, RemainderFn{});
|
||||
}
|
||||
|
||||
void Equal::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
if (equal_nan_) {
|
||||
|
@ -35,6 +35,7 @@ DEFAULT(Copy)
|
||||
DEFAULT(Cos)
|
||||
DEFAULT(Cosh)
|
||||
DEFAULT(Divide)
|
||||
DEFAULT(Remainder)
|
||||
DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
|
@ -14,6 +14,13 @@ struct Divide {
|
||||
template <typename T> T operator()(T x, T y) { return x / y; }
|
||||
};
|
||||
|
||||
struct Remainder {
|
||||
template <typename T> T operator()(T x, T y) { return x % y; }
|
||||
template <> float operator()(float x, float y) { return fmod(x, y); }
|
||||
template <> half operator()(half x, half y) { return fmod(x, y); }
|
||||
template <> bfloat operator()(bfloat x, bfloat y) { return fmod(x, y); }
|
||||
};
|
||||
|
||||
struct Equal {
|
||||
template <typename T> bool operator()(T x, T y) { return x == y; }
|
||||
};
|
||||
@ -363,6 +370,7 @@ instantiate_binary_types(min, Minimum)
|
||||
instantiate_binary_types(mul, Multiply)
|
||||
instantiate_binary_types(sub, Subtract)
|
||||
instantiate_binary_types(pow, Power)
|
||||
instantiate_binary_types(rem, Remainder)
|
||||
|
||||
// NaNEqual only needed for floating point types with boolean output
|
||||
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
|
||||
|
@ -110,3 +110,7 @@ constexpr complex64_t operator-(complex64_t a, complex64_t b) {
|
||||
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
|
||||
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
|
||||
}
|
||||
|
||||
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
|
||||
return {fmod(a.real, b.real), fmod(a.imag, b.imag)};
|
||||
}
|
||||
|
@ -363,6 +363,10 @@ void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "div");
|
||||
}
|
||||
|
||||
void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "rem");
|
||||
}
|
||||
|
||||
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
|
||||
}
|
||||
|
@ -30,6 +30,7 @@ NO_GPU(Copy)
|
||||
NO_GPU(Cos)
|
||||
NO_GPU(Cosh)
|
||||
NO_GPU(Divide)
|
||||
NO_GPU(Remainder)
|
||||
NO_GPU(Equal)
|
||||
NO_GPU(Erf)
|
||||
NO_GPU(ErfInv)
|
||||
|
14
mlx/ops.cpp
14
mlx/ops.cpp
@ -1438,6 +1438,20 @@ array operator/(const array& a, double b) {
|
||||
return divide(a, array(b));
|
||||
}
|
||||
|
||||
array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
auto dtype = promote_types(a.dtype(), b.dtype());
|
||||
auto inputs = broadcast_arrays(
|
||||
{astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
|
||||
return array(
|
||||
inputs[0].shape(),
|
||||
dtype,
|
||||
std::make_unique<Remainder>(to_stream(s)),
|
||||
inputs);
|
||||
}
|
||||
array operator%(const array& a, const array& b) {
|
||||
return remainder(a, b);
|
||||
}
|
||||
|
||||
array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||
auto inputs =
|
||||
|
12
mlx/ops.h
12
mlx/ops.h
@ -636,6 +636,18 @@ array operator/(const array& a, const array& b);
|
||||
array operator/(double a, const array& b);
|
||||
array operator/(const array& a, double b);
|
||||
|
||||
/** Compute the element-wise remainder of division */
|
||||
array remainder(const array& a, const array& b, StreamOrDevice s = {});
|
||||
array operator%(const array& a, const array& b);
|
||||
template <typename T>
|
||||
array operator%(T a, const array& b) {
|
||||
return remainder(array(a), b);
|
||||
}
|
||||
template <typename T>
|
||||
array operator%(const array& a, T b) {
|
||||
return remainder(a, array(b));
|
||||
}
|
||||
|
||||
/** Element-wise maximum between two arrays. */
|
||||
array maximum(const array& a, const array& b, StreamOrDevice s = {});
|
||||
|
||||
|
@ -738,6 +738,53 @@ std::pair<array, int> Divide::vmap(
|
||||
return {divide(a, b, stream()), to_ax};
|
||||
}
|
||||
|
||||
std::vector<array> Remainder::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const array& cotan,
|
||||
const std::vector<int>& argnums) {
|
||||
std::vector<array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
if (arg == 0) {
|
||||
vjps.push_back(cotan);
|
||||
} else {
|
||||
auto x_over_y = divide(primals[0], primals[1], stream());
|
||||
// TODO: Replace with a proper floor when available
|
||||
x_over_y = astype(x_over_y, int32, stream());
|
||||
vjps.push_back(negative(multiply(x_over_y, cotan, stream()), stream()));
|
||||
}
|
||||
}
|
||||
return vjps;
|
||||
}
|
||||
|
||||
array Remainder::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
auto jvp_fun = [&](int i) {
|
||||
int arg = argnums[i];
|
||||
if (arg == 0) {
|
||||
return tangents[i];
|
||||
} else {
|
||||
auto x_over_y = divide(primals[0], primals[1], stream());
|
||||
// TODO: Replace with a proper floor when available
|
||||
x_over_y = astype(x_over_y, int32, stream());
|
||||
return negative(multiply(x_over_y, tangents[i], stream()), stream());
|
||||
}
|
||||
};
|
||||
auto out = jvp_fun(0);
|
||||
if (argnums.size() > 1) {
|
||||
out = add(out, jvp_fun(1), stream());
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
std::pair<array, int> Remainder::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
||||
return {remainder(a, b, stream()), to_ax};
|
||||
}
|
||||
|
||||
std::pair<array, int> Equal::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
|
@ -536,6 +536,25 @@ class Divide : public Primitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Remainder : public Primitive {
|
||||
public:
|
||||
explicit Remainder(Stream stream) : Primitive(stream){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
std::pair<array, int> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Remainder)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Equal : public Primitive {
|
||||
public:
|
||||
explicit Equal(Stream stream, bool equal_nan = false)
|
||||
|
@ -624,6 +624,18 @@ void init_array(py::module_& m) {
|
||||
return divide(to_array(v, float32), a);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__mod__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
return remainder(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__rmod__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
return remainder(to_array(v, a.dtype()), a);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__eq__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
|
@ -253,6 +253,31 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
array: The quotient ``a / b``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"remainder",
|
||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||
auto [a, b] = to_arrays(a_, b_);
|
||||
return remainder(a, b, s);
|
||||
},
|
||||
"a"_a,
|
||||
"b"_a,
|
||||
py::pos_only(),
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Element-wise remainder of division.
|
||||
|
||||
Computes the remainder of dividing a with b with numpy-style
|
||||
broadcasting semantics. Either or both input arrays can also be
|
||||
scalars.
|
||||
|
||||
Args:
|
||||
a (array): Input array or scalar.
|
||||
b (array): Input array or scalar.
|
||||
|
||||
Returns:
|
||||
array: The remainder of ``a // b``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"equal",
|
||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||
|
@ -115,6 +115,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
"subtract",
|
||||
"multiply",
|
||||
"divide",
|
||||
"remainder",
|
||||
"equal",
|
||||
"not_equal",
|
||||
"less",
|
||||
@ -235,6 +236,25 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(z.dtype, mx.float32)
|
||||
self.assertEqual(z.item(), 0.5)
|
||||
|
||||
def test_remainder(self):
|
||||
for dt in [mx.int32, mx.float32]:
|
||||
x = mx.array(2, dtype=dt)
|
||||
y = mx.array(4, dtype=dt)
|
||||
|
||||
z1 = mx.remainder(x, y)
|
||||
z2 = mx.remainder(y, x)
|
||||
self.assertEqual(z1.dtype, dt)
|
||||
self.assertEqual(z1.item(), 2)
|
||||
self.assertEqual(z2.item(), 0)
|
||||
|
||||
z = x % 4
|
||||
self.assertEqual(z.dtype, dt)
|
||||
self.assertEqual(z.item(), 2)
|
||||
|
||||
z = 1 % x
|
||||
self.assertEqual(z.dtype, dt)
|
||||
self.assertEqual(z.item(), 1)
|
||||
|
||||
def test_comparisons(self):
|
||||
a = mx.array([0.0, 1.0, 5.0])
|
||||
b = mx.array([-1.0, 2.0, 5.0])
|
||||
|
Loading…
Reference in New Issue
Block a user