mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
parent
67d1894759
commit
86f495985b
@ -28,8 +28,11 @@ Operations
|
||||
atleast_1d
|
||||
atleast_2d
|
||||
atleast_3d
|
||||
broadcast_to
|
||||
bitwise_and
|
||||
bitwise_or
|
||||
bitwise_xor
|
||||
block_masked_mm
|
||||
broadcast_to
|
||||
ceil
|
||||
clip
|
||||
concatenate
|
||||
@ -69,6 +72,7 @@ Operations
|
||||
isnan
|
||||
isneginf
|
||||
isposinf
|
||||
left_shift
|
||||
less
|
||||
less_equal
|
||||
linspace
|
||||
@ -105,6 +109,7 @@ Operations
|
||||
reciprocal
|
||||
repeat
|
||||
reshape
|
||||
right_shift
|
||||
round
|
||||
rsqrt
|
||||
save
|
||||
|
@ -236,4 +236,61 @@ void Subtract::eval(const std::vector<array>& inputs, array& out) {
|
||||
binary(a, b, out, detail::Subtract());
|
||||
}
|
||||
|
||||
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto dispatch_type = [&a, &b, &out](auto op) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool>(a, b, out, op);
|
||||
case uint8:
|
||||
binary_op<uint8_t>(a, b, out, op);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t>(a, b, out, op);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t>(a, b, out, op);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t>(a, b, out, op);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t>(a, b, out, op);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t>(a, b, out, op);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t>(a, b, out, op);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t>(a, b, out, op);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[BitwiseBinary::eval_cpu] Type not supported");
|
||||
break;
|
||||
}
|
||||
};
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
dispatch_type(detail::BitwiseAnd());
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
dispatch_type(detail::BitwiseOr());
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
dispatch_type(detail::BitwiseXor());
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
dispatch_type(detail::LeftShift());
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
dispatch_type(detail::RightShift());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -606,4 +606,39 @@ struct Select {
|
||||
}
|
||||
};
|
||||
|
||||
struct BitwiseAnd {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x & y;
|
||||
};
|
||||
};
|
||||
|
||||
struct BitwiseOr {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x | y;
|
||||
};
|
||||
};
|
||||
|
||||
struct BitwiseXor {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x ^ y;
|
||||
};
|
||||
};
|
||||
|
||||
struct LeftShift {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x << y;
|
||||
};
|
||||
};
|
||||
|
||||
struct RightShift {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x >> y;
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace mlx::core::detail
|
||||
|
@ -229,3 +229,38 @@ struct LogicalOr {
|
||||
return x || y;
|
||||
};
|
||||
};
|
||||
|
||||
struct BitwiseAnd {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x & y;
|
||||
};
|
||||
};
|
||||
|
||||
struct BitwiseOr {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x | y;
|
||||
};
|
||||
};
|
||||
|
||||
struct BitwiseXor {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x ^ y;
|
||||
};
|
||||
};
|
||||
|
||||
struct LeftShift {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x << y;
|
||||
};
|
||||
};
|
||||
|
||||
struct RightShift {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x >> y;
|
||||
};
|
||||
};
|
||||
|
@ -184,13 +184,7 @@ template <typename T, typename U, typename Op>
|
||||
instantiate_binary_g("g" #name #tname, itype, otype, op) \
|
||||
instantiate_binary_g_nd("g" #name #tname, itype, otype, op)
|
||||
|
||||
#define instantiate_binary_float(name, op) \
|
||||
instantiate_binary_all(name, float16, half, half, op) \
|
||||
instantiate_binary_all(name, float32, float, float, op) \
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
|
||||
|
||||
#define instantiate_binary_types(name, op) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||
#define instantiate_binary_integer(name, op) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
|
||||
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
|
||||
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
|
||||
@ -199,6 +193,15 @@ template <typename T, typename U, typename Op>
|
||||
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
|
||||
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
|
||||
instantiate_binary_all(name, int64, int64_t, int64_t, op) \
|
||||
|
||||
#define instantiate_binary_float(name, op) \
|
||||
instantiate_binary_all(name, float16, half, half, op) \
|
||||
instantiate_binary_all(name, float32, float, float, op) \
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
|
||||
|
||||
#define instantiate_binary_types(name, op) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||
instantiate_binary_integer(name, op) \
|
||||
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
|
||||
instantiate_binary_float(name, op)
|
||||
|
||||
@ -241,3 +244,13 @@ instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
|
||||
|
||||
instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
|
||||
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
|
||||
|
||||
// Bitwise ops only need integer types and bool (except for l/r shift)
|
||||
instantiate_binary_integer(bitwise_and, BitwiseAnd)
|
||||
instantiate_binary_all(bitwise_and, bool_, bool, bool, BitwiseAnd)
|
||||
instantiate_binary_integer(bitwise_or, BitwiseOr)
|
||||
instantiate_binary_all(bitwise_or, bool_, bool, bool, BitwiseOr)
|
||||
instantiate_binary_integer(bitwise_xor, BitwiseXor)
|
||||
instantiate_binary_all(bitwise_xor, bool_, bool, bool, BitwiseXor)
|
||||
instantiate_binary_integer(left_shift, LeftShift)
|
||||
instantiate_binary_integer(right_shift, RightShift)
|
||||
|
@ -533,6 +533,26 @@ void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
binary_op(inputs, out, "bitwise_and");
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
binary_op(inputs, out, "bitwise_or");
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
binary_op(inputs, out, "bitwise_xor");
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
binary_op(inputs, out, "left_shift");
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
binary_op(inputs, out, "right_shift");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
@ -31,6 +31,7 @@ NO_GPU(ArgReduce)
|
||||
NO_GPU(ArgSort)
|
||||
NO_GPU(AsType)
|
||||
NO_GPU(AsStrided)
|
||||
NO_GPU(BitwiseBinary)
|
||||
NO_GPU(Broadcast)
|
||||
NO_GPU(Ceil)
|
||||
NO_GPU_MULTI(Compiled)
|
||||
|
@ -45,7 +45,7 @@ bool is_binary(const Primitive& p) {
|
||||
typeid(p) == typeid(LogAddExp) || typeid(p) == typeid(Maximum) ||
|
||||
typeid(p) == typeid(Minimum) || typeid(p) == typeid(Multiply) ||
|
||||
typeid(p) == typeid(NotEqual) || typeid(p) == typeid(Power) ||
|
||||
typeid(p) == typeid(Subtract));
|
||||
typeid(p) == typeid(Subtract) || typeid(p) == typeid(BitwiseBinary));
|
||||
}
|
||||
|
||||
bool is_ternary(const Primitive& p) {
|
||||
|
73
mlx/ops.cpp
73
mlx/ops.cpp
@ -3944,4 +3944,77 @@ array number_of_elements(
|
||||
{a}));
|
||||
}
|
||||
|
||||
array bitwise_impl(
|
||||
const array& a,
|
||||
const array& b,
|
||||
BitwiseBinary::Op op,
|
||||
const std::string& op_name,
|
||||
const StreamOrDevice& s) {
|
||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||
if (!(issubdtype(out_type, integer) || out_type == bool_)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[" << op_name
|
||||
<< "] Only allowed on integer or boolean types "
|
||||
"but got types "
|
||||
<< a.dtype() << " and " << b.dtype() << ".";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
auto inputs =
|
||||
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
|
||||
return array(
|
||||
a.shape(),
|
||||
out_type,
|
||||
std::make_shared<BitwiseBinary>(to_stream(s), op),
|
||||
std::move(inputs));
|
||||
}
|
||||
|
||||
array bitwise_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
return bitwise_impl(a, b, BitwiseBinary::Op::And, "bitwise_and", s);
|
||||
}
|
||||
array operator&(const array& a, const array& b) {
|
||||
return bitwise_and(a, b);
|
||||
}
|
||||
|
||||
array bitwise_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
return bitwise_impl(a, b, BitwiseBinary::Op::Or, "bitwise_or", s);
|
||||
}
|
||||
array operator|(const array& a, const array& b) {
|
||||
return bitwise_or(a, b);
|
||||
}
|
||||
|
||||
array bitwise_xor(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
return bitwise_impl(a, b, BitwiseBinary::Op::Xor, "bitwise_xor", s);
|
||||
}
|
||||
array operator^(const array& a, const array& b) {
|
||||
return bitwise_xor(a, b);
|
||||
}
|
||||
|
||||
array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
// Bit shift on bool always up-casts to uint8
|
||||
auto t = promote_types(result_type(a, b), uint8);
|
||||
return bitwise_impl(
|
||||
astype(a, t, s),
|
||||
astype(b, t, s),
|
||||
BitwiseBinary::Op::LeftShift,
|
||||
"left_shift",
|
||||
s);
|
||||
}
|
||||
array operator<<(const array& a, const array& b) {
|
||||
return left_shift(a, b);
|
||||
}
|
||||
|
||||
array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
// Bit shift on bool always up-casts to uint8
|
||||
auto t = promote_types(result_type(a, b), uint8);
|
||||
return bitwise_impl(
|
||||
astype(a, t, s),
|
||||
astype(b, t, s),
|
||||
BitwiseBinary::Op::RightShift,
|
||||
"right_shift",
|
||||
s);
|
||||
}
|
||||
array operator>>(const array& a, const array& b) {
|
||||
return right_shift(a, b);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
31
mlx/ops.h
31
mlx/ops.h
@ -1027,17 +1027,6 @@ softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {
|
||||
|
||||
/** Raise elements of a to the power of b element-wise */
|
||||
array power(const array& a, const array& b, StreamOrDevice s = {});
|
||||
inline array operator^(const array& a, const array& b) {
|
||||
return power(a, b);
|
||||
}
|
||||
template <typename T>
|
||||
array operator^(T a, const array& b) {
|
||||
return power(array(a), b);
|
||||
}
|
||||
template <typename T>
|
||||
array operator^(const array& a, T b) {
|
||||
return power(a, array(b));
|
||||
}
|
||||
|
||||
/** Cumulative sum of an array. */
|
||||
array cumsum(
|
||||
@ -1239,6 +1228,26 @@ array number_of_elements(
|
||||
Dtype dtype = int32,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Bitwise and. */
|
||||
array bitwise_and(const array& a, const array& b, StreamOrDevice s = {});
|
||||
array operator&(const array& a, const array& b);
|
||||
|
||||
/** Bitwise inclusive or. */
|
||||
array bitwise_or(const array& a, const array& b, StreamOrDevice s = {});
|
||||
array operator|(const array& a, const array& b);
|
||||
|
||||
/** Bitwise exclusive or. */
|
||||
array bitwise_xor(const array& a, const array& b, StreamOrDevice s = {});
|
||||
array operator^(const array& a, const array& b);
|
||||
|
||||
/** Shift bits to the left. */
|
||||
array left_shift(const array& a, const array& b, StreamOrDevice s = {});
|
||||
array operator<<(const array& a, const array& b);
|
||||
|
||||
/** Shift bits to the right. */
|
||||
array right_shift(const array& a, const array& b, StreamOrDevice s = {});
|
||||
array operator>>(const array& a, const array& b);
|
||||
|
||||
/** @} */
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -560,6 +560,44 @@ bool AsStrided::is_equivalent(const Primitive& other) const {
|
||||
offset_ == a_other.offset_;
|
||||
}
|
||||
|
||||
bool BitwiseBinary::is_equivalent(const Primitive& other) const {
|
||||
const BitwiseBinary& a_other = static_cast<const BitwiseBinary&>(other);
|
||||
return op_ == a_other.op_;
|
||||
}
|
||||
|
||||
void BitwiseBinary::print(std::ostream& os) {
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
os << "BitwiseAnd";
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
os << "BitwiseOr";
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
os << "BitwiseXor";
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
os << "LeftShift";
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
os << "RightShift";
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> BitwiseBinary::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
||||
return {
|
||||
{array(
|
||||
a.shape(),
|
||||
a.dtype(),
|
||||
std::make_shared<BitwiseBinary>(stream(), op_),
|
||||
{a, b})},
|
||||
{to_ax}};
|
||||
}
|
||||
|
||||
std::vector<array> Broadcast::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
|
@ -443,6 +443,25 @@ class AsStrided : public UnaryPrimitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class BitwiseBinary : public UnaryPrimitive {
|
||||
public:
|
||||
enum Op { And, Or, Xor, LeftShift, RightShift };
|
||||
|
||||
explicit BitwiseBinary(Stream stream, Op op)
|
||||
: UnaryPrimitive(stream), op_(op){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
void print(std::ostream& os) override;
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
Op op_;
|
||||
};
|
||||
|
||||
class BlockMaskedMM : public UnaryPrimitive {
|
||||
public:
|
||||
explicit BlockMaskedMM(Stream stream, int block_size)
|
||||
|
@ -1017,11 +1017,7 @@ void init_array(nb::module_& m) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with bitwise and.");
|
||||
}
|
||||
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
||||
throw std::invalid_argument(
|
||||
"Bitwise and not yet supported for integer types.");
|
||||
}
|
||||
return logical_and(a, b);
|
||||
return bitwise_and(a, b);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
@ -1036,11 +1032,7 @@ void init_array(nb::module_& m) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with bitwise and.");
|
||||
}
|
||||
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
||||
throw std::invalid_argument(
|
||||
"Bitwise and not yet supported for integer types.");
|
||||
}
|
||||
a.overwrite_descriptor(logical_and(a, b));
|
||||
a.overwrite_descriptor(bitwise_and(a, b));
|
||||
return a;
|
||||
},
|
||||
"other"_a,
|
||||
@ -1057,11 +1049,7 @@ void init_array(nb::module_& m) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or bitwise or.");
|
||||
}
|
||||
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
||||
throw std::invalid_argument(
|
||||
"Bitwise or not yet supported for integer types.");
|
||||
}
|
||||
return logical_or(a, b);
|
||||
return bitwise_or(a, b);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
@ -1076,11 +1064,71 @@ void init_array(nb::module_& m) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or bitwise or.");
|
||||
}
|
||||
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
||||
throw std::invalid_argument(
|
||||
"Bitwise or not yet supported for integer types.");
|
||||
a.overwrite_descriptor(bitwise_or(a, b));
|
||||
return a;
|
||||
},
|
||||
"other"_a,
|
||||
nb::rv_policy::none)
|
||||
.def(
|
||||
"__lshift__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("left shift", v);
|
||||
}
|
||||
a.overwrite_descriptor(logical_or(a, b));
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with left shift.");
|
||||
}
|
||||
return left_shift(a, b);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__ilshift__",
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("inplace left shift", v);
|
||||
}
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or left shift.");
|
||||
}
|
||||
a.overwrite_descriptor(left_shift(a, b));
|
||||
return a;
|
||||
},
|
||||
"other"_a,
|
||||
nb::rv_policy::none)
|
||||
.def(
|
||||
"__rshift__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("right shift", v);
|
||||
}
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with right shift.");
|
||||
}
|
||||
return right_shift(a, b);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__irshift__",
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("inplace right shift", v);
|
||||
}
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or right shift.");
|
||||
}
|
||||
a.overwrite_descriptor(right_shift(a, b));
|
||||
return a;
|
||||
},
|
||||
"other"_a,
|
||||
|
@ -3702,8 +3702,8 @@ void init_ops(nb::module_& m) {
|
||||
|
||||
* ``lhs_mask`` must have shape (..., :math:`\lceil` `M` / ``block_size`` :math:`\rceil`, :math:`\lceil` `K` / ``block_size`` :math:`\rceil`)
|
||||
|
||||
* ``rhs_mask`` must have shape (..., :math:`\lceil` `K` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`)
|
||||
|
||||
* ``rhs_mask`` must have shape (..., :math:`\lceil` `K` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`)
|
||||
|
||||
* ``out_mask`` must have shape (..., :math:`\lceil` `M` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`)
|
||||
|
||||
Note: Only ``block_size=64`` and ``block_size=32`` are currently supported
|
||||
@ -3897,4 +3897,132 @@ void init_ops(nb::module_& m) {
|
||||
&issubdtype),
|
||||
""_a,
|
||||
""_a);
|
||||
m.def(
|
||||
"bitwise_and",
|
||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||
auto [a, b] = to_arrays(a_, b_);
|
||||
return bitwise_and(a, b, s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def bitwise_and(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Element-wise bitwise and.
|
||||
|
||||
Take the bitwise and of two arrays with numpy-style broadcasting
|
||||
semantics. Either or both input arrays can also be scalars.
|
||||
|
||||
Args:
|
||||
a (array): Input array or scalar.
|
||||
b (array): Input array or scalar.
|
||||
|
||||
Returns:
|
||||
array: The bitwise and ``a & b``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"bitwise_or",
|
||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||
auto [a, b] = to_arrays(a_, b_);
|
||||
return bitwise_or(a, b, s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def bitwise_or(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Element-wise bitwise or.
|
||||
|
||||
Take the bitwise or of two arrays with numpy-style broadcasting
|
||||
semantics. Either or both input arrays can also be scalars.
|
||||
|
||||
Args:
|
||||
a (array): Input array or scalar.
|
||||
b (array): Input array or scalar.
|
||||
|
||||
Returns:
|
||||
array: The bitwise or``a | b``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"bitwise_xor",
|
||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||
auto [a, b] = to_arrays(a_, b_);
|
||||
return bitwise_xor(a, b, s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def bitwise_xor(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Element-wise bitwise xor.
|
||||
|
||||
Take the bitwise exclusive or of two arrays with numpy-style
|
||||
broadcasting semantics. Either or both input arrays can also be
|
||||
scalars.
|
||||
|
||||
Args:
|
||||
a (array): Input array or scalar.
|
||||
b (array): Input array or scalar.
|
||||
|
||||
Returns:
|
||||
array: The bitwise xor ``a ^ b``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"left_shift",
|
||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||
auto [a, b] = to_arrays(a_, b_);
|
||||
return left_shift(a, b, s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def left_shift(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Element-wise left shift.
|
||||
|
||||
Shift the bits of the first input to the left by the second using
|
||||
numpy-style broadcasting semantics. Either or both input arrays can
|
||||
also be scalars.
|
||||
|
||||
Args:
|
||||
a (array): Input array or scalar.
|
||||
b (array): Input array or scalar.
|
||||
|
||||
Returns:
|
||||
array: The bitwise left shift ``a << b``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"right_shift",
|
||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||
auto [a, b] = to_arrays(a_, b_);
|
||||
return right_shift(a, b, s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def right_shift(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Element-wise right shift.
|
||||
|
||||
Shift the bits of the first input to the right by the second using
|
||||
numpy-style broadcasting semantics. Either or both input arrays can
|
||||
also be scalars.
|
||||
|
||||
Args:
|
||||
a (array): Input array or scalar.
|
||||
b (array): Input array or scalar.
|
||||
|
||||
Returns:
|
||||
array: The bitwise right shift ``a >> b``.
|
||||
)pbdoc");
|
||||
}
|
||||
|
@ -2177,6 +2177,38 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
f"mx and np don't aggree on {a}, {b}",
|
||||
)
|
||||
|
||||
def test_bitwise_ops(self):
|
||||
types = [
|
||||
mx.uint8,
|
||||
mx.uint16,
|
||||
mx.uint32,
|
||||
mx.uint64,
|
||||
mx.int8,
|
||||
mx.int16,
|
||||
mx.int32,
|
||||
mx.int64,
|
||||
]
|
||||
a = mx.random.randint(0, 4096, (1000,))
|
||||
b = mx.random.randint(0, 4096, (1000,))
|
||||
for op in ["bitwise_and", "bitwise_or", "bitwise_xor"]:
|
||||
for t in types:
|
||||
a_mlx = a.astype(t)
|
||||
b_mlx = b.astype(t)
|
||||
a_np = np.array(a_mlx)
|
||||
b_np = np.array(b_mlx)
|
||||
out_mlx = getattr(mx, op)(a_mlx, b_mlx)
|
||||
out_np = getattr(np, op)(a_np, b_np)
|
||||
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
|
||||
for op in ["left_shift", "right_shift"]:
|
||||
for t in types:
|
||||
a_mlx = a.astype(t)
|
||||
b_mlx = mx.random.randint(0, t.size, (1000,)).astype(t)
|
||||
a_np = np.array(a_mlx)
|
||||
b_np = np.array(b_mlx)
|
||||
out_mlx = getattr(mx, op)(a_mlx, b_mlx)
|
||||
out_np = getattr(np, op)(a_np, b_np)
|
||||
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -596,7 +596,7 @@ TEST_CASE("test op vjps") {
|
||||
// Test power
|
||||
{
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
return std::vector<array>{inputs[0] ^ inputs[1]};
|
||||
return std::vector<array>{power(inputs[0], inputs[1])};
|
||||
};
|
||||
auto out = vjp(fun, {array(4.0f), array(3.0f)}, {array(1.0f)}).second;
|
||||
CHECK_EQ(out[0].item<float>(), 48.0f);
|
||||
|
@ -2308,29 +2308,26 @@ TEST_CASE("test pad") {
|
||||
|
||||
TEST_CASE("test power") {
|
||||
CHECK_EQ(power(array(1), array(2)).item<int>(), 1);
|
||||
CHECK_EQ((array(1) ^ 2).item<int>(), 1);
|
||||
CHECK_EQ((1 ^ array(2)).item<int>(), 1);
|
||||
CHECK_EQ((array(-1) ^ 2).item<int>(), 1);
|
||||
CHECK_EQ((array(-1) ^ 3).item<int>(), -1);
|
||||
CHECK_EQ((power(array(-1), array(2))).item<int>(), 1);
|
||||
CHECK_EQ((power(array(-1), array(3))).item<int>(), -1);
|
||||
|
||||
// TODO Throws but exception not caught from calling thread
|
||||
// CHECK_THROWS((x^-1).item<int>());
|
||||
|
||||
CHECK_EQ((array(true) ^ array(false)).item<bool>(), true);
|
||||
CHECK_EQ((array(false) ^ array(false)).item<bool>(), true);
|
||||
CHECK_EQ((array(true) ^ array(true)).item<bool>(), true);
|
||||
CHECK_EQ((array(false) ^ array(true)).item<bool>(), false);
|
||||
CHECK_EQ((power(array(true), array(false))).item<bool>(), true);
|
||||
CHECK_EQ((power(array(false), array(false))).item<bool>(), true);
|
||||
CHECK_EQ((power(array(true), array(true))).item<bool>(), true);
|
||||
CHECK_EQ((power(array(false), array(true))).item<bool>(), false);
|
||||
|
||||
auto x = array(2.0f);
|
||||
CHECK_EQ((x ^ 0.5).item<float>(), doctest::Approx(std::pow(2.0f, 0.5f)));
|
||||
CHECK_EQ((x ^ 2.0f).item<float>(), 4.0f);
|
||||
CHECK_EQ(
|
||||
(power(x, array(0.5))).item<float>(),
|
||||
doctest::Approx(std::pow(2.0f, 0.5f)));
|
||||
CHECK_EQ(power(x, array(2.0f)).item<float>(), 4.0f);
|
||||
|
||||
CHECK(std::isnan((array(-1.0f) ^ 0.5).item<float>()));
|
||||
CHECK(std::isnan((power(array(-1.0f), array(0.5))).item<float>()));
|
||||
|
||||
auto a = complex64_t{0.5, 0.5};
|
||||
auto b = complex64_t{0.5, 0.5};
|
||||
auto expected = std::pow(a, b);
|
||||
auto out = (array(a) ^ array(b)).item<complex64_t>();
|
||||
auto out = (power(array(a), array(b))).item<complex64_t>();
|
||||
CHECK(abs(out.real() - expected.real()) < 1e-7);
|
||||
CHECK(abs(out.imag() - expected.imag()) < 1e-7);
|
||||
}
|
||||
@ -3230,4 +3227,4 @@ TEST_CASE("test meshgrid") {
|
||||
expected_one = array({1, 2, 3}, {3, 1});
|
||||
CHECK(array_equal(out[0], expected_zero).item<bool>());
|
||||
CHECK(array_equal(out[1], expected_one).item<bool>());
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user