mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-29 13:55:29 +08:00
Add the remainder op (#85)
* Add remainder in the C++ backend * Add the python binding and test
This commit is contained in:
parent
69a24e6a1e
commit
2b714714e1
@ -322,6 +322,45 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Avoid code duplication with the common backend.
|
||||||
|
struct RemainderFn {
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
|
||||||
|
T numerator,
|
||||||
|
T denominator) {
|
||||||
|
return std::fmod(numerator, denominator);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<std::is_integral_v<T>, T> operator()(
|
||||||
|
T numerator,
|
||||||
|
T denominator) {
|
||||||
|
return numerator % denominator;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
|
||||||
|
if (a.dtype() == float32) {
|
||||||
|
binary(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
out,
|
||||||
|
RemainderFn{},
|
||||||
|
UseDefaultBinaryOp(),
|
||||||
|
UseDefaultBinaryOp(),
|
||||||
|
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||||
|
int num_el = n;
|
||||||
|
vvremainderf((float*)o, (const float*)a, (const float*)b, &num_el);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
binary(a, b, out, RemainderFn{});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
|
@ -82,6 +82,29 @@ void Divide::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
binary(a, b, out, [](auto x, auto y) { return x / y; });
|
binary(a, b, out, [](auto x, auto y) { return x / y; });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct RemainderFn {
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
|
||||||
|
T numerator,
|
||||||
|
T denominator) {
|
||||||
|
return std::fmod(numerator, denominator);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<std::is_integral_v<T>, T> operator()(
|
||||||
|
T numerator,
|
||||||
|
T denominator) {
|
||||||
|
return numerator % denominator;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void Remainder::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
binary(a, b, out, RemainderFn{});
|
||||||
|
}
|
||||||
|
|
||||||
void Equal::eval(const std::vector<array>& inputs, array& out) {
|
void Equal::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
if (equal_nan_) {
|
if (equal_nan_) {
|
||||||
|
@ -35,6 +35,7 @@ DEFAULT(Copy)
|
|||||||
DEFAULT(Cos)
|
DEFAULT(Cos)
|
||||||
DEFAULT(Cosh)
|
DEFAULT(Cosh)
|
||||||
DEFAULT(Divide)
|
DEFAULT(Divide)
|
||||||
|
DEFAULT(Remainder)
|
||||||
DEFAULT(Equal)
|
DEFAULT(Equal)
|
||||||
DEFAULT(Erf)
|
DEFAULT(Erf)
|
||||||
DEFAULT(ErfInv)
|
DEFAULT(ErfInv)
|
||||||
|
@ -14,6 +14,13 @@ struct Divide {
|
|||||||
template <typename T> T operator()(T x, T y) { return x / y; }
|
template <typename T> T operator()(T x, T y) { return x / y; }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct Remainder {
|
||||||
|
template <typename T> T operator()(T x, T y) { return x % y; }
|
||||||
|
template <> float operator()(float x, float y) { return fmod(x, y); }
|
||||||
|
template <> half operator()(half x, half y) { return fmod(x, y); }
|
||||||
|
template <> bfloat operator()(bfloat x, bfloat y) { return fmod(x, y); }
|
||||||
|
};
|
||||||
|
|
||||||
struct Equal {
|
struct Equal {
|
||||||
template <typename T> bool operator()(T x, T y) { return x == y; }
|
template <typename T> bool operator()(T x, T y) { return x == y; }
|
||||||
};
|
};
|
||||||
@ -363,6 +370,7 @@ instantiate_binary_types(min, Minimum)
|
|||||||
instantiate_binary_types(mul, Multiply)
|
instantiate_binary_types(mul, Multiply)
|
||||||
instantiate_binary_types(sub, Subtract)
|
instantiate_binary_types(sub, Subtract)
|
||||||
instantiate_binary_types(pow, Power)
|
instantiate_binary_types(pow, Power)
|
||||||
|
instantiate_binary_types(rem, Remainder)
|
||||||
|
|
||||||
// NaNEqual only needed for floating point types with boolean output
|
// NaNEqual only needed for floating point types with boolean output
|
||||||
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
|
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
|
||||||
|
@ -110,3 +110,7 @@ constexpr complex64_t operator-(complex64_t a, complex64_t b) {
|
|||||||
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
|
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
|
||||||
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
|
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
|
||||||
|
return {fmod(a.real, b.real), fmod(a.imag, b.imag)};
|
||||||
|
}
|
||||||
|
@ -363,6 +363,10 @@ void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
binary_op(inputs, out, "div");
|
binary_op(inputs, out, "div");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
binary_op(inputs, out, "rem");
|
||||||
|
}
|
||||||
|
|
||||||
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
|
binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
|
||||||
}
|
}
|
||||||
|
@ -30,6 +30,7 @@ NO_GPU(Copy)
|
|||||||
NO_GPU(Cos)
|
NO_GPU(Cos)
|
||||||
NO_GPU(Cosh)
|
NO_GPU(Cosh)
|
||||||
NO_GPU(Divide)
|
NO_GPU(Divide)
|
||||||
|
NO_GPU(Remainder)
|
||||||
NO_GPU(Equal)
|
NO_GPU(Equal)
|
||||||
NO_GPU(Erf)
|
NO_GPU(Erf)
|
||||||
NO_GPU(ErfInv)
|
NO_GPU(ErfInv)
|
||||||
|
14
mlx/ops.cpp
14
mlx/ops.cpp
@ -1438,6 +1438,20 @@ array operator/(const array& a, double b) {
|
|||||||
return divide(a, array(b));
|
return divide(a, array(b));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
|
auto dtype = promote_types(a.dtype(), b.dtype());
|
||||||
|
auto inputs = broadcast_arrays(
|
||||||
|
{astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
|
||||||
|
return array(
|
||||||
|
inputs[0].shape(),
|
||||||
|
dtype,
|
||||||
|
std::make_unique<Remainder>(to_stream(s)),
|
||||||
|
inputs);
|
||||||
|
}
|
||||||
|
array operator%(const array& a, const array& b) {
|
||||||
|
return remainder(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||||
auto inputs =
|
auto inputs =
|
||||||
|
12
mlx/ops.h
12
mlx/ops.h
@ -636,6 +636,18 @@ array operator/(const array& a, const array& b);
|
|||||||
array operator/(double a, const array& b);
|
array operator/(double a, const array& b);
|
||||||
array operator/(const array& a, double b);
|
array operator/(const array& a, double b);
|
||||||
|
|
||||||
|
/** Compute the element-wise remainder of division */
|
||||||
|
array remainder(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
array operator%(const array& a, const array& b);
|
||||||
|
template <typename T>
|
||||||
|
array operator%(T a, const array& b) {
|
||||||
|
return remainder(array(a), b);
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
array operator%(const array& a, T b) {
|
||||||
|
return remainder(a, array(b));
|
||||||
|
}
|
||||||
|
|
||||||
/** Element-wise maximum between two arrays. */
|
/** Element-wise maximum between two arrays. */
|
||||||
array maximum(const array& a, const array& b, StreamOrDevice s = {});
|
array maximum(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
@ -738,6 +738,53 @@ std::pair<array, int> Divide::vmap(
|
|||||||
return {divide(a, b, stream()), to_ax};
|
return {divide(a, b, stream()), to_ax};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<array> Remainder::vjp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const array& cotan,
|
||||||
|
const std::vector<int>& argnums) {
|
||||||
|
std::vector<array> vjps;
|
||||||
|
for (auto arg : argnums) {
|
||||||
|
if (arg == 0) {
|
||||||
|
vjps.push_back(cotan);
|
||||||
|
} else {
|
||||||
|
auto x_over_y = divide(primals[0], primals[1], stream());
|
||||||
|
// TODO: Replace with a proper floor when available
|
||||||
|
x_over_y = astype(x_over_y, int32, stream());
|
||||||
|
vjps.push_back(negative(multiply(x_over_y, cotan, stream()), stream()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return vjps;
|
||||||
|
}
|
||||||
|
|
||||||
|
array Remainder::jvp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& tangents,
|
||||||
|
const std::vector<int>& argnums) {
|
||||||
|
auto jvp_fun = [&](int i) {
|
||||||
|
int arg = argnums[i];
|
||||||
|
if (arg == 0) {
|
||||||
|
return tangents[i];
|
||||||
|
} else {
|
||||||
|
auto x_over_y = divide(primals[0], primals[1], stream());
|
||||||
|
// TODO: Replace with a proper floor when available
|
||||||
|
x_over_y = astype(x_over_y, int32, stream());
|
||||||
|
return negative(multiply(x_over_y, tangents[i], stream()), stream());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
auto out = jvp_fun(0);
|
||||||
|
if (argnums.size() > 1) {
|
||||||
|
out = add(out, jvp_fun(1), stream());
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<array, int> Remainder::vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
||||||
|
return {remainder(a, b, stream()), to_ax};
|
||||||
|
}
|
||||||
|
|
||||||
std::pair<array, int> Equal::vmap(
|
std::pair<array, int> Equal::vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
|
@ -536,6 +536,25 @@ class Divide : public Primitive {
|
|||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class Remainder : public Primitive {
|
||||||
|
public:
|
||||||
|
explicit Remainder(Stream stream) : Primitive(stream){};
|
||||||
|
|
||||||
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
|
||||||
|
std::pair<array, int> vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) override;
|
||||||
|
|
||||||
|
DEFINE_GRADS()
|
||||||
|
DEFINE_PRINT(Remainder)
|
||||||
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
|
||||||
|
private:
|
||||||
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
|
};
|
||||||
|
|
||||||
class Equal : public Primitive {
|
class Equal : public Primitive {
|
||||||
public:
|
public:
|
||||||
explicit Equal(Stream stream, bool equal_nan = false)
|
explicit Equal(Stream stream, bool equal_nan = false)
|
||||||
|
@ -624,6 +624,18 @@ void init_array(py::module_& m) {
|
|||||||
return divide(to_array(v, float32), a);
|
return divide(to_array(v, float32), a);
|
||||||
},
|
},
|
||||||
"other"_a)
|
"other"_a)
|
||||||
|
.def(
|
||||||
|
"__mod__",
|
||||||
|
[](const array& a, const ScalarOrArray v) {
|
||||||
|
return remainder(a, to_array(v, a.dtype()));
|
||||||
|
},
|
||||||
|
"other"_a)
|
||||||
|
.def(
|
||||||
|
"__rmod__",
|
||||||
|
[](const array& a, const ScalarOrArray v) {
|
||||||
|
return remainder(to_array(v, a.dtype()), a);
|
||||||
|
},
|
||||||
|
"other"_a)
|
||||||
.def(
|
.def(
|
||||||
"__eq__",
|
"__eq__",
|
||||||
[](const array& a, const ScalarOrArray v) {
|
[](const array& a, const ScalarOrArray v) {
|
||||||
|
@ -253,6 +253,31 @@ void init_ops(py::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
array: The quotient ``a / b``.
|
array: The quotient ``a / b``.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"remainder",
|
||||||
|
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||||
|
auto [a, b] = to_arrays(a_, b_);
|
||||||
|
return remainder(a, b, s);
|
||||||
|
},
|
||||||
|
"a"_a,
|
||||||
|
"b"_a,
|
||||||
|
py::pos_only(),
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
Element-wise remainder of division.
|
||||||
|
|
||||||
|
Computes the remainder of dividing a with b with numpy-style
|
||||||
|
broadcasting semantics. Either or both input arrays can also be
|
||||||
|
scalars.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array or scalar.
|
||||||
|
b (array): Input array or scalar.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The remainder of ``a // b``.
|
||||||
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"equal",
|
"equal",
|
||||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||||
|
@ -115,6 +115,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
"subtract",
|
"subtract",
|
||||||
"multiply",
|
"multiply",
|
||||||
"divide",
|
"divide",
|
||||||
|
"remainder",
|
||||||
"equal",
|
"equal",
|
||||||
"not_equal",
|
"not_equal",
|
||||||
"less",
|
"less",
|
||||||
@ -235,6 +236,25 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(z.dtype, mx.float32)
|
self.assertEqual(z.dtype, mx.float32)
|
||||||
self.assertEqual(z.item(), 0.5)
|
self.assertEqual(z.item(), 0.5)
|
||||||
|
|
||||||
|
def test_remainder(self):
|
||||||
|
for dt in [mx.int32, mx.float32]:
|
||||||
|
x = mx.array(2, dtype=dt)
|
||||||
|
y = mx.array(4, dtype=dt)
|
||||||
|
|
||||||
|
z1 = mx.remainder(x, y)
|
||||||
|
z2 = mx.remainder(y, x)
|
||||||
|
self.assertEqual(z1.dtype, dt)
|
||||||
|
self.assertEqual(z1.item(), 2)
|
||||||
|
self.assertEqual(z2.item(), 0)
|
||||||
|
|
||||||
|
z = x % 4
|
||||||
|
self.assertEqual(z.dtype, dt)
|
||||||
|
self.assertEqual(z.item(), 2)
|
||||||
|
|
||||||
|
z = 1 % x
|
||||||
|
self.assertEqual(z.dtype, dt)
|
||||||
|
self.assertEqual(z.item(), 1)
|
||||||
|
|
||||||
def test_comparisons(self):
|
def test_comparisons(self):
|
||||||
a = mx.array([0.0, 1.0, 5.0])
|
a = mx.array([0.0, 1.0, 5.0])
|
||||||
b = mx.array([-1.0, 2.0, 5.0])
|
b = mx.array([-1.0, 2.0, 5.0])
|
||||||
|
Loading…
Reference in New Issue
Block a user