From 8c2e15e6c8e654d1b057b35115255ecacb938399 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 26 Jun 2024 09:01:50 -0700 Subject: [PATCH] Accelerate import updates for iOS (#1227) * Update veclib and bnns includes to #include for compatibility with ios * Mark float literals in softmax.cpp to be float16_t for errors in ios * Add arm neon vector operation guards * Redirect to common backend for consistency --- mlx/backend/accelerate/conv.cpp | 4 +- mlx/backend/accelerate/matmul.cpp | 3 +- mlx/backend/accelerate/primitives.cpp | 3 +- mlx/backend/accelerate/reduce.cpp | 2 +- mlx/backend/accelerate/softmax.cpp | 128 ++++++++++++++------------ mlx/backend/accelerate/utils.h | 4 +- 6 files changed, 76 insertions(+), 68 deletions(-) diff --git a/mlx/backend/accelerate/conv.cpp b/mlx/backend/accelerate/conv.cpp index 22ddd907d..026813aa2 100644 --- a/mlx/backend/accelerate/conv.cpp +++ b/mlx/backend/accelerate/conv.cpp @@ -1,9 +1,9 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. 
#include +#include #include -#include #include "mlx/backend/common/copy.h" #include "mlx/primitives.h" diff --git a/mlx/backend/accelerate/matmul.cpp b/mlx/backend/accelerate/matmul.cpp index 6113223a4..78ce66e7a 100644 --- a/mlx/backend/accelerate/matmul.cpp +++ b/mlx/backend/accelerate/matmul.cpp @@ -2,8 +2,7 @@ #include -#include -#include +#include #include "mlx/backend/accelerate/utils.h" #include "mlx/backend/common/copy.h" diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index ad778cbc7..8c3615599 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -3,8 +3,7 @@ #include #include -#include -#include +#include #include "mlx/allocator.h" #include "mlx/backend/common/binary.h" diff --git a/mlx/backend/accelerate/reduce.cpp b/mlx/backend/accelerate/reduce.cpp index 15a5d83b9..287243943 100644 --- a/mlx/backend/accelerate/reduce.cpp +++ b/mlx/backend/accelerate/reduce.cpp @@ -2,8 +2,8 @@ #include +#include #include -#include #include "mlx/backend/common/reduce.h" #include "mlx/primitives.h" diff --git a/mlx/backend/accelerate/softmax.cpp b/mlx/backend/accelerate/softmax.cpp index 4d74ff683..91d9fe56a 100644 --- a/mlx/backend/accelerate/softmax.cpp +++ b/mlx/backend/accelerate/softmax.cpp @@ -3,7 +3,10 @@ #include #include +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #include +#endif + #include #include @@ -53,25 +56,26 @@ inline simd_float16 simd_fast_exp(simd_float16 x) { return (*(simd_float16*)&epart) * x; } +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC /** * The ARM neon equivalent of the fast exp above. 
*/ inline float16x8_t neon_fast_exp(float16x8_t x) { - x = vmulq_f16(x, vdupq_n_f16(1.442695)); // multiply with log_2(e) - x = vmaxq_f16(x, vdupq_n_f16(-14)); // clamp under with -14 - x = vminq_f16(x, vdupq_n_f16(14)); // clamp over with 14 + x = vmulq_f16(x, vdupq_n_f16(float16_t(1.442695f))); // multiply with log_2(e) + x = vmaxq_f16(x, vdupq_n_f16(float16_t(-14.f))); // clamp under with -14 + x = vminq_f16(x, vdupq_n_f16(float16_t(14.f))); // clamp over with 14 - float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(0.5))); + float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(float16_t(0.5f)))); float16x8_t fpart = vsubq_f16(x, ipart); - x = vdupq_n_f16(1.535336188319500e-4f); - x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart); - x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart); - x = vfmaq_f16(vdupq_n_f16(9.618437357674640e-3f), x, fpart); - x = vfmaq_f16(vdupq_n_f16(5.550332471162809e-2f), x, fpart); - x = vfmaq_f16(vdupq_n_f16(2.402264791363012e-1f), x, fpart); - x = vfmaq_f16(vdupq_n_f16(6.931472028550421e-1f), x, fpart); - x = vfmaq_f16(vdupq_n_f16(1.000000000000000f), x, fpart); + x = vdupq_n_f16(float16_t(1.535336188319500e-4f)); + x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart); + x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart); + x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart); + x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart); + x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart); + x = vfmaq_f16(vdupq_n_f16(float16_t(6.931472028550421e-1f)), x, fpart); + x = vfmaq_f16(vdupq_n_f16(float16_t(1.000000000000000f)), x, fpart); // generate 2**ipart in the floating point representation using integer // bitshifting @@ -107,53 +111,6 @@ inline float16_t neon_reduce_add(float16x8_t x) { return vget_lane_f16(y, 0); } -template -struct AccelerateSimdOps { - VT init(T a) { - return a; - } - - VT load(const T* a) { - return 
*(VT*)a; - } - - void store(T* dst, VT x) { - *(VT*)dst = x; - } - - VT max(VT a, VT b) { - return simd_max(a, b); - } - - VT exp(VT x) { - return simd_fast_exp(x); - } - - VT add(VT a, VT b) { - return a + b; - } - - VT sub(VT a, T b) { - return a - b; - } - - VT mul(VT a, VT b) { - return a * b; - } - - VT mul(VT a, T b) { - return a * b; - } - - T reduce_max(VT x) { - return simd_reduce_max(x); - } - - T reduce_add(VT x) { - return simd_reduce_add(x); - } -}; - template struct NeonFp16SimdOps { VT init(T a) { @@ -201,6 +158,55 @@ struct NeonFp16SimdOps { } }; +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +template +struct AccelerateSimdOps { + VT init(T a) { + return a; + } + + VT load(const T* a) { + return *(VT*)a; + } + + void store(T* dst, VT x) { + *(VT*)dst = x; + } + + VT max(VT a, VT b) { + return simd_max(a, b); + } + + VT exp(VT x) { + return simd_fast_exp(x); + } + + VT add(VT a, VT b) { + return a + b; + } + + VT sub(VT a, T b) { + return a - b; + } + + VT mul(VT a, VT b) { + return a * b; + } + + VT mul(VT a, T b) { + return a * b; + } + + T reduce_max(VT x) { + return simd_reduce_max(x); + } + + T reduce_add(VT x) { + return simd_reduce_add(x); + } +}; + template void softmax(const array& in, array& out) { Ops ops; @@ -362,12 +368,16 @@ void Softmax::eval_cpu(const std::vector& inputs, array& out) { AccelerateSimdOps, 16>(in, out); } else { +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC softmax< float16_t, float16_t, float16x8_t, NeonFp16SimdOps, 8>(in, out); +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + eval(inputs, out); // Redirect to common backend for consistency +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC } break; case bfloat16: diff --git a/mlx/backend/accelerate/utils.h b/mlx/backend/accelerate/utils.h index 1dbabe304..389099f37 100644 --- a/mlx/backend/accelerate/utils.h +++ b/mlx/backend/accelerate/utils.h @@ -1,8 +1,8 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. 
#pragma once -#include <vecLib/cblas_new.h> +#include <Accelerate/Accelerate.h> #include "mlx/dtype.h" namespace mlx::core {