Implements divide for integer types and adds floor_divide op (#228)

* Add floor_divide
* Add floor_divide to the tests
* Add floor_divide to the docs
This commit is contained in:
Angelos Katharopoulos
2023-12-19 20:12:19 -08:00
committed by GitHub
parent de892cb66c
commit 2807c6aff0
8 changed files with 67 additions and 14 deletions

View File

@@ -357,7 +357,7 @@ template <typename T, typename U, typename Op>
instantiate_binary_all(name, complex64, complex64_t, bool, op)
instantiate_binary_types(add, Add)
instantiate_binary_float(div, Divide)
instantiate_binary_types(div, Divide)
instantiate_binary_types_bool(eq, Equal)
instantiate_binary_types_bool(ge, Greater)
instantiate_binary_types_bool(geq, GreaterEqual)

View File

@@ -111,6 +111,13 @@ constexpr complex64_t operator*(complex64_t a, complex64_t b) {
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
}
constexpr complex64_t operator/(complex64_t a, complex64_t b) {
auto denom = b.real * b.real + b.imag * b.imag;
auto x = a.real * b.real + a.imag * b.imag;
auto y = a.imag * b.real - a.real * b.imag;
return {x / denom, y / denom};
}
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));

View File

@@ -1636,6 +1636,20 @@ array operator/(const array& a, double b) {
return divide(a, array(b));
}
array floor_divide(
const array& a,
const array& b,
StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
if (is_floating_point(dtype)) {
return floor(divide(a, b, s), s);
}
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
return array(
inputs[0].shape(), dtype, std::make_unique<Divide>(to_stream(s)), inputs);
}
array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
auto inputs = broadcast_arrays(

View File

@@ -709,6 +709,9 @@ array operator/(const array& a, const array& b);
array operator/(double a, const array& b);
array operator/(const array& a, double b);
/** Compute integer division. Equivalent to doing floor(a / x). */
array floor_divide(const array& a, const array& b, StreamOrDevice s = {});
/** Compute the element-wise remainder of division */
array remainder(const array& a, const array& b, StreamOrDevice s = {});
array operator%(const array& a, const array& b);