mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Make cpu binary_op easily accessible (#2733)
This commit is contained in:
committed by
GitHub
parent
d3bc6a9bff
commit
6ece97f69b
@@ -14,233 +14,11 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
template <typename Op>
|
|
||||||
void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
|
|
||||||
auto bopt = get_binary_op_type(a, b);
|
|
||||||
set_binary_op_output_data(a, b, out, bopt);
|
|
||||||
|
|
||||||
auto& encoder = cpu::get_command_encoder(stream);
|
|
||||||
encoder.set_input_array(a);
|
|
||||||
encoder.set_input_array(b);
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
|
||||||
b = array::unsafe_weak_copy(b),
|
|
||||||
out = array::unsafe_weak_copy(out),
|
|
||||||
bopt]() mutable {
|
|
||||||
switch (out.dtype()) {
|
|
||||||
case bool_:
|
|
||||||
binary_op<bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case uint8:
|
|
||||||
binary_op<uint8_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case uint16:
|
|
||||||
binary_op<uint16_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case uint32:
|
|
||||||
binary_op<uint32_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case uint64:
|
|
||||||
binary_op<uint64_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int8:
|
|
||||||
binary_op<int8_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int16:
|
|
||||||
binary_op<int16_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int32:
|
|
||||||
binary_op<int32_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int64:
|
|
||||||
binary_op<int64_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case float16:
|
|
||||||
binary_op<float16_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case float32:
|
|
||||||
binary_op<float, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case float64:
|
|
||||||
binary_op<double, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case bfloat16:
|
|
||||||
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case complex64:
|
|
||||||
binary_op<complex64_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Op>
|
|
||||||
void comparison_op(
|
|
||||||
const array& a,
|
|
||||||
const array& b,
|
|
||||||
array& out,
|
|
||||||
Op op,
|
|
||||||
Stream stream) {
|
|
||||||
auto bopt = get_binary_op_type(a, b);
|
|
||||||
set_binary_op_output_data(a, b, out, bopt);
|
|
||||||
|
|
||||||
auto& encoder = cpu::get_command_encoder(stream);
|
|
||||||
encoder.set_input_array(a);
|
|
||||||
encoder.set_input_array(b);
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
|
||||||
b = array::unsafe_weak_copy(b),
|
|
||||||
out = array::unsafe_weak_copy(out),
|
|
||||||
bopt]() mutable {
|
|
||||||
switch (a.dtype()) {
|
|
||||||
case bool_:
|
|
||||||
binary_op<bool, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case uint8:
|
|
||||||
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case uint16:
|
|
||||||
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case uint32:
|
|
||||||
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case uint64:
|
|
||||||
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int8:
|
|
||||||
binary_op<int8_t, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int16:
|
|
||||||
binary_op<int16_t, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int32:
|
|
||||||
binary_op<int32_t, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int64:
|
|
||||||
binary_op<int64_t, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case float16:
|
|
||||||
binary_op<float16_t, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case float32:
|
|
||||||
binary_op<float, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case float64:
|
|
||||||
binary_op<double, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case bfloat16:
|
|
||||||
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case complex64:
|
|
||||||
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Op>
|
|
||||||
void binary_float(
|
|
||||||
const array& a,
|
|
||||||
const array& b,
|
|
||||||
array& out,
|
|
||||||
Op op,
|
|
||||||
Stream stream) {
|
|
||||||
auto bopt = get_binary_op_type(a, b);
|
|
||||||
set_binary_op_output_data(a, b, out, bopt);
|
|
||||||
|
|
||||||
auto& encoder = cpu::get_command_encoder(stream);
|
|
||||||
encoder.set_input_array(a);
|
|
||||||
encoder.set_input_array(b);
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
|
||||||
b = array::unsafe_weak_copy(b),
|
|
||||||
out = array::unsafe_weak_copy(out),
|
|
||||||
bopt]() mutable {
|
|
||||||
switch (out.dtype()) {
|
|
||||||
case float16:
|
|
||||||
binary_op<float16_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case float32:
|
|
||||||
binary_op<float, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case float64:
|
|
||||||
binary_op<double, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case bfloat16:
|
|
||||||
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case complex64:
|
|
||||||
binary_op<complex64_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw std::runtime_error(
|
|
||||||
"[binary_float] Only supports floating point types.");
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Op>
|
|
||||||
void binary_int(
|
|
||||||
const array& a,
|
|
||||||
const array& b,
|
|
||||||
array& out,
|
|
||||||
Op op,
|
|
||||||
Stream stream) {
|
|
||||||
auto bopt = get_binary_op_type(a, b);
|
|
||||||
set_binary_op_output_data(a, b, out, bopt);
|
|
||||||
|
|
||||||
auto& encoder = cpu::get_command_encoder(stream);
|
|
||||||
encoder.set_input_array(a);
|
|
||||||
encoder.set_input_array(b);
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
|
||||||
b = array::unsafe_weak_copy(b),
|
|
||||||
out = array::unsafe_weak_copy(out),
|
|
||||||
bopt]() mutable {
|
|
||||||
switch (out.dtype()) {
|
|
||||||
case bool_:
|
|
||||||
binary_op<bool, Op>(a, b, out, bopt);
|
|
||||||
case uint8:
|
|
||||||
binary_op<uint8_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case uint16:
|
|
||||||
binary_op<uint16_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case uint32:
|
|
||||||
binary_op<uint32_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case uint64:
|
|
||||||
binary_op<uint64_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int8:
|
|
||||||
binary_op<int8_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int16:
|
|
||||||
binary_op<int16_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int32:
|
|
||||||
binary_op<int32_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int64:
|
|
||||||
binary_op<int64_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw std::runtime_error("[binary_int] Type not supported");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary(a, b, out, detail::Add(), stream());
|
binary_op_cpu(a, b, out, detail::Add(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void DivMod::eval_cpu(
|
void DivMod::eval_cpu(
|
||||||
@@ -324,14 +102,14 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary(a, b, out, detail::Divide(), stream());
|
binary_op_cpu(a, b, out, detail::Divide(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary(a, b, out, detail::Remainder(), stream());
|
binary_op_cpu(a, b, out, detail::Remainder(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
@@ -372,89 +150,90 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
comparison_op(a, b, out, detail::Equal(), stream());
|
comparison_op_cpu(a, b, out, detail::Equal(), stream());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
comparison_op(inputs[0], inputs[1], out, detail::Greater(), stream());
|
comparison_op_cpu(inputs[0], inputs[1], out, detail::Greater(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
|
comparison_op_cpu(
|
||||||
|
inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
comparison_op(inputs[0], inputs[1], out, detail::Less(), stream());
|
comparison_op_cpu(inputs[0], inputs[1], out, detail::Less(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
comparison_op(inputs[0], inputs[1], out, detail::LessEqual(), stream());
|
comparison_op_cpu(inputs[0], inputs[1], out, detail::LessEqual(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary_float(a, b, out, detail::LogAddExp(), stream());
|
binary_float_op_cpu(a, b, out, detail::LogAddExp(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
||||||
auto& in1 = inputs[0];
|
auto& in1 = inputs[0];
|
||||||
auto& in2 = inputs[1];
|
auto& in2 = inputs[1];
|
||||||
binary(in1, in2, out, detail::LogicalAnd(), stream());
|
binary_op_cpu(in1, in2, out, detail::LogicalAnd(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
||||||
auto& in1 = inputs[0];
|
auto& in1 = inputs[0];
|
||||||
auto& in2 = inputs[1];
|
auto& in2 = inputs[1];
|
||||||
binary(in1, in2, out, detail::LogicalOr(), stream());
|
binary_op_cpu(in1, in2, out, detail::LogicalOr(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary(a, b, out, detail::Maximum(), stream());
|
binary_op_cpu(a, b, out, detail::Maximum(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary(a, b, out, detail::Minimum(), stream());
|
binary_op_cpu(a, b, out, detail::Minimum(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary(a, b, out, detail::Multiply(), stream());
|
binary_op_cpu(a, b, out, detail::Multiply(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
comparison_op(inputs[0], inputs[1], out, detail::NotEqual(), stream());
|
comparison_op_cpu(inputs[0], inputs[1], out, detail::NotEqual(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary(a, b, out, detail::Power(), stream());
|
binary_op_cpu(a, b, out, detail::Power(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary(a, b, out, detail::Subtract(), stream());
|
binary_op_cpu(a, b, out, detail::Subtract(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
@@ -463,19 +242,19 @@ void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
switch (op_) {
|
switch (op_) {
|
||||||
case BitwiseBinary::And:
|
case BitwiseBinary::And:
|
||||||
binary_int(a, b, out, detail::BitwiseAnd(), stream());
|
binary_int_op_cpu(a, b, out, detail::BitwiseAnd(), stream());
|
||||||
break;
|
break;
|
||||||
case BitwiseBinary::Or:
|
case BitwiseBinary::Or:
|
||||||
binary_int(a, b, out, detail::BitwiseOr(), stream());
|
binary_int_op_cpu(a, b, out, detail::BitwiseOr(), stream());
|
||||||
break;
|
break;
|
||||||
case BitwiseBinary::Xor:
|
case BitwiseBinary::Xor:
|
||||||
binary_int(a, b, out, detail::BitwiseXor(), stream());
|
binary_int_op_cpu(a, b, out, detail::BitwiseXor(), stream());
|
||||||
break;
|
break;
|
||||||
case BitwiseBinary::LeftShift:
|
case BitwiseBinary::LeftShift:
|
||||||
binary_int(a, b, out, detail::LeftShift(), stream());
|
binary_int_op_cpu(a, b, out, detail::LeftShift(), stream());
|
||||||
break;
|
break;
|
||||||
case BitwiseBinary::RightShift:
|
case BitwiseBinary::RightShift:
|
||||||
binary_int(a, b, out, detail::RightShift(), stream());
|
binary_int_op_cpu(a, b, out, detail::RightShift(), stream());
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -484,7 +263,7 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
const auto& a = inputs[0];
|
const auto& a = inputs[0];
|
||||||
const auto& b = inputs[1];
|
const auto& b = inputs[1];
|
||||||
binary_float(a, b, out, detail::ArcTan2(), stream());
|
binary_float_op_cpu(a, b, out, detail::ArcTan2(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
#include "mlx/backend/common/binary.h"
|
#include "mlx/backend/common/binary.h"
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
|
|
||||||
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
#include "mlx/backend/cpu/simd/simd.h"
|
#include "mlx/backend/cpu/simd/simd.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
@@ -290,4 +291,227 @@ void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
|
|||||||
binary_op<T, T, Op>(a, b, out, bopt);
|
binary_op<T, T, Op>(a, b, out, bopt);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
void binary_op_cpu(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
Op op,
|
||||||
|
Stream stream) {
|
||||||
|
auto bopt = get_binary_op_type(a, b);
|
||||||
|
set_binary_op_output_data(a, b, out, bopt);
|
||||||
|
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||||
|
b = array::unsafe_weak_copy(b),
|
||||||
|
out = array::unsafe_weak_copy(out),
|
||||||
|
bopt]() mutable {
|
||||||
|
switch (out.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
binary_op<bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
binary_op<uint8_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
binary_op<uint16_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
binary_op<uint32_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
binary_op<uint64_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
binary_op<int8_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
binary_op<int16_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
binary_op<int32_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
binary_op<int64_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
binary_op<float16_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
binary_op<float, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case float64:
|
||||||
|
binary_op<double, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
binary_op<complex64_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
void comparison_op_cpu(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
Op op,
|
||||||
|
Stream stream) {
|
||||||
|
auto bopt = get_binary_op_type(a, b);
|
||||||
|
set_binary_op_output_data(a, b, out, bopt);
|
||||||
|
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||||
|
b = array::unsafe_weak_copy(b),
|
||||||
|
out = array::unsafe_weak_copy(out),
|
||||||
|
bopt]() mutable {
|
||||||
|
switch (a.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
binary_op<bool, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
binary_op<int8_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
binary_op<int16_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
binary_op<int32_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
binary_op<int64_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
binary_op<float16_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
binary_op<float, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case float64:
|
||||||
|
binary_op<double, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
void binary_float_op_cpu(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
Op op,
|
||||||
|
Stream stream) {
|
||||||
|
auto bopt = get_binary_op_type(a, b);
|
||||||
|
set_binary_op_output_data(a, b, out, bopt);
|
||||||
|
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||||
|
b = array::unsafe_weak_copy(b),
|
||||||
|
out = array::unsafe_weak_copy(out),
|
||||||
|
bopt]() mutable {
|
||||||
|
switch (out.dtype()) {
|
||||||
|
case float16:
|
||||||
|
binary_op<float16_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
binary_op<float, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case float64:
|
||||||
|
binary_op<double, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
binary_op<complex64_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[binary_float] Only supports floating point types.");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
void binary_int_op_cpu(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
Op op,
|
||||||
|
Stream stream) {
|
||||||
|
auto bopt = get_binary_op_type(a, b);
|
||||||
|
set_binary_op_output_data(a, b, out, bopt);
|
||||||
|
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||||
|
b = array::unsafe_weak_copy(b),
|
||||||
|
out = array::unsafe_weak_copy(out),
|
||||||
|
bopt]() mutable {
|
||||||
|
switch (out.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
binary_op<bool, Op>(a, b, out, bopt);
|
||||||
|
case uint8:
|
||||||
|
binary_op<uint8_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
binary_op<uint16_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
binary_op<uint32_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
binary_op<uint64_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
binary_op<int8_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
binary_op<int16_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
binary_op<int32_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
binary_op<int64_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("[binary_int] Type not supported");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -147,37 +147,8 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
copy_cpu(c, out, ctype, stream());
|
copy_cpu(c, out, ctype, stream());
|
||||||
} else {
|
} else {
|
||||||
array beta_scalar = array(beta_, c.dtype());
|
array beta_scalar = array(beta_, c.dtype());
|
||||||
auto bopt = get_binary_op_type(c, beta_scalar);
|
|
||||||
set_binary_op_output_data(c, beta_scalar, out, bopt);
|
|
||||||
auto& encoder = cpu::get_command_encoder(stream());
|
auto& encoder = cpu::get_command_encoder(stream());
|
||||||
encoder.set_input_array(c);
|
binary_float_op_cpu(c, beta_scalar, out, detail::Multiply(), stream());
|
||||||
encoder.set_input_array(beta_scalar);
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
encoder.dispatch([c = array::unsafe_weak_copy(c),
|
|
||||||
beta_scalar = array::unsafe_weak_copy(beta_scalar),
|
|
||||||
out = array::unsafe_weak_copy(out),
|
|
||||||
bopt]() mutable {
|
|
||||||
switch (out.dtype()) {
|
|
||||||
case float16:
|
|
||||||
binary_op<float16_t, detail::Multiply>(c, beta_scalar, out, bopt);
|
|
||||||
break;
|
|
||||||
case float32:
|
|
||||||
binary_op<float, detail::Multiply>(c, beta_scalar, out, bopt);
|
|
||||||
break;
|
|
||||||
case float64:
|
|
||||||
binary_op<double, detail::Multiply>(c, beta_scalar, out, bopt);
|
|
||||||
break;
|
|
||||||
case bfloat16:
|
|
||||||
binary_op<bfloat16_t, detail::Multiply>(c, beta_scalar, out, bopt);
|
|
||||||
break;
|
|
||||||
case complex64:
|
|
||||||
binary_op<complex64_t, detail::Multiply>(c, beta_scalar, out, bopt);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw std::runtime_error(
|
|
||||||
"[AddMM::eval_cpu] Unsupported dtype for beta scaling");
|
|
||||||
}
|
|
||||||
});
|
|
||||||
encoder.add_temporary(std::move(beta_scalar));
|
encoder.add_temporary(std::move(beta_scalar));
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
|
|||||||
Reference in New Issue
Block a user