Add bitwise ops (#1037)

* bitwise ops

* fix tests
This commit is contained in:
Awni Hannun
2024-04-26 22:03:42 -07:00
committed by GitHub
parent 67d1894759
commit 86f495985b
17 changed files with 568 additions and 58 deletions

View File

@@ -1027,17 +1027,6 @@ softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {
/** Raise elements of a to the power of b element-wise */
array power(const array& a, const array& b, StreamOrDevice s = {});
inline array operator^(const array& a, const array& b) {
return power(a, b);
}
template <typename T>
array operator^(T a, const array& b) {
return power(array(a), b);
}
template <typename T>
array operator^(const array& a, T b) {
return power(a, array(b));
}
/** Cumulative sum of an array. */
array cumsum(
@@ -1239,6 +1228,26 @@ array number_of_elements(
Dtype dtype = int32,
StreamOrDevice s = {});
/** Bitwise and. */
array bitwise_and(const array& a, const array& b, StreamOrDevice s = {});
array operator&(const array& a, const array& b);
/** Bitwise inclusive or. */
array bitwise_or(const array& a, const array& b, StreamOrDevice s = {});
array operator|(const array& a, const array& b);
/** Bitwise exclusive or. */
array bitwise_xor(const array& a, const array& b, StreamOrDevice s = {});
array operator^(const array& a, const array& b);
/** Shift bits to the left. */
array left_shift(const array& a, const array& b, StreamOrDevice s = {});
array operator<<(const array& a, const array& b);
/** Shift bits to the right. */
array right_shift(const array& a, const array& b, StreamOrDevice s = {});
array operator>>(const array& a, const array& b);
/** @} */
} // namespace mlx::core