diff --git a/CMakeLists.txt b/CMakeLists.txt index 58ccd0a60..4c1cc64ad 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -147,6 +147,7 @@ if(MLX_BUILD_CPU) if(MLX_BUILD_ACCELERATE) target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY}) + add_compile_definitions(MLX_USE_ACCELERATE) add_compile_definitions(ACCELERATE_NEW_LAPACK) elseif(MLX_BUILD_BLAS_FROM_SOURCE) # Download and build OpenBLAS from source code. diff --git a/mlx/backend/accelerate/CMakeLists.txt b/mlx/backend/accelerate/CMakeLists.txt index f718e19de..96afd3107 100644 --- a/mlx/backend/accelerate/CMakeLists.txt +++ b/mlx/backend/accelerate/CMakeLists.txt @@ -3,6 +3,4 @@ target_sources( PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp) + ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp) diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index c1c1e61ee..bfee1050f 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -11,448 +11,8 @@ #include "mlx/backend/common/unary.h" #include "mlx/primitives.h" -#define DEFAULT(primitive) \ - void primitive::eval_cpu(const std::vector& inputs, array& out) { \ - primitive::eval(inputs, out); \ - } - -#define DEFAULT_MULTI(primitive) \ - void primitive::eval_cpu( \ - const std::vector& inputs, std::vector& outputs) { \ - primitive::eval(inputs, outputs); \ - } - namespace mlx::core { -// Use the default implementation for the following primitives -DEFAULT(Arange) -DEFAULT(ArgPartition) -DEFAULT(ArgReduce) -DEFAULT(ArgSort) -DEFAULT(AsStrided) -DEFAULT(BlockMaskedMM) -DEFAULT(Broadcast) -DEFAULT(BroadcastAxes) -DEFAULT(Ceil) -DEFAULT(Concatenate) -DEFAULT(Conjugate) -DEFAULT(Copy) -DEFAULT_MULTI(CustomTransforms) -DEFAULT_MULTI(Depends) -DEFAULT_MULTI(DivMod) -DEFAULT(NumberOfElements) -DEFAULT(Equal) -DEFAULT(Erf) -DEFAULT(ErfInv) -DEFAULT(ExpandDims) -DEFAULT(FFT) -DEFAULT(Floor) -DEFAULT(Gather) -DEFAULT(GatherMM) -DEFAULT(GatherQMM) -DEFAULT(Greater) -DEFAULT(GreaterEqual) -DEFAULT(Hadamard) -DEFAULT(Less) -DEFAULT(LessEqual) -DEFAULT(Load) -DEFAULT(LogicalNot) -DEFAULT(LogicalAnd) -DEFAULT(LogicalOr) -DEFAULT(LogAddExp) -DEFAULT(Maximum) -DEFAULT(Minimum) -DEFAULT(NotEqual) -DEFAULT(Pad) -DEFAULT(Partition) -DEFAULT_MULTI(QRF) -DEFAULT(RandomBits) -DEFAULT(Remainder) -DEFAULT(Round) -DEFAULT(Scatter) -DEFAULT(Select) -DEFAULT(Sigmoid) -DEFAULT(Sign) -DEFAULT(Slice) -DEFAULT(SliceUpdate) -DEFAULT_MULTI(Split) -DEFAULT(Sort) -DEFAULT(Squeeze) -DEFAULT(StopGradient) -DEFAULT_MULTI(SVD) -DEFAULT(Transpose) -DEFAULT(Inverse) -DEFAULT(Cholesky) -DEFAULT_MULTI(Eigh) - -void Abs::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - auto& in = inputs[0]; - if (in.dtype() == float32 && in.flags().contiguous) { - set_unary_output_data(in, out); - vDSP_vabs(in.data(), 1, out.data(), 1, in.data_size()); - } else if (in.dtype() == int32 && in.flags().contiguous) { - set_unary_output_data(in, out); - vDSP_vabsi(in.data(), 1, out.data(), 1, in.data_size()); - } else { - eval(inputs, out); - } -} - -void Add::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 2); - auto& a = inputs[0]; - auto& b = inputs[1]; - - if (a.dtype() == float32) { - binary_op( - a, - b, - out, - [](auto x, auto y) { return x + y; }, - [](const auto* s, const auto* vec, auto* o, auto n) { - 
vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n); - }, - [](const auto* vec, const auto* s, auto* o, auto n) { - vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n); - }, - [](const auto* a, const auto* b, auto* o, auto n) { - vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n); - }); - } else if (a.dtype() == int32) { - binary_op( - a, - b, - out, - [](auto x, auto y) { return x + y; }, - [](const auto* s, const auto* vec, auto* o, auto n) { - vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n); - }, - [](const auto* vec, const auto* s, auto* o, auto n) { - vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n); - }, - [](const auto* a, const auto* b, auto* o, auto n) { - vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n); - }); - } else { - eval(inputs, out); - } -} - -void ArcCos::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (out.dtype() == float32 && in.flags().contiguous) { - set_unary_output_data(in, out); - int size = in.data_size(); - vvacosf(out.data(), in.data(), &size); - } else { - eval(inputs, out); - } -} - -void ArcCosh::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (out.dtype() == float32 && in.flags().contiguous) { - set_unary_output_data(in, out); - int size = in.data_size(); - vvacoshf(out.data(), in.data(), &size); - } else { - eval(inputs, out); - } -} - -void ArcSin::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (out.dtype() == float32 && in.flags().contiguous) { - set_unary_output_data(in, out); - int size = in.data_size(); - vvasinf(out.data(), in.data(), &size); - } else { - eval(inputs, out); - } -} - -void ArcSinh::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (out.dtype() == float32 && in.flags().contiguous) { - set_unary_output_data(in, out); - int size = in.data_size(); - vvasinhf(out.data(), in.data(), &size); - } else { - eval(inputs, out); - } -} - -void ArcTan::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (out.dtype() == float32 && in.flags().contiguous) { - set_unary_output_data(in, out); - int size = in.data_size(); - vvatanf(out.data(), in.data(), &size); - } else { - eval(inputs, out); - } -} - -void ArcTan2::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 2); - auto& a = inputs[0]; - auto& b = inputs[1]; - if (out.dtype() == float32 && a.flags().row_contiguous && - b.flags().row_contiguous) { - if (a.is_donatable()) { - out.copy_shared_buffer(a); - } else if (b.is_donatable()) { - out.copy_shared_buffer(b); - } else { - out.set_data(allocator::malloc_or_wait(out.nbytes())); - } - int size = a.data_size(); - vvatan2f(out.data(), a.data(), b.data(), &size); - } else { - eval(inputs, out); - } -} - -void ArcTanh::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (out.dtype() == float32 && in.flags().contiguous) { - set_unary_output_data(in, out); - int size = in.data_size(); - vvatanhf(out.data(), in.data(), &size); - } else { - eval(inputs, out); - } -} - -void AsType::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - auto& in = inputs[0]; - - if (in.flags().contiguous) { - // Use accelerate functions if possible - 
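// [editor's note] for reference, the vDSP conversion routines used below map
// as follows: vDSP_vfix32 converts float->int32, vDSP_vfixu32 float->uint32,
// vDSP_vflt32 int32->float, and vDSP_vfltu32 uint32->float.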
if (in.dtype() == float32 && out.dtype() == uint32) { - set_unary_output_data(in, out); - vDSP_vfixu32( - in.data(), 1, out.data(), 1, in.data_size()); - return; - } else if (in.dtype() == float32 && out.dtype() == int32) { - set_unary_output_data(in, out); - vDSP_vfix32(in.data(), 1, out.data(), 1, in.data_size()); - return; - } else if (in.dtype() == uint32 && out.dtype() == float32) { - set_unary_output_data(in, out); - vDSP_vfltu32( - in.data(), 1, out.data(), 1, in.data_size()); - return; - } else if (in.dtype() == int32 && out.dtype() == float32) { - set_unary_output_data(in, out); - vDSP_vflt32(in.data(), 1, out.data(), 1, in.data_size()); - return; - } - } - eval(inputs, out); -} - -void Cos::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (out.dtype() == float32 && in.flags().contiguous) { - set_unary_output_data(in, out); - int size = in.data_size(); - vvcosf(out.data(), in.data(), &size); - } else { - eval(inputs, out); - } -} - -void Cosh::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (out.dtype() == float32 && in.flags().contiguous) { - set_unary_output_data(in, out); - int size = in.data_size(); - vvcoshf(out.data(), in.data(), &size); - } else { - eval(inputs, out); - } -} - -void Divide::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 2); - auto& a = inputs[0]; - auto& b = inputs[1]; - - if (a.dtype() == int32) { - binary_op( - a, - b, - out, - [](auto x, auto y) { return x / y; }, - UseDefaultBinaryOp(), - [](const auto* vec, const auto* s, auto* o, auto n) { - vDSP_vsdivi((const int*)vec, 1, (const int*)s, (int*)o, 1, n); - }, - [](const auto* a, const auto* b, auto* o, auto n) { - vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n); - }); - } else if (a.dtype() == float32) { - binary_op( - a, - b, - out, - [](auto x, auto y) { return x / y; }, - [](const auto* s, const auto* vec, auto* o, auto n) { - vDSP_svdiv((const float*)s, (const float*)vec, 1, (float*)o, 1, n); - }, - [](const auto* vec, const auto* s, auto* o, auto n) { - vDSP_vsdiv((const float*)vec, 1, (const float*)s, (float*)o, 1, n); - }, - [](const auto* a, const auto* b, auto* o, auto n) { - vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n); - }); - } else { - eval(inputs, out); - } -} - -void Exp::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (out.dtype() == float32 && in.flags().contiguous) { - set_unary_output_data(in, out); - auto size = in.data_size(); - vvexpf(out.data(), in.data(), reinterpret_cast(&size)); - } else { - eval(inputs, out); - } -} - -void Expm1::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (out.dtype() == float32 && in.flags().contiguous) { - set_unary_output_data(in, out); - auto size = in.data_size(); - vvexpm1f( - out.data(), in.data(), reinterpret_cast(&size)); - } else { - eval(inputs, out); - } -} - -void Full::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - auto& in = inputs[0]; - assert(in.dtype() == out.dtype()); - if (in.data_size() == 1 && out.dtype() == float32) { - out.set_data(allocator::malloc_or_wait(out.nbytes())); - vDSP_vfill(in.data(), out.data(), 1, out.size()); - } else { - eval(inputs, out); - } -} - -void Log::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in 
= inputs[0]; - if (out.dtype() == float32 && in.flags().contiguous) { - set_unary_output_data(in, out); - auto size = in.data_size(); - switch (base_) { - case Base::e: - vvlogf( - out.data(), in.data(), reinterpret_cast(&size)); - break; - case Base::two: - vvlog2f( - out.data(), in.data(), reinterpret_cast(&size)); - break; - case Base::ten: - vvlog10f( - out.data(), in.data(), reinterpret_cast(&size)); - break; - } - } else { - eval(inputs, out); - } -} - -void Log1p::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (out.dtype() == float32 && in.flags().contiguous) { - set_unary_output_data(in, out); - auto size = in.data_size(); - vvlog1pf( - out.data(), in.data(), reinterpret_cast(&size)); - } else { - eval(inputs, out); - } -} - -void Multiply::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 2); - auto& a = inputs[0]; - auto& b = inputs[1]; - - if (a.dtype() == float32) { - binary_op( - a, - b, - out, - [](auto x, auto y) { return x * y; }, - [](const auto* s, const auto* vec, auto* o, auto n) { - vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n); - }, - [](const auto* vec, const auto* s, auto* o, auto n) { - vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n); - }, - [](const auto* a, const auto* b, auto* o, auto n) { - vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n); - }); - } else { - eval(inputs, out); - } -} - -void Negative::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - auto& in = inputs[0]; - if (in.dtype() == float32 && in.flags().contiguous) { - set_unary_output_data(in, out); - vDSP_vneg(in.data(), 1, out.data(), 1, in.data_size()); - } else { - eval(inputs, out); - } -} - -void Power::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 2); - auto& a = inputs[0]; - auto& b = inputs[1]; - if (out.dtype() == float32 && a.flags().row_contiguous && - b.flags().row_contiguous) { - int size = a.size(); - if (a.is_donatable() && a.itemsize() == out.itemsize()) { - out.copy_shared_buffer(a); - } else if (b.is_donatable() && b.itemsize() == out.itemsize()) { - out.copy_shared_buffer(b); - } else { - out.set_data(allocator::malloc_or_wait(out.nbytes())); - } - vvpowf(out.data(), b.data(), a.data(), &size); - } else { - eval(inputs, out); - } -} - void Scan::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; @@ -484,120 +44,4 @@ void Scan::eval_cpu(const std::vector& inputs, array& out) { } } -void Sin::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (out.dtype() == float32 && in.flags().contiguous) { - set_unary_output_data(in, out); - int size = in.data_size(); - vvsinf(out.data(), in.data(), &size); - } else { - eval(inputs, out); - } -} - -void Sinh::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (out.dtype() == float32 && in.flags().contiguous) { - set_unary_output_data(in, out); - int size = in.data_size(); - vvsinhf(out.data(), in.data(), &size); - } else { - eval(inputs, out); - } -} - -void Square::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - auto& in = inputs[0]; - if (in.dtype() == float32 && in.flags().contiguous) { - set_unary_output_data(in, out); - auto size = in.data_size(); - vDSP_vsq(in.data(), 1, out.data(), 1, size); - } else { - 
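// [editor's note] as in every fast path in this file, the else branch falls
// back to the generic common-backend eval for non-contiguous or
// unsupported-dtype inputs.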
eval(inputs, out);
-  }
-}
-
-void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  auto& in = inputs[0];
-  if (in.dtype() == float32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    int size = in.data_size();
-    if (recip_) {
-      vvrsqrtf(out.data<float>(), in.data<float>(), &size);
-    } else {
-      vvsqrtf(out.data<float>(), in.data<float>(), &size);
-    }
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 2);
-  auto& a = inputs[0];
-  auto& b = inputs[1];
-
-  if (a.dtype() == float32) {
-    binary_op<float>(
-        a,
-        b,
-        out,
-        [](auto x, auto y) { return x - y; },
-        [](const auto* s, const auto* vec, auto* o, auto n) {
-          float minus_1 = -1;
-          vDSP_vsmsa(
-              (const float*)vec, 1, &minus_1, (const float*)s, (float*)o, 1, n);
-        },
-        [](const auto* vec, const auto* s, auto* o, auto n) {
-          float val = -(*s);
-          vDSP_vsadd((const float*)vec, 1, &val, (float*)o, 1, n);
-        },
-        [](const auto* a, const auto* b, auto* o, auto n) {
-          vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
-        });
-  } else if (a.dtype() == int32) {
-    binary_op<int>(
-        a,
-        b,
-        out,
-        [](auto x, auto y) { return x - y; },
-        UseDefaultBinaryOp(),
-        [](const auto* vec, const auto* s, auto* o, auto n) {
-          int val = -(*s);
-          vDSP_vsaddi((const int*)vec, 1, &val, (int*)o, 1, n);
-        },
-        UseDefaultBinaryOp());
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (out.dtype() == float32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    int size = in.data_size();
-    vvtanf(out.data<float>(), in.data<float>(), &size);
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (out.dtype() == float32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    int size = in.data_size();
-    vvtanhf(out.data<float>(), in.data<float>(), &size);
-  } else {
-    eval(inputs, out);
-  }
-}
-
 } // namespace mlx::core
diff --git a/mlx/backend/accelerate/quantized.cpp b/mlx/backend/accelerate/quantized.cpp
deleted file mode 100644
index 3c1312fbc..000000000
--- a/mlx/backend/accelerate/quantized.cpp
+++ /dev/null
@@ -1,117 +0,0 @@
-// Copyright © 2023 Apple Inc.
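// [editor's note] the deleted file below implemented a fused matmul for 4-bit
// quantized weights with group size 64. As a reference for the unpacking the
// removed _qmm_t_4_64 kernel performs per packed word, here is a minimal
// scalar sketch (hypothetical standalone code, not part of the original file):

#include <cstdint>

// Unpack the eight 4-bit values packed into one uint32 and dequantize them
// with the per-group scale and bias, mirroring the kernel's q * scale + bias.
inline void dequantize_u32(uint32_t w_q, float scale, float bias, float out[8]) {
  constexpr int bits = 4;
  constexpr uint32_t bitmask = (1u << bits) - 1; // 0xF
  for (int p = 0; p < 32 / bits; ++p) {
    out[p] = scale * static_cast<float>(w_q & bitmask) + bias;
    w_q >>= bits;
  }
}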
-
-#include <cassert>
-
-#include <simd/vector.h>
-
-#include "mlx/primitives.h"
-
-namespace mlx::core {
-
-namespace {
-
-void _qmm_t_4_64(
-    float* result,
-    const float* x,
-    const uint32_t* w,
-    const float* scales,
-    const float* biases,
-    int M,
-    int N,
-    int K,
-    int B,
-    bool batched_w) {
-  constexpr int bits = 4;
-  constexpr int group_size = 64;
-  constexpr int bitmask = (1 << bits) - 1;
-  constexpr int pack_factor = 32 / bits;
-  constexpr int packs_in_group = group_size / pack_factor;
-
-  int w_els = N * K / pack_factor;
-  int g_els = w_els * pack_factor / group_size;
-
-  for (int i = 0; i < B; i++) {
-    for (int m = 0; m < M; m++) {
-      const uint32_t* w_local = w;
-      const float* scales_local = scales;
-      const float* biases_local = biases;
-
-      for (int n = 0; n < N; n++) {
-        const simd_float16* x_local = (simd_float16*)x;
-        simd_float16 sum = 0;
-        for (int k = 0; k < K; k += group_size) {
-          float scale = *scales_local++;
-          float bias = *biases_local++;
-
-          for (int kw = 0; kw < packs_in_group; kw += 2) {
-            // TODO: vectorize this properly
-            simd_uint16 wi;
-            for (int e = 0; e < 2; e++) {
-              uint32_t wii = *w_local++;
-              for (int p = 0; p < 8; p++) {
-                wi[e * 8 + p] = wii & bitmask;
-                wii >>= bits;
-              }
-            }
-            simd_float16 wf = simd_float(wi);
-            wf *= scale;
-            wf += bias;
-
-            sum += (*x_local) * wf;
-            x_local++;
-          }
-        }
-
-        *result = simd_reduce_add(sum);
-        result++;
-      }
-
-      x += K;
-    }
-    if (batched_w) {
-      w += w_els;
-      scales += g_els;
-      biases += g_els;
-    }
-  }
-}
-
-} // namespace
-
-void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 4);
-
-  auto& x = inputs[0];
-  auto& w = inputs[1];
-  auto& scales = inputs[2];
-  auto& biases = inputs[3];
-
-  bool condition =
-      (transpose_ && x.flags().row_contiguous && w.flags().row_contiguous &&
-       scales.flags().row_contiguous && biases.flags().row_contiguous &&
-       x.dtype() == float32 && bits_ == 4 && group_size_ == 64);
-
-  if (condition) {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
-    int K = x.shape(-1);
-    int M = x.shape(-2);
-    int N = out.shape(-1);
-    int B = x.size() / K / M;
-    bool batched_w = w.ndim() > 2;
-    _qmm_t_4_64(
-        out.data<float>(),
-        x.data<float>(),
-        w.data<uint32_t>(),
-        scales.data<float>(),
-        biases.data<float>(),
-        M,
-        N,
-        K,
-        B,
-        batched_w);
-  } else {
-    eval(inputs, out);
-  }
-}
-
-} // namespace mlx::core
diff --git a/mlx/backend/accelerate/softmax.cpp b/mlx/backend/accelerate/softmax.cpp
deleted file mode 100644
index 8326ba1c3..000000000
--- a/mlx/backend/accelerate/softmax.cpp
+++ /dev/null
@@ -1,393 +0,0 @@
-// Copyright © 2023-2024 Apple Inc.
-
-#include <cassert>
-#include <limits>
-
-#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-#include <arm_neon.h>
-#endif
-
-#include <simd/math.h>
-#include <simd/vector.h>
-
-#include "mlx/backend/common/copy.h"
-#include "mlx/primitives.h"
-
-namespace mlx::core {
-
-namespace {
-
-/**
- * Compute exp(x) in an optimizer friendly way as follows:
- *
- * First change the problem to computing 2**y where y = x / ln(2).
- *
- * Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part
- * `ipart` and y2 is fractional part. For the integer part we perform bit
- * shifting and for the fractional part we use a polynomial approximation.
- *
- * The algorithm and constants of the polynomial taken from
- * https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them
- * from Cephes math library.
- *
- * Note: The implementation below is a general fast exp. There could be faster
- * implementations for numbers strictly < 0.
- */ -inline simd_float16 simd_fast_exp(simd_float16 x_init) { - auto x = x_init * 1.442695; // multiply with log_2(e) - simd_float16 ipart, fpart; - simd_int16 epart; - x = simd_clamp(x, -80, 80); - ipart = simd::floor(x + 0.5); - fpart = x - ipart; - - x = 1.535336188319500e-4f; - x = x * fpart + 1.339887440266574e-3f; - x = x * fpart + 9.618437357674640e-3f; - x = x * fpart + 5.550332471162809e-2f; - x = x * fpart + 2.402264791363012e-1f; - x = x * fpart + 6.931472028550421e-1f; - x = x * fpart + 1.000000000000000f; - - // generate 2**ipart in the floating point representation using integer - // bitshifting - epart = (simd_int(ipart) + 127) << 23; - - // Avoid supressing NaNs - simd_int16 eq = (x_init == x_init); - return simd_bitselect(x_init, (*(simd_float16*)&epart) * x, eq); -} - -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -/** - * The ARM neon equivalent of the fast exp above. - */ -inline float16x8_t neon_fast_exp(float16x8_t x) { - x = vmulq_f16(x, vdupq_n_f16(float16_t(1.442695f))); // multiply with log_2(e) - x = vmaxq_f16(x, vdupq_n_f16(float16_t(-14.f))); // clamp under with -14 - x = vminq_f16(x, vdupq_n_f16(float16_t(14.f))); // clamp over with 14 - - float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(float16_t(0.5f)))); - float16x8_t fpart = vsubq_f16(x, ipart); - - x = vdupq_n_f16(float16_t(1.535336188319500e-4f)); - x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart); - x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart); - x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart); - x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart); - x = vfmaq_f16(vdupq_n_f16(float16_t(6.931472028550421e-1f)), x, fpart); - x = vfmaq_f16(vdupq_n_f16(float16_t(1.000000000000000f)), x, fpart); - - // generate 2**ipart in the floating point representation using integer - // bitshifting - int16x8_t epart = vcvtq_s16_f16(ipart); - epart = vaddq_s16(epart, vdupq_n_s16(15)); - epart = vshlq_n_s16(epart, 10); - - return vmulq_f16(vreinterpretq_f16_s16(epart), x); -} - -/** - * Implementation of folding maximum for ARM neon. This should possibly be - * refactored out of softmax.cpp at some point. - */ -inline float16_t neon_reduce_max(float16x8_t x) { - float16x4_t y; - y = vpmax_f16(vget_low_f16(x), vget_high_f16(x)); - y = vpmax_f16(y, y); - y = vpmax_f16(y, y); - return vget_lane_f16(y, 0); -} - -/** - * Implementation of folding sum for ARM neon. This should possibly be - * refactored out of softmax.cpp at some point. 
- */ -inline float16_t neon_reduce_add(float16x8_t x) { - float16x4_t y; - float16x4_t zero = vdup_n_f16(0); - y = vpadd_f16(vget_low_f16(x), vget_high_f16(x)); - y = vpadd_f16(y, zero); - y = vpadd_f16(y, zero); - return vget_lane_f16(y, 0); -} - -template -struct NeonFp16SimdOps { - VT init(T a) { - return vdupq_n_f16(a); - } - - VT load(const T* a) { - return vld1q_f16(a); - } - - void store(T* dst, VT x) { - vst1q_f16(dst, x); - } - - VT max(VT a, VT b) { - return vmaxq_f16(a, b); - } - - VT exp(VT x) { - return neon_fast_exp(x); - } - - VT add(VT a, VT b) { - return vaddq_f16(a, b); - } - - VT sub(VT a, T b) { - return vsubq_f16(a, vdupq_n_f16(b)); - } - - VT mul(VT a, VT b) { - return vmulq_f16(a, b); - } - - VT mul(VT a, T b) { - return vmulq_f16(a, vdupq_n_f16(b)); - } - - T reduce_max(VT x) { - return neon_reduce_max(x); - } - - T reduce_add(VT x) { - return neon_reduce_add(x); - } -}; - -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - -template -struct AccelerateSimdOps { - VT init(T a) { - return a; - } - - VT load(const T* a) { - return *(VT*)a; - } - - void store(T* dst, VT x) { - *(VT*)dst = x; - } - - VT max(VT a, VT b) { - return simd_max(a, b); - } - - VT exp(VT x) { - return simd_fast_exp(x); - } - - VT add(VT a, VT b) { - return a + b; - } - - VT sub(VT a, T b) { - return a - b; - } - - VT mul(VT a, VT b) { - return a * b; - } - - VT mul(VT a, T b) { - return a * b; - } - - T reduce_max(VT x) { - return simd_reduce_max(x); - } - - T reduce_add(VT x) { - return simd_reduce_add(x); - } -}; - -template -void softmax(const array& in, array& out) { - Ops ops; - - const T* in_ptr = in.data(); - T* out_ptr = out.data(); - int M = in.shape().back(); - int L = in.data_size() / M; - const T* current_in_ptr; - T* current_out_ptr; - - for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) { - // Find the maximum - current_in_ptr = in_ptr; - VT vmaximum = ops.init(-std::numeric_limits::infinity()); - size_t s = M; - while (s >= N) { - VT vals; - if constexpr (std::is_same::value) { - vals = ops.load(current_in_ptr); - } else { - for (int i = 0; i < N; ++i) { - vals[i] = static_cast(current_in_ptr[i]); - } - } - vmaximum = ops.max(vals, vmaximum); - current_in_ptr += N; - s -= N; - } - AccT maximum = ops.reduce_max(vmaximum); - while (s-- > 0) { - maximum = std::max(maximum, static_cast(*current_in_ptr)); - current_in_ptr++; - } - - // Compute the normalizer and the exponentials - VT vnormalizer = ops.init(0.0); - current_out_ptr = out_ptr; - current_in_ptr = in_ptr; - s = M; - while (s >= N) { - VT vexp; - if constexpr (std::is_same::value) { - vexp = ops.load(current_in_ptr); - } else { - for (int i = 0; i < N; ++i) { - vexp[i] = static_cast(current_in_ptr[i]); - } - } - vexp = ops.exp(ops.sub(vexp, maximum)); - if constexpr (std::is_same::value) { - ops.store(current_out_ptr, vexp); - } - vnormalizer = ops.add(vnormalizer, vexp); - current_in_ptr += N; - current_out_ptr += N; - s -= N; - } - AccT normalizer = ops.reduce_add(vnormalizer); - while (s-- > 0) { - AccT _exp = std::exp(*current_in_ptr - maximum); - if (std::is_same::value) { - *current_out_ptr = _exp; - } - normalizer += _exp; - current_in_ptr++; - current_out_ptr++; - } - normalizer = 1 / normalizer; - - // Normalize - current_out_ptr = out_ptr; - current_in_ptr = in_ptr; - s = M; - while (s >= N) { - if constexpr (std::is_same::value) { - ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer)); - } else { - VT vexp; - for (int i = 0; i < N; ++i) { - vexp[i] = static_cast(current_in_ptr[i]); - } - vexp = 
ops.mul(ops.exp(ops.sub(vexp, maximum)), normalizer); - for (int i = 0; i < N; ++i) { - current_out_ptr[i] = vexp[i]; - } - current_in_ptr += N; - } - current_out_ptr += N; - s -= N; - } - while (s-- > 0) { - if constexpr (std::is_same::value) { - *current_out_ptr *= normalizer; - } else { - AccT _exp = std::exp(*current_in_ptr - maximum); - *current_out_ptr = static_cast(_exp * normalizer); - current_in_ptr++; - } - current_out_ptr++; - } - } -} - -} // namespace - -void Softmax::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - - // Make sure that the last dimension is contiguous - auto check_input = [](array x) { - bool no_copy = x.strides()[x.ndim() - 1] == 1; - if (x.ndim() > 1) { - auto s = x.strides()[x.ndim() - 2]; - no_copy &= (s == 0 || s == x.shape().back()); - } - if (no_copy) { - return x; - } else { - array x_copy(x.shape(), x.dtype(), nullptr, {}); - copy(x, x_copy, CopyType::General); - return x_copy; - } - }; - array in = check_input(std::move(inputs[0])); - out.set_data( - allocator::malloc_or_wait(in.data_size() * in.itemsize()), - in.data_size(), - in.strides(), - in.flags()); - - switch (in.dtype()) { - case bool_: - case uint8: - case uint16: - case uint32: - case uint64: - case int8: - case int16: - case int32: - case int64: - throw std::invalid_argument( - "Softmax is defined only for floating point types"); - break; - case float32: - softmax< - float, - float, - simd_float16, - AccelerateSimdOps, - 16>(in, out); - break; - case float16: - if (precise_) { - softmax< - float16_t, - float, - simd_float16, - AccelerateSimdOps, - 16>(in, out); - } else { -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - softmax< - float16_t, - float16_t, - float16x8_t, - NeonFp16SimdOps, - 8>(in, out); -#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - eval(inputs, out); // Redirect to common backend for consistency -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - } - break; - case bfloat16: - eval(inputs, out); - break; - case complex64: - eval(inputs, out); - break; - } -} - -} // namespace mlx::core diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index e32123f43..97fc48008 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -5,6 +5,18 @@ else() set(COMPILER ${CMAKE_CXX_COMPILER}) endif() +set(COMPILE_DEPS + ${PROJECT_SOURCE_DIR}/mlx/types/half_types.h + ${PROJECT_SOURCE_DIR}/mlx/types/fp16.h + ${PROJECT_SOURCE_DIR}/mlx/types/bf16.h + ${PROJECT_SOURCE_DIR}/mlx/types/complex.h + simd/simd.h + simd/base_simd.h + simd/math.h + simd/type.h + unary_ops.h + binary_ops.h) + if(MSVC) set(SHELL_EXT ps1) set(SHELL_CMD powershell -ExecutionPolicy Bypass -File) @@ -19,13 +31,8 @@ add_custom_command( ${SHELL_CMD} ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.${SHELL_EXT} ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER} ${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR} - DEPENDS make_compiled_preamble.${SHELL_EXT} - compiled_preamble.h - ${PROJECT_SOURCE_DIR}/mlx/types/half_types.h - ${PROJECT_SOURCE_DIR}/mlx/types/fp16.h - ${PROJECT_SOURCE_DIR}/mlx/types/bf16.h - ${PROJECT_SOURCE_DIR}/mlx/types/complex.h - ops.h) + DEPENDS make_compiled_preamble.${SHELL_EXT} compiled_preamble.h + ${COMPILE_DEPS}) add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp) @@ -60,6 +67,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp 
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp) diff --git a/mlx/backend/common/arg_reduce.cpp b/mlx/backend/common/arg_reduce.cpp index 00f78136c..4d66796e1 100644 --- a/mlx/backend/common/arg_reduce.cpp +++ b/mlx/backend/common/arg_reduce.cpp @@ -61,7 +61,7 @@ void arg_reduce_dispatch( } // namespace -void ArgReduce::eval(const std::vector& inputs, array& out) { +void ArgReduce::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; out.set_data(allocator::malloc_or_wait(out.nbytes())); diff --git a/mlx/backend/common/binary.cpp b/mlx/backend/common/binary.cpp index 517c7e8a0..6178328d1 100644 --- a/mlx/backend/common/binary.cpp +++ b/mlx/backend/common/binary.cpp @@ -6,8 +6,8 @@ #include "mlx/allocator.h" #include "mlx/backend/common/binary.h" +#include "mlx/backend/common/binary_ops.h" #include "mlx/backend/common/binary_two.h" -#include "mlx/backend/common/ops.h" #include "mlx/primitives.h" #include "mlx/utils.h" @@ -15,69 +15,61 @@ namespace mlx::core { namespace { -template -void comparison_op(const array& a, const array& b, array& out, Op op) { - DefaultScalarVector opsv(op); - DefaultVectorScalar opvs(op); - DefaultVectorVector opvv(op); - binary_op(a, b, out, op, opsv, opvs, opvv); -} - template void comparison_op(const array& a, const array& b, array& out, Op op) { switch (a.dtype()) { case bool_: - comparison_op(a, b, out, op); + binary_op(a, b, out, op); break; case uint8: - comparison_op(a, b, out, op); + binary_op(a, b, out, op); break; case uint16: - comparison_op(a, b, out, op); + binary_op(a, b, out, op); break; case uint32: - comparison_op(a, b, out, op); + binary_op(a, b, out, op); break; case uint64: - comparison_op(a, b, out, op); + binary_op(a, b, out, op); break; case int8: - comparison_op(a, b, out, op); + binary_op(a, b, out, op); break; case int16: - comparison_op(a, b, out, op); + binary_op(a, b, out, op); break; case int32: - comparison_op(a, b, out, op); + binary_op(a, b, out, op); break; case int64: - comparison_op(a, b, out, op); + binary_op(a, b, out, op); break; case float16: - comparison_op(a, b, out, op); + binary_op(a, b, out, op); break; case float32: - comparison_op(a, b, out, op); + binary_op(a, b, out, op); break; case bfloat16: - comparison_op(a, b, out, op); + binary_op(a, b, out, op); break; case complex64: - comparison_op(a, b, out, op); + binary_op(a, b, out, op); break; } } } // namespace -void Add::eval(const std::vector& inputs, array& out) { +void Add::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; binary(a, b, out, detail::Add()); } -void DivMod::eval( +void DivMod::eval_cpu( const std::vector& inputs, std::vector& outputs) { assert(inputs.size() == 2); @@ -132,50 +124,68 @@ void DivMod::eval( } } -void Divide::eval(const std::vector& inputs, array& out) { +void Divide::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; binary(a, b, out, detail::Divide()); } -void Remainder::eval(const std::vector& inputs, array& out) { +void Remainder::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; binary(a, b, out, detail::Remainder()); } -void Equal::eval(const std::vector& inputs, array& out) { +void Equal::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); + auto& a = inputs[0]; + auto& b = inputs[1]; if (equal_nan_) { - 
comparison_op(inputs[0], inputs[1], out, detail::NaNEqual());
+    switch (a.dtype()) {
+      case float16:
+        binary_op<float16_t, bool>(a, b, out, detail::NaNEqual());
+        break;
+      case float32:
+        binary_op<float, bool>(a, b, out, detail::NaNEqual());
+        break;
+      case bfloat16:
+        binary_op<bfloat16_t, bool>(a, b, out, detail::NaNEqual());
+        break;
+      case complex64:
+        binary_op<complex64_t, bool>(a, b, out, detail::NaNEqual());
+        break;
+      default:
+        throw std::runtime_error(
+            "[NanEqual::eval_cpu] Only for floating point types.");
+    }
   } else {
-    comparison_op(inputs[0], inputs[1], out, detail::Equal());
+    comparison_op(a, b, out, detail::Equal());
   }
 }
 
-void Greater::eval(const std::vector<array>& inputs, array& out) {
+void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   comparison_op(inputs[0], inputs[1], out, detail::Greater());
 }
 
-void GreaterEqual::eval(const std::vector<array>& inputs, array& out) {
+void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual());
 }
 
-void Less::eval(const std::vector<array>& inputs, array& out) {
+void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   comparison_op(inputs[0], inputs[1], out, detail::Less());
 }
 
-void LessEqual::eval(const std::vector<array>& inputs, array& out) {
+void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   comparison_op(inputs[0], inputs[1], out, detail::LessEqual());
 }
 
-void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
+void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
@@ -196,54 +206,54 @@ void LogAddExp::eval(const std::vector<array>& inputs, array& out)
   }
 }
 
-void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
+void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2); // LogicalAnd requires two input arrays
   auto& in1 = inputs[0];
   auto& in2 = inputs[1];
   binary(in1, in2, out, detail::LogicalAnd());
 }
 
-void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
+void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2); // LogicalOr requires two input arrays
   auto& in1 = inputs[0];
   auto& in2 = inputs[1];
   binary(in1, in2, out, detail::LogicalOr());
 }
 
-void Maximum::eval(const std::vector<array>& inputs, array& out) {
+void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
   binary(a, b, out, detail::Maximum());
 }
 
-void Minimum::eval(const std::vector<array>& inputs, array& out) {
+void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
   binary(a, b, out, detail::Minimum());
 }
 
-void Multiply::eval(const std::vector<array>& inputs, array& out) {
+void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
   binary(a, b, out, detail::Multiply());
 }
 
-void NotEqual::eval(const std::vector<array>& inputs, array& out) {
+void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   comparison_op(inputs[0], inputs[1], out, detail::NotEqual());
 }
 
-void Power::eval(const std::vector<array>& inputs, array& out) {
+void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
   binary(a, b, out, detail::Power());
 }
 
-void Subtract::eval(const std::vector<array>& inputs, array& out)
{ +void Subtract::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; @@ -307,7 +317,7 @@ void BitwiseBinary::eval_cpu(const std::vector& inputs, array& out) { } } -void ArcTan2::eval(const std::vector& inputs, array& out) { +void ArcTan2::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); const auto& a = inputs[0]; const auto& b = inputs[1]; diff --git a/mlx/backend/common/binary.h b/mlx/backend/common/binary.h index d6879cda4..e28db35e1 100644 --- a/mlx/backend/common/binary.h +++ b/mlx/backend/common/binary.h @@ -7,6 +7,8 @@ #include "mlx/array.h" #include "mlx/backend/common/utils.h" +#include "mlx/backend/common/simd/simd.h" + namespace mlx::core { namespace { @@ -122,16 +124,22 @@ void set_binary_op_output_data( } } -struct UseDefaultBinaryOp {}; - -template -struct DefaultVectorScalar { +template +struct VectorScalar { Op op; - DefaultVectorScalar(Op op_) : op(op_) {} + VectorScalar(Op op_) : op(op_) {} + template void operator()(const T* a, const T* b, U* dst, int size) { T scalar = *b; + constexpr int N = simd::max_size; + while (size >= N) { + simd::store(dst, op(simd::load(a), simd::Simd(scalar))); + dst += N; + a += N; + size -= N; + } while (size-- > 0) { *dst = op(*a, scalar); dst++; @@ -140,14 +148,22 @@ struct DefaultVectorScalar { } }; -template -struct DefaultScalarVector { +template +struct ScalarVector { Op op; - DefaultScalarVector(Op op_) : op(op_) {} + ScalarVector(Op op_) : op(op_) {} + template void operator()(const T* a, const T* b, U* dst, int size) { T scalar = *a; + constexpr int N = simd::max_size; + while (size >= N) { + simd::store(dst, op(simd::Simd(scalar), simd::load(b))); + dst += N; + b += N; + size -= N; + } while (size-- > 0) { *dst = op(scalar, *b); dst++; @@ -156,13 +172,22 @@ struct DefaultScalarVector { } }; -template -struct DefaultVectorVector { +template +struct VectorVector { Op op; - DefaultVectorVector(Op op_) : op(op_) {} + VectorVector(Op op_) : op(op_) {} + template void operator()(const T* a, const T* b, U* dst, int size) { + constexpr int N = simd::max_size; + while (size >= N) { + simd::store(dst, op(simd::load(a), simd::load(b))); + dst += N; + a += N; + b += N; + size -= N; + } while (size-- > 0) { *dst = op(*a, *b); dst++; @@ -277,21 +302,8 @@ void binary_op_dispatch_dims( } } -template < - typename T, - typename U, - typename Op, - typename OpSV, - typename OpVS, - typename OpVV> -void binary_op( - const array& a, - const array& b, - array& out, - Op op, - OpSV opsv, - OpVS opvs, - OpVV opvv) { +template +void binary_op(const array& a, const array& b, array& out, Op op) { auto bopt = get_binary_op_type(a, b); set_binary_op_output_data(a, b, out, bopt); @@ -303,19 +315,19 @@ void binary_op( // The full computation is scalar vector so delegate to the op if (bopt == BinaryOpType::ScalarVector) { - opsv(a.data(), b.data(), out.data(), b.data_size()); + ScalarVector{op}(a.data(), b.data(), out.data(), b.data_size()); return; } // The full computation is vector scalar so delegate to the op if (bopt == BinaryOpType::VectorScalar) { - opvs(a.data(), b.data(), out.data(), a.data_size()); + VectorScalar{op}(a.data(), b.data(), out.data(), a.data_size()); return; } // The full computation is vector vector so delegate to the op if (bopt == BinaryOpType::VectorVector) { - opvv(a.data(), b.data(), out.data(), out.size()); + VectorVector{op}(a.data(), b.data(), out.data(), out.size()); return; } @@ -376,15 +388,39 @@ void binary_op( switch (bopt) { case 
BinaryOpType::VectorVector: binary_op_dispatch_dims( - a, b, out, opvv, dim, new_shape, a_strides, b_strides, strides); + a, + b, + out, + VectorVector{op}, + dim, + new_shape, + a_strides, + b_strides, + strides); break; case BinaryOpType::VectorScalar: binary_op_dispatch_dims( - a, b, out, opvs, dim, new_shape, a_strides, b_strides, strides); + a, + b, + out, + VectorScalar{op}, + dim, + new_shape, + a_strides, + b_strides, + strides); break; case BinaryOpType::ScalarVector: binary_op_dispatch_dims( - a, b, out, opsv, dim, new_shape, a_strides, b_strides, strides); + a, + b, + out, + ScalarVector{op}, + dim, + new_shape, + a_strides, + b_strides, + strides); break; default: binary_op_dispatch_dims( @@ -393,134 +429,52 @@ void binary_op( } } -template -void binary_op( - const array& a, - const array& b, - array& out, - Op op, - OpSV opsv, - OpVS opvs, - OpVV opvv) { - // TODO: The following mess of constexpr evaluations can probably be achieved - // with template specializations and overloading. Would it be simpler? - - if constexpr (std::is_same::value) { - if constexpr (std::is_same::value) { - if constexpr (std::is_same::value) { - // All ops are UseDefaultBinaryOp (why oh why would someone call that?) - binary_op( - a, - b, - out, - op, - DefaultScalarVector(op), - DefaultVectorScalar(op), - DefaultVectorVector(op)); - } else { - // opsv and opvs were UseDefaultBinaryOp - binary_op( - a, - b, - out, - op, - DefaultScalarVector(op), - DefaultVectorScalar(op), - opvv); - } - } else if constexpr (std::is_same:: - value) { - // opsv and opvv were UseDefaultBinaryOp - binary_op( - a, - b, - out, - op, - DefaultScalarVector(op), - opvs, - DefaultVectorVector(op)); - } else { - // opsv was UseDefaultBinaryOp - binary_op( - a, b, out, op, DefaultScalarVector(op), opvs, opvv); - } - } else if constexpr (std::is_same:: - value) { - if (std::is_same::value) { - // opvs and opvv were UseDefaultBinaryOp - binary_op( - a, - b, - out, - op, - opsv, - DefaultVectorScalar(op), - DefaultVectorVector(op)); - } else { - // opvs was UseDefaultBinaryOp - binary_op( - a, b, out, op, opsv, DefaultVectorScalar(op), opvv); - } - } else if constexpr (std::is_same:: - value) { - // opvv was UseDefaultBinaryOp - binary_op( - a, b, out, op, opsv, opvs, DefaultVectorVector(op)); - } else { - // All ops provided - binary_op(a, b, out, op, opsv, opvs, opvv); - } -} - template void binary_op(const array& a, const array& b, array& out, Op op) { - DefaultScalarVector opsv(op); - DefaultVectorScalar opvs(op); - DefaultVectorVector opvv(op); - binary_op(a, b, out, op, opsv, opvs, opvv); + binary_op(a, b, out, op); } -template -void binary(const array& a, const array& b, array& out, Ops... 
ops) { +template +void binary(const array& a, const array& b, array& out, Op op) { switch (out.dtype()) { case bool_: - binary_op(a, b, out, ops...); + binary_op(a, b, out, op); break; case uint8: - binary_op(a, b, out, ops...); + binary_op(a, b, out, op); break; case uint16: - binary_op(a, b, out, ops...); + binary_op(a, b, out, op); break; case uint32: - binary_op(a, b, out, ops...); + binary_op(a, b, out, op); break; case uint64: - binary_op(a, b, out, ops...); + binary_op(a, b, out, op); break; case int8: - binary_op(a, b, out, ops...); + binary_op(a, b, out, op); break; case int16: - binary_op(a, b, out, ops...); + binary_op(a, b, out, op); break; case int32: - binary_op(a, b, out, ops...); + binary_op(a, b, out, op); break; case int64: - binary_op(a, b, out, ops...); + binary_op(a, b, out, op); break; case float16: - binary_op(a, b, out, ops...); + binary_op(a, b, out, op); break; case float32: - binary_op(a, b, out, ops...); + binary_op(a, b, out, op); break; case bfloat16: - binary_op(a, b, out, ops...); + binary_op(a, b, out, op); break; case complex64: - binary_op(a, b, out, ops...); + binary_op(a, b, out, op); break; } } diff --git a/mlx/backend/common/binary_ops.h b/mlx/backend/common/binary_ops.h new file mode 100644 index 000000000..fd10264f9 --- /dev/null +++ b/mlx/backend/common/binary_ops.h @@ -0,0 +1,98 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include "mlx/backend/common/simd/simd.h" + +namespace mlx::core::detail { + +using namespace mlx::core::simd; + +#define BINARY_SINGLE() \ + template \ + T operator()(T x, T y) { \ + return (*this)(Simd(x), Simd(y)).value; \ + } + +#define DEFAULT_BINARY_OP(Op, op) \ + struct Op { \ + template \ + Simd operator()(Simd x, Simd y) { \ + return op(x, y); \ + } \ + BINARY_SINGLE() \ + }; + +DEFAULT_BINARY_OP(Add, operator+) +DEFAULT_BINARY_OP(ArcTan2, atan2) +DEFAULT_BINARY_OP(Divide, operator/) +DEFAULT_BINARY_OP(Multiply, operator*) +DEFAULT_BINARY_OP(Subtract, operator-) +DEFAULT_BINARY_OP(LogicalAnd, operator&&) +DEFAULT_BINARY_OP(LogicalOr, operator||) +DEFAULT_BINARY_OP(BitwiseAnd, operator&) +DEFAULT_BINARY_OP(BitwiseOr, operator|) +DEFAULT_BINARY_OP(BitwiseXor, operator^) +DEFAULT_BINARY_OP(LeftShift, operator<<) +DEFAULT_BINARY_OP(RightShift, operator>>) +DEFAULT_BINARY_OP(Remainder, remainder) +DEFAULT_BINARY_OP(Maximum, maximum) +DEFAULT_BINARY_OP(Minimum, minimum) +DEFAULT_BINARY_OP(Power, pow) + +#define DEFAULT_BOOL_OP(Op, op) \ + struct Op { \ + template \ + Simd operator()(Simd x, Simd y) { \ + return op(x, y); \ + } \ + template \ + bool operator()(T x, T y) { \ + return (*this)(Simd(x), Simd(y)).value; \ + } \ + }; + +DEFAULT_BOOL_OP(Equal, operator==) +DEFAULT_BOOL_OP(Greater, operator>) +DEFAULT_BOOL_OP(GreaterEqual, operator>=) +DEFAULT_BOOL_OP(Less, operator<) +DEFAULT_BOOL_OP(LessEqual, operator<=) +DEFAULT_BOOL_OP(NotEqual, operator!=) + +struct NaNEqual { + template + Simd operator()(Simd x, Simd y) { + return x == y || (isnan(x) && isnan(y)); + } + template + bool operator()(T x, T y) { + return (*this)(Simd(x), Simd(y)).value; + } +}; + +struct LogAddExp { + template + Simd operator()(Simd x, Simd y) { + auto maxval = maximum(x, y); + auto minval = minimum(x, y); + auto mask = minval == -inf || maxval == inf; + auto out = maxval + log1p(exp(minval - maxval)); + return select(mask, Simd(maxval), Simd(out)); + } + BINARY_SINGLE() +}; + +struct Select { + template + T operator()(bool condition, T x, T y) { + return (*this)(Simd(condition), Simd(x), Simd(y)) + .value; + } + + template + Simd 
operator()(Simd condition, Simd x, Simd y) { + return select(condition, x, y); + } +}; + +} // namespace mlx::core::detail diff --git a/mlx/backend/common/cholesky.cpp b/mlx/backend/common/cholesky.cpp index 62807e6dd..ca09d9663 100644 --- a/mlx/backend/common/cholesky.cpp +++ b/mlx/backend/common/cholesky.cpp @@ -64,7 +64,7 @@ void cholesky_impl(const array& a, array& factor, bool upper) { } } -void Cholesky::eval(const std::vector& inputs, array& output) { +void Cholesky::eval_cpu(const std::vector& inputs, array& output) { if (inputs[0].dtype() != float32) { throw std::runtime_error("[Cholesky::eval] only supports float32."); } diff --git a/mlx/backend/common/compiled_preamble.h b/mlx/backend/common/compiled_preamble.h index 84b77d29d..feea71bcb 100644 --- a/mlx/backend/common/compiled_preamble.h +++ b/mlx/backend/common/compiled_preamble.h @@ -5,7 +5,8 @@ // clang-format off #include "mlx/types/half_types.h" #include "mlx/types/complex.h" -#include "mlx/backend/common/ops.h" +#include "mlx/backend/common/unary_ops.h" +#include "mlx/backend/common/binary_ops.h" // clang-format on const char* get_kernel_preamble(); diff --git a/mlx/backend/common/copy.cpp b/mlx/backend/common/copy.cpp index 80cbc9f56..41bffdcea 100644 --- a/mlx/backend/common/copy.cpp +++ b/mlx/backend/common/copy.cpp @@ -4,6 +4,7 @@ #include "mlx/allocator.h" #include "mlx/backend/common/copy.h" +#include "mlx/backend/common/simd/simd.h" #include "mlx/backend/common/utils.h" namespace mlx::core { @@ -23,6 +24,7 @@ template void copy_vector(const array& src, array& dst) { auto src_ptr = src.data(); auto dst_ptr = dst.data(); + size_t size = src.data_size(); std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr); } diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index f7548a12f..899de35cd 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -21,98 +21,9 @@ namespace mlx::core { -DEFAULT(Abs) -DEFAULT(Add) -DEFAULT(Arange) -DEFAULT(ArcCos) -DEFAULT(ArcCosh) -DEFAULT(ArcSin) -DEFAULT(ArcSinh) -DEFAULT(ArcTan) -DEFAULT(ArcTan2) -DEFAULT(ArcTanh) -DEFAULT(ArgPartition) -DEFAULT(ArgReduce) -DEFAULT(ArgSort) -DEFAULT(AsType) -DEFAULT(AsStrided) -DEFAULT(Broadcast) -DEFAULT(BroadcastAxes) -DEFAULT(BlockMaskedMM) -DEFAULT(GatherMM) -DEFAULT(GatherQMM) -DEFAULT_MULTI(DivMod) -DEFAULT(Ceil) -DEFAULT(Concatenate) -DEFAULT(Conjugate) DEFAULT(Convolution) -DEFAULT(Copy) -DEFAULT(Cos) -DEFAULT(Cosh) -DEFAULT_MULTI(CustomTransforms) -DEFAULT_MULTI(Depends) -DEFAULT(Divide) -DEFAULT(NumberOfElements) -DEFAULT(Remainder) -DEFAULT(Equal) -DEFAULT(Erf) -DEFAULT(ErfInv) -DEFAULT(Exp) -DEFAULT(ExpandDims) -DEFAULT(Expm1) -DEFAULT(FFT) -DEFAULT(Floor) -DEFAULT(Full) -DEFAULT(Gather) -DEFAULT(Greater) -DEFAULT(GreaterEqual) -DEFAULT(Hadamard) -DEFAULT(Less) -DEFAULT(LessEqual) -DEFAULT(Load) -DEFAULT(Log) -DEFAULT(Log1p) -DEFAULT(LogicalNot) -DEFAULT(LogicalAnd) -DEFAULT(LogicalOr) -DEFAULT(LogAddExp) -DEFAULT(Maximum) -DEFAULT(Minimum) -DEFAULT(Multiply) -DEFAULT(Negative) -DEFAULT(NotEqual) -DEFAULT(Pad) -DEFAULT(Partition) -DEFAULT(Power) -DEFAULT_MULTI(QRF) -DEFAULT(QuantizedMatmul) -DEFAULT(RandomBits) DEFAULT(Reduce) -DEFAULT(Round) DEFAULT(Scan) -DEFAULT(Scatter) -DEFAULT(Select) -DEFAULT(Sigmoid) -DEFAULT(Sign) -DEFAULT(Sin) -DEFAULT(Sinh) -DEFAULT(Slice) -DEFAULT(SliceUpdate) -DEFAULT(Softmax) -DEFAULT(Sort) -DEFAULT_MULTI(Split) -DEFAULT(Square) -DEFAULT(Squeeze) -DEFAULT(Sqrt) -DEFAULT(StopGradient) -DEFAULT(Subtract) 
-DEFAULT_MULTI(SVD)
-DEFAULT(Tan)
-DEFAULT(Tanh)
-DEFAULT(Transpose)
-DEFAULT(Inverse)
-DEFAULT(Cholesky)
-DEFAULT_MULTI(Eigh)
 
 namespace {
 
diff --git a/mlx/backend/common/eigh.cpp b/mlx/backend/common/eigh.cpp
index 8a4e499a3..7fa7b7fa8 100644
--- a/mlx/backend/common/eigh.cpp
+++ b/mlx/backend/common/eigh.cpp
@@ -45,7 +45,9 @@ void ssyevd(
 
 } // namespace
 
-void Eigh::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
+void Eigh::eval_cpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
   const auto& a = inputs[0];
   auto& values = outputs[0];
 
diff --git a/mlx/backend/common/fft.cpp b/mlx/backend/common/fft.cpp
index e46fd7f92..52bf80655 100644
--- a/mlx/backend/common/fft.cpp
+++ b/mlx/backend/common/fft.cpp
@@ -8,7 +8,7 @@
 
 namespace mlx::core {
 
-void FFT::eval(const std::vector<array>& inputs, array& out) {
+void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& in = inputs[0];
   std::vector<size_t> strides_in(
       in.strides().begin(), in.strides().end());
diff --git a/mlx/backend/common/hadamard.cpp b/mlx/backend/common/hadamard.cpp
index 6c71eaf9d..4ee05345b 100644
--- a/mlx/backend/common/hadamard.cpp
+++ b/mlx/backend/common/hadamard.cpp
@@ -82,7 +82,7 @@ void hadamard(array& out, int n, int m, float scale)
   }
 }
 
-void Hadamard::eval(const std::vector<array>& inputs, array& out) {
+void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
 
   auto& in = inputs[0];
@@ -104,4 +104,4 @@ void Hadamard::eval(const std::vector<array>& inputs, array& out)
   }
 }
 
-} // namespace mlx::core
\ No newline at end of file
+} // namespace mlx::core
diff --git a/mlx/backend/common/indexing.cpp b/mlx/backend/common/indexing.cpp
index 9519bf891..b0e354e32 100644
--- a/mlx/backend/common/indexing.cpp
+++ b/mlx/backend/common/indexing.cpp
@@ -162,7 +162,7 @@ void dispatch_gather(
   }
 }
 
-void Gather::eval(const std::vector<array>& inputs, array& out) {
+void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
   auto& src = inputs[0];
@@ -337,7 +337,7 @@ void dispatch_scatter(
   }
 }
 
-void Scatter::eval(const std::vector<array>& inputs, array& out) {
+void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() >= 2);
 
   auto& src = inputs[0];
diff --git a/mlx/backend/common/inverse.cpp b/mlx/backend/common/inverse.cpp
index 96dbfc001..23e294201 100644
--- a/mlx/backend/common/inverse.cpp
+++ b/mlx/backend/common/inverse.cpp
@@ -110,7 +110,7 @@ void inverse_impl(const array& a, array& inv, bool tri, bool upper)
   }
 }
 
-void Inverse::eval(const std::vector<array>& inputs, array& output) {
+void Inverse::eval_cpu(const std::vector<array>& inputs, array& output) {
   if (inputs[0].dtype() != float32) {
     throw std::runtime_error("[Inverse::eval] only supports float32.");
   }
diff --git a/mlx/backend/common/lapack.h b/mlx/backend/common/lapack.h
index 42d937785..dc262a0ff 100644
--- a/mlx/backend/common/lapack.h
+++ b/mlx/backend/common/lapack.h
@@ -11,7 +11,7 @@
 #define lapack_complex_double std::complex<double>
 #endif
 
-#ifdef ACCELERATE_NEW_LAPACK
+#ifdef MLX_USE_ACCELERATE
 #include <Accelerate/Accelerate.h>
 #else
 #include <lapack.h>
diff --git a/mlx/backend/common/load.cpp b/mlx/backend/common/load.cpp
index 2a10ed08a..1fc6c9b31 100644
--- a/mlx/backend/common/load.cpp
+++ b/mlx/backend/common/load.cpp
@@ -1,12 +1,9 @@
 // Copyright © 2023 Apple Inc.
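// [editor's note] the renames above all follow one pattern: the common
// backend used to define Primitive::eval plus a DEFAULT(...) forwarder for
// eval_cpu, and this patch makes each primitive define eval_cpu directly.
// A self-contained sketch of the forwarding the removed macro generated
// (hypothetical types, illustrative only):

#include <vector>

struct array {};

struct SomePrimitive {
  void eval(const std::vector<array>& inputs, array& out) {}
  void eval_cpu(const std::vector<array>& inputs, array& out);
};

// Expands to an eval_cpu that simply forwards to the generic eval.
#define DEFAULT(primitive)                                                  \
  void primitive::eval_cpu(const std::vector<array>& inputs, array& out) { \
    primitive::eval(inputs, out);                                          \
  }

DEFAULT(SomePrimitive) // the kind of line deleted throughout this patch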
#include -#include #include -#include "mlx/allocator.h" #include "mlx/backend/common/load.h" -#include "mlx/primitives.h" namespace { @@ -51,11 +48,4 @@ void load( } } -void Load::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 0); - out.set_data(allocator::malloc_or_wait(out.nbytes())); - - load(out, offset_, reader_, swap_endianness_); -} - } // namespace mlx::core diff --git a/mlx/backend/common/masked_mm.cpp b/mlx/backend/common/masked_mm.cpp index f6c8300f2..5675399a3 100644 --- a/mlx/backend/common/masked_mm.cpp +++ b/mlx/backend/common/masked_mm.cpp @@ -53,7 +53,7 @@ inline void mask_matrix( } // namespace -void BlockMaskedMM::eval(const std::vector& inputs, array& out) { +void BlockMaskedMM::eval_cpu(const std::vector& inputs, array& out) { if (out.dtype() != float32) { throw std::runtime_error( "[BlockMaskedMM::eval] Currently only supports float32."); @@ -210,7 +210,7 @@ void BlockMaskedMM::eval(const std::vector& inputs, array& out) { } } -void GatherMM::eval(const std::vector& inputs, array& out) { +void GatherMM::eval_cpu(const std::vector& inputs, array& out) { if (out.dtype() != float32) { throw std::runtime_error( "[GatherMM::eval] Currently only supports float32."); diff --git a/mlx/backend/common/ops.h b/mlx/backend/common/ops.h deleted file mode 100644 index 115386ac5..000000000 --- a/mlx/backend/common/ops.h +++ /dev/null @@ -1,680 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#pragma once -#include -#include -#include - -namespace mlx::core::detail { - -namespace { -constexpr float inf = std::numeric_limits::infinity(); -} // namespace - -typedef union { - int i; - float f; -} IntOrFloat; - -inline float fast_exp(float x) { - if (x == -std::numeric_limits::infinity()) { - return 0.0f; - } else if (x == std::numeric_limits::infinity() || std::isnan(x)) { - return x; - } - x *= 1.442695; // multiply with log_2(e) - float ipart, fpart; - IntOrFloat epart; - x = std::max(-80.f, std::min(x, 80.f)); - ipart = std::floor(x + 0.5); - fpart = x - ipart; - - x = 1.535336188319500e-4f; - x = x * fpart + 1.339887440266574e-3f; - x = x * fpart + 9.618437357674640e-3f; - x = x * fpart + 5.550332471162809e-2f; - x = x * fpart + 2.402264791363012e-1f; - x = x * fpart + 6.931472028550421e-1f; - x = x * fpart + 1.000000000000000f; - - // generate 2**ipart in the floating point representation using integer - // bitshifting - epart.i = (int(ipart) + 127) << 23; - - return epart.f * x; -} - -inline float fast_erf(float a) { - float r, s, t, u; - t = std::abs(a); - s = a * a; - if (t > 0.927734375f) { - // maximum error 0.99527 ulp - r = std::fma( - -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12 - u = std::fma( - -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6 - r = std::fma(r, s, u); - r = std::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4 - r = std::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1 - r = std::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3 - r = std::fma(r, t, -t); - // TODO, replace with expm1 when implemented - r = 1.0f - std::exp(r); - r = std::copysign(r, a); - } else { - // maximum error 0.98929 ulp - r = -5.96761703e-4f; // -0x1.38e000p-11 - r = std::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8 - r = std::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6 - r = std::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4 - r = std::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2 - r = std::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3 - r = std::fma(r, a, a); - } - return r; -} - -inline float fast_erfinv(float a) { - auto 
t = std::fma(a, 0.0f - a, 1.0f); - t = std::log(t); - float p; - if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793 - p = 3.03697567e-10f; // 0x1.4deb44p-32 - p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26 - p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20 - p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16 - p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12 - p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9 - p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8 - p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2 - p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1 - } else { // maximum ulp error = 2.35002 - p = 5.43877832e-9f; // 0x1.75c000p-28 - p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23 - p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20 - p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24 - p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15 - p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13 - p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9 - p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7 - p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3 - p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1 - } - return a * p; -} - -struct Abs { - template - T operator()(T x) { - return std::abs(x); - } - uint8_t operator()(uint8_t x) { - return x; - } - uint16_t operator()(uint16_t x) { - return x; - } - uint32_t operator()(uint32_t x) { - return x; - } - uint64_t operator()(uint64_t x) { - return x; - } - bool operator()(bool x) { - return x; - } -}; - -struct ArcCos { - template - T operator()(T x) { - return std::acos(x); - } -}; - -struct ArcCosh { - template - T operator()(T x) { - return std::acosh(x); - } -}; - -struct ArcSin { - template - T operator()(T x) { - return std::asin(x); - } -}; - -struct ArcSinh { - template - T operator()(T x) { - return std::asinh(x); - } -}; - -struct ArcTan { - template - T operator()(T x) { - return std::atan(x); - } -}; - -struct ArcTan2 { - template - T operator()(T y, T x) { - return std::atan2(y, x); - } -}; - -struct ArcTanh { - template - T operator()(T x) { - return std::atanh(x); - } -}; - -struct Ceil { - template - T operator()(T x) { - return std::ceil(x); - } - int8_t operator()(int8_t x) { - return x; - } - int16_t operator()(int16_t x) { - return x; - } - int32_t operator()(int32_t x) { - return x; - } - int64_t operator()(int64_t x) { - return x; - } - uint8_t operator()(uint8_t x) { - return x; - } - uint16_t operator()(uint16_t x) { - return x; - } - uint32_t operator()(uint32_t x) { - return x; - } - uint64_t operator()(uint64_t x) { - return x; - } - bool operator()(bool x) { - return x; - } -}; - -struct Conjugate { - complex64_t operator()(complex64_t x) { - return std::conj(x); - } -}; - -struct Cos { - template - T operator()(T x) { - return std::cos(x); - } -}; - -struct Cosh { - template - T operator()(T x) { - return std::cosh(x); - } -}; - -struct Erf { - template - T operator()(T x) { - return static_cast(fast_erf(static_cast(x))); - } -}; - -struct ErfInv { - template - T operator()(T x) { - return static_cast(fast_erfinv(static_cast(x))); - } -}; - -struct Exp { - template - T operator()(T x) { - return fast_exp(x); - } - - complex64_t operator()(complex64_t x) { - return std::exp(x); - } -}; - -struct Expm1 { - template - T operator()(T x) { - return expm1(x); - } -}; - -struct Floor { - template - T operator()(T x) { - return std::floor(x); - } - int8_t operator()(int8_t x) { - return x; - } - int16_t operator()(int16_t x) { - return x; - } - 
int32_t operator()(int32_t x) { - return x; - } - int64_t operator()(int64_t x) { - return x; - } - uint8_t operator()(uint8_t x) { - return x; - } - uint16_t operator()(uint16_t x) { - return x; - } - uint32_t operator()(uint32_t x) { - return x; - } - uint64_t operator()(uint64_t x) { - return x; - } - bool operator()(bool x) { - return x; - } -}; - -struct Imag { - template - T operator()(T x) { - return std::imag(x); - } -}; - -struct Log { - template - T operator()(T x) { - return std::log(x); - } -}; - -struct Log2 { - template - T operator()(T x) { - return std::log2(x); - } -}; - -struct Log10 { - template - T operator()(T x) { - return std::log10(x); - } -}; - -struct Log1p { - template - T operator()(T x) { - return log1p(x); - } -}; - -struct LogicalNot { - template - T operator()(T x) { - return !x; - } -}; - -struct Negative { - template - T operator()(T x) { - return -x; - } -}; - -struct Real { - template - T operator()(T x) { - return std::real(x); - } -}; - -struct Round { - template - T operator()(T x) { - return std::rint(x); - } - - complex64_t operator()(complex64_t x) { - return {std::rint(x.real()), std::rint(x.imag())}; - } -}; - -struct Sigmoid { - template - T operator()(T x) { - auto one = static_cast(1.0); - return one / (one + fast_exp(-x)); - } -}; - -struct Sign { - template - T operator()(T x) { - return (x > T(0)) - (x < T(0)); - } - uint8_t operator()(uint8_t x) { - return x != 0; - } - uint16_t operator()(uint16_t x) { - return x != 0; - } - uint32_t operator()(uint32_t x) { - return x != 0; - } - uint64_t operator()(uint64_t x) { - return x != 0; - } - - complex64_t operator()(complex64_t x) { - return x == complex64_t(0) ? x : x / std::abs(x); - } -}; - -struct Sin { - template - T operator()(T x) { - return std::sin(x); - } -}; - -struct Sinh { - template - T operator()(T x) { - return std::sinh(x); - } -}; - -struct Square { - template - T operator()(T x) { - return x * x; - } -}; - -struct Sqrt { - template - T operator()(T x) { - return std::sqrt(x); - } -}; - -struct Rsqrt { - template - T operator()(T x) { - return static_cast(1.0) / std::sqrt(x); - } -}; - -struct Tan { - template - T operator()(T x) { - return std::tan(x); - } -}; - -struct Tanh { - template - T operator()(T x) { - return std::tanh(x); - } -}; - -struct Add { - template - T operator()(T x, T y) { - return x + y; - } -}; - -struct Divide { - template - T operator()(T x, T y) { - return x / y; - } -}; - -struct Remainder { - template - std::enable_if_t & !std::is_signed_v, T> operator()( - T numerator, - T denominator) { - return numerator % denominator; - } - - template - std::enable_if_t & std::is_signed_v, T> operator()( - T numerator, - T denominator) { - auto r = numerator % denominator; - if (r != 0 && (r < 0 != denominator < 0)) - r += denominator; - return r; - } - - template - std::enable_if_t, T> operator()( - T numerator, - T denominator) { - auto r = std::fmod(numerator, denominator); - if (r != 0 && (r < 0 != denominator < 0)) { - r += denominator; - } - return r; - } - - complex64_t operator()(complex64_t numerator, complex64_t denominator) { - return numerator % denominator; - } -}; - -struct Equal { - template - bool operator()(T x, T y) { - return x == y; - } -}; - -struct NaNEqual { - template - bool operator()(T x, T y) { - if constexpr (std::is_integral_v) { - // isnan always returns false for integers, and MSVC refuses to compile. 
- return x == y; - } else { - return x == y || (std::isnan(x) && std::isnan(y)); - } - } -}; - -struct Greater { - template - bool operator()(T x, T y) { - return x > y; - } -}; - -struct GreaterEqual { - template - bool operator()(T x, T y) { - return x >= y; - } -}; - -struct Less { - template - bool operator()(T x, T y) { - return x < y; - } -}; - -struct LessEqual { - template - bool operator()(T x, T y) { - return x <= y; - } -}; - -struct Maximum { - template - std::enable_if_t, T> operator()(T x, T y) { - return (x > y) ? x : y; - } - - template - std::enable_if_t, T> operator()(T x, T y) { - if (std::isnan(x)) { - return x; - } - return (x > y) ? x : y; - } -}; - -struct Minimum { - template - std::enable_if_t, T> operator()(T x, T y) { - return x < y ? x : y; - } - - template - std::enable_if_t, T> operator()(T x, T y) { - if (std::isnan(x)) { - return x; - } - return x < y ? x : y; - } -}; - -struct LogAddExp { - template - T operator()(T x, T y) { - constexpr float inf = std::numeric_limits::infinity(); - auto maxval = Maximum()(x, y); - auto minval = Minimum()(x, y); - return (minval == -inf || maxval == inf) - ? maxval - : static_cast( - maxval + std::log1p(fast_exp(minval - maxval))); - } -}; - -struct Multiply { - template - T operator()(T x, T y) { - return x * y; - } -}; - -struct NotEqual { - template - bool operator()(T x, T y) { - return x != y; - } -}; - -struct Power { - template - std::enable_if_t, T> operator()(T base, T exp) { - return std::pow(base, exp); - } - - template - std::enable_if_t, T> operator()(T base, T exp) { - T res = 1; - while (exp) { - if (exp & 1) { - res *= base; - } - exp >>= 1; - base *= base; - } - return res; - } -}; - -struct Subtract { - template - T operator()(T x, T y) { - return x - y; - } -}; - -struct LogicalAnd { - template - T operator()(T x, T y) { - return x && y; - } -}; - -struct LogicalOr { - template - T operator()(T x, T y) { - return x || y; - } -}; - -struct Select { - template - T operator()(bool condition, T x, T y) { - return condition ? 
x : y; - } -}; - -struct BitwiseAnd { - template - T operator()(T x, T y) { - return x & y; - } -}; - -struct BitwiseOr { - template - T operator()(T x, T y) { - return x | y; - } -}; - -struct BitwiseXor { - template - T operator()(T x, T y) { - return x ^ y; - } -}; - -struct LeftShift { - template - T operator()(T x, T y) { - return x << y; - } -}; - -struct RightShift { - template - T operator()(T x, T y) { - return x >> y; - } -}; - -} // namespace mlx::core::detail diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp index 9ea015cd5..8cd0763d8 100644 --- a/mlx/backend/common/primitives.cpp +++ b/mlx/backend/common/primitives.cpp @@ -9,10 +9,9 @@ #include "mlx/allocator.h" #include "mlx/backend/common/arange.h" #include "mlx/backend/common/copy.h" -#include "mlx/backend/common/ops.h" +#include "mlx/backend/common/load.h" #include "mlx/backend/common/slicing.h" #include "mlx/backend/common/threefry.h" -#include "mlx/backend/common/unary.h" #include "mlx/backend/common/utils.h" #include "mlx/primitives.h" #include "mlx/utils.h" @@ -58,112 +57,64 @@ int64_t compute_dynamic_offset( } } -void Abs::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - auto& in = inputs[0]; - if (issubdtype(in.dtype(), unsignedinteger)) { - // No-op for unsigned types - out.copy_shared_buffer(in); - } else { - unary(in, out, detail::Abs()); - } +void AsStrided::eval_cpu(const std::vector& inputs, array& out) { + eval(inputs, out); +} +void Broadcast::eval_cpu(const std::vector& inputs, array& out) { + eval(inputs, out); +} +void BroadcastAxes::eval_cpu(const std::vector& inputs, array& out) { + eval(inputs, out); +} +void Copy::eval_cpu(const std::vector& inputs, array& out) { + eval(inputs, out); +} +void CustomTransforms::eval_cpu( + const std::vector& inputs, + std::vector& outputs) { + eval(inputs, outputs); +} +void Depends::eval_cpu( + const std::vector& inputs, + std::vector& outputs) { + eval(inputs, outputs); +} +void ExpandDims::eval_cpu(const std::vector& inputs, array& out) { + eval(inputs, out); +} +void NumberOfElements::eval_cpu(const std::vector& inputs, array& out) { + eval(inputs, out); +} +void Slice::eval_cpu(const std::vector& inputs, array& out) { + eval(inputs, out); +} +void Split::eval_cpu( + const std::vector& inputs, + std::vector& outputs) { + eval(inputs, outputs); +} +void Squeeze::eval_cpu(const std::vector& inputs, array& out) { + eval(inputs, out); +} +void StopGradient::eval_cpu(const std::vector& inputs, array& out) { + eval(inputs, out); +} +void Transpose::eval_cpu(const std::vector& inputs, array& out) { + eval(inputs, out); } -void Arange::eval(const std::vector& inputs, array& out) { +void Arange::eval_cpu(const std::vector& inputs, array& out) { arange(inputs, out, start_, step_); } -void ArcCos::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (issubdtype(out.dtype(), inexact)) { - unary_fp(in, out, detail::ArcCos()); - } else { - throw std::invalid_argument( - "[arccos] Cannot compute inverse cosine of elements in array" - " with non floating point type."); - } -} - -void ArcCosh::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (issubdtype(out.dtype(), inexact)) { - unary_fp(in, out, detail::ArcCosh()); - } else { - throw std::invalid_argument( - "[arccosh] Cannot compute inverse hyperbolic cosine of elements in" - " array with non floating point type."); - } -} - -void ArcSin::eval(const 
std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (issubdtype(out.dtype(), inexact)) { - unary_fp(in, out, detail::ArcSin()); - } else { - throw std::invalid_argument( - "[arcsin] Cannot compute inverse sine of elements in array" - " with non floating point type."); - } -} - -void ArcSinh::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (issubdtype(out.dtype(), inexact)) { - unary_fp(in, out, detail::ArcSinh()); - } else { - throw std::invalid_argument( - "[arcsinh] Cannot compute inverse hyperbolic sine of elements in" - " array with non floating point type."); - } -} - -void ArcTan::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (issubdtype(out.dtype(), inexact)) { - unary_fp(in, out, detail::ArcTan()); - } else { - throw std::invalid_argument( - "[arctan] Cannot compute inverse tangent of elements in array" - " with non floating point type."); - } -} - -void ArcTanh::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (issubdtype(out.dtype(), inexact)) { - unary_fp(in, out, detail::ArcTanh()); - } else { - throw std::invalid_argument( - "[arctanh] Cannot compute inverse hyperbolic tangent of elements in" - " array with non floating point type."); - } -} - -void AsType::eval(const std::vector& inputs, array& out) { +void AsType::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General; copy(in, out, ctype); } -void Ceil::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - auto& in = inputs[0]; - if (issubdtype(in.dtype(), inexact)) { - unary_fp(in, out, detail::Ceil()); - } else { - // No-op integer types - out.copy_shared_buffer(in); - } -} - -void Concatenate::eval(const std::vector& inputs, array& out) { +void Concatenate::eval_cpu(const std::vector& inputs, array& out) { std::vector sizes; sizes.push_back(0); for (auto& p : inputs) { @@ -187,17 +138,6 @@ void Concatenate::eval(const std::vector& inputs, array& out) { } } -void Conjugate::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (out.dtype() == complex64) { - unary_fp(in, out, detail::Conjugate()); - } else { - throw std::invalid_argument( - "[conjugate] conjugate must be called on complex input."); - } -} - void Contiguous::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; @@ -209,94 +149,6 @@ void Contiguous::eval_cpu(const std::vector& inputs, array& out) { } } -void Cos::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (issubdtype(out.dtype(), inexact)) { - unary_fp(in, out, detail::Cos()); - } else { - throw std::invalid_argument( - "[cos] Cannot compute cosine of elements in array" - " with non floating point type."); - } -} - -void Cosh::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (issubdtype(out.dtype(), inexact)) { - unary_fp(in, out, detail::Cosh()); - } else { - throw std::invalid_argument( - "[cosh] Cannot compute hyperbolic cosine of elements in array" - " with non floating point type."); - } -} - -void Erf::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = 
inputs[0]; - switch (out.dtype()) { - case float32: - unary_op(in, out, detail::Erf()); - break; - case float16: - unary_op(in, out, detail::Erf()); - break; - case bfloat16: - unary_op(in, out, detail::Erf()); - break; - default: - throw std::invalid_argument( - "[erf] Error function only defined for arrays" - " with real floating point type."); - } -} - -void ErfInv::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - switch (out.dtype()) { - case float32: - unary_op(in, out, detail::ErfInv()); - break; - case float16: - unary_op(in, out, detail::ErfInv()); - break; - case bfloat16: - unary_op(in, out, detail::ErfInv()); - break; - default: - throw std::invalid_argument( - "[erf_inv] Inverse error function only defined for arrays" - " with real floating point type."); - } -} - -void Exp::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (issubdtype(out.dtype(), inexact)) { - unary_fp(in, out, detail::Exp()); - } else { - throw std::invalid_argument( - "[exp] Cannot exponentiate elements in array" - " with non floating point type."); - } -} - -void Expm1::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (issubdtype(out.dtype(), inexact)) { - unary_fp(in, out, detail::Expm1()); - } else { - throw std::invalid_argument( - "[expm1] Cannot exponentiate elements in array" - " with non floating point type."); - } -} - void Flatten::eval_cpu(const std::vector& inputs, array& out) { reshape(inputs[0], out); } @@ -305,18 +157,7 @@ void Unflatten::eval_cpu(const std::vector& inputs, array& out) { reshape(inputs[0], out); } -void Floor::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - auto& in = inputs[0]; - if (issubdtype(in.dtype(), inexact)) { - unary_fp(in, out, detail::Floor()); - } else { - // No-op integer types - out.copy_shared_buffer(in); - } -} - -void Full::eval(const std::vector& inputs, array& out) { +void Full::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; assert(in.dtype() == out.dtype()); @@ -331,57 +172,14 @@ void Full::eval(const std::vector& inputs, array& out) { copy(in, out, ctype); } -void Imag::eval_cpu(const std::vector& inputs, array& out) { - unary_op(inputs[0], out, detail::Imag()); +void Load::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 0); + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + load(out, offset_, reader_, swap_endianness_); } -void Log::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (issubdtype(out.dtype(), inexact)) { - switch (base_) { - case Base::e: - unary_fp(in, out, detail::Log()); - break; - case Base::two: - unary_fp(in, out, detail::Log2()); - break; - case Base::ten: - unary_fp(in, out, detail::Log10()); - break; - } - } else { - throw std::invalid_argument( - "[log] Cannot compute log of elements in array with" - " non floating point type."); - } -} - -void Log1p::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (issubdtype(out.dtype(), inexact)) { - unary_fp(in, out, detail::Log1p()); - } else { - throw std::invalid_argument( - "[log1p] Cannot compute log of elements in array with" - " non floating point type."); - } -} - -void LogicalNot::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - auto& 
in = inputs[0]; - unary(in, out, detail::LogicalNot()); -} - -void Negative::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - auto& in = inputs[0]; - unary(in, out, detail::Negative()); -} - -void Pad::eval(const std::vector& inputs, array& out) { +void Pad::eval_cpu(const std::vector& inputs, array& out) { // Inputs must be base input array and scalar val array assert(inputs.size() == 2); auto& in = inputs[0]; @@ -412,7 +210,7 @@ void Pad::eval(const std::vector& inputs, array& out) { copy_inplace(in, out_slice, CopyType::GeneralGeneral); } -void RandomBits::eval(const std::vector& inputs, array& out) { +void RandomBits::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); // keys has shape (N1, ..., NK, 2) // out has shape (N1, ..., NK, M1, M2, ...) @@ -460,71 +258,10 @@ void RandomBits::eval(const std::vector& inputs, array& out) { } } -void Real::eval_cpu(const std::vector& inputs, array& out) { - unary_op(inputs[0], out, detail::Real()); -} - void Reshape::eval_cpu(const std::vector& inputs, array& out) { reshape(inputs[0], out); } -void Round::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - auto& in = inputs[0]; - if (issubdtype(in.dtype(), inexact)) { - unary_fp(in, out, detail::Round()); - } else { - // No-op integer types - out.copy_shared_buffer(in); - } -} - -void Sigmoid::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (issubdtype(out.dtype(), inexact)) { - unary_fp(in, out, detail::Sigmoid()); - } else { - throw std::invalid_argument( - "[sigmoid] Cannot sigmoid of elements in array with" - " non floating point type."); - } -} - -void Sign::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - auto& in = inputs[0]; - if (in.dtype() == bool_) { - out.copy_shared_buffer(in); - } else { - unary(in, out, detail::Sign()); - } -} - -void Sin::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (issubdtype(out.dtype(), inexact)) { - unary_fp(in, out, detail::Sin()); - } else { - throw std::invalid_argument( - "[sin] Cannot compute sine of elements in array" - " with non floating point type."); - } -} - -void Sinh::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (issubdtype(out.dtype(), inexact)) { - unary_fp(in, out, detail::Sinh()); - } else { - throw std::invalid_argument( - "[sinh] Cannot compute hyperbolic sine of elements in array" - " with non floating point type."); - } -} - void Slice::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); if (out.size() == 0) { @@ -596,7 +333,7 @@ void DynamicSliceUpdate::eval_cpu( /* CopyType ctype = */ CopyType::GeneralGeneral); } -void SliceUpdate::eval(const std::vector& inputs, array& out) { +void SliceUpdate::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); if (out.size() == 0) { out.set_data(nullptr); @@ -632,46 +369,6 @@ void SliceUpdate::eval(const std::vector& inputs, array& out) { /* CopyType ctype = */ CopyType::GeneralGeneral); } -void Square::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - auto& in = inputs[0]; - unary(in, out, detail::Square()); -} - -void Sqrt::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - auto& in = inputs[0]; - if (recip_) { - unary_fp(in, out, detail::Rsqrt()); - } else { - unary_fp(in, out, detail::Sqrt()); - 
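
Note on the Pad hunk above: eval_cpu keeps the old two-step structure, first filling the output with the scalar `val` input, then copying the source array into the interior slice with copy_inplace. A minimal 1-D sketch of the same fill-then-copy idea (standalone; pad1d is an illustrative name, not from this patch):

    #include <algorithm>
    #include <vector>

    // Fill-then-copy padding on a flat buffer: write `val` everywhere,
    // then copy the input into the interior window starting at `left`.
    std::vector<float> pad1d(
        const std::vector<float>& in, int left, int right, float val) {
      std::vector<float> out(in.size() + left + right, val); // fill with val
      std::copy(in.begin(), in.end(), out.begin() + left); // copy into slice
      return out;
    }
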
} -} - -void Tan::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (issubdtype(out.dtype(), inexact)) { - unary_fp(in, out, detail::Tan()); - } else { - throw std::invalid_argument( - "[tan] Cannot compute tangent of elements in array" - " with non floating point type."); - } -} - -void Tanh::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (issubdtype(out.dtype(), inexact)) { - unary_fp(in, out, detail::Tanh()); - } else { - throw std::invalid_argument( - "[tanh] Cannot compute hyperbolic tangent of elements in array" - " with non floating point type."); - } -} - void View::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; diff --git a/mlx/backend/common/qrf.cpp b/mlx/backend/common/qrf.cpp index 21e9c71f1..1c28eec26 100644 --- a/mlx/backend/common/qrf.cpp +++ b/mlx/backend/common/qrf.cpp @@ -149,7 +149,9 @@ void qrf_impl(const array& a, array& q, array& r) { allocator::free(tau); } -void QRF::eval(const std::vector& inputs, std::vector& outputs) { +void QRF::eval_cpu( + const std::vector& inputs, + std::vector& outputs) { if (!(inputs[0].dtype() == float32)) { throw std::runtime_error("[QRF::eval] only supports float32."); } diff --git a/mlx/backend/common/quantized.cpp b/mlx/backend/common/quantized.cpp index 4b8bbdb89..e0883f490 100644 --- a/mlx/backend/common/quantized.cpp +++ b/mlx/backend/common/quantized.cpp @@ -3,7 +3,7 @@ #include #include "mlx/backend/common/copy.h" -#include "mlx/backend/common/ops.h" +#include "mlx/backend/common/simd/simd.h" #include "mlx/fast_primitives.h" #include "mlx/primitives.h" #include "mlx/utils.h" @@ -151,6 +151,78 @@ void _qmm_t( } } +template +simd::Simd extract_bits_simd(const uint32_t* w) { + constexpr int bitmask = (1 << bits) - 1; + simd::Simd wi; + if constexpr (bits == 4 && S == 8) { + constexpr std::array shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}}; + auto shifts(*(simd::Simd*)&shifts_); + wi = simd::Simd(*w); + wi = wi >> shifts; + wi = wi & bitmask; + } else if constexpr (bits == 8 && S == 8) { + constexpr std::array shifts_ = {{0, 8, 16, 24, 0, 8, 16, 24}}; + auto shifts(*(simd::Simd*)&shifts_); + auto l = simd::Simd(*w++); + auto r = simd::Simd(*w); + wi = simd::Simd(l, r); + wi = wi >> shifts; + wi = wi & bitmask; + } else { + // Appease compiler.. 
but should never get here + throw std::runtime_error("Unsupported combination for simd qmm."); + } + return wi; +} + +template +void _qmm_t_simd( + T* result, + const T* x, + const uint32_t* w, + const T* scales, + const T* biases, + int M, + int N, + int K) { + constexpr int pack_factor = 32 / bits; + constexpr int packs_in_group = group_size / pack_factor; + constexpr int S = simd::max_size; + static_assert( + S % pack_factor == 0, "SIMD size must be divisible by pack factor"); + constexpr int packs_per_simd = S / pack_factor; + + for (int m = 0; m < M; m++) { + const uint32_t* w_local = w; + const T* scales_local = scales; + const T* biases_local = biases; + + for (int n = 0; n < N; n++) { + simd::Simd acc(0); + auto x_local = x; + for (int k = 0; k < K; k += group_size) { + T scale = *scales_local++; + T bias = *biases_local++; + + for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) { + auto wf = simd::Simd(extract_bits_simd(w_local)); + w_local += packs_per_simd; + wf = wf * scale; + wf = wf + bias; + simd::Simd x_simd = simd::load(x_local); + acc = acc + x_simd * wf; + x_local += S; + } + } + + *result = T(simd::sum(acc)); + result++; + } + x += K; + } +} + template void _qmm_dispatch_transpose( T* result, @@ -163,9 +235,14 @@ void _qmm_dispatch_transpose( int K, bool transposed_w) { if (transposed_w) { - return _qmm_t(result, x, w, scales, biases, M, N, K); + // the simd size must be a multiple of the number of elements per word + if constexpr (32 % bits == 0 && simd::max_size % (32 / bits) == 0) { + _qmm_t_simd(result, x, w, scales, biases, M, N, K); + } else { + _qmm_t(result, x, w, scales, biases, M, N, K); + } } else { - return _qmm(result, x, w, scales, biases, M, N, K); + _qmm(result, x, w, scales, biases, M, N, K); } } @@ -249,13 +326,13 @@ void _qmm_dispatch( int group_size, bool transposed_w) { int K = x.shape(-1); - int M = x.shape(-2); + int M = x.ndim() > 1 ? x.shape(-2) : 1; int N = out.shape(-1); int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0; int g_els = w.ndim() > 2 ? 
scales.shape(-1) * scales.shape(-2) : 0; - int batch_size = x.size() / x.shape(-1) / x.shape(-2); + int batch_size = x.size() / (K * M); for (int i = 0; i < batch_size; i++) { switch (x.dtype()) { case float32: @@ -384,7 +461,7 @@ void _bs_qmm_dispatch( } // namespace -void QuantizedMatmul::eval(const std::vector& inputs, array& out) { +void QuantizedMatmul::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 4); auto& x_pre = inputs[0]; @@ -411,7 +488,7 @@ void QuantizedMatmul::eval(const std::vector& inputs, array& out) { _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_); } -void GatherQMM::eval(const std::vector& inputs, array& out) { +void GatherQMM::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 6); auto& x_pre = inputs[0]; diff --git a/mlx/backend/common/select.cpp b/mlx/backend/common/select.cpp index 1daa771b3..04c28ef04 100644 --- a/mlx/backend/common/select.cpp +++ b/mlx/backend/common/select.cpp @@ -2,6 +2,7 @@ #include +#include "mlx/backend/common/binary_ops.h" #include "mlx/backend/common/ternary.h" #include "mlx/primitives.h" @@ -61,7 +62,7 @@ void select_op( } // namespace -void Select::eval(const std::vector& inputs, array& out) { +void Select::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 3); const auto& condition = inputs[0]; const auto& a = inputs[1]; diff --git a/mlx/backend/common/simd/accelerate_fp16_simd.h b/mlx/backend/common/simd/accelerate_fp16_simd.h new file mode 100644 index 000000000..7fa5c9467 --- /dev/null +++ b/mlx/backend/common/simd/accelerate_fp16_simd.h @@ -0,0 +1,56 @@ +#pragma once + +#include "mlx/backend/common/simd/base_simd.h" + +#if MLX_SIMD_LIBRARY_VERSION < 6 +#include "mlx/backend/common/simd/neon_fp16_simd.h" +#endif + +namespace mlx::core::simd { + +#if MLX_SIMD_LIBRARY_VERSION >= 6 +constexpr int N = 8; +template +struct ScalarT { + using v = _Float16; +}; +#endif + +template <> +static constexpr int max_size = N; + +#define SIMD_FP16_DEFAULT_UNARY(op) \ + template <> \ + inline Simd op(Simd v) { \ + Simd in = v; \ + return op(in); \ + } + +SIMD_FP16_DEFAULT_UNARY(acos) +SIMD_FP16_DEFAULT_UNARY(acosh) +SIMD_FP16_DEFAULT_UNARY(asin) +SIMD_FP16_DEFAULT_UNARY(asinh) +SIMD_FP16_DEFAULT_UNARY(atan) +SIMD_FP16_DEFAULT_UNARY(atanh) +SIMD_FP16_DEFAULT_UNARY(cosh) +SIMD_FP16_DEFAULT_UNARY(expm1) +SIMD_FP16_DEFAULT_UNARY(log) +SIMD_FP16_DEFAULT_UNARY(log2) +SIMD_FP16_DEFAULT_UNARY(log10) +SIMD_FP16_DEFAULT_UNARY(log1p) +SIMD_FP16_DEFAULT_UNARY(sinh) +SIMD_FP16_DEFAULT_UNARY(tan) +SIMD_FP16_DEFAULT_UNARY(tanh) + +#define SIMD_FP16_DEFAULT_BINARY(op) \ + template <> \ + inline Simd op(Simd x, Simd y) { \ + Simd a = x; \ + Simd b = y; \ + return op(a, b); \ + } +SIMD_FP16_DEFAULT_BINARY(atan2) +SIMD_FP16_DEFAULT_BINARY(remainder) +SIMD_FP16_DEFAULT_BINARY(pow) + +} // namespace mlx::core::simd diff --git a/mlx/backend/common/simd/accelerate_simd.h b/mlx/backend/common/simd/accelerate_simd.h new file mode 100644 index 000000000..443a8f617 --- /dev/null +++ b/mlx/backend/common/simd/accelerate_simd.h @@ -0,0 +1,291 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include "mlx/backend/common/simd/base_simd.h" + +// There seems to be a bug in sims/base.h +// __XROS_2_0 is not defined, the expression evaluates +// to true instead of false setting the SIMD library +// higher than it should be even on macOS < 15 +#if __MAC_OS_X_VERSION_MIN_REQUIRED >= 150000 || \ + __IPHONE_OS_VERSION_MIN_REQUIRED >= 180000 || \ + 
__WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \ + __WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \ + __TV_OS_VERSION_MIN_REQUIRED >= 180000 +#define MLX_SIMD_LIBRARY_VERSION 6 +#else +#define MLX_SIMD_LIBRARY_VERSION 5 +#endif + +namespace mlx::core::simd { + +// Apple simd namespace +namespace asd = ::simd; + +// This indirection is needed to remap certain types to ones that accelerate +// SIMD can handle +template +struct ScalarT { + using v = T; +}; +template +struct ScalarT { + using v = char; +}; +template +struct ScalarT { + using v = char; +}; +template +struct ScalarT { + using v = unsigned long; +}; +template +struct ScalarT { + using v = long; +}; + +template +struct Simd { + static constexpr int size = N; + using scalar_t = typename ScalarT::v; + + Simd() {} + + template + Simd(Simd other) : value(asd::convert(other.value)) {} + + template + Simd(U v) : value(v){}; + + Simd(Simd x, Simd y) { + value = asd::make::packed_t>( + x.value, y.value); + }; + + T operator[](int idx) const { + return reinterpret_cast(&value)[idx]; + } + + T& operator[](int idx) { + return reinterpret_cast(&value)[idx]; + } + + typename asd::Vector::packed_t value; +}; + +// Values chosen based on benchmarks on M3 Max +// TODO: consider choosing these more optimally +template <> +static constexpr int max_size = 16; +template <> +static constexpr int max_size = 16; +template <> +static constexpr int max_size = 8; +template <> +static constexpr int max_size = 4; +template <> +static constexpr int max_size = 16; +template <> +static constexpr int max_size = 16; +template <> +static constexpr int max_size = 8; +template <> +static constexpr int max_size = 4; +template <> +static constexpr int max_size = 8; +template <> +static constexpr int max_size = 4; + +#define SIMD_DEFAULT_UNARY(name, op) \ + template \ + Simd name(Simd v) { \ + return op(v.value); \ + } + +SIMD_DEFAULT_UNARY(abs, asd::abs) +SIMD_DEFAULT_UNARY(floor, asd::floor) +SIMD_DEFAULT_UNARY(acos, asd::acos) +SIMD_DEFAULT_UNARY(acosh, asd::acosh) +SIMD_DEFAULT_UNARY(asin, asd::asin) +SIMD_DEFAULT_UNARY(asinh, asd::asinh) +SIMD_DEFAULT_UNARY(atan, asd::atan) +SIMD_DEFAULT_UNARY(atanh, asd::atanh) +SIMD_DEFAULT_UNARY(ceil, asd::ceil) +SIMD_DEFAULT_UNARY(cosh, asd::cosh) +SIMD_DEFAULT_UNARY(expm1, asd::expm1) +SIMD_DEFAULT_UNARY(log, asd::log) +SIMD_DEFAULT_UNARY(log2, asd::log2) +SIMD_DEFAULT_UNARY(log10, asd::log10) +SIMD_DEFAULT_UNARY(log1p, asd::log1p) +SIMD_DEFAULT_UNARY(rint, asd::rint) +SIMD_DEFAULT_UNARY(sinh, asd::sinh) +SIMD_DEFAULT_UNARY(sqrt, asd::sqrt) +SIMD_DEFAULT_UNARY(rsqrt, asd::rsqrt) +SIMD_DEFAULT_UNARY(recip, asd::recip) +SIMD_DEFAULT_UNARY(tan, asd::tan) +SIMD_DEFAULT_UNARY(tanh, asd::tanh) + +template +Simd operator-(Simd v) { + return -v.value; +} + +template +Simd isnan(Simd v) { + return asd::convert(v.value != v.value); +} + +// No simd_boolN in accelerate, use int8_t instead +template +Simd operator!(Simd v) { + return asd::convert(!v.value); +} + +#define SIMD_DEFAULT_BINARY(OP) \ + template \ + Simd operator OP(Simd x, U y) { \ + return asd::convert::scalar_t>(x.value OP y); \ + } \ + template \ + Simd operator OP(T1 x, Simd y) { \ + return asd::convert::scalar_t>(x OP y.value); \ + } \ + template \ + Simd operator OP(Simd x, Simd y) { \ + return asd::convert::scalar_t>(x.value OP y.value); \ + } + +SIMD_DEFAULT_BINARY(+) +SIMD_DEFAULT_BINARY(-) +SIMD_DEFAULT_BINARY(/) +SIMD_DEFAULT_BINARY(*) +SIMD_DEFAULT_BINARY(<<) +SIMD_DEFAULT_BINARY(>>) +SIMD_DEFAULT_BINARY(|) +SIMD_DEFAULT_BINARY(^) +SIMD_DEFAULT_BINARY(&) 
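
The remaining logical and comparison overloads continue below; together they are what let a kernel be written once against Simd<T, N> and lowered to Accelerate's vector types. A minimal usage sketch under that assumption (add_f32 is an illustrative name, not part of the patch):

    // Element-wise a + b over n floats: whole vectors first, scalar tail after.
    inline void add_f32(const float* a, const float* b, float* out, int n) {
      constexpr int N = max_size<float>; // 16 lanes with the sizes chosen above
      int i = 0;
      for (; i + N <= n; i += N) {
        store(out + i, load<float, N>(a + i) + load<float, N>(b + i));
      }
      for (; i < n; ++i) { // remainder that does not fill a vector
        out[i] = a[i] + b[i];
      }
    }
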
+SIMD_DEFAULT_BINARY(&&) +SIMD_DEFAULT_BINARY(||) + +#define SIMD_DEFAULT_COMPARISONS(OP) \ + template \ + Simd operator OP(Simd a, U b) { \ + return asd::convert(a.value OP b); \ + } \ + template \ + Simd operator OP(T a, Simd b) { \ + return asd::convert(a OP b.value); \ + } \ + template \ + Simd operator OP(Simd a, Simd b) { \ + return asd::convert(a.value OP b.value); \ + } + +SIMD_DEFAULT_COMPARISONS(>) +SIMD_DEFAULT_COMPARISONS(<) +SIMD_DEFAULT_COMPARISONS(>=) +SIMD_DEFAULT_COMPARISONS(<=) +SIMD_DEFAULT_COMPARISONS(==) +SIMD_DEFAULT_COMPARISONS(!=) + +template +Simd atan2(Simd a, Simd b) { + return asd::atan2(a.value, b.value); +} + +template +Simd maximum(Simd a, Simd b) { + // TODO add isnan + return asd::max(a.value, b.value); +} + +template +Simd minimum(Simd a, Simd b) { + // TODO add isnan + return asd::min(a.value, b.value); +} + +template +Simd remainder(Simd a, Simd b) { + Simd r; + if constexpr (!std::is_integral_v) { + r = asd::remainder(a.value, b.value); + } else { + r = a - b * (a / b); + } + if constexpr (std::is_signed_v) { + auto mask = r != 0 && (r < 0 != b < 0); + r = select(mask, r + b, r); + } + return r; +} + +template +Simd select(Simd mask, Simd x, Simd y) { + if constexpr (sizeof(T1) == 1) { + return asd::bitselect(y.value, x.value, asd::convert(mask.value)); + } else if constexpr (sizeof(T1) == 2) { + return asd::bitselect(y.value, x.value, asd::convert(mask.value)); + } else if constexpr (sizeof(T1) == 4) { + return asd::bitselect(y.value, x.value, asd::convert(mask.value)); + } else { + return asd::bitselect(y.value, x.value, asd::convert(mask.value)); + } +} + +template +Simd pow(Simd base, Simd exp) { + if constexpr (!std::is_integral_v) { + return asd::pow(base.value, exp.value); + } else { + Simd res = 1; + while (any(exp)) { + res = select(exp & 1, res * base, res); + base = select(exp, base * base, base); + exp = exp >> 1; + } + return res; + } +} + +template +Simd clamp(Simd v, Simd min, Simd max) { + return asd::clamp(v.value, min.value, max.value); +} + +template +Simd fma(Simd x, Simd y, U z) { + return asd::muladd(x.value, y.value, Simd(z).value); +} + +// Reductions + +template +bool any(Simd x) { + return asd::any(x.value); +} +template +T sum(Simd x) { + return asd::reduce_add(x.value); +} +template +T max(Simd x) { + return asd::reduce_max(x.value); +} +template +T min(Simd x) { + return asd::reduce_min(x.value); +} + +} // namespace mlx::core::simd + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#include "mlx/backend/common/simd/accelerate_fp16_simd.h" +#endif diff --git a/mlx/backend/common/simd/base_simd.h b/mlx/backend/common/simd/base_simd.h new file mode 100644 index 000000000..d7e4fdc3d --- /dev/null +++ b/mlx/backend/common/simd/base_simd.h @@ -0,0 +1,252 @@ +#pragma once + +#include +#include +#include +#include + +namespace mlx::core::simd { +template +struct Simd; + +template +static constexpr int max_size = 1; + +template +struct Simd { + static constexpr int size = 1; + T value; + Simd() {} + template + Simd(Simd v) : value(v.value) {} + template + Simd(U v) : value(v) {} +}; + +template +Simd load(const T* x) { + return *(Simd*)x; +} + +template +void store(T* dst, Simd x) { + // Maintain invariant that bool is either 0 or 1 as + // simd comparison ops set all bits in the result to 1 + if constexpr (std::is_same_v && N > 1) { + x = x & 1; + } + *(Simd*)dst = x; +} + +template +constexpr bool is_complex = false; + +template +constexpr bool is_complex().real())>> = + true; + +template +Simd rint(Simd in) { + if constexpr 
(is_complex) { + return Simd{ + T{std::rint(in.value.real()), std::rint(in.value.imag())}}; + } else { + return Simd{std::rint(in.value)}; + } +} + +template +Simd rsqrt(Simd in) { + return T(1.0) / sqrt(in); +} + +template +Simd recip(Simd in) { + return T(1.0) / in; +} + +#define DEFAULT_UNARY(name, op) \ + template \ + Simd name(Simd in) { \ + return op(in.value); \ + } + +DEFAULT_UNARY(operator-, std::negate{}) +DEFAULT_UNARY(operator!, std::logical_not{}) +DEFAULT_UNARY(abs, std::abs) +DEFAULT_UNARY(acos, std::acos) +DEFAULT_UNARY(acosh, std::acosh) +DEFAULT_UNARY(asin, std::asin) +DEFAULT_UNARY(asinh, std::asinh) +DEFAULT_UNARY(atan, std::atan) +DEFAULT_UNARY(atanh, std::atanh) +DEFAULT_UNARY(ceil, std::ceil) +DEFAULT_UNARY(conj, std::conj) +DEFAULT_UNARY(cosh, std::cosh) +DEFAULT_UNARY(expm1, std::expm1) +DEFAULT_UNARY(floor, std::floor) +DEFAULT_UNARY(log, std::log) +DEFAULT_UNARY(log2, std::log2) +DEFAULT_UNARY(log10, std::log10) +DEFAULT_UNARY(log1p, std::log1p) +DEFAULT_UNARY(sinh, std::sinh) +DEFAULT_UNARY(sqrt, std::sqrt) +DEFAULT_UNARY(tan, std::tan) +DEFAULT_UNARY(tanh, std::tanh) + +template +auto real(Simd in) -> Simd { + return std::real(in.value); +} +template +auto imag(Simd in) -> Simd { + return std::imag(in.value); +} +template +Simd isnan(Simd in) { + return std::isnan(in.value); +} + +#define DEFAULT_BINARY(OP) \ + template \ + auto operator OP(Simd a, Simd b) \ + ->Simd { \ + return a.value OP b.value; \ + } \ + template \ + auto operator OP(T1 a, Simd b)->Simd { \ + return a OP b.value; \ + } \ + template \ + auto operator OP(Simd a, T2 b)->Simd { \ + return a.value OP b; \ + } + +DEFAULT_BINARY(+) +DEFAULT_BINARY(-) +DEFAULT_BINARY(*) +DEFAULT_BINARY(/) +DEFAULT_BINARY(<<) +DEFAULT_BINARY(>>) +DEFAULT_BINARY(|) +DEFAULT_BINARY(^) +DEFAULT_BINARY(&) +DEFAULT_BINARY(&&) +DEFAULT_BINARY(||) + +template +Simd remainder(Simd a_, Simd b_) { + T a = a_.value; + T b = b_.value; + T r; + if constexpr (std::is_integral_v) { + r = a % b; + } else { + r = std::remainder(a, b); + } + if constexpr (std::is_signed_v) { + if (r != 0 && (r < 0 != b < 0)) { + r += b; + } + } + return r; +} + +template +Simd maximum(Simd a_, Simd b_) { + T a = a_.value; + T b = b_.value; + if constexpr (!std::is_integral_v) { + if (std::isnan(a)) { + return a; + } + } + return (a > b) ? a : b; +} + +template +Simd minimum(Simd a_, Simd b_) { + T a = a_.value; + T b = b_.value; + if constexpr (!std::is_integral_v) { + if (std::isnan(a)) { + return a; + } + } + return (a < b) ? a : b; +} + +template +Simd pow(Simd a, Simd b) { + T base = a.value; + T exp = b.value; + if constexpr (!std::is_integral_v) { + return std::pow(base, exp); + } else { + T res = 1; + while (exp) { + if (exp & 1) { + res *= base; + } + exp >>= 1; + base *= base; + } + return res; + } +} + +template +Simd atan2(Simd a, Simd b) { + return std::atan2(a.value, b.value); +} + +#define DEFAULT_COMPARISONS(OP) \ + template \ + Simd operator OP(Simd a, Simd b) { \ + return a.value OP b.value; \ + } \ + template \ + Simd operator OP(T1 a, Simd b) { \ + return a OP b.value; \ + } \ + template \ + Simd operator OP(Simd a, T2 b) { \ + return a.value OP b; \ + } + +DEFAULT_COMPARISONS(>) +DEFAULT_COMPARISONS(<) +DEFAULT_COMPARISONS(>=) +DEFAULT_COMPARISONS(<=) +DEFAULT_COMPARISONS(==) +DEFAULT_COMPARISONS(!=) + +template +Simd select(Simd mask, Simd x, Simd y) { + return mask.value ? 
x.value : y.value; +} + +template +Simd clamp(Simd v, Simd min, Simd max) { + return std::clamp(v.value, min.value, max.value); +} + +template +Simd fma(Simd x, Simd y, U z) { + return std::fma(x.value, y.value, Simd(z).value); +} + +// Reductions +#define DEFAULT_REDUCTION(name, type) \ + template \ + type name(Simd x) { \ + return x.value; \ + } + +DEFAULT_REDUCTION(max, T) +DEFAULT_REDUCTION(min, T) +DEFAULT_REDUCTION(sum, T) +DEFAULT_REDUCTION(any, bool) +DEFAULT_REDUCTION(all, bool) + +} // namespace mlx::core::simd diff --git a/mlx/backend/common/simd/math.h b/mlx/backend/common/simd/math.h new file mode 100644 index 000000000..c7061b2b1 --- /dev/null +++ b/mlx/backend/common/simd/math.h @@ -0,0 +1,193 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/common/simd/type.h" + +namespace mlx::core::simd { + +constexpr float inf = std::numeric_limits::infinity(); + +/** + * Compute exp(x) in an optimizer friendly way as follows: + * + * First change the problem to computing 2**y where y = x / ln(2). + * + * Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part + * `ipart` and y2 is fractional part. For the integer part we perform bit + * shifting and for the fractional part we use a polynomial approximation. + * + * The algorithm and constants of the polynomial taken from + * https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them + * from Cephes math library. + * + * Note: The implementation below is a general fast exp. There could be faster + * implementations for numbers strictly < 0. + */ +template +Simd exp(Simd in) { + if constexpr (is_complex) { + return Simd{std::exp(in.value)}; + } else { + Simd x_init = in; + auto x = x_init * 1.442695f; // multiply with log_2(e) + Simd ipart, fpart; + ipart = floor(x + 0.5); + fpart = x - ipart; + + x = 1.535336188319500e-4f; + x = fma(x, fpart, 1.339887440266574e-3f); + x = fma(x, fpart, 9.618437357674640e-3f); + x = fma(x, fpart, 5.550332471162809e-2f); + x = fma(x, fpart, 2.402264791363012e-1f); + x = fma(x, fpart, 6.931472028550421e-1f); + x = fma(x, fpart, 1.000000000000000f); + + // generate 2**ipart in the floating point representation using integer + // bitshifting + Simd epart = (Simd(ipart) + 127) << 23; + + // Deal with NaN and Inf + auto result = select(isnan(x_init), x_init, (*(Simd*)&epart) * x); + result = select(x_init > 88.0f, Simd(inf), result); + result = select(x_init < -88.0f, Simd(0), result); + return Simd(result); + } +} + +/* Implementation from: + * https://github.com/JishinMaster/simd_utils/blob/3c1433a86fb38edcc9b02039f3c9a65b16640976/neon_mathfun.h#L357 + * which originally came from the Cephes math library. + */ +template +Simd sincos(Simd in) { + auto sign_mask_sin = in < 0; + in = abs(in); + Simd x = in; + + // scale by 4/Pi + auto y = x * 1.27323954473516f; + + // store the integer part of y in mm0 + Simd emm2 = y; + + // j=(j+1) & (~1) (see the cephes sources) + emm2 = emm2 + 1; + emm2 = emm2 & ~1; + + y = emm2; + + // Get the polynom selection mask. 
There is one polynom for 0 <= x <= Pi/4 + // and another one for Pi/4(-0.78515625f), x); + x = fma(y, Simd(-2.4187564849853515625e-4f), x); + x = fma(y, Simd(-3.77489497744594108e-8f), x); + + sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != 0); + auto sign_mask_cos = ((emm2 - 2) & 4) != 0; + + // Evaluate the first polynom (0 <= x <= Pi/4) in y1, + // and the second polynom (Pi/4 <= x <= 0) in y2 + auto z = x * x; + + auto y1 = + fma(z, Simd(2.443315711809948e-5f), -1.388731625493765e-3f); + auto y2 = fma(z, Simd(-1.9515295891e-4f), 8.3321608736e-3f); + y1 = fma(y1, z, 4.166664568298827e-2f); + y2 = fma(y2, z, -1.6666654611e-1f); + y1 = y1 * z; + y2 = y2 * z; + y1 = y1 * z; + y2 = fma(x, y2, x); + y1 = fma(z, Simd(-0.5f), y1); + y1 = y1 + 1.0f; + + if constexpr (Sine) { + auto ys = select(poly_mask, y1, y2); + return select(sign_mask_sin, -ys, ys); + } else { + auto yc = select(poly_mask, y2, y1); + return select(sign_mask_cos, yc, -yc); + } +} + +template +Simd sin(Simd x) { + if constexpr (is_complex) { + return std::sin(x.value); + } else { + return sincos(x); + } +} + +template +Simd cos(Simd x) { + if constexpr (is_complex) { + return std::cos(x.value); + } else { + return sincos(x); + } +} + +template +Simd erf(Simd x) { + // https://github.com/pytorch/pytorch/blob/abf28982a8cb43342e7669d859de9543fd804cc9/aten/src/ATen/cpu/vec/vec256/vec256_float.h#L175 + Simd v = x; + auto t = recip(fma(Simd(0.3275911f), abs(v), 1.0f)); + auto r = fma(Simd(1.061405429f), t, -1.453152027f); + r = fma(r, t, 1.421413741f); + r = fma(r, t, -0.284496736f); + r = fma(r, t, 0.254829592f); + auto e = -exp(-v * v); + auto result = Simd(fma(e * t, r, 1.0f)); + return select(x > 0, result, -result); +} + +template +Simd erfinv(Simd a_) { + Simd a = a_; + auto t = fma(a, 0.0f - a, 1.0f); + t = log(t); + auto lhs = [](auto t) { + Simd p; + p = 3.03697567e-10f; // 0x1.4deb44p-32 + p = fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26 + p = fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20 + p = fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16 + p = fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12 + p = fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9 + p = fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8 + p = fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2 + return fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1 + }; + auto rhs = [](auto t) { + Simd p; + p = 5.43877832e-9f; // 0x1.75c000p-28 + p = fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23 + p = fma(p, t, 1.22774793e-6f); // 0x1.499232p-20 + p = fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24 + p = fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15 + p = fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13 + p = fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9 + p = fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7 + p = fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3 + return fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1 + }; + auto thresh = 6.125f; + // Compute both branches and select if N > 1 + if constexpr (N == 1) { + if ((abs(t) > thresh).value) { // maximum ulp error = 2.35793 + return a * lhs(t); + } else { // maximum ulp error = 2.35002 + return a * rhs(t); + } + } else { + return a * select(t > thresh, lhs(t), rhs(t)); + } +} + +} // namespace mlx::core::simd diff --git a/mlx/backend/common/simd/neon_fp16_simd.h b/mlx/backend/common/simd/neon_fp16_simd.h new file mode 100644 index 000000000..923e27776 --- /dev/null +++ b/mlx/backend/common/simd/neon_fp16_simd.h @@ -0,0 +1,204 @@ +#pragma once + +#include + +#include "mlx/backend/common/simd/base_simd.h" + +namespace mlx::core::simd { + +constexpr int N = 8; + 
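
The exp kernel in simd/math.h above builds 2**ipart by adding the IEEE-754 bias to the integer part and shifting it into the exponent bits. A scalar sketch of that trick in isolation (assumes binary32 floats; pow2i is an illustrative name):

    #include <cstring>

    // 2**i for integer i in roughly [-126, 127]: place the biased exponent
    // (i + 127) into bits 30..23 of an IEEE-754 float.
    inline float pow2i(int i) {
      unsigned bits = static_cast<unsigned>(i + 127) << 23;
      float f;
      std::memcpy(&f, &bits, sizeof(f)); // bit-cast without type-punning UB
      return f;
    }
    // pow2i(3) == 8.0f, pow2i(0) == 1.0f, pow2i(-1) == 0.5f
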
+template <> +struct Simd { + static constexpr int size = N; + using scalar_t = float16_t; + + Simd() {} + + template + Simd(U v) : value(vdupq_n_f16(v)){}; + + Simd(float16x8_t v) : value(v){}; + + Simd(Simd other) { + auto f32x4_a = *(float32x4_t*)(&other); + auto f32x4_b = *((float32x4_t*)(&other) + 1); + value = vcvt_high_f16_f32(vcvt_f16_f32(f32x4_a), f32x4_b); + }; + + Simd(Simd other) { + value = vcvtq_f16_u16(*(uint16x8_t*)(&other.value)); + }; + + operator Simd() { + auto v = vcvtq_s16_f16(value); + return load((int16_t*)&v); + }; + + operator Simd() { + float32x4x2_t v; + v.val[0] = vcvt_f32_f16(*(float16x4_t*)(&value)); + v.val[1] = vcvt_high_f32_f16(value); + return load((float*)&v); + } + float16_t operator[](int idx) const { + return reinterpret_cast(&value)[idx]; + } + + float16_t& operator[](int idx) { + return reinterpret_cast(&value)[idx]; + } + + float16x8_t value; +}; + +#define DEFINE_NEON_UNARY_OP(name, op) \ + inline Simd name(Simd a) { \ + return Simd{op(a.value)}; \ + } + +DEFINE_NEON_UNARY_OP(abs, vabsq_f16) +DEFINE_NEON_UNARY_OP(ceil, vrndpq_f16) +DEFINE_NEON_UNARY_OP(floor, vrndmq_f16) +DEFINE_NEON_UNARY_OP(sqrt, vsqrtq_f16) +DEFINE_NEON_UNARY_OP(rsqrt, vrsqrteq_f16) +DEFINE_NEON_UNARY_OP(recip, vrecpeq_f16) +DEFINE_NEON_UNARY_OP(rint, vrndnq_f16) + +#define DEFINE_NEON_BINARY_OP(name, op) \ + inline Simd name(Simd a, Simd b) { \ + return op(a.value, b.value); \ + } \ + template \ + Simd name(Simd a, T b) { \ + return op(a.value, Simd(b).value); \ + } \ + template \ + Simd name(T a, Simd b) { \ + return op(Simd(a).value, b.value); \ + } + +inline Simd operator!(Simd v) { + auto out = vceqzq_f16(v.value); + return Simd(*(uint16_t*)&out); +} + +inline Simd operator-(Simd v) { + return vnegq_f16(v.value); +} + +DEFINE_NEON_BINARY_OP(maximum, vmaxq_f16) +DEFINE_NEON_BINARY_OP(minimum, vminq_f16) +DEFINE_NEON_BINARY_OP(operator+, vaddq_f16) +DEFINE_NEON_BINARY_OP(operator-, vsubq_f16) +DEFINE_NEON_BINARY_OP(operator*, vmulq_f16) +DEFINE_NEON_BINARY_OP(operator/, vdivq_f16) + +#define DEFINE_NEON_COMPARISON(Op, op) \ + template \ + Simd operator Op(Simd a, T b) { \ + auto out = op(a.value, Simd(b).value); \ + return Simd(*(uint16_t*)(&out)); \ + } \ + template \ + Simd operator Op(T a, Simd b) { \ + auto out = op(Simd(a).value, b.value); \ + return Simd(*(uint16_t*)(&out)); \ + } \ + inline Simd operator Op( \ + Simd a, Simd b) { \ + auto out = op(a.value, b.value); \ + return Simd(*(uint16_t*)(&out)); \ + } + +DEFINE_NEON_COMPARISON(==, vceqq_f16) +DEFINE_NEON_COMPARISON(>=, vcgeq_f16) +DEFINE_NEON_COMPARISON(<=, vcleq_f16) +DEFINE_NEON_COMPARISON(>, vcgtq_f16) +DEFINE_NEON_COMPARISON(<, vcltq_f16) + +template +Simd operator!=(Simd a, T b) { + return !(a == b); +} +template +Simd operator!=(T a, Simd b) { + return !(a == b); +} +inline Simd operator!=(Simd a, Simd b) { + return !(a == b); +} + +inline Simd operator||( + Simd a, + Simd b) { + return Simd((a != 0) || (b != 0)); +} +template +Simd operator||(Simd a, T b) { + return Simd((a != 0) || (b != 0)); +} +template +Simd operator||(T a, Simd b) { + return Simd((a != 0) || (b != 0)); +} +inline Simd operator&&( + Simd a, + Simd b) { + return Simd((a != 0) && (b != 0)); +} +template +Simd operator&&(Simd a, T b) { + return Simd((a != 0) && (b != 0)); +} +template +Simd operator&&(T a, Simd b) { + return Simd((a != 0) && (b != 0)); +} + +template <> +inline Simd isnan(Simd v) { + return v != v; +} + +template <> +inline Simd +clamp(Simd v, Simd min, Simd max) { + return minimum(maximum(v, min), max); +} + +template 
+Simd fma(Simd x, Simd y, T z) { + return vfmaq_f16(x.value, y.value, Simd(z).value); +} + +template +Simd +select(Simd mask, Simd x, Simd y) { + return vbslq_f16(Simd(mask).value, x.value, y.value); +} + +// Reductions +inline float16_t max(Simd x) { + float16x4_t y; + y = vpmax_f16(vget_low_f16(x.value), vget_high_f16(x.value)); + y = vpmax_f16(y, y); + y = vpmax_f16(y, y); + return vget_lane_f16(y, 0); +} +inline float16_t min(Simd x) { + float16x4_t y; + y = vpmin_f16(vget_low_f16(x.value), vget_high_f16(x.value)); + y = vpmin_f16(y, y); + y = vpmin_f16(y, y); + return vget_lane_f16(y, 0); +} +inline float16_t sum(Simd x) { + float16x4_t y; + y = vpadd_f16(vget_low_f16(x.value), vget_high_f16(x.value)); + y = vpadd_f16(y, y); + y = vpadd_f16(y, y); + return vget_lane_f16(y, 0); +} + +} // namespace mlx::core::simd diff --git a/mlx/backend/common/simd/simd.h b/mlx/backend/common/simd/simd.h new file mode 100644 index 000000000..4b356a9e5 --- /dev/null +++ b/mlx/backend/common/simd/simd.h @@ -0,0 +1,4 @@ +#pragma once + +#include "mlx/backend/common/simd/math.h" +#include "mlx/backend/common/simd/type.h" diff --git a/mlx/backend/common/simd/type.h b/mlx/backend/common/simd/type.h new file mode 100644 index 000000000..23b71a1cf --- /dev/null +++ b/mlx/backend/common/simd/type.h @@ -0,0 +1,7 @@ +#pragma once + +#include "mlx/backend/common/simd/base_simd.h" + +#ifdef MLX_USE_ACCELERATE +#include "mlx/backend/common/simd/accelerate_simd.h" +#endif diff --git a/mlx/backend/common/softmax.cpp b/mlx/backend/common/softmax.cpp index ed4e3958b..2c7579930 100644 --- a/mlx/backend/common/softmax.cpp +++ b/mlx/backend/common/softmax.cpp @@ -4,61 +4,107 @@ #include #include "mlx/backend/common/copy.h" +#include "mlx/backend/common/simd/simd.h" #include "mlx/primitives.h" namespace mlx::core { namespace { +using namespace mlx::core::simd; + template void softmax(const array& in, array& out) { + constexpr bool same_t = std::is_same_v; + constexpr int N = std::min(max_size, max_size); + const T* in_ptr = in.data(); T* out_ptr = out.data(); - int N = in.shape().back(); - int M = in.data_size() / N; + int M = in.shape().back(); + int L = in.data_size() / M; const T* current_in_ptr; T* current_out_ptr; - for (int i = 0; i < M; i++, in_ptr += N, out_ptr += N) { + for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) { // Find the maximum current_in_ptr = in_ptr; - AccT maximum = *current_in_ptr; - for (int j = 0; j < N; j++, current_in_ptr++) { - maximum = (maximum < *current_in_ptr) ? 
static_cast(*current_in_ptr) - : maximum; + Simd vmaximum(-std::numeric_limits::infinity()); + size_t s = M; + while (s >= N) { + Simd vals = load(current_in_ptr); + vmaximum = maximum(vals, vmaximum); + current_in_ptr += N; + s -= N; + } + + AccT maximum = max(vmaximum); + while (s-- > 0) { + maximum = std::max(maximum, static_cast(*current_in_ptr)); + current_in_ptr++; } // Compute the normalizer and the exponentials - AccT normalizer = 0; + Simd vnormalizer(0.0); current_out_ptr = out_ptr; current_in_ptr = in_ptr; - for (int j = 0; j < N; j++, current_out_ptr++, current_in_ptr++) { - AccT expv = std::exp(*current_in_ptr - maximum); - normalizer += expv; - if constexpr (std::is_same::value) { - *current_out_ptr = expv; + s = M; + while (s >= N) { + Simd vexp = load(current_in_ptr); + vexp = exp(vexp - maximum); + if constexpr (same_t) { + store(current_out_ptr, vexp); } + vnormalizer = vnormalizer + vexp; + current_in_ptr += N; + current_out_ptr += N; + s -= N; + } + AccT normalizer = sum(vnormalizer); + while (s-- > 0) { + AccT _exp = std::exp(*current_in_ptr - maximum); + if constexpr (same_t) { + *current_out_ptr = _exp; + } + normalizer += _exp; + current_in_ptr++; + current_out_ptr++; } normalizer = 1 / normalizer; // Normalize - current_in_ptr = in_ptr; current_out_ptr = out_ptr; - for (int j = 0; j < N; j++, current_out_ptr++) { - if constexpr (std::is_same::value) { + current_in_ptr = in_ptr; + s = M; + while (s >= N) { + if constexpr (same_t) { + store( + current_out_ptr, + Simd(load(current_out_ptr) * normalizer)); + } else { + Simd vexp = load(current_in_ptr); + vexp = exp(vexp - maximum) * normalizer; + store(current_out_ptr, Simd(vexp)); + current_in_ptr += N; + } + current_out_ptr += N; + s -= N; + } + while (s-- > 0) { + if constexpr (same_t) { *current_out_ptr *= normalizer; } else { - auto v = std::exp(*current_in_ptr - maximum); - *current_out_ptr = static_cast(v * normalizer); + AccT _exp = std::exp(*current_in_ptr - maximum); + *current_out_ptr = static_cast(_exp * normalizer); current_in_ptr++; } + current_out_ptr++; } } } } // namespace -void Softmax::eval(const std::vector& inputs, array& out) { +void Softmax::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); // Make sure that the last dimension is contiguous @@ -97,7 +143,7 @@ void Softmax::eval(const std::vector& inputs, array& out) { case int16: case int32: case int64: - throw std::invalid_argument( + throw std::runtime_error( "Softmax is defined only for floating point types"); break; case float32: diff --git a/mlx/backend/common/sort.cpp b/mlx/backend/common/sort.cpp index e2f6d48bd..1304186d6 100644 --- a/mlx/backend/common/sort.cpp +++ b/mlx/backend/common/sort.cpp @@ -287,7 +287,7 @@ void argpartition(const array& in, array& out, int axis, int kth) { } // namespace -void ArgSort::eval(const std::vector& inputs, array& out) { +void ArgSort::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; @@ -321,7 +321,7 @@ void ArgSort::eval(const std::vector& inputs, array& out) { } } -void Sort::eval(const std::vector& inputs, array& out) { +void Sort::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; @@ -355,7 +355,7 @@ void Sort::eval(const std::vector& inputs, array& out) { } } -void ArgPartition::eval(const std::vector& inputs, array& out) { +void ArgPartition::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; @@ -389,7 +389,7 @@ void 
diff --git a/mlx/backend/common/sort.cpp b/mlx/backend/common/sort.cpp
index e2f6d48bd..1304186d6 100644
--- a/mlx/backend/common/sort.cpp
+++ b/mlx/backend/common/sort.cpp
@@ -287,7 +287,7 @@ void argpartition(const array& in, array& out, int axis, int kth) {
 
 } // namespace
 
-void ArgSort::eval(const std::vector<array>& inputs, array& out) {
+void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
 
@@ -321,7 +321,7 @@ void ArgSort::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
-void Sort::eval(const std::vector<array>& inputs, array& out) {
+void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
 
@@ -355,7 +355,7 @@ void Sort::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
-void ArgPartition::eval(const std::vector<array>& inputs, array& out) {
+void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
 
@@ -389,7 +389,7 @@ void ArgPartition::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
-void Partition::eval(const std::vector<array>& inputs, array& out) {
+void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
diff --git a/mlx/backend/common/svd.cpp b/mlx/backend/common/svd.cpp
index 1a6f1b1ad..71c620db1 100644
--- a/mlx/backend/common/svd.cpp
+++ b/mlx/backend/common/svd.cpp
@@ -137,7 +137,9 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
   }
 }
 
-void SVD::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
+void SVD::eval_cpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
   if (!(inputs[0].dtype() == float32)) {
     throw std::runtime_error("[SVD::eval] only supports float32.");
   }
diff --git a/mlx/backend/common/ternary.h b/mlx/backend/common/ternary.h
index eb513a12f..eaff8db00 100644
--- a/mlx/backend/common/ternary.h
+++ b/mlx/backend/common/ternary.h
@@ -3,8 +3,8 @@
 #pragma once
 #include "mlx/allocator.h"
 #include "mlx/array.h"
-#include "mlx/backend/common/ops.h"
 #include "mlx/backend/common/utils.h"
+
 namespace mlx::core {
 
 namespace {
diff --git a/mlx/backend/common/unary.cpp b/mlx/backend/common/unary.cpp
new file mode 100644
index 000000000..be9fec715
--- /dev/null
+++ b/mlx/backend/common/unary.cpp
@@ -0,0 +1,285 @@
+// Copyright © 2024 Apple Inc.
+
+#include <cassert>
+
+#include "mlx/backend/common/unary.h"
+#include "mlx/backend/common/unary_ops.h"
+#include "mlx/primitives.h"
+
+namespace mlx::core {
+
+void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
+  auto& in = inputs[0];
+  if (issubdtype(in.dtype(), unsignedinteger) || in.dtype() == bool_) {
+    // No-op for unsigned types
+    out.copy_shared_buffer(in);
+  } else {
+    auto op = detail::Abs{};
+    switch (out.dtype()) {
+      case int8:
+        unary_op<int8_t>(in, out, op);
+        break;
+      case int16:
+        unary_op<int16_t>(in, out, op);
+        break;
+      case int32:
+        unary_op<int32_t>(in, out, op);
+        break;
+      case int64:
+        unary_op<int64_t>(in, out, op);
+        break;
+      case float16:
+        unary_op<float16_t>(in, out, op);
+        break;
+      case float32:
+        unary_op<float>(in, out, op);
+        break;
+      case bfloat16:
+        unary_op<bfloat16_t>(in, out, op);
+        break;
+      case complex64:
+        unary_op<complex64_t>(in, out, op);
+        break;
+      default:
+        throw std::runtime_error("[Abs] Called on unsigned type");
+    }
+  }
+}
+
+void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  unary_fp(in, out, detail::ArcCos());
+}
+
+void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  unary_fp(in, out, detail::ArcCosh());
+}
+
+void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  unary_fp(in, out, detail::ArcSin());
+}
+
+void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  unary_fp(in, out, detail::ArcSinh());
+}
+
+void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  unary_fp(in, out, detail::ArcTan());
+}
+
+void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  unary_fp(in, out, detail::ArcTanh());
+}
+
+void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  auto& in = inputs[0];
+  if (issubdtype(in.dtype(), inexact)) {
+    unary_fp(in, out, detail::Ceil());
+  } else {
+    // No-op integer types
+    out.copy_shared_buffer(in);
+  }
+}
+
+void Conjugate::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  unary_op<complex64_t>(inputs[0], out, detail::Conjugate());
+}
+
+void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  unary_fp(in, out, detail::Cos());
+}
+
+void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  unary_fp(in, out, detail::Cosh());
+}
+
+void Erf::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  switch (out.dtype()) {
+    case float32:
+      unary_op<float>(in, out, detail::Erf());
+      break;
+    case float16:
+      unary_op<float16_t>(in, out, detail::Erf());
+      break;
+    case bfloat16:
+      unary_op<bfloat16_t>(in, out, detail::Erf());
+      break;
+    default:
+      throw std::invalid_argument(
+          "[erf] Error function only defined for arrays"
+          " with real floating point type.");
+  }
+}
+
+void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  switch (out.dtype()) {
+    case float32:
+      unary_op<float>(in, out, detail::ErfInv());
+      break;
+    case float16:
+      unary_op<float16_t>(in, out, detail::ErfInv());
+      break;
+    case bfloat16:
+      unary_op<bfloat16_t>(in, out, detail::ErfInv());
+      break;
+    default:
+      throw std::invalid_argument(
+          "[erf_inv] Inverse error function only defined for arrays"
+          " with real floating point type.");
+  }
+}
+
+void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  unary_fp(in, out, detail::Exp());
+}
+
+void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  unary_fp(in, out, detail::Expm1());
+}
+
+void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  auto& in = inputs[0];
+  if (issubdtype(in.dtype(), inexact)) {
+    unary_fp(in, out, detail::Floor());
+  } else {
+    // No-op integer types
+    out.copy_shared_buffer(in);
+  }
+}
+
+void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
+  unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
+}
+
+void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  switch (base_) {
+    case Base::e:
+      unary_fp(in, out, detail::Log());
+      break;
+    case Base::two:
+      unary_fp(in, out, detail::Log2());
+      break;
+    case Base::ten:
+      unary_fp(in, out, detail::Log10());
+      break;
+  }
+}
+
+void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  unary_fp(in, out, detail::Log1p());
+}
+
+void LogicalNot::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  auto& in = inputs[0];
+  unary(in, out, detail::LogicalNot());
+}
+
+void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  auto& in = inputs[0];
+  unary(in, out, detail::Negative());
+}
+
+void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
+  unary_op<complex64_t, float>(inputs[0], out, detail::Real());
+}
+
+void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  auto& in = inputs[0];
+  if (issubdtype(in.dtype(), inexact)) {
+    unary_fp(in, out, detail::Round());
+  } else {
+    // No-op integer types
+    out.copy_shared_buffer(in);
+  }
+}
+
+void Sigmoid::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  unary_fp(in, out, detail::Sigmoid());
+}
+
+void Sign::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  auto& in = inputs[0];
+  if (in.dtype() == bool_) {
+    out.copy_shared_buffer(in);
+  } else {
+    unary(in, out, detail::Sign());
+  }
+}
+
+void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  unary_fp(in, out, detail::Sin());
+}
+
+void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  unary_fp(in, out, detail::Sinh());
+}
+
+void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  auto& in = inputs[0];
+  unary(in, out, detail::Square());
+}
+
+void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  auto& in = inputs[0];
+  if (recip_) {
+    unary_fp(in, out, detail::Rsqrt());
+  } else {
+    unary_fp(in, out, detail::Sqrt());
+  }
+}
+
+void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  unary_fp(in, out, detail::Tan());
+}
+
+void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  unary_fp(in, out, detail::Tanh());
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/common/unary.h b/mlx/backend/common/unary.h
index 944f5034a..e38937d3b 100644
--- a/mlx/backend/common/unary.h
+++ b/mlx/backend/common/unary.h
@@ -4,6 +4,7 @@
 
 #include "mlx/allocator.h"
 #include "mlx/array.h"
+#include "mlx/backend/common/simd/simd.h"
 #include "mlx/backend/common/utils.h"
 #include "mlx/utils.h"
 
@@ -38,8 +39,19 @@ void unary_op(const array& a, array& out, Op op) {
   if (a.flags().contiguous) {
     set_unary_output_data(a, out);
     U* dst = out.data<U>();
-    for (size_t i = 0; i < a.data_size(); ++i) {
-      dst[i] = op(a_ptr[i]);
+    constexpr int N = simd::max_size<T>;
+    size_t size = a.data_size();
+    while (size >= N) {
+      simd::store(dst, op(simd::load<T, N>(a_ptr)));
+      size -= N;
+      a_ptr += N;
+      dst += N;
+    }
+    while (size > 0) {
+      *dst = op(*a_ptr);
+      size--;
+      dst++;
+      a_ptr++;
     }
   } else {
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
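The rewritten unary_op above uses the same chunked traversal as the softmax kernel: full N-wide blocks go through simd::load/simd::store, and whatever remains falls back to scalars. A minimal generic sketch of that pattern, where VecOp and ScalarOp are assumed stand-in callables, not names from this diff:

// Sketch only: vector main loop plus scalar tail over a contiguous buffer.
#include <cstddef>

template <int N, typename T, typename VecOp, typename ScalarOp>
void apply_chunked(const T* src, T* dst, std::size_t size, VecOp vop, ScalarOp sop) {
  while (size >= N) {  // full N-wide chunks take the vector path
    vop(src, dst);     // assumed to process exactly N contiguous elements
    src += N;
    dst += N;
    size -= N;
  }
  while (size > 0) {   // the remaining size % N elements go one at a time
    *dst++ = sop(*src++);
    --size;
  }
}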
diff --git a/mlx/backend/common/unary_ops.h b/mlx/backend/common/unary_ops.h
new file mode 100644
index 000000000..11a69c2ca
--- /dev/null
+++ b/mlx/backend/common/unary_ops.h
@@ -0,0 +1,108 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include <stdint.h>
+#include <cmath>
+#include <complex>
+
+#include "mlx/backend/common/simd/simd.h"
+
+namespace mlx::core::detail {
+
+using namespace mlx::core::simd;
+
+#define SINGLE()                         \
+  template <typename T>                  \
+  T operator()(T x) {                    \
+    return (*this)(Simd<T, 1>(x)).value; \
+  }
+
+#define DEFAULT_OP(Op, op)                  \
+  struct Op {                               \
+    template <int N, typename T>            \
+    Simd<T, N> operator()(Simd<T, N> x) {   \
+      return simd::op(x);                   \
+    }                                       \
+    SINGLE()                                \
+  };
+
+DEFAULT_OP(Abs, abs)
+DEFAULT_OP(ArcCos, acos)
+DEFAULT_OP(ArcCosh, acosh)
+DEFAULT_OP(ArcSin, asin)
+DEFAULT_OP(ArcSinh, asinh)
+DEFAULT_OP(ArcTan, atan)
+DEFAULT_OP(ArcTanh, atanh)
+DEFAULT_OP(Ceil, ceil)
+DEFAULT_OP(Conjugate, conj)
+DEFAULT_OP(Cos, cos)
+DEFAULT_OP(Cosh, cosh)
+DEFAULT_OP(Erf, erf)
+DEFAULT_OP(ErfInv, erfinv)
+DEFAULT_OP(Exp, exp)
+DEFAULT_OP(Expm1, expm1)
+DEFAULT_OP(Floor, floor);
+DEFAULT_OP(Log, log);
+DEFAULT_OP(Log2, log2);
+DEFAULT_OP(Log10, log10);
+DEFAULT_OP(Log1p, log1p);
+DEFAULT_OP(LogicalNot, operator!)
+DEFAULT_OP(Negative, operator-)
+DEFAULT_OP(Round, rint);
+DEFAULT_OP(Sin, sin)
+DEFAULT_OP(Sinh, sinh)
+DEFAULT_OP(Sqrt, sqrt)
+DEFAULT_OP(Rsqrt, rsqrt)
+DEFAULT_OP(Tan, tan)
+DEFAULT_OP(Tanh, tanh)
+
+struct Imag {
+  template <int N>
+  Simd<float, N> operator()(Simd<complex64_t, N> x) {
+    return simd::imag(x);
+  }
+  SINGLE()
+};
+
+struct Real {
+  template <int N>
+  Simd<float, N> operator()(Simd<complex64_t, N> x) {
+    return simd::real(x);
+  }
+  SINGLE()
+};
+
+struct Sigmoid {
+  template <int N, typename T>
+  Simd<T, N> operator()(Simd<T, N> x) {
+    return 1.0f / (1.0f + simd::exp(-x));
+  }
+  SINGLE()
+};
+
+struct Sign {
+  template <int N, typename T>
+  Simd<T, N> operator()(Simd<T, N> x) {
+    auto z = Simd<T, N>{0};
+    if constexpr (std::is_unsigned_v<T>) {
+      return x != z;
+    } else if constexpr (std::is_same_v<T, complex64_t>) {
+      return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x)));
+    } else {
+      return simd::select(
+          x < z, Simd<T, N>{-1}, simd::select(x > z, Simd<T, N>{1}, z));
+    }
+  }
+  SINGLE()
+};
+
+struct Square {
+  template <int N, typename T>
+  Simd<T, N> operator()(Simd<T, N> x) {
+    return x * x;
+  }
+  SINGLE()
+};
+
+} // namespace mlx::core::detail
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 89209e0f6..8158c88d6 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -163,9 +163,6 @@ class Abs : public UnaryPrimitive {
   DEFINE_PRINT(Abs)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Add : public UnaryPrimitive {
@@ -180,9 +177,6 @@ class Add : public UnaryPrimitive {
   DEFINE_PRINT(Add)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class AddMM : public UnaryPrimitive {
@@ -226,8 +220,6 @@ class Arange : public UnaryPrimitive {
  double start_;
  double stop_;
  double step_;
-
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class ArcCos : public UnaryPrimitive {
@@ -242,9 +234,6 @@ class ArcCos : public UnaryPrimitive {
   DEFINE_PRINT(ArcCos)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class ArcCosh : public UnaryPrimitive {
@@ -259,9 +248,6 @@ class ArcCosh : public UnaryPrimitive {
   DEFINE_PRINT(ArcCosh)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class ArcSin : public UnaryPrimitive {
@@ -276,9 +262,6 @@ class ArcSin : public UnaryPrimitive {
   DEFINE_PRINT(ArcSin)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class ArcSinh : public UnaryPrimitive {
@@ -293,9 +276,6 @@ class ArcSinh : public UnaryPrimitive {
   DEFINE_PRINT(ArcSinh)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class ArcTan : public UnaryPrimitive {
@@ -310,9 +290,6 @@ class ArcTan : public UnaryPrimitive {
   DEFINE_PRINT(ArcTan)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class ArcTan2 : public UnaryPrimitive {
@@ -327,9 +304,6 @@ class ArcTan2 : public UnaryPrimitive {
   DEFINE_PRINT(ArcTan2)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class ArcTanh : public UnaryPrimitive {
@@ -344,9 +318,6 @@ class ArcTanh : public UnaryPrimitive {
   DEFINE_PRINT(ArcTanh)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class ArgPartition : public UnaryPrimitive {
@@ -369,8 +340,6 @@ class ArgPartition : public UnaryPrimitive {
 private:
  int kth_;
  int axis_;
-
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class ArgReduce : public UnaryPrimitive {
@@ -398,8 +367,6 @@ class ArgReduce : public UnaryPrimitive {
 private:
  ReduceType reduce_type_;
  int axis_;
-
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class ArgSort : public UnaryPrimitive {
@@ -420,8 +387,6 @@ class ArgSort : public UnaryPrimitive {
 private:
  int axis_;
-
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class AsType : public UnaryPrimitive {
@@ -443,8 +408,6 @@ class AsType : public UnaryPrimitive {
 private:
  Dtype dtype_;
-
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class AsStrided : public UnaryPrimitive {
@@ -518,8 +481,6 @@ class BlockMaskedMM : public UnaryPrimitive {
 private:
  int block_size_;
-
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class GatherMM : public UnaryPrimitive {
@@ -537,9 +498,6 @@ class GatherMM : public UnaryPrimitive {
   DEFINE_PRINT(GatherMM)
   DEFINE_DEFAULT_IS_EQUIVALENT()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class BroadcastAxes : public UnaryPrimitive {
@@ -603,9 +561,6 @@ class Ceil : public UnaryPrimitive {
   DEFINE_PRINT(Ceil)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Compiled : public Primitive {
@@ -669,8 +624,6 @@ class Concatenate : public UnaryPrimitive {
 private:
  int axis_;
-
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Conjugate : public UnaryPrimitive {
@@ -684,9 +637,6 @@ class Conjugate : public UnaryPrimitive {
   DEFINE_PRINT(Conjugate)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Contiguous : public UnaryPrimitive {
@@ -787,9 +737,6 @@ class Cos : public UnaryPrimitive {
   DEFINE_PRINT(Cos)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Cosh : public UnaryPrimitive {
@@ -804,9 +751,6 @@ class Cosh : public UnaryPrimitive {
   DEFINE_PRINT(Cosh)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class CustomTransforms : public Primitive {
@@ -894,9 +838,6 @@ class Divide : public UnaryPrimitive {
   DEFINE_PRINT(Divide)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class DivMod : public Primitive {
@@ -915,9 +856,6 @@ class DivMod : public Primitive {
   std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
     return std::vector{inputs[0].shape(), inputs[0].shape()};
   }
-
- private:
-  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
 };
 
 class Select : public UnaryPrimitive {
@@ -932,9 +870,6 @@ class Select : public UnaryPrimitive {
   DEFINE_PRINT(Select)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Remainder : public UnaryPrimitive {
@@ -949,9 +884,6 @@ class Remainder : public UnaryPrimitive {
   DEFINE_PRINT(Remainder)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Equal : public UnaryPrimitive {
@@ -979,7 +911,6 @@ class Equal : public UnaryPrimitive {
   };
 
 private:
-  void eval(const std::vector<array>& inputs, array& out);
  bool equal_nan_;
 };
 
@@ -995,9 +926,6 @@ class Erf : public UnaryPrimitive {
   DEFINE_PRINT(Erf)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class ErfInv : public UnaryPrimitive {
@@ -1012,9 +940,6 @@ class ErfInv : public UnaryPrimitive {
   DEFINE_PRINT(ErfInv)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Exp : public UnaryPrimitive {
@@ -1029,9 +954,6 @@ class Exp : public UnaryPrimitive {
   DEFINE_PRINT(Exp)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Expm1 : public UnaryPrimitive {
@@ -1045,9 +967,6 @@ class Expm1 : public UnaryPrimitive {
   DEFINE_GRADS()
   DEFINE_PRINT(Expm1)
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class ExpandDims : public UnaryPrimitive {
@@ -1100,8 +1019,6 @@ class FFT : public UnaryPrimitive {
  std::vector<size_t> axes_;
  bool inverse_;
  bool real_;
-
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Flatten : public UnaryPrimitive {
@@ -1141,9 +1058,6 @@ class Floor : public UnaryPrimitive {
   DEFINE_PRINT(Floor)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Full : public UnaryPrimitive {
@@ -1157,9 +1071,6 @@ class Full : public UnaryPrimitive {
   DEFINE_GRADS()
   DEFINE_PRINT(Full)
   DEFINE_DEFAULT_IS_EQUIVALENT()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Gather : public UnaryPrimitive {
@@ -1182,7 +1093,6 @@ class Gather : public UnaryPrimitive {
   }
 
 private:
-  void eval(const std::vector<array>& inputs, array& out);
  std::vector<int> axes_;
  Shape slice_sizes_;
 };
 
@@ -1199,9 +1109,6 @@ class Greater : public UnaryPrimitive {
   DEFINE_PRINT(Greater)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class GreaterEqual : public UnaryPrimitive {
@@ -1216,9 +1123,6 @@ class GreaterEqual : public UnaryPrimitive {
   DEFINE_PRINT(GreaterEqual)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Hadamard : public UnaryPrimitive {
@@ -1241,8 +1145,6 @@ class Hadamard : public UnaryPrimitive {
 private:
  float scale_;
-
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Imag : public UnaryPrimitive {
@@ -1271,9 +1173,6 @@ class Less : public UnaryPrimitive {
   DEFINE_PRINT(Less)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class LessEqual : public UnaryPrimitive {
@@ -1288,9 +1187,6 @@ class LessEqual : public UnaryPrimitive {
   DEFINE_PRINT(LessEqual)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Load : public UnaryPrimitive {
@@ -1319,7 +1215,6 @@ class Load : public UnaryPrimitive {
     static Stream io_stream = new_stream(Device::cpu);
     return io_stream;
   };
-  void eval(const std::vector<array>& inputs, array& out);
  std::shared_ptr<io::Reader> reader_;
  size_t offset_;
  bool swap_endianness_;
@@ -1360,7 +1255,6 @@ class Log : public UnaryPrimitive {
 private:
  Base base_;
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Log1p : public UnaryPrimitive {
@@ -1374,9 +1268,6 @@ class Log1p : public UnaryPrimitive {
   DEFINE_GRADS()
   DEFINE_PRINT(Log1p)
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class LogicalNot : public UnaryPrimitive {
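The primitives.h removals in this diff all follow one mechanical pattern: each primitive loses its private, backend-shared eval(), because the common backend now implements eval_cpu() directly (see unary.cpp earlier in this diff). A sketch of the before/after shape, using a hypothetical primitive rather than the literal MLX code:

// Hypothetical primitive, before: eval_cpu forwarded to a shared eval().
//   void Foo::eval_cpu(const std::vector<array>& inputs, array& out) {
//     eval(inputs, out); // generic scalar fallback, declared private
//   }
// After this diff: the private fallback declaration is deleted and the CPU
// backend implements the primitive directly on top of the new simd helpers.
//   void Foo::eval_cpu(const std::vector<array>& inputs, array& out) {
//     unary_fp(inputs[0], out, detail::Foo());
//   }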
@@ -1391,9 +1282,6 @@ class LogicalNot : public UnaryPrimitive {
   DEFINE_PRINT(LogicalNot)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class LogicalAnd : public UnaryPrimitive {
@@ -1408,9 +1296,6 @@ class LogicalAnd : public UnaryPrimitive {
   DEFINE_PRINT(LogicalAnd)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class LogicalOr : public UnaryPrimitive {
@@ -1425,9 +1310,6 @@ class LogicalOr : public UnaryPrimitive {
   DEFINE_PRINT(LogicalOr)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class LogAddExp : public UnaryPrimitive {
@@ -1442,9 +1324,6 @@ class LogAddExp : public UnaryPrimitive {
   DEFINE_PRINT(LogAddExp)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Matmul : public UnaryPrimitive {
@@ -1473,9 +1352,6 @@ class Maximum : public UnaryPrimitive {
   DEFINE_PRINT(Maximum)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Minimum : public UnaryPrimitive {
@@ -1490,9 +1366,6 @@ class Minimum : public UnaryPrimitive {
   DEFINE_PRINT(Minimum)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Multiply : public UnaryPrimitive {
@@ -1507,9 +1380,6 @@ class Multiply : public UnaryPrimitive {
   DEFINE_PRINT(Multiply)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Negative : public UnaryPrimitive {
@@ -1524,9 +1394,6 @@ class Negative : public UnaryPrimitive {
   DEFINE_PRINT(Negative)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class NotEqual : public UnaryPrimitive {
@@ -1541,9 +1408,6 @@ class NotEqual : public UnaryPrimitive {
   DEFINE_PRINT(NotEqual)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class NumberOfElements : public UnaryPrimitive {
@@ -1606,8 +1470,6 @@ class Pad : public UnaryPrimitive {
  std::vector<int> axes_;
  Shape low_pad_size_;
  Shape high_pad_size_;
-
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Partition : public UnaryPrimitive {
@@ -1630,8 +1492,6 @@ class Partition : public UnaryPrimitive {
 private:
  int kth_;
  int axis_;
-
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Power : public UnaryPrimitive {
@@ -1646,9 +1506,6 @@ class Power : public UnaryPrimitive {
   DEFINE_PRINT(Power)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class QuantizedMatmul : public UnaryPrimitive {
@@ -1679,8 +1536,6 @@ class QuantizedMatmul : public UnaryPrimitive {
  int group_size_;
  int bits_;
  bool transpose_;
-
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class GatherQMM : public UnaryPrimitive {
@@ -1706,8 +1561,6 @@ class GatherQMM : public UnaryPrimitive {
  int group_size_;
  int bits_;
  bool transpose_;
-
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class RandomBits : public UnaryPrimitive {
@@ -1728,8 +1581,6 @@ class RandomBits : public UnaryPrimitive {
 private:
  Shape shape_;
  int width_;
-
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Real : public UnaryPrimitive {
@@ -1837,9 +1688,6 @@ class Round : public UnaryPrimitive {
   DEFINE_PRINT(Round)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Scan : public UnaryPrimitive {
@@ -1936,7 +1784,6 @@ class Scatter : public UnaryPrimitive {
   };
 
 private:
-  void eval(const std::vector<array>& inputs, array& out);
  ReduceType reduce_type_;
  std::vector<int> axes_;
 };
 
@@ -1953,9 +1800,6 @@ class Sigmoid : public UnaryPrimitive {
   DEFINE_PRINT(Sigmoid)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Sign : public UnaryPrimitive {
@@ -1970,9 +1814,6 @@ class Sign : public UnaryPrimitive {
   DEFINE_PRINT(Sign)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Sin : public UnaryPrimitive {
@@ -1987,9 +1828,6 @@ class Sin : public UnaryPrimitive {
   DEFINE_PRINT(Sin)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Sinh : public UnaryPrimitive {
@@ -2004,9 +1842,6 @@ class Sinh : public UnaryPrimitive {
   DEFINE_PRINT(Sinh)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Slice : public UnaryPrimitive {
@@ -2036,7 +1871,6 @@ class Slice : public UnaryPrimitive {
  Shape start_indices_;
  Shape end_indices_;
  Shape strides_;
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
@@ -2068,8 +1902,6 @@ class SliceUpdate : public UnaryPrimitive {
  Shape start_indices_;
  Shape end_indices_;
  Shape strides_;
-
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class DynamicSlice : public UnaryPrimitive {
@@ -2136,7 +1968,6 @@ class Softmax : public UnaryPrimitive {
   };
 
 private:
-  void eval(const std::vector<array>& inputs, array& out);
  bool precise_;
 };
 
@@ -2159,8 +1990,6 @@ class Sort : public UnaryPrimitive {
 private:
  int axis_;
-
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Split : public Primitive {
@@ -2200,9 +2029,6 @@ class Square : public UnaryPrimitive {
   DEFINE_PRINT(Square)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Sqrt : public UnaryPrimitive {
@@ -2230,7 +2056,6 @@ class Sqrt : public UnaryPrimitive {
   }
 
 private:
-  void eval(const std::vector<array>& inputs, array& out);
  bool recip_;
 };
 
@@ -2262,9 +2087,6 @@ class Subtract : public UnaryPrimitive {
   DEFINE_PRINT(Subtract)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Squeeze : public UnaryPrimitive {
@@ -2304,9 +2126,6 @@ class Tan : public UnaryPrimitive {
   DEFINE_PRINT(Tan)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Tanh : public UnaryPrimitive {
@@ -2321,9 +2140,6 @@ class Tanh : public UnaryPrimitive {
   DEFINE_PRINT(Tanh)
   DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
-
- private:
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Unflatten : public UnaryPrimitive {
@@ -2404,9 +2220,6 @@ class QRF : public Primitive {
       override;
 
   DEFINE_PRINT(QRF)
-
- private:
-  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
 };
 
 /* SVD primitive. */
@@ -2421,9 +2234,6 @@ class SVD : public Primitive {
   DEFINE_VMAP()
   DEFINE_PRINT(SVD)
-
- private:
-  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
 };
 
 /* Matrix inversion primitive. */
@@ -2442,7 +2252,6 @@ class Inverse : public UnaryPrimitive {
   }
 
 private:
-  void eval(const std::vector<array>& inputs, array& output);
  bool tri_;
  bool upper_;
 };
 
@@ -2462,7 +2271,6 @@ class Cholesky : public UnaryPrimitive {
   DEFINE_PRINT(Cholesky)
 
 private:
-  void eval(const std::vector<array>& inputs, array& output);
  bool upper_;
 };
 
@@ -2489,7 +2297,6 @@ class Eigh : public Primitive {
   }
 
 private:
-  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
  std::string uplo_;
  bool compute_eigenvectors_;
 };
diff --git a/mlx/types/complex.h b/mlx/types/complex.h
index 48bcdbff7..51101cc97 100644
--- a/mlx/types/complex.h
+++ b/mlx/types/complex.h
@@ -14,6 +14,7 @@ inline constexpr bool can_convert_to_complex128 =
     !std::is_same_v<T, complex128_t> &&
     std::is_convertible_v<T, std::complex<double>>;
 
 struct complex128_t : public std::complex<double> {
+  complex128_t() : std::complex<double>() {};
   complex128_t(double v, double u) : std::complex<double>(v, u) {};
   complex128_t(std::complex<double> v) : std::complex<double>(v) {};
 
@@ -32,6 +33,7 @@ inline constexpr bool can_convert_to_complex64 =
     !std::is_same_v<T, complex64_t> &&
     std::is_convertible_v<T, std::complex<float>>;
 
 struct complex64_t : public std::complex<float> {
+  complex64_t() : std::complex<float>() {};
   complex64_t(float v, float u) : std::complex<float>(v, u) {};
   complex64_t(std::complex<float> v) : std::complex<float>(v) {};
diff --git a/mlx/types/half_types.h b/mlx/types/half_types.h
index 430279565..d9d6b9bf5 100644
--- a/mlx/types/half_types.h
+++ b/mlx/types/half_types.h
@@ -1,11 +1,12 @@
 // Copyright © 2023 Apple Inc.
 
 #pragma once
+
 #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
 
 #include <arm_fp16.h>
 namespace mlx::core {
-typedef __fp16 float16_t;
+using ::float16_t;
 } // namespace mlx::core
 
 #else
@@ -17,11 +18,12 @@ typedef struct _MLX_Float16 float16_t;
 } // namespace mlx::core
 
 #endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
+
 #ifdef __ARM_FEATURE_BF16
 
 #include <arm_bf16.h>
 namespace mlx::core {
-typedef __bf16 bfloat16_t;
+using ::bfloat16_t;
 } // namespace mlx::core
 
 #else
diff --git a/python/src/array.cpp b/python/src/array.cpp
index ff358cbe4..58193b2ea 100644
--- a/python/src/array.cpp
+++ b/python/src/array.cpp
@@ -741,7 +741,7 @@ void init_array(nb::module_& m) {
       [](const mx::array& a) {
         if (mx::issubdtype(a.dtype(), mx::inexact)) {
           throw std::invalid_argument(
-              "Floating point types not allowed with or bitwise inversion.");
+              "Floating point types not allowed with bitwise inversion.");
         }
         if (a.dtype() != mx::bool_) {
           throw std::invalid_argument(
@@ -791,7 +791,7 @@ void init_array(nb::module_& m) {
         if (mx::issubdtype(a.dtype(), mx::inexact) ||
             mx::issubdtype(b.dtype(), mx::inexact)) {
           throw std::invalid_argument(
-              "Floating point types not allowed with or bitwise or.");
+              "Floating point types not allowed with bitwise or.");
         }
         return mx::bitwise_or(a, b);
       },
@@ -806,7 +806,7 @@ void init_array(nb::module_& m) {
         if (mx::issubdtype(a.dtype(), mx::inexact) ||
             mx::issubdtype(b.dtype(), mx::inexact)) {
           throw std::invalid_argument(
-              "Floating point types not allowed with or bitwise or.");
+              "Floating point types not allowed with bitwise or.");
         }
         a.overwrite_descriptor(mx::bitwise_or(a, b));
         return a;
@@ -838,7 +838,7 @@ void init_array(nb::module_& m) {
         if (mx::issubdtype(a.dtype(), mx::inexact) ||
             mx::issubdtype(b.dtype(), mx::inexact)) {
           throw std::invalid_argument(
-              "Floating point types not allowed with or left shift.");
+              "Floating point types not allowed with left shift.");
         }
         a.overwrite_descriptor(mx::left_shift(a, b));
         return a;
@@ -870,7 +870,7 @@ void init_array(nb::module_& m) {
         if (mx::issubdtype(a.dtype(), mx::inexact) ||
             mx::issubdtype(b.dtype(), mx::inexact)) {
           throw std::invalid_argument(
-              "Floating point types not allowed with or right shift.");
+              "Floating point types not allowed with right shift.");
         }
         a.overwrite_descriptor(mx::right_shift(a, b));
         return a;
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 899899c4e..97b8afa47 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -289,6 +289,9 @@ class TestOps(mlx_tests.MLXTestCase):
         self.assertEqual(y.tolist(), [0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
         self.assertEqual(z.tolist(), [0, -4, -3, -2, -1, 0, -4, -3, -2, -1])
 
+        z = -mx.ones(64) % mx.full(64, 2)
+        self.assertTrue(mx.array_equal(z, mx.ones(64)))
+
     def test_comparisons(self):
         a = mx.array([0.0, 1.0, 5.0])
         b = mx.array([-1.0, 2.0, 5.0])
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index 363722bcf..160eb6400 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -207,8 +207,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
             with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
                 x_shape = (1, N) if B == 0 else (B, 1, N)
                 w_shape = (N, M) if B == 0 else (B, N, M)
-                x = mx.random.normal(shape=x_shape, key=k1)
-                w = mx.random.normal(shape=w_shape, key=k2)
+                x = 1e-1 * mx.random.normal(shape=x_shape, key=k1)
+                w = 1e-1 * mx.random.normal(shape=w_shape, key=k2)
                 w_q, scales, biases = mx.quantize(w, group_size, bits)
                 w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
                 y_q = mx.quantized_matmul(