mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 11:14:32 +08:00
@@ -229,3 +229,38 @@ struct LogicalOr {
|
||||
return x || y;
|
||||
};
|
||||
};
|
||||
|
||||
struct BitwiseAnd {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x & y;
|
||||
};
|
||||
};
|
||||
|
||||
struct BitwiseOr {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x | y;
|
||||
};
|
||||
};
|
||||
|
||||
struct BitwiseXor {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x ^ y;
|
||||
};
|
||||
};
|
||||
|
||||
struct LeftShift {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x << y;
|
||||
};
|
||||
};
|
||||
|
||||
struct RightShift {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x >> y;
|
||||
};
|
||||
};
|
||||
|
@@ -184,13 +184,7 @@ template <typename T, typename U, typename Op>
|
||||
instantiate_binary_g("g" #name #tname, itype, otype, op) \
|
||||
instantiate_binary_g_nd("g" #name #tname, itype, otype, op)
|
||||
|
||||
#define instantiate_binary_float(name, op) \
|
||||
instantiate_binary_all(name, float16, half, half, op) \
|
||||
instantiate_binary_all(name, float32, float, float, op) \
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
|
||||
|
||||
#define instantiate_binary_types(name, op) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||
#define instantiate_binary_integer(name, op) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
|
||||
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
|
||||
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
|
||||
@@ -199,6 +193,15 @@ template <typename T, typename U, typename Op>
|
||||
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
|
||||
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
|
||||
instantiate_binary_all(name, int64, int64_t, int64_t, op) \
|
||||
|
||||
#define instantiate_binary_float(name, op) \
|
||||
instantiate_binary_all(name, float16, half, half, op) \
|
||||
instantiate_binary_all(name, float32, float, float, op) \
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
|
||||
|
||||
#define instantiate_binary_types(name, op) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||
instantiate_binary_integer(name, op) \
|
||||
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
|
||||
instantiate_binary_float(name, op)
|
||||
|
||||
@@ -241,3 +244,13 @@ instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
|
||||
|
||||
instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
|
||||
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
|
||||
|
||||
// Bitwise ops only need integer types and bool (except for l/r shift)
|
||||
instantiate_binary_integer(bitwise_and, BitwiseAnd)
|
||||
instantiate_binary_all(bitwise_and, bool_, bool, bool, BitwiseAnd)
|
||||
instantiate_binary_integer(bitwise_or, BitwiseOr)
|
||||
instantiate_binary_all(bitwise_or, bool_, bool, bool, BitwiseOr)
|
||||
instantiate_binary_integer(bitwise_xor, BitwiseXor)
|
||||
instantiate_binary_all(bitwise_xor, bool_, bool, bool, BitwiseXor)
|
||||
instantiate_binary_integer(left_shift, LeftShift)
|
||||
instantiate_binary_integer(right_shift, RightShift)
|
||||
|
@@ -533,6 +533,26 @@ void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
binary_op(inputs, out, "bitwise_and");
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
binary_op(inputs, out, "bitwise_or");
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
binary_op(inputs, out, "bitwise_xor");
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
binary_op(inputs, out, "left_shift");
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
binary_op(inputs, out, "right_shift");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
Reference in New Issue
Block a user