mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Refactor common into cpu specific and truly common (#1817)
* refactor * fix extension example * fix no-cpu
This commit is contained in:
56
mlx/backend/cpu/simd/accelerate_fp16_simd.h
Normal file
56
mlx/backend/cpu/simd/accelerate_fp16_simd.h
Normal file
@@ -0,0 +1,56 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cpu/simd/base_simd.h"
|
||||
|
||||
#if MLX_SIMD_LIBRARY_VERSION < 6
|
||||
#include "mlx/backend/cpu/simd/neon_fp16_simd.h"
|
||||
#endif
|
||||
|
||||
namespace mlx::core::simd {
|
||||
|
||||
#if MLX_SIMD_LIBRARY_VERSION >= 6
|
||||
constexpr int N = 8;
|
||||
template <int N>
|
||||
struct ScalarT<float16_t, N> {
|
||||
using v = _Float16;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
static constexpr int max_size<float16_t> = N;
|
||||
|
||||
#define SIMD_FP16_DEFAULT_UNARY(op) \
|
||||
template <> \
|
||||
inline Simd<float16_t, N> op(Simd<float16_t, N> v) { \
|
||||
Simd<float, N> in = v; \
|
||||
return op(in); \
|
||||
}
|
||||
|
||||
SIMD_FP16_DEFAULT_UNARY(acos)
|
||||
SIMD_FP16_DEFAULT_UNARY(acosh)
|
||||
SIMD_FP16_DEFAULT_UNARY(asin)
|
||||
SIMD_FP16_DEFAULT_UNARY(asinh)
|
||||
SIMD_FP16_DEFAULT_UNARY(atan)
|
||||
SIMD_FP16_DEFAULT_UNARY(atanh)
|
||||
SIMD_FP16_DEFAULT_UNARY(cosh)
|
||||
SIMD_FP16_DEFAULT_UNARY(expm1)
|
||||
SIMD_FP16_DEFAULT_UNARY(log)
|
||||
SIMD_FP16_DEFAULT_UNARY(log2)
|
||||
SIMD_FP16_DEFAULT_UNARY(log10)
|
||||
SIMD_FP16_DEFAULT_UNARY(log1p)
|
||||
SIMD_FP16_DEFAULT_UNARY(sinh)
|
||||
SIMD_FP16_DEFAULT_UNARY(tan)
|
||||
SIMD_FP16_DEFAULT_UNARY(tanh)
|
||||
|
||||
#define SIMD_FP16_DEFAULT_BINARY(op) \
|
||||
template <> \
|
||||
inline Simd<float16_t, N> op(Simd<float16_t, N> x, Simd<float16_t, N> y) { \
|
||||
Simd<float, N> a = x; \
|
||||
Simd<float, N> b = y; \
|
||||
return op(a, b); \
|
||||
}
|
||||
SIMD_FP16_DEFAULT_BINARY(atan2)
|
||||
SIMD_FP16_DEFAULT_BINARY(remainder)
|
||||
SIMD_FP16_DEFAULT_BINARY(pow)
|
||||
|
||||
} // namespace mlx::core::simd
|
||||
303
mlx/backend/cpu/simd/accelerate_simd.h
Normal file
303
mlx/backend/cpu/simd/accelerate_simd.h
Normal file
@@ -0,0 +1,303 @@
|
||||
#pragma once
|
||||
|
||||
#include <simd/math.h>
|
||||
#include <simd/vector.h>
|
||||
|
||||
#include <stdint.h>
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
|
||||
#include "mlx/backend/cpu/simd/base_simd.h"
|
||||
|
||||
// There seems to be a bug in sims/base.h
|
||||
// __XROS_2_0 is not defined, the expression evaluates
|
||||
// to true instead of false setting the SIMD library
|
||||
// higher than it should be even on macOS < 15
|
||||
#if __MAC_OS_X_VERSION_MIN_REQUIRED >= 150000 || \
|
||||
__IPHONE_OS_VERSION_MIN_REQUIRED >= 180000 || \
|
||||
__WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \
|
||||
__WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \
|
||||
__TV_OS_VERSION_MIN_REQUIRED >= 180000
|
||||
#define MLX_SIMD_LIBRARY_VERSION 6
|
||||
#else
|
||||
#define MLX_SIMD_LIBRARY_VERSION 5
|
||||
#endif
|
||||
|
||||
namespace mlx::core::simd {
|
||||
|
||||
// Apple simd namespace
|
||||
namespace asd = ::simd;
|
||||
|
||||
// This indirection is needed to remap certain types to ones that accelerate
|
||||
// SIMD can handle
|
||||
template <typename T, int N>
|
||||
struct ScalarT {
|
||||
using v = T;
|
||||
};
|
||||
template <int N>
|
||||
struct ScalarT<bool, N> {
|
||||
using v = char;
|
||||
};
|
||||
template <int N>
|
||||
struct ScalarT<int8_t, N> {
|
||||
using v = char;
|
||||
};
|
||||
template <int N>
|
||||
struct ScalarT<uint64_t, N> {
|
||||
using v = unsigned long;
|
||||
};
|
||||
template <int N>
|
||||
struct ScalarT<int64_t, N> {
|
||||
using v = long;
|
||||
};
|
||||
|
||||
template <typename T, int N>
|
||||
struct Simd {
|
||||
static constexpr int size = N;
|
||||
using scalar_t = typename ScalarT<T, N>::v;
|
||||
|
||||
Simd<T, N>() {}
|
||||
|
||||
template <typename U>
|
||||
Simd<T, N>(Simd<U, N> other) : value(asd::convert<scalar_t>(other.value)) {}
|
||||
|
||||
template <typename U>
|
||||
Simd<T, N>(U v) : value(v){};
|
||||
|
||||
Simd<T, N>(Simd<T, N / 2> x, Simd<T, N / 2> y) {
|
||||
value = asd::make<typename asd::Vector<scalar_t, N>::packed_t>(
|
||||
x.value, y.value);
|
||||
};
|
||||
|
||||
T operator[](int idx) const {
|
||||
return reinterpret_cast<const T*>(&value)[idx];
|
||||
}
|
||||
|
||||
T& operator[](int idx) {
|
||||
return reinterpret_cast<T*>(&value)[idx];
|
||||
}
|
||||
|
||||
typename asd::Vector<scalar_t, N>::packed_t value;
|
||||
};
|
||||
|
||||
// Values chosen based on benchmarks on M3 Max
|
||||
// TODO: consider choosing these more optimally
|
||||
template <>
|
||||
static constexpr int max_size<int8_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<int16_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<int> = 8;
|
||||
template <>
|
||||
static constexpr int max_size<int64_t> = 4;
|
||||
template <>
|
||||
static constexpr int max_size<uint8_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<uint16_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<uint32_t> = 8;
|
||||
template <>
|
||||
static constexpr int max_size<uint64_t> = 4;
|
||||
template <>
|
||||
static constexpr int max_size<float> = 8;
|
||||
template <>
|
||||
static constexpr int max_size<double> = 4;
|
||||
|
||||
#define SIMD_DEFAULT_UNARY(name, op) \
|
||||
template <typename T, int N> \
|
||||
Simd<T, N> name(Simd<T, N> v) { \
|
||||
return op(v.value); \
|
||||
}
|
||||
|
||||
SIMD_DEFAULT_UNARY(abs, asd::abs)
|
||||
SIMD_DEFAULT_UNARY(floor, asd::floor)
|
||||
SIMD_DEFAULT_UNARY(acos, asd::acos)
|
||||
SIMD_DEFAULT_UNARY(acosh, asd::acosh)
|
||||
SIMD_DEFAULT_UNARY(asin, asd::asin)
|
||||
SIMD_DEFAULT_UNARY(asinh, asd::asinh)
|
||||
SIMD_DEFAULT_UNARY(atan, asd::atan)
|
||||
SIMD_DEFAULT_UNARY(atanh, asd::atanh)
|
||||
SIMD_DEFAULT_UNARY(ceil, asd::ceil)
|
||||
SIMD_DEFAULT_UNARY(cosh, asd::cosh)
|
||||
SIMD_DEFAULT_UNARY(expm1, asd::expm1)
|
||||
SIMD_DEFAULT_UNARY(log, asd::log)
|
||||
SIMD_DEFAULT_UNARY(log2, asd::log2)
|
||||
SIMD_DEFAULT_UNARY(log10, asd::log10)
|
||||
SIMD_DEFAULT_UNARY(log1p, asd::log1p)
|
||||
SIMD_DEFAULT_UNARY(rint, asd::rint)
|
||||
SIMD_DEFAULT_UNARY(sinh, asd::sinh)
|
||||
SIMD_DEFAULT_UNARY(sqrt, asd::sqrt)
|
||||
SIMD_DEFAULT_UNARY(rsqrt, asd::rsqrt)
|
||||
SIMD_DEFAULT_UNARY(recip, asd::recip)
|
||||
SIMD_DEFAULT_UNARY(tan, asd::tan)
|
||||
SIMD_DEFAULT_UNARY(tanh, asd::tanh)
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> operator-(Simd<T, N> v) {
|
||||
return -v.value;
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<bool, N> isnan(Simd<T, N> v) {
|
||||
return asd::convert<char>(v.value != v.value);
|
||||
}
|
||||
|
||||
// No simd_boolN in accelerate, use int8_t instead
|
||||
template <typename T, int N>
|
||||
Simd<bool, N> operator!(Simd<T, N> v) {
|
||||
return asd::convert<char>(!v.value);
|
||||
}
|
||||
|
||||
#define SIMD_DEFAULT_BINARY(OP) \
|
||||
template <typename T, typename U, int N> \
|
||||
Simd<T, N> operator OP(Simd<T, N> x, U y) { \
|
||||
return asd::convert<typename Simd<T, N>::scalar_t>(x.value OP y); \
|
||||
} \
|
||||
template <typename T1, typename T2, int N> \
|
||||
Simd<T2, N> operator OP(T1 x, Simd<T2, N> y) { \
|
||||
return asd::convert<typename Simd<T2, N>::scalar_t>(x OP y.value); \
|
||||
} \
|
||||
template <typename T1, typename T2, int N> \
|
||||
Simd<T1, N> operator OP(Simd<T1, N> x, Simd<T2, N> y) { \
|
||||
return asd::convert<typename Simd<T1, N>::scalar_t>(x.value OP y.value); \
|
||||
}
|
||||
|
||||
SIMD_DEFAULT_BINARY(+)
|
||||
SIMD_DEFAULT_BINARY(-)
|
||||
SIMD_DEFAULT_BINARY(/)
|
||||
SIMD_DEFAULT_BINARY(*)
|
||||
SIMD_DEFAULT_BINARY(<<)
|
||||
SIMD_DEFAULT_BINARY(>>)
|
||||
SIMD_DEFAULT_BINARY(|)
|
||||
SIMD_DEFAULT_BINARY(^)
|
||||
SIMD_DEFAULT_BINARY(&)
|
||||
SIMD_DEFAULT_BINARY(&&)
|
||||
SIMD_DEFAULT_BINARY(||)
|
||||
|
||||
#define SIMD_DEFAULT_COMPARISONS(OP) \
|
||||
template <int N, typename T, typename U> \
|
||||
Simd<bool, N> operator OP(Simd<T, N> a, U b) { \
|
||||
return asd::convert<char>(a.value OP b); \
|
||||
} \
|
||||
template <int N, typename T, typename U> \
|
||||
Simd<bool, N> operator OP(T a, Simd<U, N> b) { \
|
||||
return asd::convert<char>(a OP b.value); \
|
||||
} \
|
||||
template <int N, typename T1, typename T2> \
|
||||
Simd<bool, N> operator OP(Simd<T1, N> a, Simd<T2, N> b) { \
|
||||
return asd::convert<char>(a.value OP b.value); \
|
||||
}
|
||||
|
||||
SIMD_DEFAULT_COMPARISONS(>)
|
||||
SIMD_DEFAULT_COMPARISONS(<)
|
||||
SIMD_DEFAULT_COMPARISONS(>=)
|
||||
SIMD_DEFAULT_COMPARISONS(<=)
|
||||
SIMD_DEFAULT_COMPARISONS(==)
|
||||
SIMD_DEFAULT_COMPARISONS(!=)
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
|
||||
return asd::atan2(a.value, b.value);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> maximum(Simd<T, N> a, Simd<T, N> b) {
|
||||
// TODO add isnan
|
||||
return asd::max(a.value, b.value);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> minimum(Simd<T, N> a, Simd<T, N> b) {
|
||||
// TODO add isnan
|
||||
return asd::min(a.value, b.value);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> remainder(Simd<T, N> a, Simd<T, N> b) {
|
||||
Simd<T, N> r;
|
||||
if constexpr (!std::is_integral_v<T>) {
|
||||
r = asd::remainder(a.value, b.value);
|
||||
} else {
|
||||
r = a - b * (a / b);
|
||||
}
|
||||
if constexpr (std::is_signed_v<T>) {
|
||||
auto mask = r != 0 && (r < 0 != b < 0);
|
||||
r = select(mask, r + b, r);
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename MaskT, typename T1, typename T2, int N>
|
||||
Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) {
|
||||
if constexpr (sizeof(T1) == 1) {
|
||||
return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value));
|
||||
} else if constexpr (sizeof(T1) == 2) {
|
||||
return asd::bitselect(y.value, x.value, asd::convert<short>(mask.value));
|
||||
} else if constexpr (sizeof(T1) == 4) {
|
||||
return asd::bitselect(y.value, x.value, asd::convert<int>(mask.value));
|
||||
} else {
|
||||
return asd::bitselect(y.value, x.value, asd::convert<long>(mask.value));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {
|
||||
if constexpr (!std::is_integral_v<T>) {
|
||||
return asd::pow(base.value, exp.value);
|
||||
} else {
|
||||
Simd<T, N> res = 1;
|
||||
while (any(exp)) {
|
||||
res = select(exp & 1, res * base, res);
|
||||
base = select(exp, base * base, base);
|
||||
exp = exp >> 1;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> clamp(Simd<T, N> v, Simd<T, N> min, Simd<T, N> max) {
|
||||
return asd::clamp(v.value, min.value, max.value);
|
||||
}
|
||||
|
||||
template <typename T, typename U, int N>
|
||||
Simd<T, N> fma(Simd<T, N> x, Simd<T, N> y, U z) {
|
||||
return asd::muladd(x.value, y.value, Simd<T, N>(z).value);
|
||||
}
|
||||
|
||||
// Reductions
|
||||
|
||||
template <typename T, int N>
|
||||
bool all(Simd<T, N> x) {
|
||||
return asd::all(x.value);
|
||||
}
|
||||
template <typename T, int N>
|
||||
bool any(Simd<T, N> x) {
|
||||
return asd::any(x.value);
|
||||
}
|
||||
template <typename T, int N>
|
||||
T sum(Simd<T, N> x) {
|
||||
return asd::reduce_add(x.value);
|
||||
}
|
||||
template <typename T, int N>
|
||||
T max(Simd<T, N> x) {
|
||||
return asd::reduce_max(x.value);
|
||||
}
|
||||
template <typename T, int N>
|
||||
T min(Simd<T, N> x) {
|
||||
return asd::reduce_min(x.value);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
T prod(Simd<T, N> x) {
|
||||
auto ptr = (T*)&x;
|
||||
auto lhs = load<T, N / 2>(ptr);
|
||||
auto rhs = load<T, N / 2>(ptr + N / 2);
|
||||
return prod(lhs * rhs);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::simd
|
||||
|
||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
#include "mlx/backend/cpu/simd/accelerate_fp16_simd.h"
|
||||
#endif
|
||||
253
mlx/backend/cpu/simd/base_simd.h
Normal file
253
mlx/backend/cpu/simd/base_simd.h
Normal file
@@ -0,0 +1,253 @@
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
|
||||
namespace mlx::core::simd {
|
||||
template <typename T, int N>
|
||||
struct Simd;
|
||||
|
||||
template <typename T>
|
||||
static constexpr int max_size = 1;
|
||||
|
||||
template <typename T>
|
||||
struct Simd<T, 1> {
|
||||
static constexpr int size = 1;
|
||||
T value;
|
||||
Simd() {}
|
||||
template <typename U>
|
||||
Simd(Simd<U, 1> v) : value(v.value) {}
|
||||
template <typename U>
|
||||
Simd(U v) : value(v) {}
|
||||
};
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> load(const T* x) {
|
||||
return *(Simd<T, N>*)x;
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
void store(T* dst, Simd<T, N> x) {
|
||||
// Maintain invariant that bool is either 0 or 1 as
|
||||
// simd comparison ops set all bits in the result to 1
|
||||
if constexpr (std::is_same_v<T, bool> && N > 1) {
|
||||
x = x & 1;
|
||||
}
|
||||
*(Simd<T, N>*)dst = x;
|
||||
}
|
||||
|
||||
template <typename, typename = void>
|
||||
constexpr bool is_complex = false;
|
||||
|
||||
template <typename T>
|
||||
constexpr bool is_complex<T, std::void_t<decltype(std::declval<T>().real())>> =
|
||||
true;
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> rint(Simd<T, 1> in) {
|
||||
if constexpr (is_complex<T>) {
|
||||
return Simd<T, 1>{
|
||||
T{std::rint(in.value.real()), std::rint(in.value.imag())}};
|
||||
} else {
|
||||
return Simd<T, 1>{std::rint(in.value)};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> rsqrt(Simd<T, 1> in) {
|
||||
return T(1.0) / sqrt(in);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> recip(Simd<T, 1> in) {
|
||||
return T(1.0) / in;
|
||||
}
|
||||
|
||||
#define DEFAULT_UNARY(name, op) \
|
||||
template <typename T> \
|
||||
Simd<T, 1> name(Simd<T, 1> in) { \
|
||||
return op(in.value); \
|
||||
}
|
||||
|
||||
DEFAULT_UNARY(operator-, std::negate{})
|
||||
DEFAULT_UNARY(operator!, std::logical_not{})
|
||||
DEFAULT_UNARY(abs, std::abs)
|
||||
DEFAULT_UNARY(acos, std::acos)
|
||||
DEFAULT_UNARY(acosh, std::acosh)
|
||||
DEFAULT_UNARY(asin, std::asin)
|
||||
DEFAULT_UNARY(asinh, std::asinh)
|
||||
DEFAULT_UNARY(atan, std::atan)
|
||||
DEFAULT_UNARY(atanh, std::atanh)
|
||||
DEFAULT_UNARY(ceil, std::ceil)
|
||||
DEFAULT_UNARY(conj, std::conj)
|
||||
DEFAULT_UNARY(cosh, std::cosh)
|
||||
DEFAULT_UNARY(expm1, std::expm1)
|
||||
DEFAULT_UNARY(floor, std::floor)
|
||||
DEFAULT_UNARY(log, std::log)
|
||||
DEFAULT_UNARY(log2, std::log2)
|
||||
DEFAULT_UNARY(log10, std::log10)
|
||||
DEFAULT_UNARY(log1p, std::log1p)
|
||||
DEFAULT_UNARY(sinh, std::sinh)
|
||||
DEFAULT_UNARY(sqrt, std::sqrt)
|
||||
DEFAULT_UNARY(tan, std::tan)
|
||||
DEFAULT_UNARY(tanh, std::tanh)
|
||||
|
||||
template <typename T>
|
||||
auto real(Simd<T, 1> in) -> Simd<decltype(std::real(in.value)), 1> {
|
||||
return std::real(in.value);
|
||||
}
|
||||
template <typename T>
|
||||
auto imag(Simd<T, 1> in) -> Simd<decltype(std::imag(in.value)), 1> {
|
||||
return std::imag(in.value);
|
||||
}
|
||||
template <typename T>
|
||||
Simd<bool, 1> isnan(Simd<T, 1> in) {
|
||||
return std::isnan(in.value);
|
||||
}
|
||||
|
||||
#define DEFAULT_BINARY(OP) \
|
||||
template <typename T1, typename T2> \
|
||||
auto operator OP(Simd<T1, 1> a, Simd<T2, 1> b) \
|
||||
->Simd<decltype(a.value OP b.value), 1> { \
|
||||
return a.value OP b.value; \
|
||||
} \
|
||||
template <typename T1, typename T2> \
|
||||
auto operator OP(T1 a, Simd<T2, 1> b)->Simd<decltype(a OP b.value), 1> { \
|
||||
return a OP b.value; \
|
||||
} \
|
||||
template <typename T1, typename T2> \
|
||||
auto operator OP(Simd<T1, 1> a, T2 b)->Simd<decltype(a.value OP b), 1> { \
|
||||
return a.value OP b; \
|
||||
}
|
||||
|
||||
DEFAULT_BINARY(+)
|
||||
DEFAULT_BINARY(-)
|
||||
DEFAULT_BINARY(*)
|
||||
DEFAULT_BINARY(/)
|
||||
DEFAULT_BINARY(<<)
|
||||
DEFAULT_BINARY(>>)
|
||||
DEFAULT_BINARY(|)
|
||||
DEFAULT_BINARY(^)
|
||||
DEFAULT_BINARY(&)
|
||||
DEFAULT_BINARY(&&)
|
||||
DEFAULT_BINARY(||)
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {
|
||||
T a = a_.value;
|
||||
T b = b_.value;
|
||||
T r;
|
||||
if constexpr (std::is_integral_v<T>) {
|
||||
r = a % b;
|
||||
} else {
|
||||
r = std::remainder(a, b);
|
||||
}
|
||||
if constexpr (std::is_signed_v<T>) {
|
||||
if (r != 0 && (r < 0 != b < 0)) {
|
||||
r += b;
|
||||
}
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> maximum(Simd<T, 1> a_, Simd<T, 1> b_) {
|
||||
T a = a_.value;
|
||||
T b = b_.value;
|
||||
if constexpr (!std::is_integral_v<T>) {
|
||||
if (std::isnan(a)) {
|
||||
return a;
|
||||
}
|
||||
}
|
||||
return (a > b) ? a : b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> minimum(Simd<T, 1> a_, Simd<T, 1> b_) {
|
||||
T a = a_.value;
|
||||
T b = b_.value;
|
||||
if constexpr (!std::is_integral_v<T>) {
|
||||
if (std::isnan(a)) {
|
||||
return a;
|
||||
}
|
||||
}
|
||||
return (a < b) ? a : b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> pow(Simd<T, 1> a, Simd<T, 1> b) {
|
||||
T base = a.value;
|
||||
T exp = b.value;
|
||||
if constexpr (!std::is_integral_v<T>) {
|
||||
return std::pow(base, exp);
|
||||
} else {
|
||||
T res = 1;
|
||||
while (exp) {
|
||||
if (exp & 1) {
|
||||
res *= base;
|
||||
}
|
||||
exp >>= 1;
|
||||
base *= base;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> atan2(Simd<T, 1> a, Simd<T, 1> b) {
|
||||
return std::atan2(a.value, b.value);
|
||||
}
|
||||
|
||||
#define DEFAULT_COMPARISONS(OP) \
|
||||
template <typename T1, typename T2> \
|
||||
Simd<bool, 1> operator OP(Simd<T1, 1> a, Simd<T2, 1> b) { \
|
||||
return a.value OP b.value; \
|
||||
} \
|
||||
template <typename T1, typename T2> \
|
||||
Simd<bool, 1> operator OP(T1 a, Simd<T2, 1> b) { \
|
||||
return a OP b.value; \
|
||||
} \
|
||||
template <typename T1, typename T2> \
|
||||
Simd<bool, 1> operator OP(Simd<T1, 1> a, T2 b) { \
|
||||
return a.value OP b; \
|
||||
}
|
||||
|
||||
DEFAULT_COMPARISONS(>)
|
||||
DEFAULT_COMPARISONS(<)
|
||||
DEFAULT_COMPARISONS(>=)
|
||||
DEFAULT_COMPARISONS(<=)
|
||||
DEFAULT_COMPARISONS(==)
|
||||
DEFAULT_COMPARISONS(!=)
|
||||
|
||||
template <typename MaskT, typename T>
|
||||
Simd<T, 1> select(Simd<MaskT, 1> mask, Simd<T, 1> x, Simd<T, 1> y) {
|
||||
return mask.value ? x.value : y.value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> clamp(Simd<T, 1> v, Simd<T, 1> min, Simd<T, 1> max) {
|
||||
return std::clamp(v.value, min.value, max.value);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
Simd<T, 1> fma(Simd<T, 1> x, Simd<T, 1> y, U z) {
|
||||
return std::fma(x.value, y.value, Simd<T, 1>(z).value);
|
||||
}
|
||||
|
||||
// Reductions
|
||||
#define DEFAULT_REDUCTION(name, type) \
|
||||
template <typename T> \
|
||||
type name(Simd<T, 1> x) { \
|
||||
return x.value; \
|
||||
}
|
||||
|
||||
DEFAULT_REDUCTION(max, T)
|
||||
DEFAULT_REDUCTION(min, T)
|
||||
DEFAULT_REDUCTION(sum, T)
|
||||
DEFAULT_REDUCTION(prod, T)
|
||||
DEFAULT_REDUCTION(any, bool)
|
||||
DEFAULT_REDUCTION(all, bool)
|
||||
|
||||
} // namespace mlx::core::simd
|
||||
193
mlx/backend/cpu/simd/math.h
Normal file
193
mlx/backend/cpu/simd/math.h
Normal file
@@ -0,0 +1,193 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cpu/simd/type.h"
|
||||
|
||||
namespace mlx::core::simd {
|
||||
|
||||
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||
|
||||
/**
|
||||
* Compute exp(x) in an optimizer friendly way as follows:
|
||||
*
|
||||
* First change the problem to computing 2**y where y = x / ln(2).
|
||||
*
|
||||
* Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part
|
||||
* `ipart` and y2 is fractional part. For the integer part we perform bit
|
||||
* shifting and for the fractional part we use a polynomial approximation.
|
||||
*
|
||||
* The algorithm and constants of the polynomial taken from
|
||||
* https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them
|
||||
* from Cephes math library.
|
||||
*
|
||||
* Note: The implementation below is a general fast exp. There could be faster
|
||||
* implementations for numbers strictly < 0.
|
||||
*/
|
||||
template <typename T, int N>
|
||||
Simd<T, N> exp(Simd<T, N> in) {
|
||||
if constexpr (is_complex<T>) {
|
||||
return Simd<T, 1>{std::exp(in.value)};
|
||||
} else {
|
||||
Simd<float, N> x_init = in;
|
||||
auto x = x_init * 1.442695f; // multiply with log_2(e)
|
||||
Simd<float, N> ipart, fpart;
|
||||
ipart = floor(x + 0.5);
|
||||
fpart = x - ipart;
|
||||
|
||||
x = 1.535336188319500e-4f;
|
||||
x = fma(x, fpart, 1.339887440266574e-3f);
|
||||
x = fma(x, fpart, 9.618437357674640e-3f);
|
||||
x = fma(x, fpart, 5.550332471162809e-2f);
|
||||
x = fma(x, fpart, 2.402264791363012e-1f);
|
||||
x = fma(x, fpart, 6.931472028550421e-1f);
|
||||
x = fma(x, fpart, 1.000000000000000f);
|
||||
|
||||
// generate 2**ipart in the floating point representation using integer
|
||||
// bitshifting
|
||||
Simd<int, N> epart = (Simd<int, N>(ipart) + 127) << 23;
|
||||
|
||||
// Deal with NaN and Inf
|
||||
auto result = select(isnan(x_init), x_init, (*(Simd<float, N>*)&epart) * x);
|
||||
result = select(x_init > 88.0f, Simd<float, N>(inf), result);
|
||||
result = select(x_init < -88.0f, Simd<float, N>(0), result);
|
||||
return Simd<T, N>(result);
|
||||
}
|
||||
}
|
||||
|
||||
/* Implementation from:
|
||||
* https://github.com/JishinMaster/simd_utils/blob/3c1433a86fb38edcc9b02039f3c9a65b16640976/neon_mathfun.h#L357
|
||||
* which originally came from the Cephes math library.
|
||||
*/
|
||||
template <bool Sine, typename T, int N>
|
||||
Simd<T, N> sincos(Simd<T, N> in) {
|
||||
auto sign_mask_sin = in < 0;
|
||||
in = abs(in);
|
||||
Simd<float, N> x = in;
|
||||
|
||||
// scale by 4/Pi
|
||||
auto y = x * 1.27323954473516f;
|
||||
|
||||
// store the integer part of y in mm0
|
||||
Simd<uint32_t, N> emm2 = y;
|
||||
|
||||
// j=(j+1) & (~1) (see the cephes sources)
|
||||
emm2 = emm2 + 1;
|
||||
emm2 = emm2 & ~1;
|
||||
|
||||
y = emm2;
|
||||
|
||||
// Get the polynom selection mask. There is one polynom for 0 <= x <= Pi/4
|
||||
// and another one for Pi/4<x<=Pi/2. Both branches will be computed.
|
||||
auto poly_mask = (emm2 & 2) != 0;
|
||||
|
||||
// The magic pass: "Extended precision modular arithmetic"
|
||||
// x = ((x - y * DP1) - y * DP2) - y * DP3
|
||||
x = fma(y, Simd<float, N>(-0.78515625f), x);
|
||||
x = fma(y, Simd<float, N>(-2.4187564849853515625e-4f), x);
|
||||
x = fma(y, Simd<float, N>(-3.77489497744594108e-8f), x);
|
||||
|
||||
sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != 0);
|
||||
auto sign_mask_cos = ((emm2 - 2) & 4) != 0;
|
||||
|
||||
// Evaluate the first polynom (0 <= x <= Pi/4) in y1,
|
||||
// and the second polynom (Pi/4 <= x <= 0) in y2
|
||||
auto z = x * x;
|
||||
|
||||
auto y1 =
|
||||
fma(z, Simd<float, N>(2.443315711809948e-5f), -1.388731625493765e-3f);
|
||||
auto y2 = fma(z, Simd<float, N>(-1.9515295891e-4f), 8.3321608736e-3f);
|
||||
y1 = fma(y1, z, 4.166664568298827e-2f);
|
||||
y2 = fma(y2, z, -1.6666654611e-1f);
|
||||
y1 = y1 * z;
|
||||
y2 = y2 * z;
|
||||
y1 = y1 * z;
|
||||
y2 = fma(x, y2, x);
|
||||
y1 = fma(z, Simd<float, N>(-0.5f), y1);
|
||||
y1 = y1 + 1.0f;
|
||||
|
||||
if constexpr (Sine) {
|
||||
auto ys = select(poly_mask, y1, y2);
|
||||
return select(sign_mask_sin, -ys, ys);
|
||||
} else {
|
||||
auto yc = select(poly_mask, y2, y1);
|
||||
return select(sign_mask_cos, yc, -yc);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> sin(Simd<T, N> x) {
|
||||
if constexpr (is_complex<T>) {
|
||||
return std::sin(x.value);
|
||||
} else {
|
||||
return sincos<true>(x);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> cos(Simd<T, N> x) {
|
||||
if constexpr (is_complex<T>) {
|
||||
return std::cos(x.value);
|
||||
} else {
|
||||
return sincos<false>(x);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> erf(Simd<T, N> x) {
|
||||
// https://github.com/pytorch/pytorch/blob/abf28982a8cb43342e7669d859de9543fd804cc9/aten/src/ATen/cpu/vec/vec256/vec256_float.h#L175
|
||||
Simd<float, N> v = x;
|
||||
auto t = recip(fma(Simd<float, N>(0.3275911f), abs(v), 1.0f));
|
||||
auto r = fma(Simd<float, N>(1.061405429f), t, -1.453152027f);
|
||||
r = fma(r, t, 1.421413741f);
|
||||
r = fma(r, t, -0.284496736f);
|
||||
r = fma(r, t, 0.254829592f);
|
||||
auto e = -exp(-v * v);
|
||||
auto result = Simd<T, N>(fma(e * t, r, 1.0f));
|
||||
return select(x > 0, result, -result);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> erfinv(Simd<T, N> a_) {
|
||||
Simd<float, N> a = a_;
|
||||
auto t = fma(a, 0.0f - a, 1.0f);
|
||||
t = log(t);
|
||||
auto lhs = [](auto t) {
|
||||
Simd<float, N> p;
|
||||
p = 3.03697567e-10f; // 0x1.4deb44p-32
|
||||
p = fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
|
||||
p = fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
|
||||
p = fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
|
||||
p = fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
|
||||
p = fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
|
||||
p = fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
|
||||
p = fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
|
||||
return fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
|
||||
};
|
||||
auto rhs = [](auto t) {
|
||||
Simd<float, N> p;
|
||||
p = 5.43877832e-9f; // 0x1.75c000p-28
|
||||
p = fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
|
||||
p = fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
|
||||
p = fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
|
||||
p = fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
|
||||
p = fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
|
||||
p = fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
|
||||
p = fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
|
||||
p = fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
|
||||
return fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
|
||||
};
|
||||
auto thresh = 6.125f;
|
||||
// Compute both branches and select if N > 1
|
||||
if constexpr (N == 1) {
|
||||
if ((abs(t) > thresh).value) { // maximum ulp error = 2.35793
|
||||
return a * lhs(t);
|
||||
} else { // maximum ulp error = 2.35002
|
||||
return a * rhs(t);
|
||||
}
|
||||
} else {
|
||||
return a * select(t > thresh, lhs(t), rhs(t));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::simd
|
||||
212
mlx/backend/cpu/simd/neon_fp16_simd.h
Normal file
212
mlx/backend/cpu/simd/neon_fp16_simd.h
Normal file
@@ -0,0 +1,212 @@
|
||||
#pragma once
|
||||
|
||||
#include <arm_neon.h>
|
||||
|
||||
#include "mlx/backend/cpu/simd/base_simd.h"
|
||||
|
||||
namespace mlx::core::simd {
|
||||
|
||||
constexpr int N = 8;
|
||||
|
||||
template <>
|
||||
struct Simd<float16_t, N> {
|
||||
static constexpr int size = N;
|
||||
using scalar_t = float16_t;
|
||||
|
||||
Simd<float16_t, N>() {}
|
||||
|
||||
template <typename U>
|
||||
Simd<float16_t, N>(U v) : value(vdupq_n_f16(v)){};
|
||||
|
||||
Simd<float16_t, N>(float16x8_t v) : value(v){};
|
||||
|
||||
Simd<float16_t, N>(Simd<float, N> other) {
|
||||
auto f32x4_a = *(float32x4_t*)(&other);
|
||||
auto f32x4_b = *((float32x4_t*)(&other) + 1);
|
||||
value = vcvt_high_f16_f32(vcvt_f16_f32(f32x4_a), f32x4_b);
|
||||
};
|
||||
|
||||
Simd<float16_t, N>(Simd<uint16_t, N> other) {
|
||||
value = vcvtq_f16_u16(*(uint16x8_t*)(&other.value));
|
||||
};
|
||||
|
||||
operator Simd<int16_t, N>() {
|
||||
auto v = vcvtq_s16_f16(value);
|
||||
return load<int16_t, N>((int16_t*)&v);
|
||||
};
|
||||
|
||||
operator Simd<float, N>() {
|
||||
float32x4x2_t v;
|
||||
v.val[0] = vcvt_f32_f16(*(float16x4_t*)(&value));
|
||||
v.val[1] = vcvt_high_f32_f16(value);
|
||||
return load<float, N>((float*)&v);
|
||||
}
|
||||
float16_t operator[](int idx) const {
|
||||
return reinterpret_cast<const float16_t*>(&value)[idx];
|
||||
}
|
||||
|
||||
float16_t& operator[](int idx) {
|
||||
return reinterpret_cast<float16_t*>(&value)[idx];
|
||||
}
|
||||
|
||||
float16x8_t value;
|
||||
};
|
||||
|
||||
#define DEFINE_NEON_UNARY_OP(name, op) \
|
||||
inline Simd<float16_t, N> name(Simd<float16_t, N> a) { \
|
||||
return Simd<float16_t, N>{op(a.value)}; \
|
||||
}
|
||||
|
||||
DEFINE_NEON_UNARY_OP(abs, vabsq_f16)
|
||||
DEFINE_NEON_UNARY_OP(ceil, vrndpq_f16)
|
||||
DEFINE_NEON_UNARY_OP(floor, vrndmq_f16)
|
||||
DEFINE_NEON_UNARY_OP(sqrt, vsqrtq_f16)
|
||||
DEFINE_NEON_UNARY_OP(rsqrt, vrsqrteq_f16)
|
||||
DEFINE_NEON_UNARY_OP(recip, vrecpeq_f16)
|
||||
DEFINE_NEON_UNARY_OP(rint, vrndnq_f16)
|
||||
|
||||
#define DEFINE_NEON_BINARY_OP(name, op) \
|
||||
inline Simd<float16_t, N> name(Simd<float16_t, N> a, Simd<float16_t, N> b) { \
|
||||
return op(a.value, b.value); \
|
||||
} \
|
||||
template <typename T> \
|
||||
Simd<float16_t, N> name(Simd<float16_t, N> a, T b) { \
|
||||
return op(a.value, Simd<float16_t, N>(b).value); \
|
||||
} \
|
||||
template <typename T> \
|
||||
Simd<float16_t, N> name(T a, Simd<float16_t, N> b) { \
|
||||
return op(Simd<float16_t, N>(a).value, b.value); \
|
||||
}
|
||||
|
||||
inline Simd<float16_t, N> operator!(Simd<float16_t, N> v) {
|
||||
auto out = vceqzq_f16(v.value);
|
||||
return Simd<uint16_t, N>(*(uint16_t*)&out);
|
||||
}
|
||||
|
||||
inline Simd<float16_t, N> operator-(Simd<float16_t, N> v) {
|
||||
return vnegq_f16(v.value);
|
||||
}
|
||||
|
||||
DEFINE_NEON_BINARY_OP(maximum, vmaxq_f16)
|
||||
DEFINE_NEON_BINARY_OP(minimum, vminq_f16)
|
||||
DEFINE_NEON_BINARY_OP(operator+, vaddq_f16)
|
||||
DEFINE_NEON_BINARY_OP(operator-, vsubq_f16)
|
||||
DEFINE_NEON_BINARY_OP(operator*, vmulq_f16)
|
||||
DEFINE_NEON_BINARY_OP(operator/, vdivq_f16)
|
||||
|
||||
#define DEFINE_NEON_COMPARISON(Op, op) \
|
||||
template <typename T> \
|
||||
Simd<bool, N> operator Op(Simd<float16_t, N> a, T b) { \
|
||||
auto out = op(a.value, Simd<float16_t, N>(b).value); \
|
||||
return Simd<uint16_t, N>(*(uint16_t*)(&out)); \
|
||||
} \
|
||||
template <typename T> \
|
||||
Simd<bool, N> operator Op(T a, Simd<float16_t, N> b) { \
|
||||
auto out = op(Simd<float16_t, N>(a).value, b.value); \
|
||||
return Simd<uint16_t, N>(*(uint16_t*)(&out)); \
|
||||
} \
|
||||
inline Simd<bool, N> operator Op( \
|
||||
Simd<float16_t, N> a, Simd<float16_t, N> b) { \
|
||||
auto out = op(a.value, b.value); \
|
||||
return Simd<uint16_t, N>(*(uint16_t*)(&out)); \
|
||||
}
|
||||
|
||||
DEFINE_NEON_COMPARISON(==, vceqq_f16)
|
||||
DEFINE_NEON_COMPARISON(>=, vcgeq_f16)
|
||||
DEFINE_NEON_COMPARISON(<=, vcleq_f16)
|
||||
DEFINE_NEON_COMPARISON(>, vcgtq_f16)
|
||||
DEFINE_NEON_COMPARISON(<, vcltq_f16)
|
||||
|
||||
template <typename T>
|
||||
Simd<bool, N> operator!=(Simd<float16_t, N> a, T b) {
|
||||
return !(a == b);
|
||||
}
|
||||
template <typename T>
|
||||
Simd<bool, N> operator!=(T a, Simd<float16_t, N> b) {
|
||||
return !(a == b);
|
||||
}
|
||||
inline Simd<bool, N> operator!=(Simd<float16_t, N> a, Simd<float16_t, N> b) {
|
||||
return !(a == b);
|
||||
}
|
||||
|
||||
inline Simd<float16_t, N> operator||(
|
||||
Simd<float16_t, N> a,
|
||||
Simd<float16_t, N> b) {
|
||||
return Simd<uint16_t, N>((a != 0) || (b != 0));
|
||||
}
|
||||
template <typename T>
|
||||
Simd<float16_t, N> operator||(Simd<float16_t, N> a, T b) {
|
||||
return Simd<uint16_t, N>((a != 0) || (b != 0));
|
||||
}
|
||||
template <typename T>
|
||||
Simd<float16_t, N> operator||(T a, Simd<float16_t, N> b) {
|
||||
return Simd<uint16_t, N>((a != 0) || (b != 0));
|
||||
}
|
||||
inline Simd<float16_t, N> operator&&(
|
||||
Simd<float16_t, N> a,
|
||||
Simd<float16_t, N> b) {
|
||||
return Simd<uint16_t, N>((a != 0) && (b != 0));
|
||||
}
|
||||
template <typename T>
|
||||
Simd<float16_t, N> operator&&(Simd<float16_t, N> a, T b) {
|
||||
return Simd<uint16_t, N>((a != 0) && (b != 0));
|
||||
}
|
||||
template <typename T>
|
||||
Simd<float16_t, N> operator&&(T a, Simd<float16_t, N> b) {
|
||||
return Simd<uint16_t, N>((a != 0) && (b != 0));
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Simd<bool, N> isnan(Simd<float16_t, N> v) {
|
||||
return v != v;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Simd<float16_t, N>
|
||||
clamp(Simd<float16_t, N> v, Simd<float16_t, N> min, Simd<float16_t, N> max) {
|
||||
return minimum(maximum(v, min), max);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<float16_t, N> fma(Simd<float16_t, N> x, Simd<float16_t, N> y, T z) {
|
||||
return vfmaq_f16(x.value, y.value, Simd<float16_t, N>(z).value);
|
||||
}
|
||||
|
||||
template <typename MaskT>
|
||||
Simd<float16_t, N>
|
||||
select(Simd<MaskT, N> mask, Simd<float16_t, N> x, Simd<float16_t, N> y) {
|
||||
return vbslq_f16(Simd<uint16_t, N>(mask).value, x.value, y.value);
|
||||
}
|
||||
|
||||
// Reductions
|
||||
inline float16_t max(Simd<float16_t, N> x) {
|
||||
float16x4_t y;
|
||||
y = vpmax_f16(vget_low_f16(x.value), vget_high_f16(x.value));
|
||||
y = vpmax_f16(y, y);
|
||||
y = vpmax_f16(y, y);
|
||||
return vget_lane_f16(y, 0);
|
||||
}
|
||||
inline float16_t min(Simd<float16_t, N> x) {
|
||||
float16x4_t y;
|
||||
y = vpmin_f16(vget_low_f16(x.value), vget_high_f16(x.value));
|
||||
y = vpmin_f16(y, y);
|
||||
y = vpmin_f16(y, y);
|
||||
return vget_lane_f16(y, 0);
|
||||
}
|
||||
inline float16_t sum(Simd<float16_t, N> x) {
|
||||
float16x4_t y;
|
||||
y = vpadd_f16(vget_low_f16(x.value), vget_high_f16(x.value));
|
||||
y = vpadd_f16(y, y);
|
||||
y = vpadd_f16(y, y);
|
||||
return vget_lane_f16(y, 0);
|
||||
}
|
||||
inline float16_t prod(Simd<float16_t, N> x) {
|
||||
auto hx = vmul_f16(vget_low_f16(x.value), vget_high_f16(x.value));
|
||||
auto out = hx[0];
|
||||
hx[0] *= hx[1];
|
||||
hx[0] *= hx[2];
|
||||
hx[0] *= hx[3];
|
||||
return hx[0];
|
||||
}
|
||||
|
||||
} // namespace mlx::core::simd
|
||||
4
mlx/backend/cpu/simd/simd.h
Normal file
4
mlx/backend/cpu/simd/simd.h
Normal file
@@ -0,0 +1,4 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cpu/simd/math.h"
|
||||
#include "mlx/backend/cpu/simd/type.h"
|
||||
7
mlx/backend/cpu/simd/type.h
Normal file
7
mlx/backend/cpu/simd/type.h
Normal file
@@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cpu/simd/base_simd.h"
|
||||
|
||||
#ifdef MLX_USE_ACCELERATE
|
||||
#include "mlx/backend/cpu/simd/accelerate_simd.h"
|
||||
#endif
|
||||
Reference in New Issue
Block a user