Mirror of https://github.com/ml-explore/mlx.git
Accelerate import updates for iOS (#1227)

* Update the vecLib and BNNS includes to #include <Accelerate/Accelerate.h> for compatibility with iOS
* Mark float literals in softmax.cpp as float16_t to fix errors on iOS
* Add ARM NEON vector-operation guards
* Redirect to the common backend for consistency

parent 56c8a33439
commit 8c2e15e6c8
@@ -1,9 +1,9 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #include <cassert>
 
+#include <Accelerate/Accelerate.h>
 #include <simd/vector.h>
-#include <vecLib/vDSP.h>
 
 #include "mlx/backend/common/copy.h"
 #include "mlx/primitives.h"
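This hunk is the heart of the change: per the commit message, the <vecLib/...> sub-framework headers resolve on macOS but not when building for iOS, whereas the Accelerate umbrella header exposes the same vDSP/vForce/BLAS/BNNS declarations on both platforms. A minimal sketch of the pattern (illustrative code, not part of the diff):

#include <cstddef>

#include <Accelerate/Accelerate.h> // umbrella header; replaces <vecLib/vDSP.h> et al.

// vDSP symbols arrive through the umbrella header on macOS and iOS alike.
inline void add_arrays(const float* a, const float* b, float* c, size_t n) {
  vDSP_vadd(a, 1, b, 1, c, 1, n); // c[i] = a[i] + b[i]
}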
@@ -2,8 +2,7 @@
 
 #include <cassert>
 
-#include <vecLib/BNNS/bnns.h>
-#include <vecLib/cblas_new.h>
+#include <Accelerate/Accelerate.h>
 
 #include "mlx/backend/accelerate/utils.h"
 #include "mlx/backend/common/copy.h"
@@ -3,8 +3,7 @@
 #include <cassert>
 #include <cmath>
 
-#include <vecLib/vDSP.h>
-#include <vecLib/vForce.h>
+#include <Accelerate/Accelerate.h>
 
 #include "mlx/allocator.h"
 #include "mlx/backend/common/binary.h"
@@ -2,8 +2,8 @@
 
 #include <cassert>
 
+#include <Accelerate/Accelerate.h>
 #include <simd/vector.h>
-#include <vecLib/vDSP.h>
 
 #include "mlx/backend/common/reduce.h"
 #include "mlx/primitives.h"
@@ -3,7 +3,10 @@
 #include <cassert>
 #include <limits>
 
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 #include <arm_neon.h>
+#endif
+
 #include <simd/math.h>
 #include <simd/vector.h>
 
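The guard is needed because <arm_neon.h> exists only for ARM targets, and the fp16 arithmetic intrinsics used later in this file additionally require the compiler to define __ARM_FEATURE_FP16_VECTOR_ARITHMETIC. A minimal sketch of the same pattern (not mlx code):

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_neon.h>

// Multiply eight half-precision lanes by a scalar; this only compiles when the
// target supports fp16 vector arithmetic (e.g. arm64 Apple devices).
inline float16x8_t scale8_f16(float16x8_t v, float16_t s) {
  return vmulq_f16(v, vdupq_n_f16(s));
}
#endif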
@@ -53,25 +56,26 @@ inline simd_float16 simd_fast_exp(simd_float16 x) {
   return (*(simd_float16*)&epart) * x;
 }
 
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 /**
  * The ARM neon equivalent of the fast exp above.
  */
 inline float16x8_t neon_fast_exp(float16x8_t x) {
-  x = vmulq_f16(x, vdupq_n_f16(1.442695)); // multiply with log_2(e)
-  x = vmaxq_f16(x, vdupq_n_f16(-14)); // clamp under with -14
-  x = vminq_f16(x, vdupq_n_f16(14)); // clamp over with 14
+  x = vmulq_f16(x, vdupq_n_f16(float16_t(1.442695f))); // multiply with log_2(e)
+  x = vmaxq_f16(x, vdupq_n_f16(float16_t(-14.f))); // clamp under with -14
+  x = vminq_f16(x, vdupq_n_f16(float16_t(14.f))); // clamp over with 14
 
-  float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(0.5)));
+  float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(float16_t(0.5f))));
   float16x8_t fpart = vsubq_f16(x, ipart);
 
-  x = vdupq_n_f16(1.535336188319500e-4f);
-  x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
-  x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
-  x = vfmaq_f16(vdupq_n_f16(9.618437357674640e-3f), x, fpart);
-  x = vfmaq_f16(vdupq_n_f16(5.550332471162809e-2f), x, fpart);
-  x = vfmaq_f16(vdupq_n_f16(2.402264791363012e-1f), x, fpart);
-  x = vfmaq_f16(vdupq_n_f16(6.931472028550421e-1f), x, fpart);
-  x = vfmaq_f16(vdupq_n_f16(1.000000000000000f), x, fpart);
+  x = vdupq_n_f16(float16_t(1.535336188319500e-4f));
+  x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
+  x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
+  x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart);
+  x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart);
+  x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart);
+  x = vfmaq_f16(vdupq_n_f16(float16_t(6.931472028550421e-1f)), x, fpart);
+  x = vfmaq_f16(vdupq_n_f16(float16_t(1.000000000000000f)), x, fpart);
 
   // generate 2**ipart in the floating point representation using integer
   // bitshifting
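Two things happen in this hunk. First, vdupq_n_f16 takes a float16_t argument, and per the commit message the bare float/double literals caused conversion errors on the iOS toolchain, hence the explicit float16_t(...) casts. Second, the guarded function is a polynomial fast-exp: exp(x) = 2^(x·log2 e), with the exponent split into an integer part (synthesized by writing exponent bits directly) and a fractional part approximated by the polynomial above. A scalar float32 sketch of the same idea, reusing the coefficients from this file (a simplified illustration, not the library's kernel; the clamp range is an assumption for float32):

#include <cmath>
#include <cstdint>
#include <cstring>

inline float fast_exp_sketch(float x) {
  x *= 1.442695f; // rewrite e^x as 2^(x * log2(e))
  x = std::fmax(std::fmin(x, 80.f), -80.f); // keep the final exponent representable
  float ipart = std::floor(x + 0.5f);
  float fpart = x - ipart;

  // Horner evaluation of a small polynomial approximating 2^fpart on [-0.5, 0.5).
  float p = 1.535336188319500e-4f;
  p = p * fpart + 1.339887440266574e-3f;
  p = p * fpart + 9.618437357674640e-3f;
  p = p * fpart + 5.550332471162809e-2f;
  p = p * fpart + 2.402264791363012e-1f;
  p = p * fpart + 6.931472028550421e-1f;
  p = p * fpart + 1.000000000000000f;

  // 2^ipart: place the biased exponent (ipart + 127) into the exponent field.
  uint32_t bits = static_cast<uint32_t>(ipart + 127.0f) << 23;
  float two_ipart;
  std::memcpy(&two_ipart, &bits, sizeof(two_ipart));
  return two_ipart * p;
}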
@@ -107,53 +111,6 @@ inline float16_t neon_reduce_add(float16x8_t x) {
   return vget_lane_f16(y, 0);
 }
 
-template <typename T, typename VT>
-struct AccelerateSimdOps {
-  VT init(T a) {
-    return a;
-  }
-
-  VT load(const T* a) {
-    return *(VT*)a;
-  }
-
-  void store(T* dst, VT x) {
-    *(VT*)dst = x;
-  }
-
-  VT max(VT a, VT b) {
-    return simd_max(a, b);
-  }
-
-  VT exp(VT x) {
-    return simd_fast_exp(x);
-  }
-
-  VT add(VT a, VT b) {
-    return a + b;
-  }
-
-  VT sub(VT a, T b) {
-    return a - b;
-  }
-
-  VT mul(VT a, VT b) {
-    return a * b;
-  }
-
-  VT mul(VT a, T b) {
-    return a * b;
-  }
-
-  T reduce_max(VT x) {
-    return simd_reduce_max(x);
-  }
-
-  T reduce_add(VT x) {
-    return simd_reduce_add(x);
-  }
-};
-
 template <typename T, typename VT>
 struct NeonFp16SimdOps {
   VT init(T a) {
@@ -201,6 +158,55 @@ struct NeonFp16SimdOps {
   }
 };
 
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+template <typename T, typename VT>
+struct AccelerateSimdOps {
+  VT init(T a) {
+    return a;
+  }
+
+  VT load(const T* a) {
+    return *(VT*)a;
+  }
+
+  void store(T* dst, VT x) {
+    *(VT*)dst = x;
+  }
+
+  VT max(VT a, VT b) {
+    return simd_max(a, b);
+  }
+
+  VT exp(VT x) {
+    return simd_fast_exp(x);
+  }
+
+  VT add(VT a, VT b) {
+    return a + b;
+  }
+
+  VT sub(VT a, T b) {
+    return a - b;
+  }
+
+  VT mul(VT a, VT b) {
+    return a * b;
+  }
+
+  VT mul(VT a, T b) {
+    return a * b;
+  }
+
+  T reduce_max(VT x) {
+    return simd_reduce_max(x);
+  }
+
+  T reduce_add(VT x) {
+    return simd_reduce_add(x);
+  }
+};
+
 template <typename T, typename AccT, typename VT, typename Ops, int N>
 void softmax(const array& in, array& out) {
   Ops ops;
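With this move, only the NEON-specific helpers (neon_fast_exp, NeonFp16SimdOps, the NEON reductions) sit inside the new guard, while AccelerateSimdOps stays available on every Apple target. The templated softmax below is written purely against this small Ops interface (init/load/store/max/exp/add/sub/mul/reduce_max/reduce_add). A simplified sketch of how a kernel can be structured on top of such an Ops policy, assuming one contiguous row whose length is a multiple of N (illustrative only, not the actual mlx implementation):

template <typename T, typename VT, typename Ops, int N>
void softmax_row_sketch(const T* in, T* out, int size) {
  Ops ops;

  // Pass 1: running maximum for numerical stability.
  VT vmax = ops.init(in[0]);
  for (int i = 0; i < size; i += N) {
    vmax = ops.max(vmax, ops.load(in + i));
  }
  T row_max = ops.reduce_max(vmax);

  // Pass 2: exponentiate the shifted values and accumulate the normalizer.
  VT vsum = ops.init(T(0));
  for (int i = 0; i < size; i += N) {
    VT e = ops.exp(ops.sub(ops.load(in + i), row_max));
    ops.store(out + i, e);
    vsum = ops.add(vsum, e);
  }
  T denom = ops.reduce_add(vsum);

  // Pass 3: normalize in place.
  for (int i = 0; i < size; i += N) {
    ops.store(out + i, ops.mul(ops.load(out + i), T(1) / denom));
  }
}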
@@ -362,12 +368,16 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
             AccelerateSimdOps<float, simd_float16>,
             16>(in, out);
       } else {
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         softmax<
             float16_t,
             float16_t,
             float16x8_t,
             NeonFp16SimdOps<float16_t, float16x8_t>,
             8>(in, out);
+#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+        eval(inputs, out); // Redirect to common backend for consistency
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
       }
       break;
     case bfloat16:
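On targets without fp16 NEON, the new #else branch hands the float16 case back to the generic CPU implementation via eval(inputs, out) rather than failing to build. Semantically that fallback is an ordinary numerically stable softmax; a scalar sketch of the reference behaviour (not the actual common-backend code):

#include <cmath>

inline void softmax_reference(const float* in, float* out, int n) {
  float row_max = in[0];
  for (int i = 1; i < n; ++i) {
    row_max = std::fmax(row_max, in[i]); // shift by the max for stability
  }
  float denom = 0.f;
  for (int i = 0; i < n; ++i) {
    out[i] = std::exp(in[i] - row_max);
    denom += out[i];
  }
  for (int i = 0; i < n; ++i) {
    out[i] /= denom; // each row sums to one
  }
}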
@@ -1,8 +1,8 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #pragma once
 
-#include <vecLib/BNNS/bnns.h>
+#include <Accelerate/Accelerate.h>
 #include "mlx/dtype.h"
 
 namespace mlx::core {
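The BNNS include here suggests this header works with BNNS type declarations (e.g. BNNSDataType) when describing tensors to BNNS; the umbrella header provides those declarations on iOS as well. A hypothetical sketch of the kind of helper such a header can support — the function name and the exact dtype coverage are assumptions, not the real mlx API:

#include <stdexcept>

#include <Accelerate/Accelerate.h>

#include "mlx/dtype.h"

namespace mlx::core {

// Map an mlx dtype to the BNNSDataType expected by BNNS descriptors (sketch).
inline BNNSDataType bnns_dtype_sketch(const Dtype& t) {
  if (t == float32) {
    return BNNSDataTypeFloat32;
  } else if (t == float16) {
    return BNNSDataTypeFloat16;
  } else if (t == int32) {
    return BNNSDataTypeInt32;
  }
  throw std::invalid_argument("[bnns_dtype_sketch] dtype not supported by BNNS");
}

} // namespace mlx::core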