Add the remainder op (#85)

* Add remainder in the C++ backend
* Add the python binding and test
This commit is contained in:
Angelos Katharopoulos
2023-12-08 15:08:52 -08:00
committed by GitHub
parent 69a24e6a1e
commit 2b714714e1
14 changed files with 229 additions and 0 deletions

View File

@@ -322,6 +322,45 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
// TODO: Avoid code duplication with the common backend.
struct RemainderFn {
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return std::fmod(numerator, denominator);
}
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return numerator % denominator;
}
};
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == float32) {
binary(
a,
b,
out,
RemainderFn{},
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* o, auto n) {
int num_el = n;
vvremainderf((float*)o, (const float*)a, (const float*)b, &num_el);
});
} else {
binary(a, b, out, RemainderFn{});
}
}
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];

View File

@@ -82,6 +82,29 @@ void Divide::eval(const std::vector<array>& inputs, array& out) {
binary(a, b, out, [](auto x, auto y) { return x / y; });
}
struct RemainderFn {
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return std::fmod(numerator, denominator);
}
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return numerator % denominator;
}
};
void Remainder::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, RemainderFn{});
}
void Equal::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (equal_nan_) {

View File

@@ -35,6 +35,7 @@ DEFAULT(Copy)
DEFAULT(Cos)
DEFAULT(Cosh)
DEFAULT(Divide)
DEFAULT(Remainder)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)

View File

@@ -14,6 +14,13 @@ struct Divide {
template <typename T> T operator()(T x, T y) { return x / y; }
};
struct Remainder {
template <typename T> T operator()(T x, T y) { return x % y; }
template <> float operator()(float x, float y) { return fmod(x, y); }
template <> half operator()(half x, half y) { return fmod(x, y); }
template <> bfloat operator()(bfloat x, bfloat y) { return fmod(x, y); }
};
struct Equal {
template <typename T> bool operator()(T x, T y) { return x == y; }
};
@@ -363,6 +370,7 @@ instantiate_binary_types(min, Minimum)
instantiate_binary_types(mul, Multiply)
instantiate_binary_types(sub, Subtract)
instantiate_binary_types(pow, Power)
instantiate_binary_types(rem, Remainder)
// NaNEqual only needed for floating point types with boolean output
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)

View File

@@ -110,3 +110,7 @@ constexpr complex64_t operator-(complex64_t a, complex64_t b) {
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
}
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
return {fmod(a.real, b.real), fmod(a.imag, b.imag)};
}

View File

@@ -363,6 +363,10 @@ void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "div");
}
void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "rem");
}
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
}

View File

@@ -30,6 +30,7 @@ NO_GPU(Copy)
NO_GPU(Cos)
NO_GPU(Cosh)
NO_GPU(Divide)
NO_GPU(Remainder)
NO_GPU(Equal)
NO_GPU(Erf)
NO_GPU(ErfInv)

View File

@@ -1438,6 +1438,20 @@ array operator/(const array& a, double b) {
return divide(a, array(b));
}
array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
auto inputs = broadcast_arrays(
{astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
return array(
inputs[0].shape(),
dtype,
std::make_unique<Remainder>(to_stream(s)),
inputs);
}
array operator%(const array& a, const array& b) {
return remainder(a, b);
}
array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto out_type = promote_types(a.dtype(), b.dtype());
auto inputs =

View File

@@ -636,6 +636,18 @@ array operator/(const array& a, const array& b);
array operator/(double a, const array& b);
array operator/(const array& a, double b);
/** Compute the element-wise remainder of division */
array remainder(const array& a, const array& b, StreamOrDevice s = {});
array operator%(const array& a, const array& b);
template <typename T>
array operator%(T a, const array& b) {
return remainder(array(a), b);
}
template <typename T>
array operator%(const array& a, T b) {
return remainder(a, array(b));
}
/** Element-wise maximum between two arrays. */
array maximum(const array& a, const array& b, StreamOrDevice s = {});

View File

@@ -738,6 +738,53 @@ std::pair<array, int> Divide::vmap(
return {divide(a, b, stream()), to_ax};
}
std::vector<array> Remainder::vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<int>& argnums) {
std::vector<array> vjps;
for (auto arg : argnums) {
if (arg == 0) {
vjps.push_back(cotan);
} else {
auto x_over_y = divide(primals[0], primals[1], stream());
// TODO: Replace with a proper floor when available
x_over_y = astype(x_over_y, int32, stream());
vjps.push_back(negative(multiply(x_over_y, cotan, stream()), stream()));
}
}
return vjps;
}
array Remainder::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
auto jvp_fun = [&](int i) {
int arg = argnums[i];
if (arg == 0) {
return tangents[i];
} else {
auto x_over_y = divide(primals[0], primals[1], stream());
// TODO: Replace with a proper floor when available
x_over_y = astype(x_over_y, int32, stream());
return negative(multiply(x_over_y, tangents[i], stream()), stream());
}
};
auto out = jvp_fun(0);
if (argnums.size() > 1) {
out = add(out, jvp_fun(1), stream());
}
return out;
}
std::pair<array, int> Remainder::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
return {remainder(a, b, stream()), to_ax};
}
std::pair<array, int> Equal::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {

View File

@@ -536,6 +536,25 @@ class Divide : public Primitive {
void eval(const std::vector<array>& inputs, array& out);
};
class Remainder : public Primitive {
public:
explicit Remainder(Stream stream) : Primitive(stream){};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
std::pair<array, int> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
DEFINE_GRADS()
DEFINE_PRINT(Remainder)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Equal : public Primitive {
public:
explicit Equal(Stream stream, bool equal_nan = false)