2023-11-30 11:12:53 -08:00
|
|
|
// Copyright © 2023 Apple Inc.
|
|
|
|
|
|
2023-11-29 10:42:59 -08:00
|
|
|
#include <cassert>
|
|
|
|
|
#include <cmath>
|
|
|
|
|
#include <sstream>
|
|
|
|
|
|
|
|
|
|
#include "mlx/allocator.h"
|
2025-02-03 15:58:02 -08:00
|
|
|
#include "mlx/backend/cpu/binary.h"
|
|
|
|
|
#include "mlx/backend/cpu/binary_ops.h"
|
|
|
|
|
#include "mlx/backend/cpu/binary_two.h"
|
2025-03-11 06:30:44 -07:00
|
|
|
#include "mlx/backend/cpu/encoder.h"
|
2023-11-29 10:42:59 -08:00
|
|
|
#include "mlx/primitives.h"
|
|
|
|
|
#include "mlx/utils.h"
|
|
|
|
|
|
|
|
|
|
namespace mlx::core {
|
|
|
|
|
|
2025-01-29 14:34:49 -08:00
|
|
|
void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
2023-11-29 10:42:59 -08:00
|
|
|
assert(inputs.size() == 2);
|
|
|
|
|
auto& a = inputs[0];
|
|
|
|
|
auto& b = inputs[1];
|
2025-11-05 01:08:41 -08:00
|
|
|
binary_op_cpu(a, b, out, detail::Add(), stream());
|
2023-11-29 10:42:59 -08:00
|
|
|
}
|
|
|
|
|
|
2025-01-29 14:34:49 -08:00
|
|
|
void DivMod::eval_cpu(
|
2024-01-08 16:39:08 -08:00
|
|
|
const std::vector<array>& inputs,
|
|
|
|
|
std::vector<array>& outputs) {
|
|
|
|
|
assert(inputs.size() == 2);
|
|
|
|
|
auto& a = inputs[0];
|
|
|
|
|
auto& b = inputs[1];
|
2025-03-11 06:30:44 -07:00
|
|
|
auto bopt = get_binary_op_type(a, b);
|
|
|
|
|
auto& out_a = outputs[0];
|
|
|
|
|
auto& out_b = outputs[1];
|
|
|
|
|
set_binary_op_output_data(a, b, out_a, bopt);
|
|
|
|
|
set_binary_op_output_data(a, b, out_b, bopt);
|
|
|
|
|
|
|
|
|
|
auto& encoder = cpu::get_command_encoder(stream());
|
|
|
|
|
encoder.set_input_array(a);
|
|
|
|
|
encoder.set_input_array(b);
|
|
|
|
|
encoder.set_output_array(out_a);
|
|
|
|
|
encoder.set_output_array(out_b);
|
|
|
|
|
|
|
|
|
|
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
|
|
|
|
b = array::unsafe_weak_copy(b),
|
|
|
|
|
out_a = array::unsafe_weak_copy(out_a),
|
|
|
|
|
out_b = array::unsafe_weak_copy(out_b),
|
|
|
|
|
bopt]() mutable {
|
|
|
|
|
auto integral_op = [](auto x, auto y) {
|
|
|
|
|
return std::make_pair(x / y, x % y);
|
|
|
|
|
};
|
|
|
|
|
auto float_op = [](auto x, auto y) {
|
|
|
|
|
return std::make_pair(std::trunc(x / y), std::fmod(x, y));
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
switch (out_a.dtype()) {
|
|
|
|
|
case bool_:
|
|
|
|
|
binary_op<bool>(a, b, out_a, out_b, integral_op, bopt);
|
|
|
|
|
case uint8:
|
|
|
|
|
binary_op<uint8_t>(a, b, out_a, out_b, integral_op, bopt);
|
|
|
|
|
break;
|
|
|
|
|
case uint16:
|
|
|
|
|
binary_op<uint16_t>(a, b, out_a, out_b, integral_op, bopt);
|
|
|
|
|
break;
|
|
|
|
|
case uint32:
|
|
|
|
|
binary_op<uint32_t>(a, b, out_a, out_b, integral_op, bopt);
|
|
|
|
|
break;
|
|
|
|
|
case uint64:
|
|
|
|
|
binary_op<uint64_t>(a, b, out_a, out_b, integral_op, bopt);
|
|
|
|
|
break;
|
|
|
|
|
case int8:
|
|
|
|
|
binary_op<int8_t>(a, b, out_a, out_b, integral_op, bopt);
|
|
|
|
|
break;
|
|
|
|
|
case int16:
|
|
|
|
|
binary_op<int16_t>(a, b, out_a, out_b, integral_op, bopt);
|
|
|
|
|
break;
|
|
|
|
|
case int32:
|
|
|
|
|
binary_op<int32_t>(a, b, out_a, out_b, integral_op, bopt);
|
|
|
|
|
break;
|
|
|
|
|
case int64:
|
|
|
|
|
binary_op<int64_t>(a, b, out_a, out_b, integral_op, bopt);
|
|
|
|
|
break;
|
|
|
|
|
case float16:
|
|
|
|
|
binary_op<float16_t>(a, b, out_a, out_b, float_op, bopt);
|
|
|
|
|
break;
|
|
|
|
|
case float32:
|
|
|
|
|
binary_op<float>(a, b, out_a, out_b, float_op, bopt);
|
|
|
|
|
break;
|
|
|
|
|
case float64:
|
|
|
|
|
binary_op<double>(a, b, out_a, out_b, float_op, bopt);
|
|
|
|
|
break;
|
|
|
|
|
case bfloat16:
|
|
|
|
|
binary_op<bfloat16_t>(a, b, out_a, out_b, float_op, bopt);
|
|
|
|
|
break;
|
|
|
|
|
case complex64:
|
|
|
|
|
// Should never get here
|
|
|
|
|
throw std::runtime_error("[DivMod] Complex type not supported");
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
});
|
2024-01-08 16:39:08 -08:00
|
|
|
}
|
|
|
|
|
|
2025-01-29 14:34:49 -08:00
|
|
|
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
2023-11-29 10:42:59 -08:00
|
|
|
assert(inputs.size() == 2);
|
|
|
|
|
auto& a = inputs[0];
|
|
|
|
|
auto& b = inputs[1];
|
2025-11-05 01:08:41 -08:00
|
|
|
binary_op_cpu(a, b, out, detail::Divide(), stream());
|
2023-11-29 10:42:59 -08:00
|
|
|
}
|
|
|
|
|
|
2025-01-29 14:34:49 -08:00
|
|
|
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
|
2023-12-08 15:08:52 -08:00
|
|
|
assert(inputs.size() == 2);
|
|
|
|
|
auto& a = inputs[0];
|
|
|
|
|
auto& b = inputs[1];
|
2025-11-05 01:08:41 -08:00
|
|
|
binary_op_cpu(a, b, out, detail::Remainder(), stream());
|
2023-12-08 15:08:52 -08:00
|
|
|
}
|
|
|
|
|
|
2025-01-29 14:34:49 -08:00
|
|
|
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
|
2023-11-29 10:42:59 -08:00
|
|
|
assert(inputs.size() == 2);
|
2025-01-29 14:34:49 -08:00
|
|
|
auto& a = inputs[0];
|
|
|
|
|
auto& b = inputs[1];
|
2023-11-29 10:42:59 -08:00
|
|
|
if (equal_nan_) {
|
2025-03-11 06:30:44 -07:00
|
|
|
auto bopt = get_binary_op_type(a, b);
|
|
|
|
|
set_binary_op_output_data(a, b, out, bopt);
|
|
|
|
|
|
|
|
|
|
auto& encoder = cpu::get_command_encoder(stream());
|
|
|
|
|
encoder.set_input_array(a);
|
|
|
|
|
encoder.set_input_array(b);
|
|
|
|
|
encoder.set_output_array(out);
|
|
|
|
|
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
|
|
|
|
b = array::unsafe_weak_copy(b),
|
|
|
|
|
out = array::unsafe_weak_copy(out),
|
|
|
|
|
bopt]() mutable {
|
|
|
|
|
switch (a.dtype()) {
|
|
|
|
|
case float16:
|
|
|
|
|
binary_op<float16_t, bool, detail::NaNEqual>(a, b, out, bopt);
|
|
|
|
|
break;
|
|
|
|
|
case float32:
|
|
|
|
|
binary_op<float, bool, detail::NaNEqual>(a, b, out, bopt);
|
|
|
|
|
break;
|
|
|
|
|
case float64:
|
|
|
|
|
binary_op<double, bool, detail::NaNEqual>(a, b, out, bopt);
|
|
|
|
|
break;
|
|
|
|
|
case bfloat16:
|
|
|
|
|
binary_op<bfloat16_t, bool, detail::NaNEqual>(a, b, out, bopt);
|
|
|
|
|
break;
|
|
|
|
|
case complex64:
|
|
|
|
|
binary_op<complex64_t, bool, detail::NaNEqual>(a, b, out, bopt);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
throw std::runtime_error(
|
|
|
|
|
"[NanEqual::eval_cpu] Only for floating point types.");
|
|
|
|
|
}
|
|
|
|
|
});
|
2023-11-29 10:42:59 -08:00
|
|
|
} else {
|
2025-11-05 01:08:41 -08:00
|
|
|
comparison_op_cpu(a, b, out, detail::Equal(), stream());
|
2023-11-29 10:42:59 -08:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2025-01-29 14:34:49 -08:00
|
|
|
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
|
2023-11-29 10:42:59 -08:00
|
|
|
assert(inputs.size() == 2);
|
2025-11-05 01:08:41 -08:00
|
|
|
comparison_op_cpu(inputs[0], inputs[1], out, detail::Greater(), stream());
|
2023-11-29 10:42:59 -08:00
|
|
|
}
|
|
|
|
|
|
2025-01-29 14:34:49 -08:00
|
|
|
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
2023-11-29 10:42:59 -08:00
|
|
|
assert(inputs.size() == 2);
|
2025-11-05 01:08:41 -08:00
|
|
|
comparison_op_cpu(
|
|
|
|
|
inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
|
2023-11-29 10:42:59 -08:00
|
|
|
}
|
|
|
|
|
|
2025-01-29 14:34:49 -08:00
|
|
|
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
|
2023-11-29 10:42:59 -08:00
|
|
|
assert(inputs.size() == 2);
|
2025-11-05 01:08:41 -08:00
|
|
|
comparison_op_cpu(inputs[0], inputs[1], out, detail::Less(), stream());
|
2023-11-29 10:42:59 -08:00
|
|
|
}
|
|
|
|
|
|
2025-01-29 14:34:49 -08:00
|
|
|
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
2023-11-29 10:42:59 -08:00
|
|
|
assert(inputs.size() == 2);
|
2025-11-05 01:08:41 -08:00
|
|
|
comparison_op_cpu(inputs[0], inputs[1], out, detail::LessEqual(), stream());
|
2023-11-29 10:42:59 -08:00
|
|
|
}
|
|
|
|
|
|
2025-01-29 14:34:49 -08:00
|
|
|
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
2023-11-29 10:42:59 -08:00
|
|
|
assert(inputs.size() == 2);
|
|
|
|
|
auto& a = inputs[0];
|
|
|
|
|
auto& b = inputs[1];
|
2025-11-05 01:08:41 -08:00
|
|
|
binary_float_op_cpu(a, b, out, detail::LogAddExp(), stream());
|
2023-11-29 10:42:59 -08:00
|
|
|
}
|
|
|
|
|
|
2025-01-29 14:34:49 -08:00
|
|
|
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
|
2024-06-14 09:46:55 -07:00
|
|
|
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
|
|
|
|
auto& in1 = inputs[0];
|
|
|
|
|
auto& in2 = inputs[1];
|
2025-11-05 01:08:41 -08:00
|
|
|
binary_op_cpu(in1, in2, out, detail::LogicalAnd(), stream());
|
2024-06-14 09:46:55 -07:00
|
|
|
}
|
|
|
|
|
|
2025-01-29 14:34:49 -08:00
|
|
|
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
|
2024-06-14 09:46:55 -07:00
|
|
|
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
|
|
|
|
auto& in1 = inputs[0];
|
|
|
|
|
auto& in2 = inputs[1];
|
2025-11-05 01:08:41 -08:00
|
|
|
binary_op_cpu(in1, in2, out, detail::LogicalOr(), stream());
|
2024-06-14 09:46:55 -07:00
|
|
|
}
|
|
|
|
|
|
2025-01-29 14:34:49 -08:00
|
|
|
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
2023-11-29 10:42:59 -08:00
|
|
|
assert(inputs.size() == 2);
|
|
|
|
|
auto& a = inputs[0];
|
|
|
|
|
auto& b = inputs[1];
|
2025-11-05 01:08:41 -08:00
|
|
|
binary_op_cpu(a, b, out, detail::Maximum(), stream());
|
2023-11-29 10:42:59 -08:00
|
|
|
}
|
|
|
|
|
|
2025-01-29 14:34:49 -08:00
|
|
|
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
2023-11-29 10:42:59 -08:00
|
|
|
assert(inputs.size() == 2);
|
|
|
|
|
auto& a = inputs[0];
|
|
|
|
|
auto& b = inputs[1];
|
2025-11-05 01:08:41 -08:00
|
|
|
binary_op_cpu(a, b, out, detail::Minimum(), stream());
|
2023-11-29 10:42:59 -08:00
|
|
|
}
|
|
|
|
|
|
2025-01-29 14:34:49 -08:00
|
|
|
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
2023-11-29 10:42:59 -08:00
|
|
|
assert(inputs.size() == 2);
|
|
|
|
|
auto& a = inputs[0];
|
|
|
|
|
auto& b = inputs[1];
|
2025-11-05 01:08:41 -08:00
|
|
|
binary_op_cpu(a, b, out, detail::Multiply(), stream());
|
2023-11-29 10:42:59 -08:00
|
|
|
}
|
|
|
|
|
|
2025-01-29 14:34:49 -08:00
|
|
|
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
2023-11-29 10:42:59 -08:00
|
|
|
assert(inputs.size() == 2);
|
2025-11-05 01:08:41 -08:00
|
|
|
comparison_op_cpu(inputs[0], inputs[1], out, detail::NotEqual(), stream());
|
2023-11-29 10:42:59 -08:00
|
|
|
}
|
|
|
|
|
|
2025-01-29 14:34:49 -08:00
|
|
|
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
|
2023-11-29 10:42:59 -08:00
|
|
|
assert(inputs.size() == 2);
|
|
|
|
|
auto& a = inputs[0];
|
|
|
|
|
auto& b = inputs[1];
|
2025-11-05 01:08:41 -08:00
|
|
|
binary_op_cpu(a, b, out, detail::Power(), stream());
|
2023-11-29 10:42:59 -08:00
|
|
|
}
|
|
|
|
|
|
2025-01-29 14:34:49 -08:00
|
|
|
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
2023-11-29 10:42:59 -08:00
|
|
|
assert(inputs.size() == 2);
|
|
|
|
|
auto& a = inputs[0];
|
|
|
|
|
auto& b = inputs[1];
|
2025-11-05 01:08:41 -08:00
|
|
|
binary_op_cpu(a, b, out, detail::Subtract(), stream());
|
2023-11-29 10:42:59 -08:00
|
|
|
}
|
|
|
|
|
|
2024-04-26 22:03:42 -07:00
|
|
|
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|
|
|
|
assert(inputs.size() == 2);
|
|
|
|
|
auto& a = inputs[0];
|
|
|
|
|
auto& b = inputs[1];
|
|
|
|
|
switch (op_) {
|
|
|
|
|
case BitwiseBinary::And:
|
2025-11-05 01:08:41 -08:00
|
|
|
binary_int_op_cpu(a, b, out, detail::BitwiseAnd(), stream());
|
2024-04-26 22:03:42 -07:00
|
|
|
break;
|
|
|
|
|
case BitwiseBinary::Or:
|
2025-11-05 01:08:41 -08:00
|
|
|
binary_int_op_cpu(a, b, out, detail::BitwiseOr(), stream());
|
2024-04-26 22:03:42 -07:00
|
|
|
break;
|
|
|
|
|
case BitwiseBinary::Xor:
|
2025-11-05 01:08:41 -08:00
|
|
|
binary_int_op_cpu(a, b, out, detail::BitwiseXor(), stream());
|
2024-04-26 22:03:42 -07:00
|
|
|
break;
|
|
|
|
|
case BitwiseBinary::LeftShift:
|
2025-11-05 01:08:41 -08:00
|
|
|
binary_int_op_cpu(a, b, out, detail::LeftShift(), stream());
|
2024-04-26 22:03:42 -07:00
|
|
|
break;
|
|
|
|
|
case BitwiseBinary::RightShift:
|
2025-11-05 01:08:41 -08:00
|
|
|
binary_int_op_cpu(a, b, out, detail::RightShift(), stream());
|
2024-04-26 22:03:42 -07:00
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2025-01-29 14:34:49 -08:00
|
|
|
void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
|
2024-05-08 11:35:15 -04:00
|
|
|
assert(inputs.size() == 2);
|
|
|
|
|
const auto& a = inputs[0];
|
|
|
|
|
const auto& b = inputs[1];
|
2025-11-05 01:08:41 -08:00
|
|
|
binary_float_op_cpu(a, b, out, detail::ArcTan2(), stream());
|
2024-05-08 11:35:15 -04:00
|
|
|
}
|
|
|
|
|
|
2023-11-29 10:42:59 -08:00
|
|
|
} // namespace mlx::core
|