// Copyright © 2023 Apple Inc. #include #include #include #include "mlx/allocator.h" #include "mlx/backend/common/binary.h" #include "mlx/backend/common/binary_two.h" #include "mlx/backend/common/ops.h" #include "mlx/primitives.h" #include "mlx/utils.h" namespace mlx::core { namespace { template void comparison_op(const array& a, const array& b, array& out, Op op) { DefaultScalarVector opsv(op); DefaultVectorScalar opvs(op); DefaultVectorVector opvv(op); binary_op(a, b, out, op, opsv, opvs, opvv); } template void comparison_op(const array& a, const array& b, array& out, Op op) { switch (a.dtype()) { case bool_: comparison_op(a, b, out, op); break; case uint8: comparison_op(a, b, out, op); break; case uint16: comparison_op(a, b, out, op); break; case uint32: comparison_op(a, b, out, op); break; case uint64: comparison_op(a, b, out, op); break; case int8: comparison_op(a, b, out, op); break; case int16: comparison_op(a, b, out, op); break; case int32: comparison_op(a, b, out, op); break; case int64: comparison_op(a, b, out, op); break; case float16: comparison_op(a, b, out, op); break; case float32: comparison_op(a, b, out, op); break; case bfloat16: comparison_op(a, b, out, op); break; case complex64: comparison_op(a, b, out, op); break; } } } // namespace void Add::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; binary(a, b, out, detail::Add()); } void DivMod::eval( const std::vector& inputs, std::vector& outputs) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; auto integral_op = [](auto x, auto y) { return std::make_pair(x / y, x % y); }; auto float_op = [](auto x, auto y) { return std::make_pair(std::trunc(x / y), std::fmod(x, y)); }; switch (outputs[0].dtype()) { case bool_: binary_op(a, b, outputs, integral_op); case uint8: binary_op(a, b, outputs, integral_op); break; case uint16: binary_op(a, b, outputs, integral_op); break; case uint32: binary_op(a, b, outputs, integral_op); break; case uint64: binary_op(a, b, outputs, integral_op); break; case int8: binary_op(a, b, outputs, integral_op); break; case int16: binary_op(a, b, outputs, integral_op); break; case int32: binary_op(a, b, outputs, integral_op); break; case int64: binary_op(a, b, outputs, integral_op); break; case float16: binary_op(a, b, outputs, float_op); break; case float32: binary_op(a, b, outputs, float_op); break; case bfloat16: binary_op(a, b, outputs, float_op); break; case complex64: // Should never get here throw std::runtime_error("[DivMod] Complex type not supported"); break; } } void Divide::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; binary(a, b, out, detail::Divide()); } void Remainder::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; binary(a, b, out, detail::Remainder()); } void Equal::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); if (equal_nan_) { comparison_op(inputs[0], inputs[1], out, detail::NaNEqual()); } else { comparison_op(inputs[0], inputs[1], out, detail::Equal()); } } void Greater::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); comparison_op(inputs[0], inputs[1], out, detail::Greater()); } void GreaterEqual::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual()); } void Less::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); comparison_op(inputs[0], inputs[1], out, detail::Less()); } void LessEqual::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); comparison_op(inputs[0], inputs[1], out, detail::LessEqual()); } void LogAddExp::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; if (out.dtype() == float32) { binary_op(a, b, out, detail::LogAddExp()); } else if (out.dtype() == float16) { binary_op(a, b, out, detail::LogAddExp()); } else if (out.dtype() == bfloat16) { binary_op(a, b, out, detail::LogAddExp()); } else if (issubdtype(out.dtype(), inexact)) { std::ostringstream err; err << "[logaddexp] Does not support " << out.dtype(); throw std::invalid_argument(err.str()); } else { throw std::invalid_argument( "[logaddexp] Cannot compute logaddexp for arrays with" " non floating point type."); } } void Maximum::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; binary(a, b, out, detail::Maximum()); } void Minimum::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; binary(a, b, out, detail::Minimum()); } void Multiply::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; binary(a, b, out, detail::Multiply()); } void NotEqual::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); comparison_op(inputs[0], inputs[1], out, detail::NotEqual()); } void Power::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; binary(a, b, out, detail::Power()); } void Subtract::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; binary(a, b, out, detail::Subtract()); } } // namespace mlx::core