diff --git a/mlx/backend/cpu/binary.cpp b/mlx/backend/cpu/binary.cpp index 35aa2a3e0..c1db2e118 100644 --- a/mlx/backend/cpu/binary.cpp +++ b/mlx/backend/cpu/binary.cpp @@ -14,233 +14,11 @@ namespace mlx::core { -namespace { - -template -void binary(const array& a, const array& b, array& out, Op op, Stream stream) { - auto bopt = get_binary_op_type(a, b); - set_binary_op_output_data(a, b, out, bopt); - - auto& encoder = cpu::get_command_encoder(stream); - encoder.set_input_array(a); - encoder.set_input_array(b); - encoder.set_output_array(out); - encoder.dispatch([a = array::unsafe_weak_copy(a), - b = array::unsafe_weak_copy(b), - out = array::unsafe_weak_copy(out), - bopt]() mutable { - switch (out.dtype()) { - case bool_: - binary_op(a, b, out, bopt); - break; - case uint8: - binary_op(a, b, out, bopt); - break; - case uint16: - binary_op(a, b, out, bopt); - break; - case uint32: - binary_op(a, b, out, bopt); - break; - case uint64: - binary_op(a, b, out, bopt); - break; - case int8: - binary_op(a, b, out, bopt); - break; - case int16: - binary_op(a, b, out, bopt); - break; - case int32: - binary_op(a, b, out, bopt); - break; - case int64: - binary_op(a, b, out, bopt); - break; - case float16: - binary_op(a, b, out, bopt); - break; - case float32: - binary_op(a, b, out, bopt); - break; - case float64: - binary_op(a, b, out, bopt); - break; - case bfloat16: - binary_op(a, b, out, bopt); - break; - case complex64: - binary_op(a, b, out, bopt); - break; - } - }); -} - -template -void comparison_op( - const array& a, - const array& b, - array& out, - Op op, - Stream stream) { - auto bopt = get_binary_op_type(a, b); - set_binary_op_output_data(a, b, out, bopt); - - auto& encoder = cpu::get_command_encoder(stream); - encoder.set_input_array(a); - encoder.set_input_array(b); - encoder.set_output_array(out); - encoder.dispatch([a = array::unsafe_weak_copy(a), - b = array::unsafe_weak_copy(b), - out = array::unsafe_weak_copy(out), - bopt]() mutable { - switch (a.dtype()) 
{ - case bool_: - binary_op(a, b, out, bopt); - break; - case uint8: - binary_op(a, b, out, bopt); - break; - case uint16: - binary_op(a, b, out, bopt); - break; - case uint32: - binary_op(a, b, out, bopt); - break; - case uint64: - binary_op(a, b, out, bopt); - break; - case int8: - binary_op(a, b, out, bopt); - break; - case int16: - binary_op(a, b, out, bopt); - break; - case int32: - binary_op(a, b, out, bopt); - break; - case int64: - binary_op(a, b, out, bopt); - break; - case float16: - binary_op(a, b, out, bopt); - break; - case float32: - binary_op(a, b, out, bopt); - break; - case float64: - binary_op(a, b, out, bopt); - break; - case bfloat16: - binary_op(a, b, out, bopt); - break; - case complex64: - binary_op(a, b, out, bopt); - break; - } - }); -} - -template -void binary_float( - const array& a, - const array& b, - array& out, - Op op, - Stream stream) { - auto bopt = get_binary_op_type(a, b); - set_binary_op_output_data(a, b, out, bopt); - - auto& encoder = cpu::get_command_encoder(stream); - encoder.set_input_array(a); - encoder.set_input_array(b); - encoder.set_output_array(out); - encoder.dispatch([a = array::unsafe_weak_copy(a), - b = array::unsafe_weak_copy(b), - out = array::unsafe_weak_copy(out), - bopt]() mutable { - switch (out.dtype()) { - case float16: - binary_op(a, b, out, bopt); - break; - case float32: - binary_op(a, b, out, bopt); - break; - case float64: - binary_op(a, b, out, bopt); - break; - case bfloat16: - binary_op(a, b, out, bopt); - break; - case complex64: - binary_op(a, b, out, bopt); - break; - default: - throw std::runtime_error( - "[binary_float] Only supports floating point types."); - } - }); -} - -template -void binary_int( - const array& a, - const array& b, - array& out, - Op op, - Stream stream) { - auto bopt = get_binary_op_type(a, b); - set_binary_op_output_data(a, b, out, bopt); - - auto& encoder = cpu::get_command_encoder(stream); - encoder.set_input_array(a); - encoder.set_input_array(b); - 
encoder.set_output_array(out); - encoder.dispatch([a = array::unsafe_weak_copy(a), - b = array::unsafe_weak_copy(b), - out = array::unsafe_weak_copy(out), - bopt]() mutable { - switch (out.dtype()) { - case bool_: - binary_op(a, b, out, bopt); - case uint8: - binary_op(a, b, out, bopt); - break; - case uint16: - binary_op(a, b, out, bopt); - break; - case uint32: - binary_op(a, b, out, bopt); - break; - case uint64: - binary_op(a, b, out, bopt); - break; - case int8: - binary_op(a, b, out, bopt); - break; - case int16: - binary_op(a, b, out, bopt); - break; - case int32: - binary_op(a, b, out, bopt); - break; - case int64: - binary_op(a, b, out, bopt); - break; - default: - throw std::runtime_error("[binary_int] Type not supported"); - break; - } - }); -} - -} // namespace - void Add::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, detail::Add(), stream()); + binary_op_cpu(a, b, out, detail::Add(), stream()); } void DivMod::eval_cpu( @@ -324,14 +102,14 @@ void Divide::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, detail::Divide(), stream()); + binary_op_cpu(a, b, out, detail::Divide(), stream()); } void Remainder::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, detail::Remainder(), stream()); + binary_op_cpu(a, b, out, detail::Remainder(), stream()); } void Equal::eval_cpu(const std::vector& inputs, array& out) { @@ -372,89 +150,90 @@ void Equal::eval_cpu(const std::vector& inputs, array& out) { } }); } else { - comparison_op(a, b, out, detail::Equal(), stream()); + comparison_op_cpu(a, b, out, detail::Equal(), stream()); } } void Greater::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op(inputs[0], inputs[1], out, detail::Greater(), stream()); + 
comparison_op_cpu(inputs[0], inputs[1], out, detail::Greater(), stream()); } void GreaterEqual::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual(), stream()); + comparison_op_cpu( + inputs[0], inputs[1], out, detail::GreaterEqual(), stream()); } void Less::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op(inputs[0], inputs[1], out, detail::Less(), stream()); + comparison_op_cpu(inputs[0], inputs[1], out, detail::Less(), stream()); } void LessEqual::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op(inputs[0], inputs[1], out, detail::LessEqual(), stream()); + comparison_op_cpu(inputs[0], inputs[1], out, detail::LessEqual(), stream()); } void LogAddExp::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary_float(a, b, out, detail::LogAddExp(), stream()); + binary_float_op_cpu(a, b, out, detail::LogAddExp(), stream()); } void LogicalAnd::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); // LogicalAnd requires two input arrays auto& in1 = inputs[0]; auto& in2 = inputs[1]; - binary(in1, in2, out, detail::LogicalAnd(), stream()); + binary_op_cpu(in1, in2, out, detail::LogicalAnd(), stream()); } void LogicalOr::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); // LogicalOr requires two input arrays auto& in1 = inputs[0]; auto& in2 = inputs[1]; - binary(in1, in2, out, detail::LogicalOr(), stream()); + binary_op_cpu(in1, in2, out, detail::LogicalOr(), stream()); } void Maximum::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, detail::Maximum(), stream()); + binary_op_cpu(a, b, out, detail::Maximum(), stream()); } void Minimum::eval_cpu(const std::vector& inputs, array& out) { 
assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, detail::Minimum(), stream()); + binary_op_cpu(a, b, out, detail::Minimum(), stream()); } void Multiply::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, detail::Multiply(), stream()); + binary_op_cpu(a, b, out, detail::Multiply(), stream()); } void NotEqual::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op(inputs[0], inputs[1], out, detail::NotEqual(), stream()); + comparison_op_cpu(inputs[0], inputs[1], out, detail::NotEqual(), stream()); } void Power::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, detail::Power(), stream()); + binary_op_cpu(a, b, out, detail::Power(), stream()); } void Subtract::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, detail::Subtract(), stream()); + binary_op_cpu(a, b, out, detail::Subtract(), stream()); } void BitwiseBinary::eval_cpu(const std::vector& inputs, array& out) { @@ -463,19 +242,19 @@ void BitwiseBinary::eval_cpu(const std::vector& inputs, array& out) { auto& b = inputs[1]; switch (op_) { case BitwiseBinary::And: - binary_int(a, b, out, detail::BitwiseAnd(), stream()); + binary_int_op_cpu(a, b, out, detail::BitwiseAnd(), stream()); break; case BitwiseBinary::Or: - binary_int(a, b, out, detail::BitwiseOr(), stream()); + binary_int_op_cpu(a, b, out, detail::BitwiseOr(), stream()); break; case BitwiseBinary::Xor: - binary_int(a, b, out, detail::BitwiseXor(), stream()); + binary_int_op_cpu(a, b, out, detail::BitwiseXor(), stream()); break; case BitwiseBinary::LeftShift: - binary_int(a, b, out, detail::LeftShift(), stream()); + binary_int_op_cpu(a, b, out, detail::LeftShift(), stream()); break; case BitwiseBinary::RightShift: - 
binary_int(a, b, out, detail::RightShift(), stream()); + binary_int_op_cpu(a, b, out, detail::RightShift(), stream()); break; } } @@ -484,7 +263,7 @@ void ArcTan2::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); const auto& a = inputs[0]; const auto& b = inputs[1]; - binary_float(a, b, out, detail::ArcTan2(), stream()); + binary_float_op_cpu(a, b, out, detail::ArcTan2(), stream()); } } // namespace mlx::core diff --git a/mlx/backend/cpu/binary.h b/mlx/backend/cpu/binary.h index 31966c0ea..acaca50e1 100644 --- a/mlx/backend/cpu/binary.h +++ b/mlx/backend/cpu/binary.h @@ -7,6 +7,7 @@ #include "mlx/backend/common/binary.h" #include "mlx/backend/common/utils.h" +#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/simd/simd.h" namespace mlx::core { @@ -290,4 +291,228 @@ void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) { binary_op(a, b, out, bopt); } +template +void binary_op_cpu( + const array& a, + const array& b, + array& out, + Op op, + Stream stream) { + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out, bopt); + + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + b = array::unsafe_weak_copy(b), + out = array::unsafe_weak_copy(out), + bopt]() mutable { + switch (out.dtype()) { + case bool_: + binary_op(a, b, out, bopt); + break; + case uint8: + binary_op(a, b, out, bopt); + break; + case uint16: + binary_op(a, b, out, bopt); + break; + case uint32: + binary_op(a, b, out, bopt); + break; + case uint64: + binary_op(a, b, out, bopt); + break; + case int8: + binary_op(a, b, out, bopt); + break; + case int16: + binary_op(a, b, out, bopt); + break; + case int32: + binary_op(a, b, out, bopt); + break; + case int64: + binary_op(a, b, out, bopt); + break; + case float16: + binary_op(a, b, out, bopt); + break; + case float32: + 
binary_op(a, b, out, bopt); + break; + case float64: + binary_op(a, b, out, bopt); + break; + case bfloat16: + binary_op(a, b, out, bopt); + break; + case complex64: + binary_op(a, b, out, bopt); + break; + } + }); +} + +template +void comparison_op_cpu( + const array& a, + const array& b, + array& out, + Op op, + Stream stream) { + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out, bopt); + + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + b = array::unsafe_weak_copy(b), + out = array::unsafe_weak_copy(out), + bopt]() mutable { + switch (a.dtype()) { + case bool_: + binary_op(a, b, out, bopt); + break; + case uint8: + binary_op(a, b, out, bopt); + break; + case uint16: + binary_op(a, b, out, bopt); + break; + case uint32: + binary_op(a, b, out, bopt); + break; + case uint64: + binary_op(a, b, out, bopt); + break; + case int8: + binary_op(a, b, out, bopt); + break; + case int16: + binary_op(a, b, out, bopt); + break; + case int32: + binary_op(a, b, out, bopt); + break; + case int64: + binary_op(a, b, out, bopt); + break; + case float16: + binary_op(a, b, out, bopt); + break; + case float32: + binary_op(a, b, out, bopt); + break; + case float64: + binary_op(a, b, out, bopt); + break; + case bfloat16: + binary_op(a, b, out, bopt); + break; + case complex64: + binary_op(a, b, out, bopt); + break; + } + }); +} + +template +void binary_float_op_cpu( + const array& a, + const array& b, + array& out, + Op op, + Stream stream) { + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out, bopt); + + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + b = array::unsafe_weak_copy(b), + out = array::unsafe_weak_copy(out), + bopt]() mutable { + 
switch (out.dtype()) { + case float16: + binary_op(a, b, out, bopt); + break; + case float32: + binary_op(a, b, out, bopt); + break; + case float64: + binary_op(a, b, out, bopt); + break; + case bfloat16: + binary_op(a, b, out, bopt); + break; + case complex64: + binary_op(a, b, out, bopt); + break; + default: + throw std::runtime_error( + "[binary_float] Only supports floating point types."); + } + }); +} + +template +void binary_int_op_cpu( + const array& a, + const array& b, + array& out, + Op op, + Stream stream) { + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out, bopt); + + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + b = array::unsafe_weak_copy(b), + out = array::unsafe_weak_copy(out), + bopt]() mutable { + switch (out.dtype()) { + case bool_: + binary_op(a, b, out, bopt); + break; + case uint8: + binary_op(a, b, out, bopt); + break; + case uint16: + binary_op(a, b, out, bopt); + break; + case uint32: + binary_op(a, b, out, bopt); + break; + case uint64: + binary_op(a, b, out, bopt); + break; + case int8: + binary_op(a, b, out, bopt); + break; + case int16: + binary_op(a, b, out, bopt); + break; + case int32: + binary_op(a, b, out, bopt); + break; + case int64: + binary_op(a, b, out, bopt); + break; + default: + throw std::runtime_error("[binary_int] Type not supported"); + break; + } + }); +} + } // namespace mlx::core diff --git a/mlx/backend/cpu/matmul.cpp b/mlx/backend/cpu/matmul.cpp index 0998c527c..7df331671 100644 --- a/mlx/backend/cpu/matmul.cpp +++ b/mlx/backend/cpu/matmul.cpp @@ -147,37 +147,8 @@ void AddMM::eval_cpu(const std::vector& inputs, array& out) { copy_cpu(c, out, ctype, stream()); } else { array beta_scalar = array(beta_, c.dtype()); - auto bopt = get_binary_op_type(c, beta_scalar); - set_binary_op_output_data(c, beta_scalar, out, bopt); auto& encoder = 
cpu::get_command_encoder(stream()); - encoder.set_input_array(c); - encoder.set_input_array(beta_scalar); - encoder.set_output_array(out); - encoder.dispatch([c = array::unsafe_weak_copy(c), - beta_scalar = array::unsafe_weak_copy(beta_scalar), - out = array::unsafe_weak_copy(out), - bopt]() mutable { - switch (out.dtype()) { - case float16: - binary_op(c, beta_scalar, out, bopt); - break; - case float32: - binary_op(c, beta_scalar, out, bopt); - break; - case float64: - binary_op(c, beta_scalar, out, bopt); - break; - case bfloat16: - binary_op(c, beta_scalar, out, bopt); - break; - case complex64: - binary_op(c, beta_scalar, out, bopt); - break; - default: - throw std::runtime_error( - "[AddMM::eval_cpu] Unsupported dtype for beta scaling"); - } - }); + binary_float_op_cpu(c, beta_scalar, out, detail::Multiply(), stream()); encoder.add_temporary(std::move(beta_scalar)); } return;