Start to clean up/unify accelerate and common back-ends (Part 1/N) (#1777)

* start to clean up/unify accelerate and common back-ends

* more progress

* simplify

* add half type and allow infs in simd exp

* unify softmax + quantized, more dispatches to simd quantized mm

* add sin/cos, use simd in vector-scalar ops

* faster vectorized CPU quantization

* faster erf/erfinv
Awni Hannun, 2025-01-29 14:34:49 -08:00 (committed via GitHub)
parent 7064fed1b1
commit 4758c8baa1
47 changed files with 1920 additions and 2640 deletions


@@ -147,6 +147,7 @@ if(MLX_BUILD_CPU)
  if(MLX_BUILD_ACCELERATE)
    target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
+   add_compile_definitions(MLX_USE_ACCELERATE)
    add_compile_definitions(ACCELERATE_NEW_LAPACK)
  elseif(MLX_BUILD_BLAS_FROM_SOURCE)
    # Download and build OpenBLAS from source code.


@@ -3,6 +3,4 @@ target_sources(
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
-         ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
-         ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
-         ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp)
+         ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp)


@@ -11,448 +11,8 @@
#include "mlx/backend/common/unary.h"
#include "mlx/primitives.h"
#define DEFAULT(primitive) \
void primitive::eval_cpu(const std::vector<array>& inputs, array& out) { \
primitive::eval(inputs, out); \
}
#define DEFAULT_MULTI(primitive) \
void primitive::eval_cpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
primitive::eval(inputs, outputs); \
}
namespace mlx::core {
// Use the default implementation for the following primitives
DEFAULT(Arange)
DEFAULT(ArgPartition)
DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(BlockMaskedMM)
DEFAULT(Broadcast)
DEFAULT(BroadcastAxes)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomTransforms)
DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod)
DEFAULT(NumberOfElements)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(ExpandDims)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Gather)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Hadamard)
DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)
DEFAULT(LogicalNot)
DEFAULT(LogicalAnd)
DEFAULT(LogicalOr)
DEFAULT(LogAddExp)
DEFAULT(Maximum)
DEFAULT(Minimum)
DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT_MULTI(QRF)
DEFAULT(RandomBits)
DEFAULT(Remainder)
DEFAULT(Round)
DEFAULT(Scatter)
DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT_MULTI(Split)
DEFAULT(Sort)
DEFAULT(Squeeze)
DEFAULT(StopGradient)
DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)
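// [Editor's note — not part of the diff] DEFAULT expands to a forwarding
// stub, e.g. DEFAULT(Ceil) becomes:
//
//   void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
//     Ceil::eval(inputs, out);
//   }
//
// so every primitive listed above falls back to the common CPU
// implementation rather than an Accelerate-specific kernel.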
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
} else if (in.dtype() == int32 && in.flags().contiguous) {
set_unary_output_data(in, out);
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, in.data_size());
} else {
eval(inputs, out);
}
}
void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == float32) {
binary_op<float>(
a,
b,
out,
[](auto x, auto y) { return x + y; },
[](const auto* s, const auto* vec, auto* o, auto n) {
vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
},
[](const auto* vec, const auto* s, auto* o, auto n) {
vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
});
} else if (a.dtype() == int32) {
binary_op<int>(
a,
b,
out,
[](auto x, auto y) { return x + y; },
[](const auto* s, const auto* vec, auto* o, auto n) {
vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
},
[](const auto* vec, const auto* s, auto* o, auto n) {
vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
});
} else {
eval(inputs, out);
}
}
void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvacosf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvacoshf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvasinf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvasinhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvatanf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
if (a.is_donatable()) {
out.copy_shared_buffer(a);
} else if (b.is_donatable()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
int size = a.data_size();
vvatan2f(out.data<float>(), a.data<float>(), b.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvatanhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().contiguous) {
// Use accelerate functions if possible
if (in.dtype() == float32 && out.dtype() == uint32) {
set_unary_output_data(in, out);
vDSP_vfixu32(
in.data<float>(), 1, out.data<uint32_t>(), 1, in.data_size());
return;
} else if (in.dtype() == float32 && out.dtype() == int32) {
set_unary_output_data(in, out);
vDSP_vfix32(in.data<float>(), 1, out.data<int32_t>(), 1, in.data_size());
return;
} else if (in.dtype() == uint32 && out.dtype() == float32) {
set_unary_output_data(in, out);
vDSP_vfltu32(
in.data<uint32_t>(), 1, out.data<float>(), 1, in.data_size());
return;
} else if (in.dtype() == int32 && out.dtype() == float32) {
set_unary_output_data(in, out);
vDSP_vflt32(in.data<int32_t>(), 1, out.data<float>(), 1, in.data_size());
return;
}
}
eval(inputs, out);
}
void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvcosf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvcoshf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == int32) {
binary_op<int>(
a,
b,
out,
[](auto x, auto y) { return x / y; },
UseDefaultBinaryOp(),
[](const auto* vec, const auto* s, auto* o, auto n) {
vDSP_vsdivi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
});
} else if (a.dtype() == float32) {
binary_op<float>(
a,
b,
out,
[](auto x, auto y) { return x / y; },
[](const auto* s, const auto* vec, auto* o, auto n) {
vDSP_svdiv((const float*)s, (const float*)vec, 1, (float*)o, 1, n);
},
[](const auto* vec, const auto* s, auto* o, auto n) {
vDSP_vsdiv((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
});
} else {
eval(inputs, out);
}
}
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else {
eval(inputs, out);
}
}
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
vvexpm1f(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else {
eval(inputs, out);
}
}
void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
assert(in.dtype() == out.dtype());
if (in.data_size() == 1 && out.dtype() == float32) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
vDSP_vfill(in.data<float>(), out.data<float>(), 1, out.size());
} else {
eval(inputs, out);
}
}
void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
switch (base_) {
case Base::e:
vvlogf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
break;
case Base::two:
vvlog2f(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
break;
case Base::ten:
vvlog10f(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
break;
}
} else {
eval(inputs, out);
}
}
void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
vvlog1pf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else {
eval(inputs, out);
}
}
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == float32) {
binary_op<float>(
a,
b,
out,
[](auto x, auto y) { return x * y; },
[](const auto* s, const auto* vec, auto* o, auto n) {
vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
},
[](const auto* vec, const auto* s, auto* o, auto n) {
vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
});
} else {
eval(inputs, out);
}
}
void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
} else {
eval(inputs, out);
}
}
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
int size = a.size();
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
out.copy_shared_buffer(a);
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size);
} else {
eval(inputs, out);
}
}
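// [Editor's note — not part of the diff] vvpowf's signature is
// vvpowf(z, y, x, n) with z[i] = x[i]^y[i], i.e. the exponent vector comes
// before the base vector, which is why b precedes a in the call above.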
void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];

@@ -484,120 +44,4 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
  }
}
void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvsinf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvsinhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
} else {
eval(inputs, out);
}
}
void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
if (recip_) {
vvrsqrtf(out.data<float>(), in.data<float>(), &size);
} else {
vvsqrtf(out.data<float>(), in.data<float>(), &size);
}
} else {
eval(inputs, out);
}
}
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == float32) {
binary_op<float>(
a,
b,
out,
[](auto x, auto y) { return x - y; },
[](const auto* s, const auto* vec, auto* o, auto n) {
float minus_1 = -1;
vDSP_vsmsa(
(const float*)vec, 1, &minus_1, (const float*)s, (float*)o, 1, n);
},
[](const auto* vec, const auto* s, auto* o, auto n) {
float val = -(*s);
vDSP_vsadd((const float*)vec, 1, &val, (float*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
});
} else if (a.dtype() == int32) {
binary_op<int>(
a,
b,
out,
[](auto x, auto y) { return x - y; },
UseDefaultBinaryOp(),
[](const auto* vec, const auto* s, auto* o, auto n) {
int val = -(*s);
vDSP_vsaddi((const int*)vec, 1, &val, (int*)o, 1, n);
},
UseDefaultBinaryOp());
} else {
eval(inputs, out);
}
}
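// [Editor's note — not part of the diff] vDSP has no direct scalar-minus-
// vector or vector-minus-scalar routine, so the lambdas above synthesize
// them: s - vec as vec * (-1) + s via vDSP_vsmsa, and vec - s as
// vec + (-s) via vDSP_vsadd / vDSP_vsaddi.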
void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvtanf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvtanhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
} // namespace mlx::core


@@ -1,117 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <simd/vector.h>
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
void _qmm_t_4_64(
float* result,
const float* x,
const uint32_t* w,
const float* scales,
const float* biases,
int M,
int N,
int K,
int B,
bool batched_w) {
constexpr int bits = 4;
constexpr int group_size = 64;
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
int w_els = N * K / pack_factor;
int g_els = w_els * pack_factor / group_size;
for (int i = 0; i < B; i++) {
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const float* scales_local = scales;
const float* biases_local = biases;
for (int n = 0; n < N; n++) {
const simd_float16* x_local = (simd_float16*)x;
simd_float16 sum = 0;
for (int k = 0; k < K; k += group_size) {
float scale = *scales_local++;
float bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw += 2) {
// TODO: vectorize this properly
simd_uint16 wi;
for (int e = 0; e < 2; e++) {
uint32_t wii = *w_local++;
for (int p = 0; p < 8; p++) {
wi[e * 8 + p] = wii & bitmask;
wii >>= bits;
}
}
simd_float16 wf = simd_float(wi);
wf *= scale;
wf += bias;
sum += (*x_local) * wf;
x_local++;
}
}
*result = simd_reduce_add(sum);
result++;
}
x += K;
}
if (batched_w) {
w += w_els;
scales += g_els;
biases += g_els;
}
}
}
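// [Editor's sketch — not part of the diff] Scalar equivalent of the nibble
// unpacking in the inner loop above: one uint32 packs 8 four-bit weights
// (pack_factor = 32 / bits), and each group of 64 weights shares one
// (scale, bias) pair, so a weight dequantizes as scale * w + bias.
inline void dequant_u32_nibbles(
    uint32_t packed,
    float scale,
    float bias,
    float out[8]) {
  for (int p = 0; p < 8; ++p) {
    out[p] = scale * static_cast<float>(packed & 0xF) + bias;
    packed >>= 4;
  }
}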
} // namespace
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
auto& x = inputs[0];
auto& w = inputs[1];
auto& scales = inputs[2];
auto& biases = inputs[3];
bool condition =
(transpose_ && x.flags().row_contiguous && w.flags().row_contiguous &&
scales.flags().row_contiguous && biases.flags().row_contiguous &&
x.dtype() == float32 && bits_ == 4 && group_size_ == 64);
if (condition) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
int K = x.shape(-1);
int M = x.shape(-2);
int N = out.shape(-1);
int B = x.size() / K / M;
bool batched_w = w.ndim() > 2;
_qmm_t_4_64(
out.data<float>(),
x.data<float>(),
w.data<uint32_t>(),
scales.data<float>(),
biases.data<float>(),
M,
N,
K,
B,
batched_w);
} else {
eval(inputs, out);
}
}
} // namespace mlx::core


@@ -1,393 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <limits>
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_neon.h>
#endif
#include <simd/math.h>
#include <simd/vector.h>
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
/**
* Compute exp(x) in an optimizer friendly way as follows:
*
* First change the problem to computing 2**y where y = x / ln(2).
*
* Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part
* `ipart` and y2 is fractional part. For the integer part we perform bit
* shifting and for the fractional part we use a polynomial approximation.
*
* The algorithm and constants of the polynomial taken from
* https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them
* from Cephes math library.
*
* Note: The implementation below is a general fast exp. There could be faster
* implementations for numbers strictly < 0.
*/
inline simd_float16 simd_fast_exp(simd_float16 x_init) {
auto x = x_init * 1.442695; // multiply with log_2(e)
simd_float16 ipart, fpart;
simd_int16 epart;
x = simd_clamp(x, -80, 80);
ipart = simd::floor(x + 0.5);
fpart = x - ipart;
x = 1.535336188319500e-4f;
x = x * fpart + 1.339887440266574e-3f;
x = x * fpart + 9.618437357674640e-3f;
x = x * fpart + 5.550332471162809e-2f;
x = x * fpart + 2.402264791363012e-1f;
x = x * fpart + 6.931472028550421e-1f;
x = x * fpart + 1.000000000000000f;
// generate 2**ipart in the floating point representation using integer
// bitshifting
epart = (simd_int(ipart) + 127) << 23;
// Avoid suppressing NaNs
simd_int16 eq = (x_init == x_init);
return simd_bitselect(x_init, (*(simd_float16*)&epart) * x, eq);
}
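// [Editor's check — not part of the diff] The identity behind the code
// above is exp(x) = 2^(x * log2(e)) = 2^ipart * 2^fpart; the polynomial
// approximates 2^fpart and the bit shift builds 2^ipart. A scalar sanity
// check of the decomposition (assumes <cmath>):
inline float exp_via_decomposition(float x) {
  float y = x * 1.442695f; // log2(e)
  float ipart = std::floor(y + 0.5f);
  float fpart = y - ipart;
  // std::exp2 stands in for the polynomial; ldexp multiplies by 2^ipart.
  return std::ldexp(std::exp2(fpart), static_cast<int>(ipart));
}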
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/**
* The ARM neon equivalent of the fast exp above.
*/
inline float16x8_t neon_fast_exp(float16x8_t x) {
x = vmulq_f16(x, vdupq_n_f16(float16_t(1.442695f))); // multiply with log_2(e)
x = vmaxq_f16(x, vdupq_n_f16(float16_t(-14.f))); // clamp under with -14
x = vminq_f16(x, vdupq_n_f16(float16_t(14.f))); // clamp over with 14
float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(float16_t(0.5f))));
float16x8_t fpart = vsubq_f16(x, ipart);
x = vdupq_n_f16(float16_t(1.535336188319500e-4f));
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(6.931472028550421e-1f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(1.000000000000000f)), x, fpart);
// generate 2**ipart in the floating point representation using integer
// bitshifting
int16x8_t epart = vcvtq_s16_f16(ipart);
epart = vaddq_s16(epart, vdupq_n_s16(15));
epart = vshlq_n_s16(epart, 10);
return vmulq_f16(vreinterpretq_f16_s16(epart), x);
}
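// [Editor's note — not part of the diff] The fp16 constants mirror the fp32
// path above: a float16 has 10 mantissa bits and an exponent bias of 15, so
// 2^ipart is built as (ipart + 15) << 10 instead of (ipart + 127) << 23,
// and the input is clamped to ±14 so that 2^ipart stays a finite, normal
// half-precision value.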
/**
* Implementation of folding maximum for ARM neon. This should possibly be
* refactored out of softmax.cpp at some point.
*/
inline float16_t neon_reduce_max(float16x8_t x) {
float16x4_t y;
y = vpmax_f16(vget_low_f16(x), vget_high_f16(x));
y = vpmax_f16(y, y);
y = vpmax_f16(y, y);
return vget_lane_f16(y, 0);
}
/**
* Implementation of folding sum for ARM neon. This should possibly be
* refactored out of softmax.cpp at some point.
*/
inline float16_t neon_reduce_add(float16x8_t x) {
float16x4_t y;
float16x4_t zero = vdup_n_f16(0);
y = vpadd_f16(vget_low_f16(x), vget_high_f16(x));
y = vpadd_f16(y, zero);
y = vpadd_f16(y, zero);
return vget_lane_f16(y, 0);
}
template <typename T, typename VT>
struct NeonFp16SimdOps {
VT init(T a) {
return vdupq_n_f16(a);
}
VT load(const T* a) {
return vld1q_f16(a);
}
void store(T* dst, VT x) {
vst1q_f16(dst, x);
}
VT max(VT a, VT b) {
return vmaxq_f16(a, b);
}
VT exp(VT x) {
return neon_fast_exp(x);
}
VT add(VT a, VT b) {
return vaddq_f16(a, b);
}
VT sub(VT a, T b) {
return vsubq_f16(a, vdupq_n_f16(b));
}
VT mul(VT a, VT b) {
return vmulq_f16(a, b);
}
VT mul(VT a, T b) {
return vmulq_f16(a, vdupq_n_f16(b));
}
T reduce_max(VT x) {
return neon_reduce_max(x);
}
T reduce_add(VT x) {
return neon_reduce_add(x);
}
};
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <typename T, typename VT>
struct AccelerateSimdOps {
VT init(T a) {
return a;
}
VT load(const T* a) {
return *(VT*)a;
}
void store(T* dst, VT x) {
*(VT*)dst = x;
}
VT max(VT a, VT b) {
return simd_max(a, b);
}
VT exp(VT x) {
return simd_fast_exp(x);
}
VT add(VT a, VT b) {
return a + b;
}
VT sub(VT a, T b) {
return a - b;
}
VT mul(VT a, VT b) {
return a * b;
}
VT mul(VT a, T b) {
return a * b;
}
T reduce_max(VT x) {
return simd_reduce_max(x);
}
T reduce_add(VT x) {
return simd_reduce_add(x);
}
};
template <typename T, typename AccT, typename VT, typename Ops, int N>
void softmax(const array& in, array& out) {
Ops ops;
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
int M = in.shape().back();
int L = in.data_size() / M;
const T* current_in_ptr;
T* current_out_ptr;
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
// Find the maximum
current_in_ptr = in_ptr;
VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
size_t s = M;
while (s >= N) {
VT vals;
if constexpr (std::is_same<T, AccT>::value) {
vals = ops.load(current_in_ptr);
} else {
for (int i = 0; i < N; ++i) {
vals[i] = static_cast<AccT>(current_in_ptr[i]);
}
}
vmaximum = ops.max(vals, vmaximum);
current_in_ptr += N;
s -= N;
}
AccT maximum = ops.reduce_max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
}
// Compute the normalizer and the exponentials
VT vnormalizer = ops.init(0.0);
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
VT vexp;
if constexpr (std::is_same<T, AccT>::value) {
vexp = ops.load(current_in_ptr);
} else {
for (int i = 0; i < N; ++i) {
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
}
}
vexp = ops.exp(ops.sub(vexp, maximum));
if constexpr (std::is_same<T, AccT>::value) {
ops.store(current_out_ptr, vexp);
}
vnormalizer = ops.add(vnormalizer, vexp);
current_in_ptr += N;
current_out_ptr += N;
s -= N;
}
AccT normalizer = ops.reduce_add(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
if (std::is_same<T, AccT>::value) {
*current_out_ptr = _exp;
}
normalizer += _exp;
current_in_ptr++;
current_out_ptr++;
}
normalizer = 1 / normalizer;
// Normalize
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
if constexpr (std::is_same<T, AccT>::value) {
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
} else {
VT vexp;
for (int i = 0; i < N; ++i) {
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
}
vexp = ops.mul(ops.exp(ops.sub(vexp, maximum)), normalizer);
for (int i = 0; i < N; ++i) {
current_out_ptr[i] = vexp[i];
}
current_in_ptr += N;
}
current_out_ptr += N;
s -= N;
}
while (s-- > 0) {
if constexpr (std::is_same<T, AccT>::value) {
*current_out_ptr *= normalizer;
} else {
AccT _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(_exp * normalizer);
current_in_ptr++;
}
current_out_ptr++;
}
}
}
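// [Editor's sketch — not part of the diff] The template above is a SIMD
// rendering of the standard three-pass, numerically stable softmax. The
// same passes in plain scalar C++ (assumes <cmath> and <algorithm>):
template <typename T>
void softmax_rows_scalar(const T* in, T* out, int rows, int cols) {
  for (int r = 0; r < rows; ++r, in += cols, out += cols) {
    T maximum = *std::max_element(in, in + cols); // pass 1: row maximum
    T normalizer = T(0);
    for (int c = 0; c < cols; ++c) { // pass 2: exponentials and their sum
      out[c] = std::exp(in[c] - maximum);
      normalizer += out[c];
    }
    for (int c = 0; c < cols; ++c) { // pass 3: normalize
      out[c] /= normalizer;
    }
  }
}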
} // namespace
void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Make sure that the last dimension is contiguous
auto check_input = [](array x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General);
return x_copy;
}
};
array in = check_input(std::move(inputs[0]));
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
switch (in.dtype()) {
case bool_:
case uint8:
case uint16:
case uint32:
case uint64:
case int8:
case int16:
case int32:
case int64:
throw std::invalid_argument(
"Softmax is defined only for floating point types");
break;
case float32:
softmax<
float,
float,
simd_float16,
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
break;
case float16:
if (precise_) {
softmax<
float16_t,
float,
simd_float16,
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
} else {
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
softmax<
float16_t,
float16_t,
float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out);
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
eval(inputs, out); // Redirect to common backend for consistency
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
}
break;
case bfloat16:
eval(inputs, out);
break;
case complex64:
eval(inputs, out);
break;
}
}
} // namespace mlx::core


@@ -5,6 +5,18 @@ else()
  set(COMPILER ${CMAKE_CXX_COMPILER})
endif()
set(COMPILE_DEPS
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
simd/simd.h
simd/base_simd.h
simd/math.h
simd/type.h
unary_ops.h
binary_ops.h)
if(MSVC)
  set(SHELL_EXT ps1)
  set(SHELL_CMD powershell -ExecutionPolicy Bypass -File)
@@ -19,13 +31,8 @@ add_custom_command(
  ${SHELL_CMD} ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.${SHELL_EXT}
  ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
  ${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR}
-  DEPENDS make_compiled_preamble.${SHELL_EXT}
-          compiled_preamble.h
-          ${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
-          ${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
-          ${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
-          ${PROJECT_SOURCE_DIR}/mlx/types/complex.h
-          ops.h)
+  DEPENDS make_compiled_preamble.${SHELL_EXT} compiled_preamble.h
+          ${COMPILE_DEPS})
add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp)
@@ -60,6 +67,7 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
  ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)


@@ -61,7 +61,7 @@ void arg_reduce_dispatch(
} // namespace

-void ArgReduce::eval(const std::vector<array>& inputs, array& out) {
+void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  out.set_data(allocator::malloc_or_wait(out.nbytes()));


@@ -6,8 +6,8 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
+#include "mlx/backend/common/binary_ops.h"
#include "mlx/backend/common/binary_two.h"
-#include "mlx/backend/common/ops.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -15,69 +15,61 @@ namespace mlx::core {

namespace {

-template <typename T, typename U, typename Op>
-void comparison_op(const array& a, const array& b, array& out, Op op) {
-  DefaultScalarVector<T, U, Op> opsv(op);
-  DefaultVectorScalar<T, U, Op> opvs(op);
-  DefaultVectorVector<T, U, Op> opvv(op);
-  binary_op<T, U>(a, b, out, op, opsv, opvs, opvv);
-}
-
template <typename Op>
void comparison_op(const array& a, const array& b, array& out, Op op) {
  switch (a.dtype()) {
    case bool_:
-      comparison_op<bool, bool>(a, b, out, op);
+      binary_op<bool, bool>(a, b, out, op);
      break;
    case uint8:
-      comparison_op<uint8_t, bool>(a, b, out, op);
+      binary_op<uint8_t, bool>(a, b, out, op);
      break;
    case uint16:
-      comparison_op<uint16_t, bool>(a, b, out, op);
+      binary_op<uint16_t, bool>(a, b, out, op);
      break;
    case uint32:
-      comparison_op<uint32_t, bool>(a, b, out, op);
+      binary_op<uint32_t, bool>(a, b, out, op);
      break;
    case uint64:
-      comparison_op<uint64_t, bool>(a, b, out, op);
+      binary_op<uint64_t, bool>(a, b, out, op);
      break;
    case int8:
-      comparison_op<int8_t, bool>(a, b, out, op);
+      binary_op<int8_t, bool>(a, b, out, op);
      break;
    case int16:
-      comparison_op<int16_t, bool>(a, b, out, op);
+      binary_op<int16_t, bool>(a, b, out, op);
      break;
    case int32:
-      comparison_op<int32_t, bool>(a, b, out, op);
+      binary_op<int32_t, bool>(a, b, out, op);
      break;
    case int64:
-      comparison_op<int64_t, bool>(a, b, out, op);
+      binary_op<int64_t, bool>(a, b, out, op);
      break;
    case float16:
-      comparison_op<float16_t, bool>(a, b, out, op);
+      binary_op<float16_t, bool>(a, b, out, op);
      break;
    case float32:
-      comparison_op<float, bool>(a, b, out, op);
+      binary_op<float, bool>(a, b, out, op);
      break;
    case bfloat16:
-      comparison_op<bfloat16_t, bool>(a, b, out, op);
+      binary_op<bfloat16_t, bool>(a, b, out, op);
      break;
    case complex64:
-      comparison_op<complex64_t, bool>(a, b, out, op);
+      binary_op<complex64_t, bool>(a, b, out, op);
      break;
  }
}
} // namespace
-void Add::eval(const std::vector<array>& inputs, array& out) {
+void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Add());
}

-void DivMod::eval(
+void DivMod::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 2);
@@ -132,50 +124,68 @@ void DivMod::eval(
  }
}
-void Divide::eval(const std::vector<array>& inputs, array& out) {
+void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Divide());
}

-void Remainder::eval(const std::vector<array>& inputs, array& out) {
+void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Remainder());
}

-void Equal::eval(const std::vector<array>& inputs, array& out) {
+void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
+  auto& a = inputs[0];
+  auto& b = inputs[1];
  if (equal_nan_) {
-    comparison_op(inputs[0], inputs[1], out, detail::NaNEqual());
+    switch (a.dtype()) {
+      case float16:
+        binary_op<float16_t, bool>(a, b, out, detail::NaNEqual());
+        break;
+      case float32:
+        binary_op<float, bool>(a, b, out, detail::NaNEqual());
+        break;
+      case bfloat16:
+        binary_op<bfloat16_t, bool>(a, b, out, detail::NaNEqual());
+        break;
+      case complex64:
+        binary_op<complex64_t, bool>(a, b, out, detail::NaNEqual());
+        break;
+      default:
+        throw std::runtime_error(
+            "[NanEqual::eval_cpu] Only for floating point types.");
+    }
  } else {
-    comparison_op(inputs[0], inputs[1], out, detail::Equal());
+    comparison_op(a, b, out, detail::Equal());
  }
}

-void Greater::eval(const std::vector<array>& inputs, array& out) {
+void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  comparison_op(inputs[0], inputs[1], out, detail::Greater());
}

-void GreaterEqual::eval(const std::vector<array>& inputs, array& out) {
+void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual());
}

-void Less::eval(const std::vector<array>& inputs, array& out) {
+void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  comparison_op(inputs[0], inputs[1], out, detail::Less());
}

-void LessEqual::eval(const std::vector<array>& inputs, array& out) {
+void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  comparison_op(inputs[0], inputs[1], out, detail::LessEqual());
}

-void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
+void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
@@ -196,54 +206,54 @@ void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
  }
}
-void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
+void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2); // LogicalAnd requires two input arrays
  auto& in1 = inputs[0];
  auto& in2 = inputs[1];
  binary(in1, in2, out, detail::LogicalAnd());
}

-void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
+void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2); // LogicalOr requires two input arrays
  auto& in1 = inputs[0];
  auto& in2 = inputs[1];
  binary(in1, in2, out, detail::LogicalOr());
}

-void Maximum::eval(const std::vector<array>& inputs, array& out) {
+void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Maximum());
}

-void Minimum::eval(const std::vector<array>& inputs, array& out) {
+void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Minimum());
}

-void Multiply::eval(const std::vector<array>& inputs, array& out) {
+void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Multiply());
}

-void NotEqual::eval(const std::vector<array>& inputs, array& out) {
+void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  comparison_op(inputs[0], inputs[1], out, detail::NotEqual());
}

-void Power::eval(const std::vector<array>& inputs, array& out) {
+void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Power());
}

-void Subtract::eval(const std::vector<array>& inputs, array& out) {
+void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
@@ -307,7 +317,7 @@ void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
  }
}

-void ArcTan2::eval(const std::vector<array>& inputs, array& out) {
+void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  const auto& a = inputs[0];
  const auto& b = inputs[1];


@@ -7,6 +7,8 @@
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"

+#include "mlx/backend/common/simd/simd.h"
+
namespace mlx::core {

namespace {

@@ -122,16 +124,22 @@ void set_binary_op_output_data(
  }
}
-struct UseDefaultBinaryOp {};
-
-template <typename T, typename U, typename Op>
-struct DefaultVectorScalar {
+template <typename Op>
+struct VectorScalar {
  Op op;

-  DefaultVectorScalar(Op op_) : op(op_) {}
+  VectorScalar(Op op_) : op(op_) {}

+  template <typename T, typename U>
  void operator()(const T* a, const T* b, U* dst, int size) {
    T scalar = *b;
+    constexpr int N = simd::max_size<T>;
+    while (size >= N) {
+      simd::store(dst, op(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));
+      dst += N;
+      a += N;
+      size -= N;
+    }
    while (size-- > 0) {
      *dst = op(*a, scalar);
      dst++;
@@ -140,14 +148,22 @@ struct DefaultVectorScalar {
  }
};
-template <typename T, typename U, typename Op>
-struct DefaultScalarVector {
+template <typename Op>
+struct ScalarVector {
  Op op;

-  DefaultScalarVector(Op op_) : op(op_) {}
+  ScalarVector(Op op_) : op(op_) {}

+  template <typename T, typename U>
  void operator()(const T* a, const T* b, U* dst, int size) {
    T scalar = *a;
+    constexpr int N = simd::max_size<T>;
+    while (size >= N) {
+      simd::store(dst, op(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));
+      dst += N;
+      b += N;
+      size -= N;
+    }
    while (size-- > 0) {
      *dst = op(scalar, *b);
      dst++;
@@ -156,13 +172,22 @@ struct DefaultScalarVector {
  }
};
-template <typename T, typename U, typename Op>
-struct DefaultVectorVector {
+template <typename Op>
+struct VectorVector {
  Op op;

-  DefaultVectorVector(Op op_) : op(op_) {}
+  VectorVector(Op op_) : op(op_) {}

+  template <typename T, typename U>
  void operator()(const T* a, const T* b, U* dst, int size) {
+    constexpr int N = simd::max_size<T>;
+    while (size >= N) {
+      simd::store(dst, op(simd::load<T, N>(a), simd::load<T, N>(b)));
+      dst += N;
+      a += N;
+      b += N;
+      size -= N;
+    }
    while (size-- > 0) {
      *dst = op(*a, *b);
      dst++;
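// [Editor's sketch — not part of the diff] All three functors above share
// one shape: a SIMD main loop over full lanes of width N, then a scalar
// tail. The same pattern in plain C++ (N stands in for simd::max_size<T>):
template <typename T, typename F>
void apply_chunked(const T* a, const T* b, T* dst, int size, F op) {
  constexpr int N = 8;
  int i = 0;
  for (; i + N <= size; i += N) { // full SIMD-width chunks
    for (int j = 0; j < N; ++j) {
      dst[i + j] = op(a[i + j], b[i + j]);
    }
  }
  for (; i < size; ++i) { // scalar tail
    dst[i] = op(a[i], b[i]);
  }
}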
@@ -277,21 +302,8 @@ void binary_op_dispatch_dims(
  }
}
-template <
-    typename T,
-    typename U,
-    typename Op,
-    typename OpSV,
-    typename OpVS,
-    typename OpVV>
-void binary_op(
-    const array& a,
-    const array& b,
-    array& out,
-    Op op,
-    OpSV opsv,
-    OpVS opvs,
-    OpVV opvv) {
+template <typename T, typename U, typename Op>
+void binary_op(const array& a, const array& b, array& out, Op op) {
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);
@@ -303,19 +315,19 @@ void binary_op(

  // The full computation is scalar vector so delegate to the op
  if (bopt == BinaryOpType::ScalarVector) {
-    opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
+    ScalarVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
    return;
  }

  // The full computation is vector scalar so delegate to the op
  if (bopt == BinaryOpType::VectorScalar) {
-    opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
+    VectorScalar{op}(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
    return;
  }

  // The full computation is vector vector so delegate to the op
  if (bopt == BinaryOpType::VectorVector) {
-    opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
+    VectorVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
    return;
  }
@@ -376,15 +388,39 @@ void binary_op(
  switch (bopt) {
    case BinaryOpType::VectorVector:
      binary_op_dispatch_dims<T, U, true>(
-          a, b, out, opvv, dim, new_shape, a_strides, b_strides, strides);
+          a,
+          b,
+          out,
+          VectorVector{op},
+          dim,
+          new_shape,
+          a_strides,
+          b_strides,
+          strides);
      break;
    case BinaryOpType::VectorScalar:
      binary_op_dispatch_dims<T, U, true>(
-          a, b, out, opvs, dim, new_shape, a_strides, b_strides, strides);
+          a,
+          b,
+          out,
+          VectorScalar{op},
+          dim,
+          new_shape,
+          a_strides,
+          b_strides,
+          strides);
      break;
    case BinaryOpType::ScalarVector:
      binary_op_dispatch_dims<T, U, true>(
-          a, b, out, opsv, dim, new_shape, a_strides, b_strides, strides);
+          a,
+          b,
+          out,
+          ScalarVector{op},
+          dim,
+          new_shape,
+          a_strides,
+          b_strides,
+          strides);
      break;
    default:
      binary_op_dispatch_dims<T, U, false>(

@@ -393,134 +429,52 @@ void binary_op(
  }
}
template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
void binary_op(
const array& a,
const array& b,
array& out,
Op op,
OpSV opsv,
OpVS opvs,
OpVV opvv) {
// TODO: The following mess of constexpr evaluations can probably be achieved
// with template specializations and overloading. Would it be simpler?
if constexpr (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// All ops are UseDefaultBinaryOp (why oh why would someone call that?)
binary_op<T, T>(
a,
b,
out,
op,
DefaultScalarVector<T, T, Op>(op),
DefaultVectorScalar<T, T, Op>(op),
DefaultVectorVector<T, T, Op>(op));
} else {
// opsv and opvs were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
out,
op,
DefaultScalarVector<T, T, Op>(op),
DefaultVectorScalar<T, T, Op>(op),
opvv);
}
} else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
value) {
// opsv and opvv were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
out,
op,
DefaultScalarVector<T, T, Op>(op),
opvs,
DefaultVectorVector<T, T, Op>(op));
} else {
// opsv was UseDefaultBinaryOp
binary_op<T, T>(
a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
}
} else if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::
value) {
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opvs and opvv were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
out,
op,
opsv,
DefaultVectorScalar<T, T, Op>(op),
DefaultVectorVector<T, T, Op>(op));
} else {
// opvs was UseDefaultBinaryOp
binary_op<T, T>(
a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
}
} else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
value) {
// opvv was UseDefaultBinaryOp
binary_op<T, T>(
a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
} else {
// All ops provided
binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
}
}
template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
-  DefaultScalarVector<T, T, Op> opsv(op);
-  DefaultVectorScalar<T, T, Op> opvs(op);
-  DefaultVectorVector<T, T, Op> opvv(op);
-  binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
+  binary_op<T, T>(a, b, out, op);
}
-template <typename... Ops>
-void binary(const array& a, const array& b, array& out, Ops... ops) {
+template <typename Op>
+void binary(const array& a, const array& b, array& out, Op op) {
  switch (out.dtype()) {
    case bool_:
-      binary_op<bool>(a, b, out, ops...);
+      binary_op<bool>(a, b, out, op);
      break;
    case uint8:
-      binary_op<uint8_t>(a, b, out, ops...);
+      binary_op<uint8_t>(a, b, out, op);
      break;
    case uint16:
-      binary_op<uint16_t>(a, b, out, ops...);
+      binary_op<uint16_t>(a, b, out, op);
      break;
    case uint32:
-      binary_op<uint32_t>(a, b, out, ops...);
+      binary_op<uint32_t>(a, b, out, op);
      break;
    case uint64:
-      binary_op<uint64_t>(a, b, out, ops...);
+      binary_op<uint64_t>(a, b, out, op);
      break;
    case int8:
-      binary_op<int8_t>(a, b, out, ops...);
+      binary_op<int8_t>(a, b, out, op);
      break;
    case int16:
-      binary_op<int16_t>(a, b, out, ops...);
+      binary_op<int16_t>(a, b, out, op);
      break;
    case int32:
-      binary_op<int32_t>(a, b, out, ops...);
+      binary_op<int32_t>(a, b, out, op);
      break;
    case int64:
-      binary_op<int64_t>(a, b, out, ops...);
+      binary_op<int64_t>(a, b, out, op);
      break;
    case float16:
-      binary_op<float16_t>(a, b, out, ops...);
+      binary_op<float16_t>(a, b, out, op);
      break;
    case float32:
-      binary_op<float>(a, b, out, ops...);
+      binary_op<float>(a, b, out, op);
      break;
    case bfloat16:
-      binary_op<bfloat16_t>(a, b, out, ops...);
+      binary_op<bfloat16_t>(a, b, out, op);
      break;
    case complex64:
-      binary_op<complex64_t>(a, b, out, ops...);
+      binary_op<complex64_t>(a, b, out, op);
      break;
  }
}


@@ -0,0 +1,98 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include "mlx/backend/common/simd/simd.h"
namespace mlx::core::detail {
using namespace mlx::core::simd;
#define BINARY_SINGLE() \
template <typename T> \
T operator()(T x, T y) { \
return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value; \
}
#define DEFAULT_BINARY_OP(Op, op) \
struct Op { \
template <int N, typename T> \
Simd<T, N> operator()(Simd<T, N> x, Simd<T, N> y) { \
return op(x, y); \
} \
BINARY_SINGLE() \
};
DEFAULT_BINARY_OP(Add, operator+)
DEFAULT_BINARY_OP(ArcTan2, atan2)
DEFAULT_BINARY_OP(Divide, operator/)
DEFAULT_BINARY_OP(Multiply, operator*)
DEFAULT_BINARY_OP(Subtract, operator-)
DEFAULT_BINARY_OP(LogicalAnd, operator&&)
DEFAULT_BINARY_OP(LogicalOr, operator||)
DEFAULT_BINARY_OP(BitwiseAnd, operator&)
DEFAULT_BINARY_OP(BitwiseOr, operator|)
DEFAULT_BINARY_OP(BitwiseXor, operator^)
DEFAULT_BINARY_OP(LeftShift, operator<<)
DEFAULT_BINARY_OP(RightShift, operator>>)
DEFAULT_BINARY_OP(Remainder, remainder)
DEFAULT_BINARY_OP(Maximum, maximum)
DEFAULT_BINARY_OP(Minimum, minimum)
DEFAULT_BINARY_OP(Power, pow)
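// [Editor's note — not part of the diff] For example,
// DEFAULT_BINARY_OP(Add, operator+) expands to a functor that works on
// whole SIMD vectors and, via BINARY_SINGLE(), on scalars routed through
// width-1 vectors:
//
//   struct Add {
//     template <int N, typename T>
//     Simd<T, N> operator()(Simd<T, N> x, Simd<T, N> y) {
//       return operator+(x, y);
//     }
//     template <typename T>
//     T operator()(T x, T y) {
//       return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value;
//     }
//   };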
#define DEFAULT_BOOL_OP(Op, op) \
struct Op { \
template <int N, typename T> \
Simd<bool, N> operator()(Simd<T, N> x, Simd<T, N> y) { \
return op(x, y); \
} \
template <typename T> \
bool operator()(T x, T y) { \
return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value; \
} \
};
DEFAULT_BOOL_OP(Equal, operator==)
DEFAULT_BOOL_OP(Greater, operator>)
DEFAULT_BOOL_OP(GreaterEqual, operator>=)
DEFAULT_BOOL_OP(Less, operator<)
DEFAULT_BOOL_OP(LessEqual, operator<=)
DEFAULT_BOOL_OP(NotEqual, operator!=)
struct NaNEqual {
template <int N, typename T>
Simd<bool, N> operator()(Simd<T, N> x, Simd<T, N> y) {
return x == y || (isnan(x) && isnan(y));
}
template <typename T>
bool operator()(T x, T y) {
return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value;
}
};
struct LogAddExp {
template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x, Simd<T, N> y) {
auto maxval = maximum(x, y);
auto minval = minimum(x, y);
auto mask = minval == -inf || maxval == inf;
auto out = maxval + log1p(exp(minval - maxval));
return select(mask, Simd<T, N>(maxval), Simd<T, N>(out));
}
BINARY_SINGLE()
};
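// [Editor's note — not part of the diff] This is the stable two-term
// log-sum-exp: log(e^x + e^y) = max(x, y) + log1p(e^(min - max)). The mask
// short-circuits the infinite cases so that, e.g., LogAddExp(-inf, -inf)
// returns -inf instead of the NaN that -inf + log1p(exp(-inf - -inf))
// would produce.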
struct Select {
template <typename T>
T operator()(bool condition, T x, T y) {
return (*this)(Simd<bool, 1>(condition), Simd<T, 1>(x), Simd<T, 1>(y))
.value;
}
template <int N, typename T>
Simd<T, N> operator()(Simd<bool, N> condition, Simd<T, N> x, Simd<T, N> y) {
return select(condition, x, y);
}
};
} // namespace mlx::core::detail


@@ -64,7 +64,7 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
  }
}

-void Cholesky::eval(const std::vector<array>& inputs, array& output) {
+void Cholesky::eval_cpu(const std::vector<array>& inputs, array& output) {
  if (inputs[0].dtype() != float32) {
    throw std::runtime_error("[Cholesky::eval] only supports float32.");
  }


@@ -5,7 +5,8 @@
// clang-format off
#include "mlx/types/half_types.h"
#include "mlx/types/complex.h"
-#include "mlx/backend/common/ops.h"
+#include "mlx/backend/common/unary_ops.h"
+#include "mlx/backend/common/binary_ops.h"
// clang-format on

const char* get_kernel_preamble();


@@ -4,6 +4,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/simd/simd.h"
#include "mlx/backend/common/utils.h"

namespace mlx::core {

@@ -23,6 +24,7 @@ template <typename SrcT, typename DstT>
void copy_vector(const array& src, array& dst) {
  auto src_ptr = src.data<SrcT>();
  auto dst_ptr = dst.data<DstT>();
+  size_t size = src.data_size();
  std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
}


@@ -21,98 +21,9 @@
namespace mlx::core {
DEFAULT(Abs)
DEFAULT(Add)
DEFAULT(Arange)
DEFAULT(ArcCos)
DEFAULT(ArcCosh)
DEFAULT(ArcSin)
DEFAULT(ArcSinh)
DEFAULT(ArcTan)
DEFAULT(ArcTan2)
DEFAULT(ArcTanh)
DEFAULT(ArgPartition)
DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(BroadcastAxes)
DEFAULT(BlockMaskedMM)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Convolution)
DEFAULT(Copy)
DEFAULT(Cos)
DEFAULT(Cosh)
DEFAULT_MULTI(CustomTransforms)
DEFAULT_MULTI(Depends)
DEFAULT(Divide)
DEFAULT(NumberOfElements)
DEFAULT(Remainder)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(Exp)
DEFAULT(ExpandDims)
DEFAULT(Expm1)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Full)
DEFAULT(Gather)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Hadamard)
DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)
DEFAULT(Log)
DEFAULT(Log1p)
DEFAULT(LogicalNot)
DEFAULT(LogicalAnd)
DEFAULT(LogicalOr)
DEFAULT(LogAddExp)
DEFAULT(Maximum)
DEFAULT(Minimum)
DEFAULT(Multiply)
DEFAULT(Negative)
DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT(Power)
DEFAULT_MULTI(QRF)
DEFAULT(QuantizedMatmul)
DEFAULT(RandomBits)
DEFAULT(Reduce)
DEFAULT(Round)
DEFAULT(Scan)
DEFAULT(Scatter)
DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Sin)
DEFAULT(Sinh)
DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT(Softmax)
DEFAULT(Sort)
DEFAULT_MULTI(Split)
DEFAULT(Square)
DEFAULT(Squeeze)
DEFAULT(Sqrt)
DEFAULT(StopGradient)
DEFAULT(Subtract)
DEFAULT_MULTI(SVD)
DEFAULT(Tan)
DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)
namespace {


@@ -45,7 +45,9 @@ void ssyevd(
} // namespace

-void Eigh::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
+void Eigh::eval_cpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
  const auto& a = inputs[0];
  auto& values = outputs[0];


@@ -8,7 +8,7 @@
namespace mlx::core {

-void FFT::eval(const std::vector<array>& inputs, array& out) {
+void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
  auto& in = inputs[0];
  std::vector<std::ptrdiff_t> strides_in(
      in.strides().begin(), in.strides().end());


@@ -82,7 +82,7 @@ void hadamard(array& out, int n, int m, float scale) {
  }
}

-void Hadamard::eval(const std::vector<array>& inputs, array& out) {
+void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];


@@ -162,7 +162,7 @@ void dispatch_gather(
  }
}

-void Gather::eval(const std::vector<array>& inputs, array& out) {
+void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& src = inputs[0];

@@ -337,7 +337,7 @@ void dispatch_scatter(
  }
}

-void Scatter::eval(const std::vector<array>& inputs, array& out) {
+void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() >= 2);

  auto& src = inputs[0];

View File

@@ -110,7 +110,7 @@ void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
  }
}
-void Inverse::eval(const std::vector<array>& inputs, array& output) {
+void Inverse::eval_cpu(const std::vector<array>& inputs, array& output) {
  if (inputs[0].dtype() != float32) {
    throw std::runtime_error("[Inverse::eval] only supports float32.");
  }

View File

@@ -11,7 +11,7 @@
#define lapack_complex_double std::complex<double>
#endif
-#ifdef ACCELERATE_NEW_LAPACK
+#ifdef MLX_USE_ACCELERATE
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>

View File

@@ -1,12 +1,9 @@
// Copyright © 2023 Apple Inc.
#include <algorithm>
-#include <cassert>
#include <utility>
-#include "mlx/allocator.h"
#include "mlx/backend/common/load.h"
-#include "mlx/primitives.h"
namespace {
@@ -51,11 +48,4 @@ void load(
  }
}
-void Load::eval(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 0);
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  load(out, offset_, reader_, swap_endianness_);
-}
} // namespace mlx::core
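// For context: Load::eval_cpu now lives in primitives.cpp (further down in
// this commit); the load() helper kept above does the raw IO. When
// swap_endianness_ is set it byte-swaps each element along these lines
// (a minimal sketch with a hypothetical helper, not code from this diff):
#include <algorithm>
#include <cstddef>
#include <cstdint>

inline void bswap32(uint32_t* data, size_t count) {
  for (size_t i = 0; i < count; ++i) {
    char* b = reinterpret_cast<char*>(data + i);
    std::swap(b[0], b[3]); // swap outer bytes
    std::swap(b[1], b[2]); // swap inner bytes
  }
}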

View File

@@ -53,7 +53,7 @@ inline void mask_matrix(
} // namespace
-void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
+void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
  if (out.dtype() != float32) {
    throw std::runtime_error(
        "[BlockMaskedMM::eval] Currently only supports float32.");
@@ -210,7 +210,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
  }
}
-void GatherMM::eval(const std::vector<array>& inputs, array& out) {
+void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
  if (out.dtype() != float32) {
    throw std::runtime_error(
        "[GatherMM::eval] Currently only supports float32.");

View File

@@ -1,680 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <stdint.h>
#include <cmath>
#include <complex>
namespace mlx::core::detail {
namespace {
constexpr float inf = std::numeric_limits<float>::infinity();
} // namespace
typedef union {
int i;
float f;
} IntOrFloat;
inline float fast_exp(float x) {
if (x == -std::numeric_limits<float>::infinity()) {
return 0.0f;
} else if (x == std::numeric_limits<float>::infinity() || std::isnan(x)) {
return x;
}
x *= 1.442695; // multiply with log_2(e)
float ipart, fpart;
IntOrFloat epart;
x = std::max(-80.f, std::min(x, 80.f));
ipart = std::floor(x + 0.5);
fpart = x - ipart;
x = 1.535336188319500e-4f;
x = x * fpart + 1.339887440266574e-3f;
x = x * fpart + 9.618437357674640e-3f;
x = x * fpart + 5.550332471162809e-2f;
x = x * fpart + 2.402264791363012e-1f;
x = x * fpart + 6.931472028550421e-1f;
x = x * fpart + 1.000000000000000f;
// generate 2**ipart in the floating point representation using integer
// bitshifting
epart.i = (int(ipart) + 127) << 23;
return epart.f * x;
}
inline float fast_erf(float a) {
float r, s, t, u;
t = std::abs(a);
s = a * a;
if (t > 0.927734375f) {
// maximum error 0.99527 ulp
r = std::fma(
-1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
u = std::fma(
-3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
r = std::fma(r, s, u);
r = std::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4
r = std::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1
r = std::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3
r = std::fma(r, t, -t);
// TODO, replace with expm1 when implemented
r = 1.0f - std::exp(r);
r = std::copysign(r, a);
} else {
// maximum error 0.98929 ulp
r = -5.96761703e-4f; // -0x1.38e000p-11
r = std::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8
r = std::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6
r = std::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4
r = std::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2
r = std::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3
r = std::fma(r, a, a);
}
return r;
}
inline float fast_erfinv(float a) {
auto t = std::fma(a, 0.0f - a, 1.0f);
t = std::log(t);
float p;
if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793
p = 3.03697567e-10f; // 0x1.4deb44p-32
p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
} else { // maximum ulp error = 2.35002
p = 5.43877832e-9f; // 0x1.75c000p-28
p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
}
return a * p;
}
struct Abs {
template <typename T>
T operator()(T x) {
return std::abs(x);
}
uint8_t operator()(uint8_t x) {
return x;
}
uint16_t operator()(uint16_t x) {
return x;
}
uint32_t operator()(uint32_t x) {
return x;
}
uint64_t operator()(uint64_t x) {
return x;
}
bool operator()(bool x) {
return x;
}
};
struct ArcCos {
template <typename T>
T operator()(T x) {
return std::acos(x);
}
};
struct ArcCosh {
template <typename T>
T operator()(T x) {
return std::acosh(x);
}
};
struct ArcSin {
template <typename T>
T operator()(T x) {
return std::asin(x);
}
};
struct ArcSinh {
template <typename T>
T operator()(T x) {
return std::asinh(x);
}
};
struct ArcTan {
template <typename T>
T operator()(T x) {
return std::atan(x);
}
};
struct ArcTan2 {
template <typename T>
T operator()(T y, T x) {
return std::atan2(y, x);
}
};
struct ArcTanh {
template <typename T>
T operator()(T x) {
return std::atanh(x);
}
};
struct Ceil {
template <typename T>
T operator()(T x) {
return std::ceil(x);
}
int8_t operator()(int8_t x) {
return x;
}
int16_t operator()(int16_t x) {
return x;
}
int32_t operator()(int32_t x) {
return x;
}
int64_t operator()(int64_t x) {
return x;
}
uint8_t operator()(uint8_t x) {
return x;
}
uint16_t operator()(uint16_t x) {
return x;
}
uint32_t operator()(uint32_t x) {
return x;
}
uint64_t operator()(uint64_t x) {
return x;
}
bool operator()(bool x) {
return x;
}
};
struct Conjugate {
complex64_t operator()(complex64_t x) {
return std::conj(x);
}
};
struct Cos {
template <typename T>
T operator()(T x) {
return std::cos(x);
}
};
struct Cosh {
template <typename T>
T operator()(T x) {
return std::cosh(x);
}
};
struct Erf {
template <typename T>
T operator()(T x) {
return static_cast<T>(fast_erf(static_cast<float>(x)));
}
};
struct ErfInv {
template <typename T>
T operator()(T x) {
return static_cast<T>(fast_erfinv(static_cast<float>(x)));
}
};
struct Exp {
template <typename T>
T operator()(T x) {
return fast_exp(x);
}
complex64_t operator()(complex64_t x) {
return std::exp(x);
}
};
struct Expm1 {
template <typename T>
T operator()(T x) {
return expm1(x);
}
};
struct Floor {
template <typename T>
T operator()(T x) {
return std::floor(x);
}
int8_t operator()(int8_t x) {
return x;
}
int16_t operator()(int16_t x) {
return x;
}
int32_t operator()(int32_t x) {
return x;
}
int64_t operator()(int64_t x) {
return x;
}
uint8_t operator()(uint8_t x) {
return x;
}
uint16_t operator()(uint16_t x) {
return x;
}
uint32_t operator()(uint32_t x) {
return x;
}
uint64_t operator()(uint64_t x) {
return x;
}
bool operator()(bool x) {
return x;
}
};
struct Imag {
template <typename T>
T operator()(T x) {
return std::imag(x);
}
};
struct Log {
template <typename T>
T operator()(T x) {
return std::log(x);
}
};
struct Log2 {
template <typename T>
T operator()(T x) {
return std::log2(x);
}
};
struct Log10 {
template <typename T>
T operator()(T x) {
return std::log10(x);
}
};
struct Log1p {
template <typename T>
T operator()(T x) {
return log1p(x);
}
};
struct LogicalNot {
template <typename T>
T operator()(T x) {
return !x;
}
};
struct Negative {
template <typename T>
T operator()(T x) {
return -x;
}
};
struct Real {
template <typename T>
T operator()(T x) {
return std::real(x);
}
};
struct Round {
template <typename T>
T operator()(T x) {
return std::rint(x);
}
complex64_t operator()(complex64_t x) {
return {std::rint(x.real()), std::rint(x.imag())};
}
};
struct Sigmoid {
template <typename T>
T operator()(T x) {
auto one = static_cast<decltype(x)>(1.0);
return one / (one + fast_exp(-x));
}
};
struct Sign {
template <typename T>
T operator()(T x) {
return (x > T(0)) - (x < T(0));
}
uint8_t operator()(uint8_t x) {
return x != 0;
}
uint16_t operator()(uint16_t x) {
return x != 0;
}
uint32_t operator()(uint32_t x) {
return x != 0;
}
uint64_t operator()(uint64_t x) {
return x != 0;
}
complex64_t operator()(complex64_t x) {
return x == complex64_t(0) ? x : x / std::abs(x);
}
};
struct Sin {
template <typename T>
T operator()(T x) {
return std::sin(x);
}
};
struct Sinh {
template <typename T>
T operator()(T x) {
return std::sinh(x);
}
};
struct Square {
template <typename T>
T operator()(T x) {
return x * x;
}
};
struct Sqrt {
template <typename T>
T operator()(T x) {
return std::sqrt(x);
}
};
struct Rsqrt {
template <typename T>
T operator()(T x) {
return static_cast<decltype(x)>(1.0) / std::sqrt(x);
}
};
struct Tan {
template <typename T>
T operator()(T x) {
return std::tan(x);
}
};
struct Tanh {
template <typename T>
T operator()(T x) {
return std::tanh(x);
}
};
struct Add {
template <typename T>
T operator()(T x, T y) {
return x + y;
}
};
struct Divide {
template <typename T>
T operator()(T x, T y) {
return x / y;
}
};
struct Remainder {
template <typename T>
std::enable_if_t<std::is_integral_v<T> & !std::is_signed_v<T>, T> operator()(
T numerator,
T denominator) {
return numerator % denominator;
}
template <typename T>
std::enable_if_t<std::is_integral_v<T> & std::is_signed_v<T>, T> operator()(
T numerator,
T denominator) {
auto r = numerator % denominator;
if (r != 0 && (r < 0 != denominator < 0))
r += denominator;
return r;
}
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
auto r = std::fmod(numerator, denominator);
if (r != 0 && (r < 0 != denominator < 0)) {
r += denominator;
}
return r;
}
complex64_t operator()(complex64_t numerator, complex64_t denominator) {
return numerator % denominator;
}
};
struct Equal {
template <typename T>
bool operator()(T x, T y) {
return x == y;
}
};
struct NaNEqual {
template <typename T>
bool operator()(T x, T y) {
if constexpr (std::is_integral_v<T>) {
// isnan always returns false for integers, and MSVC refuses to compile.
return x == y;
} else {
return x == y || (std::isnan(x) && std::isnan(y));
}
}
};
struct Greater {
template <typename T>
bool operator()(T x, T y) {
return x > y;
}
};
struct GreaterEqual {
template <typename T>
bool operator()(T x, T y) {
return x >= y;
}
};
struct Less {
template <typename T>
bool operator()(T x, T y) {
return x < y;
}
};
struct LessEqual {
template <typename T>
bool operator()(T x, T y) {
return x <= y;
}
};
struct Maximum {
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
return (x > y) ? x : y;
}
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
if (std::isnan(x)) {
return x;
}
return (x > y) ? x : y;
}
};
struct Minimum {
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
return x < y ? x : y;
}
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
if (std::isnan(x)) {
return x;
}
return x < y ? x : y;
}
};
struct LogAddExp {
template <typename T>
T operator()(T x, T y) {
constexpr float inf = std::numeric_limits<float>::infinity();
auto maxval = Maximum()(x, y);
auto minval = Minimum()(x, y);
return (minval == -inf || maxval == inf)
? maxval
: static_cast<decltype(x)>(
maxval + std::log1p(fast_exp(minval - maxval)));
}
};
struct Multiply {
template <typename T>
T operator()(T x, T y) {
return x * y;
}
};
struct NotEqual {
template <typename T>
bool operator()(T x, T y) {
return x != y;
}
};
struct Power {
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T base, T exp) {
return std::pow(base, exp);
}
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(T base, T exp) {
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
};
struct Subtract {
template <typename T>
T operator()(T x, T y) {
return x - y;
}
};
struct LogicalAnd {
template <typename T>
T operator()(T x, T y) {
return x && y;
}
};
struct LogicalOr {
template <typename T>
T operator()(T x, T y) {
return x || y;
}
};
struct Select {
template <typename T>
T operator()(bool condition, T x, T y) {
return condition ? x : y;
}
};
struct BitwiseAnd {
template <typename T>
T operator()(T x, T y) {
return x & y;
}
};
struct BitwiseOr {
template <typename T>
T operator()(T x, T y) {
return x | y;
}
};
struct BitwiseXor {
template <typename T>
T operator()(T x, T y) {
return x ^ y;
}
};
struct LeftShift {
template <typename T>
T operator()(T x, T y) {
return x << y;
}
};
struct RightShift {
template <typename T>
T operator()(T x, T y) {
return x >> y;
}
};
} // namespace mlx::core::detail

View File

@@ -9,10 +9,9 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/arange.h"
#include "mlx/backend/common/copy.h"
-#include "mlx/backend/common/ops.h"
+#include "mlx/backend/common/load.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/threefry.h"
-#include "mlx/backend/common/unary.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -58,112 +57,64 @@ int64_t compute_dynamic_offset(
  }
}
-void Abs::eval(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  auto& in = inputs[0];
-  if (issubdtype(in.dtype(), unsignedinteger)) {
-    // No-op for unsigned types
-    out.copy_shared_buffer(in);
-  } else {
-    unary(in, out, detail::Abs());
-  }
-}
+void AsStrided::eval_cpu(const std::vector<array>& inputs, array& out) {
+  eval(inputs, out);
+}
void Broadcast::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void BroadcastAxes::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Copy::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void CustomTransforms::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Depends::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void ExpandDims::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void NumberOfElements::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Slice::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Split::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Squeeze::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void StopGradient::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
-void Arange::eval(const std::vector<array>& inputs, array& out) {
+void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
  arange(inputs, out, start_, step_);
}
-void ArcCos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcCos());
} else {
throw std::invalid_argument(
"[arccos] Cannot compute inverse cosine of elements in array"
" with non floating point type.");
}
}
void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcCosh());
} else {
throw std::invalid_argument(
"[arccosh] Cannot compute inverse hyperbolic cosine of elements in"
" array with non floating point type.");
}
}
void ArcSin::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcSin());
} else {
throw std::invalid_argument(
"[arcsin] Cannot compute inverse sine of elements in array"
" with non floating point type.");
}
}
void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcSinh());
} else {
throw std::invalid_argument(
"[arcsinh] Cannot compute inverse hyperbolic sine of elements in"
" array with non floating point type.");
}
}
void ArcTan::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcTan());
} else {
throw std::invalid_argument(
"[arctan] Cannot compute inverse tangent of elements in array"
" with non floating point type.");
}
}
void ArcTanh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcTanh());
} else {
throw std::invalid_argument(
"[arctanh] Cannot compute inverse hyperbolic tangent of elements in"
" array with non floating point type.");
}
}
-void AsType::eval(const std::vector<array>& inputs, array& out) {
+void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
  copy(in, out, ctype);
}
-void Ceil::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Ceil());
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
-void Concatenate::eval(const std::vector<array>& inputs, array& out) {
+void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
  std::vector<int> sizes;
  sizes.push_back(0);
  for (auto& p : inputs) {
@@ -187,17 +138,6 @@ void Concatenate::eval(const std::vector<array>& inputs, array& out) {
  }
}
void Conjugate::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == complex64) {
unary_fp(in, out, detail::Conjugate());
} else {
throw std::invalid_argument(
"[conjugate] conjugate must be called on complex input.");
}
}
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
@@ -209,94 +149,6 @@ void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
  }
}
void Cos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Cos());
} else {
throw std::invalid_argument(
"[cos] Cannot compute cosine of elements in array"
" with non floating point type.");
}
}
void Cosh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Cosh());
} else {
throw std::invalid_argument(
"[cosh] Cannot compute hyperbolic cosine of elements in array"
" with non floating point type.");
}
}
void Erf::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
unary_op<float>(in, out, detail::Erf());
break;
case float16:
unary_op<float16_t>(in, out, detail::Erf());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::Erf());
break;
default:
throw std::invalid_argument(
"[erf] Error function only defined for arrays"
" with real floating point type.");
}
}
void ErfInv::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
unary_op<float>(in, out, detail::ErfInv());
break;
case float16:
unary_op<float16_t>(in, out, detail::ErfInv());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::ErfInv());
break;
default:
throw std::invalid_argument(
"[erf_inv] Inverse error function only defined for arrays"
" with real floating point type.");
}
}
void Exp::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Exp());
} else {
throw std::invalid_argument(
"[exp] Cannot exponentiate elements in array"
" with non floating point type.");
}
}
void Expm1::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Expm1());
} else {
throw std::invalid_argument(
"[expm1] Cannot exponentiate elements in array"
" with non floating point type.");
}
}
void Flatten::eval_cpu(const std::vector<array>& inputs, array& out) {
  reshape(inputs[0], out);
}
@@ -305,18 +157,7 @@ void Unflatten::eval_cpu(const std::vector<array>& inputs, array& out) {
  reshape(inputs[0], out);
}
-void Floor::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Floor());
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
-void Full::eval(const std::vector<array>& inputs, array& out) {
+void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  assert(in.dtype() == out.dtype());
@@ -331,57 +172,14 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
  copy(in, out, ctype);
}
-void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
-  unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
-}
+void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 0);
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  load(out, offset_, reader_, swap_endianness_);
+}
-void Log::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
switch (base_) {
case Base::e:
unary_fp(in, out, detail::Log());
break;
case Base::two:
unary_fp(in, out, detail::Log2());
break;
case Base::ten:
unary_fp(in, out, detail::Log10());
break;
}
} else {
throw std::invalid_argument(
"[log] Cannot compute log of elements in array with"
" non floating point type.");
}
}
void Log1p::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Log1p());
} else {
throw std::invalid_argument(
"[log1p] Cannot compute log of elements in array with"
" non floating point type.");
}
}
void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::LogicalNot());
}
void Negative::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::Negative());
}
-void Pad::eval(const std::vector<array>& inputs, array& out) {
+void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
  // Inputs must be base input array and scalar val array
  assert(inputs.size() == 2);
  auto& in = inputs[0];
@@ -412,7 +210,7 @@ void Pad::eval(const std::vector<array>& inputs, array& out) {
  copy_inplace(in, out_slice, CopyType::GeneralGeneral);
}
-void RandomBits::eval(const std::vector<array>& inputs, array& out) {
+void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  // keys has shape (N1, ..., NK, 2)
  // out has shape (N1, ..., NK, M1, M2, ...)
@@ -460,71 +258,10 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
  }
}
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
}
void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
  reshape(inputs[0], out);
}
void Round::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Round());
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Sigmoid());
} else {
throw std::invalid_argument(
"[sigmoid] Cannot sigmoid of elements in array with"
" non floating point type.");
}
}
void Sign::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == bool_) {
out.copy_shared_buffer(in);
} else {
unary(in, out, detail::Sign());
}
}
void Sin::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Sin());
} else {
throw std::invalid_argument(
"[sin] Cannot compute sine of elements in array"
" with non floating point type.");
}
}
void Sinh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Sinh());
} else {
throw std::invalid_argument(
"[sinh] Cannot compute hyperbolic sine of elements in array"
" with non floating point type.");
}
}
void Slice::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  if (out.size() == 0) {
@@ -596,7 +333,7 @@ void DynamicSliceUpdate::eval_cpu(
  /* CopyType ctype = */ CopyType::GeneralGeneral);
}
-void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
+void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  if (out.size() == 0) {
    out.set_data(nullptr);
@@ -632,46 +369,6 @@ void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
  /* CopyType ctype = */ CopyType::GeneralGeneral);
}
void Square::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::Square());
}
void Sqrt::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (recip_) {
unary_fp(in, out, detail::Rsqrt());
} else {
unary_fp(in, out, detail::Sqrt());
}
}
void Tan::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Tan());
} else {
throw std::invalid_argument(
"[tan] Cannot compute tangent of elements in array"
" with non floating point type.");
}
}
void Tanh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Tanh());
} else {
throw std::invalid_argument(
"[tanh] Cannot compute hyperbolic tangent of elements in array"
" with non floating point type.");
}
}
void View::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];

View File

@@ -149,7 +149,9 @@ void qrf_impl(const array& a, array& q, array& r) {
  allocator::free(tau);
}
-void QRF::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
+void QRF::eval_cpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
  if (!(inputs[0].dtype() == float32)) {
    throw std::runtime_error("[QRF::eval] only supports float32.");
  }

View File

@@ -3,7 +3,7 @@
#include <cassert>
#include "mlx/backend/common/copy.h"
-#include "mlx/backend/common/ops.h"
+#include "mlx/backend/common/simd/simd.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -151,6 +151,78 @@ void _qmm_t(
  }
}
template <int bits, int S>
simd::Simd<uint32_t, S> extract_bits_simd(const uint32_t* w) {
constexpr int bitmask = (1 << bits) - 1;
simd::Simd<uint32_t, S> wi;
if constexpr (bits == 4 && S == 8) {
constexpr std::array<uint32_t, 8> shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}};
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
wi = simd::Simd<uint32_t, S>(*w);
wi = wi >> shifts;
wi = wi & bitmask;
} else if constexpr (bits == 8 && S == 8) {
constexpr std::array<uint32_t, 8> shifts_ = {{0, 8, 16, 24, 0, 8, 16, 24}};
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
auto l = simd::Simd<uint32_t, 4>(*w++);
auto r = simd::Simd<uint32_t, 4>(*w);
wi = simd::Simd<uint32_t, S>(l, r);
wi = wi >> shifts;
wi = wi & bitmask;
} else {
// Appease compiler.. but should never get here
throw std::runtime_error("Unsupported combination for simd qmm.");
}
return wi;
}
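// Scalar reference for the unpacking above (a sketch, not part of this
// diff): each 32-bit word packs 32/bits values; value i comes from word
// i / per_word, shifted down by bits * (i % per_word) and masked.
template <int bits>
void extract_bits_scalar(const uint32_t* w, uint32_t* out, int n) {
  static_assert(bits == 4 || bits == 8, "mirrors the simd cases above");
  constexpr int per_word = 32 / bits;
  constexpr uint32_t mask = (1u << bits) - 1;
  for (int i = 0; i < n; i++) {
    out[i] = (w[i / per_word] >> (bits * (i % per_word))) & mask;
  }
}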
template <typename T, int bits, int group_size>
void _qmm_t_simd(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K) {
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
constexpr int S = simd::max_size<T>;
static_assert(
S % pack_factor == 0, "SIMD size must be divisible by pack factor");
constexpr int packs_per_simd = S / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const T* scales_local = scales;
const T* biases_local = biases;
for (int n = 0; n < N; n++) {
simd::Simd<float, S> acc(0);
auto x_local = x;
for (int k = 0; k < K; k += group_size) {
T scale = *scales_local++;
T bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) {
auto wf = simd::Simd<float, S>(extract_bits_simd<bits, S>(w_local));
w_local += packs_per_simd;
wf = wf * scale;
wf = wf + bias;
simd::Simd<float, S> x_simd = simd::load<T, S>(x_local);
acc = acc + x_simd * wf;
x_local += S;
}
}
*result = T(simd::sum(acc));
result++;
}
x += K;
}
}
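// Equivalent scalar computation (a reference sketch, not part of this diff,
// assuming the same row-major layout and per-group scales/biases as above):
// out[m, n] = sum_k x[m, k] * (scale_g * wq[n, k] + bias_g).
template <typename T, int bits, int group_size>
void qmm_t_scalar_ref(
    T* result,
    const T* x,
    const uint32_t* w,
    const T* scales,
    const T* biases,
    int M,
    int N,
    int K) {
  constexpr int per_word = 32 / bits;
  constexpr uint32_t mask = (1u << bits) - 1;
  for (int m = 0; m < M; m++) {
    for (int n = 0; n < N; n++) {
      float acc = 0;
      for (int k = 0; k < K; k++) {
        int idx = n * K + k;
        uint32_t wq = (w[idx / per_word] >> (bits * (idx % per_word))) & mask;
        int g = idx / group_size;
        acc += float(x[m * K + k]) * (float(scales[g]) * wq + float(biases[g]));
      }
      result[m * N + n] = T(acc);
    }
  }
}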
template <typename T, int bits, int group_size>
void _qmm_dispatch_transpose(
    T* result,
@@ -163,9 +235,14 @@ void _qmm_dispatch_transpose(
    int K,
    bool transposed_w) {
  if (transposed_w) {
-    return _qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
-  } else {
-    return _qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
+    // the simd size must be a multiple of the number of elements per word
+    if constexpr (32 % bits == 0 && simd::max_size<T> % (32 / bits) == 0) {
+      _qmm_t_simd<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
+    } else {
+      _qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
+    }
+  } else {
+    _qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
  }
}
@@ -249,13 +326,13 @@ void _qmm_dispatch(
    int group_size,
    bool transposed_w) {
  int K = x.shape(-1);
-  int M = x.shape(-2);
+  int M = x.ndim() > 1 ? x.shape(-2) : 1;
  int N = out.shape(-1);
  int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
  int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
-  int batch_size = x.size() / x.shape(-1) / x.shape(-2);
+  int batch_size = x.size() / (K * M);
  for (int i = 0; i < batch_size; i++) {
    switch (x.dtype()) {
      case float32:
@@ -384,7 +461,7 @@ void _bs_qmm_dispatch(
} // namespace
-void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
+void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 4);
  auto& x_pre = inputs[0];
@@ -411,7 +488,7 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
  _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
}
-void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
+void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 6);
  auto& x_pre = inputs[0];

View File

@@ -2,6 +2,7 @@
#include <cassert>
+#include "mlx/backend/common/binary_ops.h"
#include "mlx/backend/common/ternary.h"
#include "mlx/primitives.h"
@@ -61,7 +62,7 @@ void select_op(
} // namespace
-void Select::eval(const std::vector<array>& inputs, array& out) {
+void Select::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 3);
  const auto& condition = inputs[0];
  const auto& a = inputs[1];

View File

@@ -0,0 +1,56 @@
#pragma once
#include "mlx/backend/common/simd/base_simd.h"
#if MLX_SIMD_LIBRARY_VERSION < 6
#include "mlx/backend/common/simd/neon_fp16_simd.h"
#endif
namespace mlx::core::simd {
#if MLX_SIMD_LIBRARY_VERSION >= 6
constexpr int N = 8;
template <int N>
struct ScalarT<float16_t, N> {
using v = _Float16;
};
#endif
template <>
static constexpr int max_size<float16_t> = N;
#define SIMD_FP16_DEFAULT_UNARY(op) \
template <> \
inline Simd<float16_t, N> op(Simd<float16_t, N> v) { \
Simd<float, N> in = v; \
return op(in); \
}
SIMD_FP16_DEFAULT_UNARY(acos)
SIMD_FP16_DEFAULT_UNARY(acosh)
SIMD_FP16_DEFAULT_UNARY(asin)
SIMD_FP16_DEFAULT_UNARY(asinh)
SIMD_FP16_DEFAULT_UNARY(atan)
SIMD_FP16_DEFAULT_UNARY(atanh)
SIMD_FP16_DEFAULT_UNARY(cosh)
SIMD_FP16_DEFAULT_UNARY(expm1)
SIMD_FP16_DEFAULT_UNARY(log)
SIMD_FP16_DEFAULT_UNARY(log2)
SIMD_FP16_DEFAULT_UNARY(log10)
SIMD_FP16_DEFAULT_UNARY(log1p)
SIMD_FP16_DEFAULT_UNARY(sinh)
SIMD_FP16_DEFAULT_UNARY(tan)
SIMD_FP16_DEFAULT_UNARY(tanh)
#define SIMD_FP16_DEFAULT_BINARY(op) \
template <> \
inline Simd<float16_t, N> op(Simd<float16_t, N> x, Simd<float16_t, N> y) { \
Simd<float, N> a = x; \
Simd<float, N> b = y; \
return op(a, b); \
}
SIMD_FP16_DEFAULT_BINARY(atan2)
SIMD_FP16_DEFAULT_BINARY(remainder)
SIMD_FP16_DEFAULT_BINARY(pow)
} // namespace mlx::core::simd
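// The pattern above runs fp16 transcendentals by widening to fp32 and
// narrowing the result. A self-contained scalar analogue of the same trick
// (a sketch, not part of this diff; assumes _Float16 support, as this
// header does):
#include <cmath>

inline _Float16 tanh_f16_via_f32(_Float16 x) {
  // Compute in float, then round back to half; fp16's ~11-bit mantissa
  // loses nothing meaningful from the round trip here.
  return static_cast<_Float16>(std::tanh(static_cast<float>(x)));
}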

View File

@@ -0,0 +1,291 @@
#pragma once
#include <simd/math.h>
#include <simd/vector.h>
#include <stdint.h>
#include <cmath>
#include <complex>
#include "mlx/backend/common/simd/base_simd.h"
// There seems to be a bug in simd/base.h:
// __XROS_2_0 is not defined, so the expression evaluates
// to true instead of false, setting the SIMD library
// version higher than it should be even on macOS < 15
#if __MAC_OS_X_VERSION_MIN_REQUIRED >= 150000 || \
    __IPHONE_OS_VERSION_MIN_REQUIRED >= 180000 || \
    __WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \
    __TV_OS_VERSION_MIN_REQUIRED >= 180000
#define MLX_SIMD_LIBRARY_VERSION 6
#else
#define MLX_SIMD_LIBRARY_VERSION 5
#endif
namespace mlx::core::simd {
// Apple simd namespace
namespace asd = ::simd;
// This indirection is needed to remap certain types to ones that accelerate
// SIMD can handle
template <typename T, int N>
struct ScalarT {
using v = T;
};
template <int N>
struct ScalarT<bool, N> {
using v = char;
};
template <int N>
struct ScalarT<int8_t, N> {
using v = char;
};
template <int N>
struct ScalarT<uint64_t, N> {
using v = unsigned long;
};
template <int N>
struct ScalarT<int64_t, N> {
using v = long;
};
template <typename T, int N>
struct Simd {
static constexpr int size = N;
using scalar_t = typename ScalarT<T, N>::v;
Simd<T, N>() {}
template <typename U>
Simd<T, N>(Simd<U, N> other) : value(asd::convert<scalar_t>(other.value)) {}
template <typename U>
Simd<T, N>(U v) : value(v){};
Simd<T, N>(Simd<T, N / 2> x, Simd<T, N / 2> y) {
value = asd::make<typename asd::Vector<scalar_t, N>::packed_t>(
x.value, y.value);
};
T operator[](int idx) const {
return reinterpret_cast<const T*>(&value)[idx];
}
T& operator[](int idx) {
return reinterpret_cast<T*>(&value)[idx];
}
typename asd::Vector<scalar_t, N>::packed_t value;
};
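// Usage sketch for the wrapper (illustrative only, not part of this diff):
// construction from a scalar splats it across lanes, arithmetic is
// element-wise via the operators defined below, and operator[] reads a lane.
//
//   Simd<float, 8> a(1.0f);      // all lanes = 1.0f
//   Simd<float, 8> b = a + 2.0f; // all lanes = 3.0f
//   float lane0 = b[0];          // 3.0f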
// Values chosen based on benchmarks on M3 Max
// TODO: consider choosing these more optimally
template <>
static constexpr int max_size<int8_t> = 16;
template <>
static constexpr int max_size<int16_t> = 16;
template <>
static constexpr int max_size<int> = 8;
template <>
static constexpr int max_size<int64_t> = 4;
template <>
static constexpr int max_size<uint8_t> = 16;
template <>
static constexpr int max_size<uint16_t> = 16;
template <>
static constexpr int max_size<uint32_t> = 8;
template <>
static constexpr int max_size<uint64_t> = 4;
template <>
static constexpr int max_size<float> = 8;
template <>
static constexpr int max_size<double> = 4;
#define SIMD_DEFAULT_UNARY(name, op) \
template <typename T, int N> \
Simd<T, N> name(Simd<T, N> v) { \
return op(v.value); \
}
SIMD_DEFAULT_UNARY(abs, asd::abs)
SIMD_DEFAULT_UNARY(floor, asd::floor)
SIMD_DEFAULT_UNARY(acos, asd::acos)
SIMD_DEFAULT_UNARY(acosh, asd::acosh)
SIMD_DEFAULT_UNARY(asin, asd::asin)
SIMD_DEFAULT_UNARY(asinh, asd::asinh)
SIMD_DEFAULT_UNARY(atan, asd::atan)
SIMD_DEFAULT_UNARY(atanh, asd::atanh)
SIMD_DEFAULT_UNARY(ceil, asd::ceil)
SIMD_DEFAULT_UNARY(cosh, asd::cosh)
SIMD_DEFAULT_UNARY(expm1, asd::expm1)
SIMD_DEFAULT_UNARY(log, asd::log)
SIMD_DEFAULT_UNARY(log2, asd::log2)
SIMD_DEFAULT_UNARY(log10, asd::log10)
SIMD_DEFAULT_UNARY(log1p, asd::log1p)
SIMD_DEFAULT_UNARY(rint, asd::rint)
SIMD_DEFAULT_UNARY(sinh, asd::sinh)
SIMD_DEFAULT_UNARY(sqrt, asd::sqrt)
SIMD_DEFAULT_UNARY(rsqrt, asd::rsqrt)
SIMD_DEFAULT_UNARY(recip, asd::recip)
SIMD_DEFAULT_UNARY(tan, asd::tan)
SIMD_DEFAULT_UNARY(tanh, asd::tanh)
template <typename T, int N>
Simd<T, N> operator-(Simd<T, N> v) {
return -v.value;
}
template <typename T, int N>
Simd<bool, N> isnan(Simd<T, N> v) {
return asd::convert<char>(v.value != v.value);
}
// No simd_boolN in accelerate, use int8_t instead
template <typename T, int N>
Simd<bool, N> operator!(Simd<T, N> v) {
return asd::convert<char>(!v.value);
}
#define SIMD_DEFAULT_BINARY(OP) \
template <typename T, typename U, int N> \
Simd<T, N> operator OP(Simd<T, N> x, U y) { \
return asd::convert<typename Simd<T, N>::scalar_t>(x.value OP y); \
} \
template <typename T1, typename T2, int N> \
Simd<T2, N> operator OP(T1 x, Simd<T2, N> y) { \
return asd::convert<typename Simd<T2, N>::scalar_t>(x OP y.value); \
} \
template <typename T1, typename T2, int N> \
Simd<T1, N> operator OP(Simd<T1, N> x, Simd<T2, N> y) { \
return asd::convert<typename Simd<T1, N>::scalar_t>(x.value OP y.value); \
}
SIMD_DEFAULT_BINARY(+)
SIMD_DEFAULT_BINARY(-)
SIMD_DEFAULT_BINARY(/)
SIMD_DEFAULT_BINARY(*)
SIMD_DEFAULT_BINARY(<<)
SIMD_DEFAULT_BINARY(>>)
SIMD_DEFAULT_BINARY(|)
SIMD_DEFAULT_BINARY(^)
SIMD_DEFAULT_BINARY(&)
SIMD_DEFAULT_BINARY(&&)
SIMD_DEFAULT_BINARY(||)
#define SIMD_DEFAULT_COMPARISONS(OP) \
template <int N, typename T, typename U> \
Simd<bool, N> operator OP(Simd<T, N> a, U b) { \
return asd::convert<char>(a.value OP b); \
} \
template <int N, typename T, typename U> \
Simd<bool, N> operator OP(T a, Simd<U, N> b) { \
return asd::convert<char>(a OP b.value); \
} \
template <int N, typename T1, typename T2> \
Simd<bool, N> operator OP(Simd<T1, N> a, Simd<T2, N> b) { \
return asd::convert<char>(a.value OP b.value); \
}
SIMD_DEFAULT_COMPARISONS(>)
SIMD_DEFAULT_COMPARISONS(<)
SIMD_DEFAULT_COMPARISONS(>=)
SIMD_DEFAULT_COMPARISONS(<=)
SIMD_DEFAULT_COMPARISONS(==)
SIMD_DEFAULT_COMPARISONS(!=)
template <typename T, int N>
Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
return asd::atan2(a.value, b.value);
}
template <typename T, int N>
Simd<T, N> maximum(Simd<T, N> a, Simd<T, N> b) {
// TODO add isnan
return asd::max(a.value, b.value);
}
template <typename T, int N>
Simd<T, N> minimum(Simd<T, N> a, Simd<T, N> b) {
// TODO add isnan
return asd::min(a.value, b.value);
}
template <typename T, int N>
Simd<T, N> remainder(Simd<T, N> a, Simd<T, N> b) {
Simd<T, N> r;
if constexpr (!std::is_integral_v<T>) {
r = asd::remainder(a.value, b.value);
} else {
r = a - b * (a / b);
}
if constexpr (std::is_signed_v<T>) {
auto mask = r != 0 && (r < 0 != b < 0);
r = select(mask, r + b, r);
}
return r;
}
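// The select() above gives the result the sign of the divisor, matching
// Python-style modulo. Worked integer example (sketch): a = -7, b = 3 gives
// r = -7 - 3 * (-7 / 3) = -7 + 6 = -1; r != 0 and the signs differ, so
// r + b = 2, i.e. remainder(-7, 3) == 2.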
template <typename MaskT, typename T1, typename T2, int N>
Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) {
if constexpr (sizeof(T1) == 1) {
return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value));
} else if constexpr (sizeof(T1) == 2) {
return asd::bitselect(y.value, x.value, asd::convert<short>(mask.value));
} else if constexpr (sizeof(T1) == 4) {
return asd::bitselect(y.value, x.value, asd::convert<int>(mask.value));
} else {
return asd::bitselect(y.value, x.value, asd::convert<long>(mask.value));
}
}
template <typename T, int N>
Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {
if constexpr (!std::is_integral_v<T>) {
return asd::pow(base.value, exp.value);
} else {
Simd<T, N> res = 1;
while (any(exp)) {
res = select(exp & 1, res * base, res);
base = select(exp, base * base, base);
exp = exp >> 1;
}
return res;
}
}
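// The integer branch is exponentiation by squaring, kept branch-free across
// lanes with select(). Scalar trace for base = 3, exp = 5 (sketch):
//   res = 1              base = 3       exp = 5
//   exp odd:  res = 3    base = 9       exp = 2
//   exp even: res = 3    base = 81      exp = 1
//   exp odd:  res = 243                 exp = 0   => 3^5 = 243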
template <typename T, int N>
Simd<T, N> clamp(Simd<T, N> v, Simd<T, N> min, Simd<T, N> max) {
return asd::clamp(v.value, min.value, max.value);
}
template <typename T, typename U, int N>
Simd<T, N> fma(Simd<T, N> x, Simd<T, N> y, U z) {
return asd::muladd(x.value, y.value, Simd<T, N>(z).value);
}
// Reductions
template <typename T, int N>
bool any(Simd<T, N> x) {
return asd::any(x.value);
}
template <typename T, int N>
T sum(Simd<T, N> x) {
return asd::reduce_add(x.value);
}
template <typename T, int N>
T max(Simd<T, N> x) {
return asd::reduce_max(x.value);
}
template <typename T, int N>
T min(Simd<T, N> x) {
return asd::reduce_min(x.value);
}
} // namespace mlx::core::simd
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include "mlx/backend/common/simd/accelerate_fp16_simd.h"
#endif

View File

@@ -0,0 +1,252 @@
#pragma once
#include <stdint.h>
#include <algorithm>
#include <cmath>
#include <complex>
namespace mlx::core::simd {
template <typename T, int N>
struct Simd;
template <typename T>
static constexpr int max_size = 1;
template <typename T>
struct Simd<T, 1> {
static constexpr int size = 1;
T value;
Simd() {}
template <typename U>
Simd(Simd<U, 1> v) : value(v.value) {}
template <typename U>
Simd(U v) : value(v) {}
};
template <typename T, int N>
Simd<T, N> load(const T* x) {
return *(Simd<T, N>*)x;
}
template <typename T, int N>
void store(T* dst, Simd<T, N> x) {
// Maintain invariant that bool is either 0 or 1 as
// simd comparison ops set all bits in the result to 1
if constexpr (std::is_same_v<T, bool> && N > 1) {
x = x & 1;
}
*(Simd<T, N>*)dst = x;
}
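// Why the `& 1`: vector comparisons produce all-ones lane masks (0xFF for an
// 8-bit lane), not 0/1. Scalar analogue (sketch, not part of this diff):
//   uint8_t lane = 0xFF;  // "true" lane coming out of a vector compare
//   bool b = lane & 1;    // normalize to exactly 1 before storing as bool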
template <typename, typename = void>
constexpr bool is_complex = false;
template <typename T>
constexpr bool is_complex<T, std::void_t<decltype(std::declval<T>().real())>> =
true;
template <typename T>
Simd<T, 1> rint(Simd<T, 1> in) {
if constexpr (is_complex<T>) {
return Simd<T, 1>{
T{std::rint(in.value.real()), std::rint(in.value.imag())}};
} else {
return Simd<T, 1>{std::rint(in.value)};
}
}
template <typename T>
Simd<T, 1> rsqrt(Simd<T, 1> in) {
return T(1.0) / sqrt(in);
}
template <typename T>
Simd<T, 1> recip(Simd<T, 1> in) {
return T(1.0) / in;
}
#define DEFAULT_UNARY(name, op) \
template <typename T> \
Simd<T, 1> name(Simd<T, 1> in) { \
return op(in.value); \
}
DEFAULT_UNARY(operator-, std::negate{})
DEFAULT_UNARY(operator!, std::logical_not{})
DEFAULT_UNARY(abs, std::abs)
DEFAULT_UNARY(acos, std::acos)
DEFAULT_UNARY(acosh, std::acosh)
DEFAULT_UNARY(asin, std::asin)
DEFAULT_UNARY(asinh, std::asinh)
DEFAULT_UNARY(atan, std::atan)
DEFAULT_UNARY(atanh, std::atanh)
DEFAULT_UNARY(ceil, std::ceil)
DEFAULT_UNARY(conj, std::conj)
DEFAULT_UNARY(cosh, std::cosh)
DEFAULT_UNARY(expm1, std::expm1)
DEFAULT_UNARY(floor, std::floor)
DEFAULT_UNARY(log, std::log)
DEFAULT_UNARY(log2, std::log2)
DEFAULT_UNARY(log10, std::log10)
DEFAULT_UNARY(log1p, std::log1p)
DEFAULT_UNARY(sinh, std::sinh)
DEFAULT_UNARY(sqrt, std::sqrt)
DEFAULT_UNARY(tan, std::tan)
DEFAULT_UNARY(tanh, std::tanh)
template <typename T>
auto real(Simd<T, 1> in) -> Simd<decltype(std::real(in.value)), 1> {
return std::real(in.value);
}
template <typename T>
auto imag(Simd<T, 1> in) -> Simd<decltype(std::imag(in.value)), 1> {
return std::imag(in.value);
}
template <typename T>
Simd<bool, 1> isnan(Simd<T, 1> in) {
return std::isnan(in.value);
}
#define DEFAULT_BINARY(OP) \
template <typename T1, typename T2> \
auto operator OP(Simd<T1, 1> a, Simd<T2, 1> b) \
->Simd<decltype(a.value OP b.value), 1> { \
return a.value OP b.value; \
} \
template <typename T1, typename T2> \
auto operator OP(T1 a, Simd<T2, 1> b)->Simd<decltype(a OP b.value), 1> { \
return a OP b.value; \
} \
template <typename T1, typename T2> \
auto operator OP(Simd<T1, 1> a, T2 b)->Simd<decltype(a.value OP b), 1> { \
return a.value OP b; \
}
DEFAULT_BINARY(+)
DEFAULT_BINARY(-)
DEFAULT_BINARY(*)
DEFAULT_BINARY(/)
DEFAULT_BINARY(<<)
DEFAULT_BINARY(>>)
DEFAULT_BINARY(|)
DEFAULT_BINARY(^)
DEFAULT_BINARY(&)
DEFAULT_BINARY(&&)
DEFAULT_BINARY(||)
template <typename T>
Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {
T a = a_.value;
T b = b_.value;
T r;
if constexpr (std::is_integral_v<T>) {
r = a % b;
} else {
r = std::remainder(a, b);
}
if constexpr (std::is_signed_v<T>) {
if (r != 0 && (r < 0 != b < 0)) {
r += b;
}
}
return r;
}
template <typename T>
Simd<T, 1> maximum(Simd<T, 1> a_, Simd<T, 1> b_) {
T a = a_.value;
T b = b_.value;
if constexpr (!std::is_integral_v<T>) {
if (std::isnan(a)) {
return a;
}
}
return (a > b) ? a : b;
}
template <typename T>
Simd<T, 1> minimum(Simd<T, 1> a_, Simd<T, 1> b_) {
T a = a_.value;
T b = b_.value;
if constexpr (!std::is_integral_v<T>) {
if (std::isnan(a)) {
return a;
}
}
return (a < b) ? a : b;
}
template <typename T>
Simd<T, 1> pow(Simd<T, 1> a, Simd<T, 1> b) {
T base = a.value;
T exp = b.value;
if constexpr (!std::is_integral_v<T>) {
return std::pow(base, exp);
} else {
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
}
template <typename T>
Simd<T, 1> atan2(Simd<T, 1> a, Simd<T, 1> b) {
return std::atan2(a.value, b.value);
}
#define DEFAULT_COMPARISONS(OP) \
template <typename T1, typename T2> \
Simd<bool, 1> operator OP(Simd<T1, 1> a, Simd<T2, 1> b) { \
return a.value OP b.value; \
} \
template <typename T1, typename T2> \
Simd<bool, 1> operator OP(T1 a, Simd<T2, 1> b) { \
return a OP b.value; \
} \
template <typename T1, typename T2> \
Simd<bool, 1> operator OP(Simd<T1, 1> a, T2 b) { \
return a.value OP b; \
}
DEFAULT_COMPARISONS(>)
DEFAULT_COMPARISONS(<)
DEFAULT_COMPARISONS(>=)
DEFAULT_COMPARISONS(<=)
DEFAULT_COMPARISONS(==)
DEFAULT_COMPARISONS(!=)
template <typename MaskT, typename T>
Simd<T, 1> select(Simd<MaskT, 1> mask, Simd<T, 1> x, Simd<T, 1> y) {
return mask.value ? x.value : y.value;
}
template <typename T>
Simd<T, 1> clamp(Simd<T, 1> v, Simd<T, 1> min, Simd<T, 1> max) {
return std::clamp(v.value, min.value, max.value);
}
template <typename T, typename U>
Simd<T, 1> fma(Simd<T, 1> x, Simd<T, 1> y, U z) {
return std::fma(x.value, y.value, Simd<T, 1>(z).value);
}
// Reductions
#define DEFAULT_REDUCTION(name, type) \
template <typename T> \
type name(Simd<T, 1> x) { \
return x.value; \
}
DEFAULT_REDUCTION(max, T)
DEFAULT_REDUCTION(min, T)
DEFAULT_REDUCTION(sum, T)
DEFAULT_REDUCTION(any, bool)
DEFAULT_REDUCTION(all, bool)
} // namespace mlx::core::simd

View File

@@ -0,0 +1,193 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/common/simd/type.h"
namespace mlx::core::simd {
constexpr float inf = std::numeric_limits<float>::infinity();
/**
* Compute exp(x) in an optimizer friendly way as follows:
*
* First change the problem to computing 2**y where y = x / ln(2).
*
* Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part
* `ipart` and y2 is fractional part. For the integer part we perform bit
* shifting and for the fractional part we use a polynomial approximation.
*
* The algorithm and constants of the polynomial taken from
* https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them
* from Cephes math library.
*
* Note: The implementation below is a general fast exp. There could be faster
* implementations for numbers strictly < 0.
*/
template <typename T, int N>
Simd<T, N> exp(Simd<T, N> in) {
if constexpr (is_complex<T>) {
return Simd<T, 1>{std::exp(in.value)};
} else {
Simd<float, N> x_init = in;
auto x = x_init * 1.442695f; // multiply with log_2(e)
Simd<float, N> ipart, fpart;
ipart = floor(x + 0.5);
fpart = x - ipart;
x = 1.535336188319500e-4f;
x = fma(x, fpart, 1.339887440266574e-3f);
x = fma(x, fpart, 9.618437357674640e-3f);
x = fma(x, fpart, 5.550332471162809e-2f);
x = fma(x, fpart, 2.402264791363012e-1f);
x = fma(x, fpart, 6.931472028550421e-1f);
x = fma(x, fpart, 1.000000000000000f);
// generate 2**ipart in the floating point representation using integer
// bitshifting
Simd<int, N> epart = (Simd<int, N>(ipart) + 127) << 23;
// Deal with NaN and Inf
auto result = select(isnan(x_init), x_init, (*(Simd<float, N>*)&epart) * x);
result = select(x_init > 88.0f, Simd<float, N>(inf), result);
result = select(x_init < -88.0f, Simd<float, N>(0), result);
return Simd<T, N>(result);
}
}
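// Numeric sanity check of the scheme above (sketch): for x = 1,
//   y = x * log2(e) = 1.442695, ipart = floor(y + 0.5) = 1, fpart = 0.442695;
//   the polynomial gives 2^0.442695 ~= 1.35914, and (1 + 127) << 23
//   reinterpreted as a float is exactly 2.0f, so the result is
//   2.0 * 1.35914 ~= 2.71828 = e.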
/* Implementation from:
* https://github.com/JishinMaster/simd_utils/blob/3c1433a86fb38edcc9b02039f3c9a65b16640976/neon_mathfun.h#L357
* which originally came from the Cephes math library.
*/
template <bool Sine, typename T, int N>
Simd<T, N> sincos(Simd<T, N> in) {
auto sign_mask_sin = in < 0;
in = abs(in);
Simd<float, N> x = in;
// scale by 4/Pi
auto y = x * 1.27323954473516f;
// store the integer part of y in mm0
Simd<uint32_t, N> emm2 = y;
// j=(j+1) & (~1) (see the cephes sources)
emm2 = emm2 + 1;
emm2 = emm2 & ~1;
y = emm2;
  // Get the polynomial selection mask. There is one polynomial for
  // 0 <= x <= Pi/4 and another one for Pi/4 < x <= Pi/2.
  // Both branches will be computed.
auto poly_mask = (emm2 & 2) != 0;
// The magic pass: "Extended precision modular arithmetic"
// x = ((x - y * DP1) - y * DP2) - y * DP3
x = fma(y, Simd<float, N>(-0.78515625f), x);
x = fma(y, Simd<float, N>(-2.4187564849853515625e-4f), x);
x = fma(y, Simd<float, N>(-3.77489497744594108e-8f), x);
sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != 0);
auto sign_mask_cos = ((emm2 - 2) & 4) != 0;
  // Evaluate the first polynomial (0 <= x <= Pi/4) in y1,
  // and the second polynomial (Pi/4 <= x <= Pi/2) in y2
auto z = x * x;
auto y1 =
fma(z, Simd<float, N>(2.443315711809948e-5f), -1.388731625493765e-3f);
auto y2 = fma(z, Simd<float, N>(-1.9515295891e-4f), 8.3321608736e-3f);
y1 = fma(y1, z, 4.166664568298827e-2f);
y2 = fma(y2, z, -1.6666654611e-1f);
y1 = y1 * z;
y2 = y2 * z;
y1 = y1 * z;
y2 = fma(x, y2, x);
y1 = fma(z, Simd<float, N>(-0.5f), y1);
y1 = y1 + 1.0f;
if constexpr (Sine) {
auto ys = select(poly_mask, y1, y2);
return select(sign_mask_sin, -ys, ys);
} else {
auto yc = select(poly_mask, y2, y1);
return select(sign_mask_cos, yc, -yc);
}
}
template <typename T, int N>
Simd<T, N> sin(Simd<T, N> x) {
if constexpr (is_complex<T>) {
return std::sin(x.value);
} else {
return sincos<true>(x);
}
}
template <typename T, int N>
Simd<T, N> cos(Simd<T, N> x) {
if constexpr (is_complex<T>) {
return std::cos(x.value);
} else {
return sincos<false>(x);
}
}
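// The three fma() steps in sincos() are Cody-Waite style range reduction:
// pi/4 is split into three constants so x - y * (pi/4) stays accurate for
// large x. Check (sketch): 0.78515625 + 2.4187564849853515625e-4
//   + 3.77489497744594108e-8 ~= 0.7853981633974483 = pi/4,
// matching the 4/pi scaling used to compute y.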
template <typename T, int N>
Simd<T, N> erf(Simd<T, N> x) {
// https://github.com/pytorch/pytorch/blob/abf28982a8cb43342e7669d859de9543fd804cc9/aten/src/ATen/cpu/vec/vec256/vec256_float.h#L175
Simd<float, N> v = x;
auto t = recip(fma(Simd<float, N>(0.3275911f), abs(v), 1.0f));
auto r = fma(Simd<float, N>(1.061405429f), t, -1.453152027f);
r = fma(r, t, 1.421413741f);
r = fma(r, t, -0.284496736f);
r = fma(r, t, 0.254829592f);
auto e = -exp(-v * v);
auto result = Simd<T, N>(fma(e * t, r, 1.0f));
return select(x > 0, result, -result);
}
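// These coefficients are the classic Abramowitz & Stegun 7.1.26 rational
// approximation: erf(x) ~= 1 - t * P(t) * exp(-x^2) with
// t = 1 / (1 + 0.3275911 * |x|), max absolute error around 1.5e-7.
// Spot check (sketch): at x = 1 it yields ~0.842700, and erf(1) = 0.8427.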
template <typename T, int N>
Simd<T, N> erfinv(Simd<T, N> a_) {
Simd<float, N> a = a_;
auto t = fma(a, 0.0f - a, 1.0f);
t = log(t);
auto lhs = [](auto t) {
Simd<float, N> p;
p = 3.03697567e-10f; // 0x1.4deb44p-32
p = fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
p = fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
p = fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
p = fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
p = fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
p = fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
p = fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
return fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
};
auto rhs = [](auto t) {
Simd<float, N> p;
p = 5.43877832e-9f; // 0x1.75c000p-28
p = fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
p = fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
p = fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
p = fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
p = fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
p = fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
p = fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
p = fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
return fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
};
auto thresh = 6.125f;
// Compute both branches and select if N > 1
if constexpr (N == 1) {
if ((abs(t) > thresh).value) { // maximum ulp error = 2.35793
return a * lhs(t);
} else { // maximum ulp error = 2.35002
return a * rhs(t);
}
} else {
    return a * select(abs(t) > thresh, lhs(t), rhs(t));
}
}
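// Quick validation of the two branches (sketch): erf(erfinv(a)) should
// round-trip for a in (-1, 1), e.g. erfinv(0.5) ~= 0.476936 and
// erf(0.476936) ~= 0.5. Since t = log(1 - a^2) <= 0, the |t| > 6.125 tail
// branch only triggers for |a| greater than roughly 0.9989.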
} // namespace mlx::core::simd

View File

@@ -0,0 +1,204 @@
#pragma once
#include <arm_neon.h>
#include "mlx/backend/common/simd/base_simd.h"
namespace mlx::core::simd {
constexpr int N = 8;
template <>
struct Simd<float16_t, N> {
static constexpr int size = N;
using scalar_t = float16_t;
Simd<float16_t, N>() {}
template <typename U>
Simd<float16_t, N>(U v) : value(vdupq_n_f16(v)){};
Simd<float16_t, N>(float16x8_t v) : value(v){};
Simd<float16_t, N>(Simd<float, N> other) {
auto f32x4_a = *(float32x4_t*)(&other);
auto f32x4_b = *((float32x4_t*)(&other) + 1);
value = vcvt_high_f16_f32(vcvt_f16_f32(f32x4_a), f32x4_b);
};
Simd<float16_t, N>(Simd<uint16_t, N> other) {
value = vcvtq_f16_u16(*(uint16x8_t*)(&other.value));
};
operator Simd<int16_t, N>() {
auto v = vcvtq_s16_f16(value);
return load<int16_t, N>((int16_t*)&v);
};
operator Simd<float, N>() {
float32x4x2_t v;
v.val[0] = vcvt_f32_f16(*(float16x4_t*)(&value));
v.val[1] = vcvt_high_f32_f16(value);
return load<float, N>((float*)&v);
}
float16_t operator[](int idx) const {
return reinterpret_cast<const float16_t*>(&value)[idx];
}
float16_t& operator[](int idx) {
return reinterpret_cast<float16_t*>(&value)[idx];
}
float16x8_t value;
};
#define DEFINE_NEON_UNARY_OP(name, op) \
inline Simd<float16_t, N> name(Simd<float16_t, N> a) { \
return Simd<float16_t, N>{op(a.value)}; \
}
DEFINE_NEON_UNARY_OP(abs, vabsq_f16)
DEFINE_NEON_UNARY_OP(ceil, vrndpq_f16)
DEFINE_NEON_UNARY_OP(floor, vrndmq_f16)
DEFINE_NEON_UNARY_OP(sqrt, vsqrtq_f16)
DEFINE_NEON_UNARY_OP(rsqrt, vrsqrteq_f16)
DEFINE_NEON_UNARY_OP(recip, vrecpeq_f16)
DEFINE_NEON_UNARY_OP(rint, vrndnq_f16)
#define DEFINE_NEON_BINARY_OP(name, op) \
inline Simd<float16_t, N> name(Simd<float16_t, N> a, Simd<float16_t, N> b) { \
return op(a.value, b.value); \
} \
template <typename T> \
Simd<float16_t, N> name(Simd<float16_t, N> a, T b) { \
return op(a.value, Simd<float16_t, N>(b).value); \
} \
template <typename T> \
Simd<float16_t, N> name(T a, Simd<float16_t, N> b) { \
return op(Simd<float16_t, N>(a).value, b.value); \
}
inline Simd<float16_t, N> operator!(Simd<float16_t, N> v) {
auto out = vceqzq_f16(v.value);
return Simd<uint16_t, N>(*(uint16_t*)&out);
}
inline Simd<float16_t, N> operator-(Simd<float16_t, N> v) {
return vnegq_f16(v.value);
}
DEFINE_NEON_BINARY_OP(maximum, vmaxq_f16)
DEFINE_NEON_BINARY_OP(minimum, vminq_f16)
DEFINE_NEON_BINARY_OP(operator+, vaddq_f16)
DEFINE_NEON_BINARY_OP(operator-, vsubq_f16)
DEFINE_NEON_BINARY_OP(operator*, vmulq_f16)
DEFINE_NEON_BINARY_OP(operator/, vdivq_f16)
#define DEFINE_NEON_COMPARISON(Op, op) \
template <typename T> \
Simd<bool, N> operator Op(Simd<float16_t, N> a, T b) { \
auto out = op(a.value, Simd<float16_t, N>(b).value); \
return Simd<uint16_t, N>(*(uint16_t*)(&out)); \
} \
template <typename T> \
Simd<bool, N> operator Op(T a, Simd<float16_t, N> b) { \
auto out = op(Simd<float16_t, N>(a).value, b.value); \
return Simd<uint16_t, N>(*(uint16_t*)(&out)); \
} \
inline Simd<bool, N> operator Op( \
Simd<float16_t, N> a, Simd<float16_t, N> b) { \
auto out = op(a.value, b.value); \
return Simd<uint16_t, N>(*(uint16_t*)(&out)); \
}
DEFINE_NEON_COMPARISON(==, vceqq_f16)
DEFINE_NEON_COMPARISON(>=, vcgeq_f16)
DEFINE_NEON_COMPARISON(<=, vcleq_f16)
DEFINE_NEON_COMPARISON(>, vcgtq_f16)
DEFINE_NEON_COMPARISON(<, vcltq_f16)
template <typename T>
Simd<bool, N> operator!=(Simd<float16_t, N> a, T b) {
return !(a == b);
}
template <typename T>
Simd<bool, N> operator!=(T a, Simd<float16_t, N> b) {
return !(a == b);
}
inline Simd<bool, N> operator!=(Simd<float16_t, N> a, Simd<float16_t, N> b) {
return !(a == b);
}
inline Simd<float16_t, N> operator||(
Simd<float16_t, N> a,
Simd<float16_t, N> b) {
return Simd<uint16_t, N>((a != 0) || (b != 0));
}
template <typename T>
Simd<float16_t, N> operator||(Simd<float16_t, N> a, T b) {
return Simd<uint16_t, N>((a != 0) || (b != 0));
}
template <typename T>
Simd<float16_t, N> operator||(T a, Simd<float16_t, N> b) {
return Simd<uint16_t, N>((a != 0) || (b != 0));
}
inline Simd<float16_t, N> operator&&(
Simd<float16_t, N> a,
Simd<float16_t, N> b) {
return Simd<uint16_t, N>((a != 0) && (b != 0));
}
template <typename T>
Simd<float16_t, N> operator&&(Simd<float16_t, N> a, T b) {
return Simd<uint16_t, N>((a != 0) && (b != 0));
}
template <typename T>
Simd<float16_t, N> operator&&(T a, Simd<float16_t, N> b) {
return Simd<uint16_t, N>((a != 0) && (b != 0));
}
template <>
inline Simd<bool, N> isnan(Simd<float16_t, N> v) {
return v != v;
}
template <>
inline Simd<float16_t, N>
clamp(Simd<float16_t, N> v, Simd<float16_t, N> min, Simd<float16_t, N> max) {
return minimum(maximum(v, min), max);
}
template <typename T>
Simd<float16_t, N> fma(Simd<float16_t, N> x, Simd<float16_t, N> y, T z) {
// vfmaq_f16(acc, a, b) computes acc + a*b, so the addend z goes first to
// match the generic fma(x, y, z) = x*y + z used elsewhere in this diff.
return vfmaq_f16(Simd<float16_t, N>(z).value, x.value, y.value);
}
template <typename MaskT>
Simd<float16_t, N>
select(Simd<MaskT, N> mask, Simd<float16_t, N> x, Simd<float16_t, N> y) {
return vbslq_f16(Simd<uint16_t, N>(mask).value, x.value, y.value);
}
// Reductions
inline float16_t max(Simd<float16_t, N> x) {
float16x4_t y;
y = vpmax_f16(vget_low_f16(x.value), vget_high_f16(x.value));
y = vpmax_f16(y, y);
y = vpmax_f16(y, y);
return vget_lane_f16(y, 0);
}
inline float16_t min(Simd<float16_t, N> x) {
float16x4_t y;
y = vpmin_f16(vget_low_f16(x.value), vget_high_f16(x.value));
y = vpmin_f16(y, y);
y = vpmin_f16(y, y);
return vget_lane_f16(y, 0);
}
inline float16_t sum(Simd<float16_t, N> x) {
float16x4_t y;
y = vpadd_f16(vget_low_f16(x.value), vget_high_f16(x.value));
y = vpadd_f16(y, y);
y = vpadd_f16(y, y);
return vget_lane_f16(y, 0);
}
} // namespace mlx::core::simd
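
A minimal sketch of the float16 path above (it assumes an ARMv8.2-A FP16 target and that the simd.h umbrella header pulls in this NEON specialization; dot8 itself is illustrative, not an API in this commit):

#include "mlx/backend/common/simd/simd.h"

using namespace mlx::core;
using namespace mlx::core::simd;

// Multiply-accumulate eight half-precision lanes with the vfmaq_f16-backed
// fma above, then reduce with the pairwise-add sum().
float16_t dot8(const float16_t* a, const float16_t* b) {
  Simd<float16_t, 8> acc(0.0f);
  acc = fma(load<float16_t, 8>(a), load<float16_t, 8>(b), acc);
  return sum(acc);
}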

View File

@ -0,0 +1,4 @@
#pragma once
#include "mlx/backend/common/simd/math.h"
#include "mlx/backend/common/simd/type.h"

View File

@ -0,0 +1,7 @@
#pragma once
#include "mlx/backend/common/simd/base_simd.h"
#ifdef MLX_USE_ACCELERATE
#include "mlx/backend/common/simd/accelerate_simd.h"
#endif

View File

@ -4,61 +4,107 @@
#include <cmath> #include <cmath>
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/simd/simd.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
namespace mlx::core { namespace mlx::core {
namespace { namespace {
using namespace mlx::core::simd;
template <typename T, typename AccT> template <typename T, typename AccT>
void softmax(const array& in, array& out) { void softmax(const array& in, array& out) {
constexpr bool same_t = std::is_same_v<T, AccT>;
constexpr int N = std::min(max_size<AccT>, max_size<T>);
const T* in_ptr = in.data<T>(); const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>(); T* out_ptr = out.data<T>();
int N = in.shape().back(); int M = in.shape().back();
int M = in.data_size() / N; int L = in.data_size() / M;
const T* current_in_ptr; const T* current_in_ptr;
T* current_out_ptr; T* current_out_ptr;
for (int i = 0; i < M; i++, in_ptr += N, out_ptr += N) { for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
// Find the maximum // Find the maximum
current_in_ptr = in_ptr; current_in_ptr = in_ptr;
AccT maximum = *current_in_ptr; Simd<AccT, N> vmaximum(-std::numeric_limits<float>::infinity());
for (int j = 0; j < N; j++, current_in_ptr++) { size_t s = M;
maximum = (maximum < *current_in_ptr) ? static_cast<AccT>(*current_in_ptr) while (s >= N) {
: maximum; Simd<AccT, N> vals = load<T, N>(current_in_ptr);
vmaximum = maximum(vals, vmaximum);
current_in_ptr += N;
s -= N;
}
AccT maximum = max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
} }
// Compute the normalizer and the exponentials // Compute the normalizer and the exponentials
AccT normalizer = 0; Simd<AccT, N> vnormalizer(0.0);
current_out_ptr = out_ptr; current_out_ptr = out_ptr;
current_in_ptr = in_ptr; current_in_ptr = in_ptr;
for (int j = 0; j < N; j++, current_out_ptr++, current_in_ptr++) { s = M;
AccT expv = std::exp(*current_in_ptr - maximum); while (s >= N) {
normalizer += expv; Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
if constexpr (std::is_same<T, AccT>::value) { vexp = exp(vexp - maximum);
*current_out_ptr = expv; if constexpr (same_t) {
store(current_out_ptr, vexp);
} }
vnormalizer = vnormalizer + vexp;
current_in_ptr += N;
current_out_ptr += N;
s -= N;
}
AccT normalizer = sum(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
if constexpr (same_t) {
*current_out_ptr = _exp;
}
normalizer += _exp;
current_in_ptr++;
current_out_ptr++;
} }
normalizer = 1 / normalizer; normalizer = 1 / normalizer;
// Normalize // Normalize
current_in_ptr = in_ptr;
current_out_ptr = out_ptr; current_out_ptr = out_ptr;
for (int j = 0; j < N; j++, current_out_ptr++) { current_in_ptr = in_ptr;
if constexpr (std::is_same<T, AccT>::value) { s = M;
while (s >= N) {
if constexpr (same_t) {
store(
current_out_ptr,
Simd<T, N>(load<T, N>(current_out_ptr) * normalizer));
} else {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum) * normalizer;
store(current_out_ptr, Simd<T, N>(vexp));
current_in_ptr += N;
}
current_out_ptr += N;
s -= N;
}
while (s-- > 0) {
if constexpr (same_t) {
*current_out_ptr *= normalizer; *current_out_ptr *= normalizer;
} else { } else {
auto v = std::exp(*current_in_ptr - maximum); AccT _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(v * normalizer); *current_out_ptr = static_cast<T>(_exp * normalizer);
current_in_ptr++; current_in_ptr++;
} }
current_out_ptr++;
} }
} }
} }
} // namespace } // namespace
void Softmax::eval(const std::vector<array>& inputs, array& out) { void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
// Make sure that the last dimension is contiguous // Make sure that the last dimension is contiguous
@ -97,7 +143,7 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
case int16: case int16:
case int32: case int32:
case int64: case int64:
throw std::invalid_argument( throw std::runtime_error(
"Softmax is defined only for floating point types"); "Softmax is defined only for floating point types");
break; break;
case float32: case float32:
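
For reference, the vectorized kernel above keeps the same three numerically stable passes as this scalar version (a testing sketch, not part of the commit):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> softmax_ref(const std::vector<float>& x) {
  // Pass 1: find the maximum.
  float m = *std::max_element(x.begin(), x.end());
  std::vector<float> y(x.size());
  float norm = 0.0f;
  // Pass 2: exponentiate and accumulate the normalizer.
  for (std::size_t i = 0; i < x.size(); ++i) {
    y[i] = std::exp(x[i] - m); // subtract the max so exp cannot overflow
    norm += y[i];
  }
  // Pass 3: normalize.
  for (auto& v : y) {
    v /= norm;
  }
  return y;
}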

View File

@ -287,7 +287,7 @@ void argpartition(const array& in, array& out, int axis, int kth) {
} // namespace } // namespace
void ArgSort::eval(const std::vector<array>& inputs, array& out) { void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
@ -321,7 +321,7 @@ void ArgSort::eval(const std::vector<array>& inputs, array& out) {
} }
} }
void Sort::eval(const std::vector<array>& inputs, array& out) { void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
@ -355,7 +355,7 @@ void Sort::eval(const std::vector<array>& inputs, array& out) {
} }
} }
void ArgPartition::eval(const std::vector<array>& inputs, array& out) { void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
@ -389,7 +389,7 @@ void ArgPartition::eval(const std::vector<array>& inputs, array& out) {
} }
} }
void Partition::eval(const std::vector<array>& inputs, array& out) { void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];

View File

@ -137,7 +137,9 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
} }
} }
void SVD::eval(const std::vector<array>& inputs, std::vector<array>& outputs) { void SVD::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
if (!(inputs[0].dtype() == float32)) { if (!(inputs[0].dtype() == float32)) {
throw std::runtime_error("[SVD::eval] only supports float32."); throw std::runtime_error("[SVD::eval] only supports float32.");
} }

View File

@ -3,8 +3,8 @@
#pragma once #pragma once
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/common/ops.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
namespace mlx::core { namespace mlx::core {
namespace { namespace {

View File

@ -0,0 +1,285 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/backend/common/unary.h"
#include "mlx/backend/common/unary_ops.h"
#include "mlx/primitives.h"
namespace mlx::core {
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
if (issubdtype(in.dtype(), unsignedinteger) || in.dtype() == bool_) {
// No-op for unsigned types
out.copy_shared_buffer(in);
} else {
auto op = detail::Abs{};
switch (out.dtype()) {
case int8:
unary_op<int8_t>(in, out, op);
break;
case int16:
unary_op<int16_t>(in, out, op);
break;
case int32:
unary_op<int32_t>(in, out, op);
break;
case int64:
unary_op<int64_t>(in, out, op);
break;
case float16:
unary_op<float16_t>(in, out, op);
break;
case float32:
unary_op<float>(in, out, op);
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, op);
break;
case complex64:
unary_op<complex64_t>(in, out, op);
break;
default:
throw std::runtime_error("[Abs] Called on unsigned type");
}
}
}
void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcCos());
}
void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcCosh());
}
void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcSin());
}
void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcSinh());
}
void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcTan());
}
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcTanh());
}
void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Ceil());
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Conjugate::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
unary_op<complex64_t>(inputs[0], out, detail::Conjugate());
}
void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Cos());
}
void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Cosh());
}
void Erf::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
unary_op<float>(in, out, detail::Erf());
break;
case float16:
unary_op<float16_t>(in, out, detail::Erf());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::Erf());
break;
default:
throw std::invalid_argument(
"[erf] Error function only defined for arrays"
" with real floating point type.");
}
}
void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
unary_op<float>(in, out, detail::ErfInv());
break;
case float16:
unary_op<float16_t>(in, out, detail::ErfInv());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::ErfInv());
break;
default:
throw std::invalid_argument(
"[erf_inv] Inverse error function only defined for arrays"
" with real floating point type.");
}
}
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Exp());
}
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Expm1());
}
void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Floor());
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
}
void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (base_) {
case Base::e:
unary_fp(in, out, detail::Log());
break;
case Base::two:
unary_fp(in, out, detail::Log2());
break;
case Base::ten:
unary_fp(in, out, detail::Log10());
break;
}
}
void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Log1p());
}
void LogicalNot::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::LogicalNot());
}
void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::Negative());
}
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
}
void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Round());
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Sigmoid::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Sigmoid());
}
void Sign::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == bool_) {
out.copy_shared_buffer(in);
} else {
unary(in, out, detail::Sign());
}
}
void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Sin());
}
void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Sinh());
}
void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::Square());
}
void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (recip_) {
unary_fp(in, out, detail::Rsqrt());
} else {
unary_fp(in, out, detail::Sqrt());
}
}
void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Tan());
}
void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Tanh());
}
} // namespace mlx::core

View File

@ -4,6 +4,7 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/common/simd/simd.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/utils.h" #include "mlx/utils.h"
@ -38,8 +39,19 @@ void unary_op(const array& a, array& out, Op op) {
if (a.flags().contiguous) { if (a.flags().contiguous) {
set_unary_output_data(a, out); set_unary_output_data(a, out);
U* dst = out.data<U>(); U* dst = out.data<U>();
for (size_t i = 0; i < a.data_size(); ++i) { constexpr int N = simd::max_size<T>;
dst[i] = op(a_ptr[i]); size_t size = a.data_size();
while (size >= N) {
simd::store(dst, op(simd::load<T, N>(a_ptr)));
size -= N;
a_ptr += N;
dst += N;
}
while (size > 0) {
*dst = op(*a_ptr);
size--;
dst++;
a_ptr++;
} }
} else { } else {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
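
This is the loop shape the whole commit standardizes on: if, say, simd::max_size<T> were 8 and data_size() were 20, the vector loop would cover two full blocks (16 elements) and the scalar loop would finish the remaining 4.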

View File

@ -0,0 +1,108 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <stdint.h>
#include <cmath>
#include <complex>
#include "mlx/backend/common/simd/simd.h"
namespace mlx::core::detail {
using namespace mlx::core::simd;
#define SINGLE() \
template <typename T> \
T operator()(T x) { \
return (*this)(Simd<T, 1>(x)).value; \
}
#define DEFAULT_OP(Op, op) \
struct Op { \
template <int N, typename T> \
Simd<T, N> operator()(Simd<T, N> x) { \
return simd::op(x); \
} \
SINGLE() \
};
DEFAULT_OP(Abs, abs)
DEFAULT_OP(ArcCos, acos)
DEFAULT_OP(ArcCosh, acosh)
DEFAULT_OP(ArcSin, asin)
DEFAULT_OP(ArcSinh, asinh)
DEFAULT_OP(ArcTan, atan)
DEFAULT_OP(ArcTanh, atanh)
DEFAULT_OP(Ceil, ceil)
DEFAULT_OP(Conjugate, conj)
DEFAULT_OP(Cos, cos)
DEFAULT_OP(Cosh, cosh)
DEFAULT_OP(Erf, erf)
DEFAULT_OP(ErfInv, erfinv)
DEFAULT_OP(Exp, exp)
DEFAULT_OP(Expm1, expm1)
DEFAULT_OP(Floor, floor);
DEFAULT_OP(Log, log);
DEFAULT_OP(Log2, log2);
DEFAULT_OP(Log10, log10);
DEFAULT_OP(Log1p, log1p);
DEFAULT_OP(LogicalNot, operator!)
DEFAULT_OP(Negative, operator-)
DEFAULT_OP(Round, rint);
DEFAULT_OP(Sin, sin)
DEFAULT_OP(Sinh, sinh)
DEFAULT_OP(Sqrt, sqrt)
DEFAULT_OP(Rsqrt, rsqrt)
DEFAULT_OP(Tan, tan)
DEFAULT_OP(Tanh, tanh)
struct Imag {
template <int N>
Simd<float, N> operator()(Simd<complex64_t, N> x) {
return simd::imag(x);
}
SINGLE()
};
struct Real {
template <int N>
Simd<float, N> operator()(Simd<complex64_t, N> x) {
return simd::real(x);
}
SINGLE()
};
struct Sigmoid {
template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) {
return 1.0f / (1.0f + simd::exp(-x));
}
SINGLE()
};
struct Sign {
template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) {
auto z = Simd<T, N>{0};
if constexpr (std::is_unsigned_v<T>) {
return x != z;
} else if constexpr (std::is_same_v<T, complex64_t>) {
return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x)));
} else {
return simd::select(
x < z, Simd<T, N>{-1}, simd::select(x > z, Simd<T, N>{1}, z));
}
}
SINGLE()
};
struct Square {
template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) {
return x * x;
}
SINGLE()
};
} // namespace mlx::core::detail
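
The point of the SINGLE()/DEFAULT_OP pair is that one functor serves both the vector body and the scalar tail of unary_op; a sketch (the header path is assumed from this commit's layout):

#include "mlx/backend/common/unary_ops.h"

void demo() {
  mlx::core::detail::Abs abs_op;
  // Scalar call: routed through SINGLE() as a width-1 Simd.
  float s = abs_op(-2.0f);
  // Vector call: dispatches to the Simd overload from DEFAULT_OP.
  auto v = abs_op(
      mlx::core::simd::Simd<float, mlx::core::simd::max_size<float>>(-2.0f));
  (void)s;
  (void)v;
}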

View File

@ -163,9 +163,6 @@ class Abs : public UnaryPrimitive {
DEFINE_PRINT(Abs) DEFINE_PRINT(Abs)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Add : public UnaryPrimitive { class Add : public UnaryPrimitive {
@ -180,9 +177,6 @@ class Add : public UnaryPrimitive {
DEFINE_PRINT(Add) DEFINE_PRINT(Add)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class AddMM : public UnaryPrimitive { class AddMM : public UnaryPrimitive {
@ -226,8 +220,6 @@ class Arange : public UnaryPrimitive {
double start_; double start_;
double stop_; double stop_;
double step_; double step_;
void eval(const std::vector<array>& inputs, array& out);
}; };
class ArcCos : public UnaryPrimitive { class ArcCos : public UnaryPrimitive {
@ -242,9 +234,6 @@ class ArcCos : public UnaryPrimitive {
DEFINE_PRINT(ArcCos) DEFINE_PRINT(ArcCos)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class ArcCosh : public UnaryPrimitive { class ArcCosh : public UnaryPrimitive {
@ -259,9 +248,6 @@ class ArcCosh : public UnaryPrimitive {
DEFINE_PRINT(ArcCosh) DEFINE_PRINT(ArcCosh)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class ArcSin : public UnaryPrimitive { class ArcSin : public UnaryPrimitive {
@ -276,9 +262,6 @@ class ArcSin : public UnaryPrimitive {
DEFINE_PRINT(ArcSin) DEFINE_PRINT(ArcSin)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class ArcSinh : public UnaryPrimitive { class ArcSinh : public UnaryPrimitive {
@ -293,9 +276,6 @@ class ArcSinh : public UnaryPrimitive {
DEFINE_PRINT(ArcSinh) DEFINE_PRINT(ArcSinh)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class ArcTan : public UnaryPrimitive { class ArcTan : public UnaryPrimitive {
@ -310,9 +290,6 @@ class ArcTan : public UnaryPrimitive {
DEFINE_PRINT(ArcTan) DEFINE_PRINT(ArcTan)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class ArcTan2 : public UnaryPrimitive { class ArcTan2 : public UnaryPrimitive {
@ -327,9 +304,6 @@ class ArcTan2 : public UnaryPrimitive {
DEFINE_PRINT(ArcTan2) DEFINE_PRINT(ArcTan2)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class ArcTanh : public UnaryPrimitive { class ArcTanh : public UnaryPrimitive {
@ -344,9 +318,6 @@ class ArcTanh : public UnaryPrimitive {
DEFINE_PRINT(ArcTanh) DEFINE_PRINT(ArcTanh)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class ArgPartition : public UnaryPrimitive { class ArgPartition : public UnaryPrimitive {
@ -369,8 +340,6 @@ class ArgPartition : public UnaryPrimitive {
private: private:
int kth_; int kth_;
int axis_; int axis_;
void eval(const std::vector<array>& inputs, array& out);
}; };
class ArgReduce : public UnaryPrimitive { class ArgReduce : public UnaryPrimitive {
@ -398,8 +367,6 @@ class ArgReduce : public UnaryPrimitive {
private: private:
ReduceType reduce_type_; ReduceType reduce_type_;
int axis_; int axis_;
void eval(const std::vector<array>& inputs, array& out);
}; };
class ArgSort : public UnaryPrimitive { class ArgSort : public UnaryPrimitive {
@ -420,8 +387,6 @@ class ArgSort : public UnaryPrimitive {
private: private:
int axis_; int axis_;
void eval(const std::vector<array>& inputs, array& out);
}; };
class AsType : public UnaryPrimitive { class AsType : public UnaryPrimitive {
@ -443,8 +408,6 @@ class AsType : public UnaryPrimitive {
private: private:
Dtype dtype_; Dtype dtype_;
void eval(const std::vector<array>& inputs, array& out);
}; };
class AsStrided : public UnaryPrimitive { class AsStrided : public UnaryPrimitive {
@ -518,8 +481,6 @@ class BlockMaskedMM : public UnaryPrimitive {
private: private:
int block_size_; int block_size_;
void eval(const std::vector<array>& inputs, array& out);
}; };
class GatherMM : public UnaryPrimitive { class GatherMM : public UnaryPrimitive {
@ -537,9 +498,6 @@ class GatherMM : public UnaryPrimitive {
DEFINE_PRINT(GatherMM) DEFINE_PRINT(GatherMM)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class BroadcastAxes : public UnaryPrimitive { class BroadcastAxes : public UnaryPrimitive {
@ -603,9 +561,6 @@ class Ceil : public UnaryPrimitive {
DEFINE_PRINT(Ceil) DEFINE_PRINT(Ceil)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Compiled : public Primitive { class Compiled : public Primitive {
@ -669,8 +624,6 @@ class Concatenate : public UnaryPrimitive {
private: private:
int axis_; int axis_;
void eval(const std::vector<array>& inputs, array& out);
}; };
class Conjugate : public UnaryPrimitive { class Conjugate : public UnaryPrimitive {
@ -684,9 +637,6 @@ class Conjugate : public UnaryPrimitive {
DEFINE_PRINT(Conjugate) DEFINE_PRINT(Conjugate)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Contiguous : public UnaryPrimitive { class Contiguous : public UnaryPrimitive {
@ -787,9 +737,6 @@ class Cos : public UnaryPrimitive {
DEFINE_PRINT(Cos) DEFINE_PRINT(Cos)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Cosh : public UnaryPrimitive { class Cosh : public UnaryPrimitive {
@ -804,9 +751,6 @@ class Cosh : public UnaryPrimitive {
DEFINE_PRINT(Cosh) DEFINE_PRINT(Cosh)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class CustomTransforms : public Primitive { class CustomTransforms : public Primitive {
@ -894,9 +838,6 @@ class Divide : public UnaryPrimitive {
DEFINE_PRINT(Divide) DEFINE_PRINT(Divide)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class DivMod : public Primitive { class DivMod : public Primitive {
@ -915,9 +856,6 @@ class DivMod : public Primitive {
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override { std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
return std::vector{inputs[0].shape(), inputs[0].shape()}; return std::vector{inputs[0].shape(), inputs[0].shape()};
} }
private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
}; };
class Select : public UnaryPrimitive { class Select : public UnaryPrimitive {
@ -932,9 +870,6 @@ class Select : public UnaryPrimitive {
DEFINE_PRINT(Select) DEFINE_PRINT(Select)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Remainder : public UnaryPrimitive { class Remainder : public UnaryPrimitive {
@ -949,9 +884,6 @@ class Remainder : public UnaryPrimitive {
DEFINE_PRINT(Remainder) DEFINE_PRINT(Remainder)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Equal : public UnaryPrimitive { class Equal : public UnaryPrimitive {
@ -979,7 +911,6 @@ class Equal : public UnaryPrimitive {
}; };
private: private:
void eval(const std::vector<array>& inputs, array& out);
bool equal_nan_; bool equal_nan_;
}; };
@ -995,9 +926,6 @@ class Erf : public UnaryPrimitive {
DEFINE_PRINT(Erf) DEFINE_PRINT(Erf)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class ErfInv : public UnaryPrimitive { class ErfInv : public UnaryPrimitive {
@ -1012,9 +940,6 @@ class ErfInv : public UnaryPrimitive {
DEFINE_PRINT(ErfInv) DEFINE_PRINT(ErfInv)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Exp : public UnaryPrimitive { class Exp : public UnaryPrimitive {
@ -1029,9 +954,6 @@ class Exp : public UnaryPrimitive {
DEFINE_PRINT(Exp) DEFINE_PRINT(Exp)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Expm1 : public UnaryPrimitive { class Expm1 : public UnaryPrimitive {
@ -1045,9 +967,6 @@ class Expm1 : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Expm1) DEFINE_PRINT(Expm1)
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class ExpandDims : public UnaryPrimitive { class ExpandDims : public UnaryPrimitive {
@ -1100,8 +1019,6 @@ class FFT : public UnaryPrimitive {
std::vector<size_t> axes_; std::vector<size_t> axes_;
bool inverse_; bool inverse_;
bool real_; bool real_;
void eval(const std::vector<array>& inputs, array& out);
}; };
class Flatten : public UnaryPrimitive { class Flatten : public UnaryPrimitive {
@ -1141,9 +1058,6 @@ class Floor : public UnaryPrimitive {
DEFINE_PRINT(Floor) DEFINE_PRINT(Floor)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Full : public UnaryPrimitive { class Full : public UnaryPrimitive {
@ -1157,9 +1071,6 @@ class Full : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Full) DEFINE_PRINT(Full)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Gather : public UnaryPrimitive { class Gather : public UnaryPrimitive {
@ -1182,7 +1093,6 @@ class Gather : public UnaryPrimitive {
} }
private: private:
void eval(const std::vector<array>& inputs, array& out);
std::vector<int> axes_; std::vector<int> axes_;
Shape slice_sizes_; Shape slice_sizes_;
}; };
@ -1199,9 +1109,6 @@ class Greater : public UnaryPrimitive {
DEFINE_PRINT(Greater) DEFINE_PRINT(Greater)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class GreaterEqual : public UnaryPrimitive { class GreaterEqual : public UnaryPrimitive {
@ -1216,9 +1123,6 @@ class GreaterEqual : public UnaryPrimitive {
DEFINE_PRINT(GreaterEqual) DEFINE_PRINT(GreaterEqual)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Hadamard : public UnaryPrimitive { class Hadamard : public UnaryPrimitive {
@ -1241,8 +1145,6 @@ class Hadamard : public UnaryPrimitive {
private: private:
float scale_; float scale_;
void eval(const std::vector<array>& inputs, array& out);
}; };
class Imag : public UnaryPrimitive { class Imag : public UnaryPrimitive {
@ -1271,9 +1173,6 @@ class Less : public UnaryPrimitive {
DEFINE_PRINT(Less) DEFINE_PRINT(Less)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class LessEqual : public UnaryPrimitive { class LessEqual : public UnaryPrimitive {
@ -1288,9 +1187,6 @@ class LessEqual : public UnaryPrimitive {
DEFINE_PRINT(LessEqual) DEFINE_PRINT(LessEqual)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Load : public UnaryPrimitive { class Load : public UnaryPrimitive {
@ -1319,7 +1215,6 @@ class Load : public UnaryPrimitive {
static Stream io_stream = new_stream(Device::cpu); static Stream io_stream = new_stream(Device::cpu);
return io_stream; return io_stream;
}; };
void eval(const std::vector<array>& inputs, array& out);
std::shared_ptr<io::Reader> reader_; std::shared_ptr<io::Reader> reader_;
size_t offset_; size_t offset_;
bool swap_endianness_; bool swap_endianness_;
@ -1360,7 +1255,6 @@ class Log : public UnaryPrimitive {
private: private:
Base base_; Base base_;
void eval(const std::vector<array>& inputs, array& out);
}; };
class Log1p : public UnaryPrimitive { class Log1p : public UnaryPrimitive {
@ -1374,9 +1268,6 @@ class Log1p : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Log1p) DEFINE_PRINT(Log1p)
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class LogicalNot : public UnaryPrimitive { class LogicalNot : public UnaryPrimitive {
@ -1391,9 +1282,6 @@ class LogicalNot : public UnaryPrimitive {
DEFINE_PRINT(LogicalNot) DEFINE_PRINT(LogicalNot)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class LogicalAnd : public UnaryPrimitive { class LogicalAnd : public UnaryPrimitive {
@ -1408,9 +1296,6 @@ class LogicalAnd : public UnaryPrimitive {
DEFINE_PRINT(LogicalAnd) DEFINE_PRINT(LogicalAnd)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class LogicalOr : public UnaryPrimitive { class LogicalOr : public UnaryPrimitive {
@ -1425,9 +1310,6 @@ class LogicalOr : public UnaryPrimitive {
DEFINE_PRINT(LogicalOr) DEFINE_PRINT(LogicalOr)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class LogAddExp : public UnaryPrimitive { class LogAddExp : public UnaryPrimitive {
@ -1442,9 +1324,6 @@ class LogAddExp : public UnaryPrimitive {
DEFINE_PRINT(LogAddExp) DEFINE_PRINT(LogAddExp)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Matmul : public UnaryPrimitive { class Matmul : public UnaryPrimitive {
@ -1473,9 +1352,6 @@ class Maximum : public UnaryPrimitive {
DEFINE_PRINT(Maximum) DEFINE_PRINT(Maximum)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Minimum : public UnaryPrimitive { class Minimum : public UnaryPrimitive {
@ -1490,9 +1366,6 @@ class Minimum : public UnaryPrimitive {
DEFINE_PRINT(Minimum) DEFINE_PRINT(Minimum)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Multiply : public UnaryPrimitive { class Multiply : public UnaryPrimitive {
@ -1507,9 +1380,6 @@ class Multiply : public UnaryPrimitive {
DEFINE_PRINT(Multiply) DEFINE_PRINT(Multiply)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Negative : public UnaryPrimitive { class Negative : public UnaryPrimitive {
@ -1524,9 +1394,6 @@ class Negative : public UnaryPrimitive {
DEFINE_PRINT(Negative) DEFINE_PRINT(Negative)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class NotEqual : public UnaryPrimitive { class NotEqual : public UnaryPrimitive {
@ -1541,9 +1408,6 @@ class NotEqual : public UnaryPrimitive {
DEFINE_PRINT(NotEqual) DEFINE_PRINT(NotEqual)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class NumberOfElements : public UnaryPrimitive { class NumberOfElements : public UnaryPrimitive {
@ -1606,8 +1470,6 @@ class Pad : public UnaryPrimitive {
std::vector<int> axes_; std::vector<int> axes_;
Shape low_pad_size_; Shape low_pad_size_;
Shape high_pad_size_; Shape high_pad_size_;
void eval(const std::vector<array>& inputs, array& out);
}; };
class Partition : public UnaryPrimitive { class Partition : public UnaryPrimitive {
@ -1630,8 +1492,6 @@ class Partition : public UnaryPrimitive {
private: private:
int kth_; int kth_;
int axis_; int axis_;
void eval(const std::vector<array>& inputs, array& out);
}; };
class Power : public UnaryPrimitive { class Power : public UnaryPrimitive {
@ -1646,9 +1506,6 @@ class Power : public UnaryPrimitive {
DEFINE_PRINT(Power) DEFINE_PRINT(Power)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class QuantizedMatmul : public UnaryPrimitive { class QuantizedMatmul : public UnaryPrimitive {
@ -1679,8 +1536,6 @@ class QuantizedMatmul : public UnaryPrimitive {
int group_size_; int group_size_;
int bits_; int bits_;
bool transpose_; bool transpose_;
void eval(const std::vector<array>& inputs, array& out);
}; };
class GatherQMM : public UnaryPrimitive { class GatherQMM : public UnaryPrimitive {
@ -1706,8 +1561,6 @@ class GatherQMM : public UnaryPrimitive {
int group_size_; int group_size_;
int bits_; int bits_;
bool transpose_; bool transpose_;
void eval(const std::vector<array>& inputs, array& out);
}; };
class RandomBits : public UnaryPrimitive { class RandomBits : public UnaryPrimitive {
@ -1728,8 +1581,6 @@ class RandomBits : public UnaryPrimitive {
private: private:
Shape shape_; Shape shape_;
int width_; int width_;
void eval(const std::vector<array>& inputs, array& out);
}; };
class Real : public UnaryPrimitive { class Real : public UnaryPrimitive {
@ -1837,9 +1688,6 @@ class Round : public UnaryPrimitive {
DEFINE_PRINT(Round) DEFINE_PRINT(Round)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Scan : public UnaryPrimitive { class Scan : public UnaryPrimitive {
@ -1936,7 +1784,6 @@ class Scatter : public UnaryPrimitive {
}; };
private: private:
void eval(const std::vector<array>& inputs, array& out);
ReduceType reduce_type_; ReduceType reduce_type_;
std::vector<int> axes_; std::vector<int> axes_;
}; };
@ -1953,9 +1800,6 @@ class Sigmoid : public UnaryPrimitive {
DEFINE_PRINT(Sigmoid) DEFINE_PRINT(Sigmoid)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Sign : public UnaryPrimitive { class Sign : public UnaryPrimitive {
@ -1970,9 +1814,6 @@ class Sign : public UnaryPrimitive {
DEFINE_PRINT(Sign) DEFINE_PRINT(Sign)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Sin : public UnaryPrimitive { class Sin : public UnaryPrimitive {
@ -1987,9 +1828,6 @@ class Sin : public UnaryPrimitive {
DEFINE_PRINT(Sin) DEFINE_PRINT(Sin)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Sinh : public UnaryPrimitive { class Sinh : public UnaryPrimitive {
@ -2004,9 +1842,6 @@ class Sinh : public UnaryPrimitive {
DEFINE_PRINT(Sinh) DEFINE_PRINT(Sinh)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Slice : public UnaryPrimitive { class Slice : public UnaryPrimitive {
@ -2036,7 +1871,6 @@ class Slice : public UnaryPrimitive {
Shape start_indices_; Shape start_indices_;
Shape end_indices_; Shape end_indices_;
Shape strides_; Shape strides_;
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
}; };
@ -2068,8 +1902,6 @@ class SliceUpdate : public UnaryPrimitive {
Shape start_indices_; Shape start_indices_;
Shape end_indices_; Shape end_indices_;
Shape strides_; Shape strides_;
void eval(const std::vector<array>& inputs, array& out);
}; };
class DynamicSlice : public UnaryPrimitive { class DynamicSlice : public UnaryPrimitive {
@ -2136,7 +1968,6 @@ class Softmax : public UnaryPrimitive {
}; };
private: private:
void eval(const std::vector<array>& inputs, array& out);
bool precise_; bool precise_;
}; };
@ -2159,8 +1990,6 @@ class Sort : public UnaryPrimitive {
private: private:
int axis_; int axis_;
void eval(const std::vector<array>& inputs, array& out);
}; };
class Split : public Primitive { class Split : public Primitive {
@ -2200,9 +2029,6 @@ class Square : public UnaryPrimitive {
DEFINE_PRINT(Square) DEFINE_PRINT(Square)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Sqrt : public UnaryPrimitive { class Sqrt : public UnaryPrimitive {
@ -2230,7 +2056,6 @@ class Sqrt : public UnaryPrimitive {
} }
private: private:
void eval(const std::vector<array>& inputs, array& out);
bool recip_; bool recip_;
}; };
@ -2262,9 +2087,6 @@ class Subtract : public UnaryPrimitive {
DEFINE_PRINT(Subtract) DEFINE_PRINT(Subtract)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Squeeze : public UnaryPrimitive { class Squeeze : public UnaryPrimitive {
@ -2304,9 +2126,6 @@ class Tan : public UnaryPrimitive {
DEFINE_PRINT(Tan) DEFINE_PRINT(Tan)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Tanh : public UnaryPrimitive { class Tanh : public UnaryPrimitive {
@ -2321,9 +2140,6 @@ class Tanh : public UnaryPrimitive {
DEFINE_PRINT(Tanh) DEFINE_PRINT(Tanh)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE() DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
}; };
class Unflatten : public UnaryPrimitive { class Unflatten : public UnaryPrimitive {
@ -2404,9 +2220,6 @@ class QRF : public Primitive {
override; override;
DEFINE_PRINT(QRF) DEFINE_PRINT(QRF)
private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
}; };
/* SVD primitive. */ /* SVD primitive. */
@ -2421,9 +2234,6 @@ class SVD : public Primitive {
DEFINE_VMAP() DEFINE_VMAP()
DEFINE_PRINT(SVD) DEFINE_PRINT(SVD)
private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
}; };
/* Matrix inversion primitive. */ /* Matrix inversion primitive. */
@ -2442,7 +2252,6 @@ class Inverse : public UnaryPrimitive {
} }
private: private:
void eval(const std::vector<array>& inputs, array& output);
bool tri_; bool tri_;
bool upper_; bool upper_;
}; };
@ -2462,7 +2271,6 @@ class Cholesky : public UnaryPrimitive {
DEFINE_PRINT(Cholesky) DEFINE_PRINT(Cholesky)
private: private:
void eval(const std::vector<array>& inputs, array& output);
bool upper_; bool upper_;
}; };
@ -2489,7 +2297,6 @@ class Eigh : public Primitive {
} }
private: private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
std::string uplo_; std::string uplo_;
bool compute_eigenvectors_; bool compute_eigenvectors_;
}; };
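
Net effect of this header diff: the per-primitive private eval() declarations are removed, so the CPU backend now implements eval_cpu directly (see unary.cpp and the sort/SVD changes above) instead of forwarding every primitive through a shared eval().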

View File

@ -14,6 +14,7 @@ inline constexpr bool can_convert_to_complex128 =
!std::is_same_v<T, complex128_t> && std::is_convertible_v<T, double>; !std::is_same_v<T, complex128_t> && std::is_convertible_v<T, double>;
struct complex128_t : public std::complex<double> { struct complex128_t : public std::complex<double> {
complex128_t() : std::complex<double>() {};
complex128_t(double v, double u) : std::complex<double>(v, u) {}; complex128_t(double v, double u) : std::complex<double>(v, u) {};
complex128_t(std::complex<double> v) : std::complex<double>(v) {}; complex128_t(std::complex<double> v) : std::complex<double>(v) {};
@ -32,6 +33,7 @@ inline constexpr bool can_convert_to_complex64 =
!std::is_same_v<T, complex64_t> && std::is_convertible_v<T, float>; !std::is_same_v<T, complex64_t> && std::is_convertible_v<T, float>;
struct complex64_t : public std::complex<float> { struct complex64_t : public std::complex<float> {
complex64_t() : std::complex<float>() {};
complex64_t(float v, float u) : std::complex<float>(v, u) {}; complex64_t(float v, float u) : std::complex<float>(v, u) {};
complex64_t(std::complex<float> v) : std::complex<float>(v) {}; complex64_t(std::complex<float> v) : std::complex<float>(v) {};
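
The added default constructors are presumably needed so complex scalars can be default-constructed inside the new SIMD kernels; the existing user-declared constructors had suppressed the implicit default one.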

View File

@ -1,11 +1,12 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#pragma once #pragma once
#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
#include <arm_fp16.h> #include <arm_fp16.h>
namespace mlx::core { namespace mlx::core {
typedef __fp16 float16_t; using ::float16_t;
} // namespace mlx::core } // namespace mlx::core
#else #else
@ -17,11 +18,12 @@ typedef struct _MLX_Float16 float16_t;
} // namespace mlx::core } // namespace mlx::core
#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC #endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
#ifdef __ARM_FEATURE_BF16 #ifdef __ARM_FEATURE_BF16
#include <arm_bf16.h> #include <arm_bf16.h>
namespace mlx::core { namespace mlx::core {
typedef __bf16 bfloat16_t; using ::bfloat16_t;
} // namespace mlx::core } // namespace mlx::core
#else #else

View File

@ -741,7 +741,7 @@ void init_array(nb::module_& m) {
[](const mx::array& a) { [](const mx::array& a) {
if (mx::issubdtype(a.dtype(), mx::inexact)) { if (mx::issubdtype(a.dtype(), mx::inexact)) {
throw std::invalid_argument( throw std::invalid_argument(
"Floating point types not allowed with or bitwise inversion."); "Floating point types not allowed with bitwise inversion.");
} }
if (a.dtype() != mx::bool_) { if (a.dtype() != mx::bool_) {
throw std::invalid_argument( throw std::invalid_argument(
@ -791,7 +791,7 @@ void init_array(nb::module_& m) {
if (mx::issubdtype(a.dtype(), mx::inexact) || if (mx::issubdtype(a.dtype(), mx::inexact) ||
mx::issubdtype(b.dtype(), mx::inexact)) { mx::issubdtype(b.dtype(), mx::inexact)) {
throw std::invalid_argument( throw std::invalid_argument(
"Floating point types not allowed with or bitwise or."); "Floating point types not allowed with bitwise or.");
} }
return mx::bitwise_or(a, b); return mx::bitwise_or(a, b);
}, },
@ -806,7 +806,7 @@ void init_array(nb::module_& m) {
if (mx::issubdtype(a.dtype(), mx::inexact) || if (mx::issubdtype(a.dtype(), mx::inexact) ||
mx::issubdtype(b.dtype(), mx::inexact)) { mx::issubdtype(b.dtype(), mx::inexact)) {
throw std::invalid_argument( throw std::invalid_argument(
"Floating point types not allowed with or bitwise or."); "Floating point types not allowed with bitwise or.");
} }
a.overwrite_descriptor(mx::bitwise_or(a, b)); a.overwrite_descriptor(mx::bitwise_or(a, b));
return a; return a;
@ -838,7 +838,7 @@ void init_array(nb::module_& m) {
if (mx::issubdtype(a.dtype(), mx::inexact) || if (mx::issubdtype(a.dtype(), mx::inexact) ||
mx::issubdtype(b.dtype(), mx::inexact)) { mx::issubdtype(b.dtype(), mx::inexact)) {
throw std::invalid_argument( throw std::invalid_argument(
"Floating point types not allowed with or left shift."); "Floating point types not allowed with left shift.");
} }
a.overwrite_descriptor(mx::left_shift(a, b)); a.overwrite_descriptor(mx::left_shift(a, b));
return a; return a;
@ -870,7 +870,7 @@ void init_array(nb::module_& m) {
if (mx::issubdtype(a.dtype(), mx::inexact) || if (mx::issubdtype(a.dtype(), mx::inexact) ||
mx::issubdtype(b.dtype(), mx::inexact)) { mx::issubdtype(b.dtype(), mx::inexact)) {
throw std::invalid_argument( throw std::invalid_argument(
"Floating point types not allowed with or right shift."); "Floating point types not allowed with right shift.");
} }
a.overwrite_descriptor(mx::right_shift(a, b)); a.overwrite_descriptor(mx::right_shift(a, b));
return a; return a;

View File

@ -289,6 +289,9 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(y.tolist(), [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]) self.assertEqual(y.tolist(), [0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
self.assertEqual(z.tolist(), [0, -4, -3, -2, -1, 0, -4, -3, -2, -1]) self.assertEqual(z.tolist(), [0, -4, -3, -2, -1, 0, -4, -3, -2, -1])
z = -mx.ones(64) % mx.full(64, 2)
self.assertTrue(mx.array_equal(z, mx.ones(64)))
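# Remainder takes the divisor's sign, so (-1) % 2 == 1 elementwise; 64
# elements is presumably wide enough to hit the vectorized path.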
def test_comparisons(self): def test_comparisons(self):
a = mx.array([0.0, 1.0, 5.0]) a = mx.array([0.0, 1.0, 5.0])
b = mx.array([-1.0, 2.0, 5.0]) b = mx.array([-1.0, 2.0, 5.0])

View File

@ -207,8 +207,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits): with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
x_shape = (1, N) if B == 0 else (B, 1, N) x_shape = (1, N) if B == 0 else (B, 1, N)
w_shape = (N, M) if B == 0 else (B, N, M) w_shape = (N, M) if B == 0 else (B, N, M)
x = mx.random.normal(shape=x_shape, key=k1) x = 1e-1 * mx.random.normal(shape=x_shape, key=k1)
w = mx.random.normal(shape=w_shape, key=k2) w = 1e-1 * mx.random.normal(shape=w_shape, key=k2)
w_q, scales, biases = mx.quantize(w, group_size, bits) w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits) w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
y_q = mx.quantized_matmul( y_q = mx.quantized_matmul(