mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Start to cleanup/unify accelerate and common back-ends (Part 1/N) (#1777)
* start to cleanup/unify accelerate and common back-ends * more progress * simplify * add half type and allow infs in simd exp * unify softmax + quantized, more dispatches to simd quantized mm * add sin/cos, use simd in vector-scalar ops * faster CPU vectorize quant * faster erf/erfinv
This commit is contained in:
parent
7064fed1b1
commit
4758c8baa1
@ -147,6 +147,7 @@ if(MLX_BUILD_CPU)
|
||||
|
||||
if(MLX_BUILD_ACCELERATE)
|
||||
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
|
||||
add_compile_definitions(MLX_USE_ACCELERATE)
|
||||
add_compile_definitions(ACCELERATE_NEW_LAPACK)
|
||||
elseif(MLX_BUILD_BLAS_FROM_SOURCE)
|
||||
# Download and build OpenBLAS from source code.
|
||||
|
@ -3,6 +3,4 @@ target_sources(
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp)
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp)
|
||||
|
@ -11,448 +11,8 @@
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#define DEFAULT(primitive) \
|
||||
void primitive::eval_cpu(const std::vector<array>& inputs, array& out) { \
|
||||
primitive::eval(inputs, out); \
|
||||
}
|
||||
|
||||
#define DEFAULT_MULTI(primitive) \
|
||||
void primitive::eval_cpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
primitive::eval(inputs, outputs); \
|
||||
}
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Use the default implementation for the following primitives
|
||||
DEFAULT(Arange)
|
||||
DEFAULT(ArgPartition)
|
||||
DEFAULT(ArgReduce)
|
||||
DEFAULT(ArgSort)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(BlockMaskedMM)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(BroadcastAxes)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Conjugate)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT_MULTI(CustomTransforms)
|
||||
DEFAULT_MULTI(Depends)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
DEFAULT(NumberOfElements)
|
||||
DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(ExpandDims)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
DEFAULT(Gather)
|
||||
DEFAULT(GatherMM)
|
||||
DEFAULT(GatherQMM)
|
||||
DEFAULT(Greater)
|
||||
DEFAULT(GreaterEqual)
|
||||
DEFAULT(Hadamard)
|
||||
DEFAULT(Less)
|
||||
DEFAULT(LessEqual)
|
||||
DEFAULT(Load)
|
||||
DEFAULT(LogicalNot)
|
||||
DEFAULT(LogicalAnd)
|
||||
DEFAULT(LogicalOr)
|
||||
DEFAULT(LogAddExp)
|
||||
DEFAULT(Maximum)
|
||||
DEFAULT(Minimum)
|
||||
DEFAULT(NotEqual)
|
||||
DEFAULT(Pad)
|
||||
DEFAULT(Partition)
|
||||
DEFAULT_MULTI(QRF)
|
||||
DEFAULT(RandomBits)
|
||||
DEFAULT(Remainder)
|
||||
DEFAULT(Round)
|
||||
DEFAULT(Scatter)
|
||||
DEFAULT(Select)
|
||||
DEFAULT(Sigmoid)
|
||||
DEFAULT(Sign)
|
||||
DEFAULT(Slice)
|
||||
DEFAULT(SliceUpdate)
|
||||
DEFAULT_MULTI(Split)
|
||||
DEFAULT(Sort)
|
||||
DEFAULT(Squeeze)
|
||||
DEFAULT(StopGradient)
|
||||
DEFAULT_MULTI(SVD)
|
||||
DEFAULT(Transpose)
|
||||
DEFAULT(Inverse)
|
||||
DEFAULT(Cholesky)
|
||||
DEFAULT_MULTI(Eigh)
|
||||
|
||||
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
|
||||
} else if (in.dtype() == int32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, in.data_size());
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
[](auto x, auto y) { return x + y; },
|
||||
[](const auto* s, const auto* vec, auto* o, auto n) {
|
||||
vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||
vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||
vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else if (a.dtype() == int32) {
|
||||
binary_op<int>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
[](auto x, auto y) { return x + y; },
|
||||
[](const auto* s, const auto* vec, auto* o, auto n) {
|
||||
vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
|
||||
},
|
||||
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||
vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
|
||||
},
|
||||
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||
vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
|
||||
});
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvacosf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvacoshf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvasinf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvasinhf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvatanf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
if (out.dtype() == float32 && a.flags().row_contiguous &&
|
||||
b.flags().row_contiguous) {
|
||||
if (a.is_donatable()) {
|
||||
out.copy_shared_buffer(a);
|
||||
} else if (b.is_donatable()) {
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
int size = a.data_size();
|
||||
vvatan2f(out.data<float>(), a.data<float>(), b.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvatanhf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
if (in.flags().contiguous) {
|
||||
// Use accelerate functions if possible
|
||||
if (in.dtype() == float32 && out.dtype() == uint32) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vfixu32(
|
||||
in.data<float>(), 1, out.data<uint32_t>(), 1, in.data_size());
|
||||
return;
|
||||
} else if (in.dtype() == float32 && out.dtype() == int32) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vfix32(in.data<float>(), 1, out.data<int32_t>(), 1, in.data_size());
|
||||
return;
|
||||
} else if (in.dtype() == uint32 && out.dtype() == float32) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vfltu32(
|
||||
in.data<uint32_t>(), 1, out.data<float>(), 1, in.data_size());
|
||||
return;
|
||||
} else if (in.dtype() == int32 && out.dtype() == float32) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vflt32(in.data<int32_t>(), 1, out.data<float>(), 1, in.data_size());
|
||||
return;
|
||||
}
|
||||
}
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvcosf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvcoshf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == int32) {
|
||||
binary_op<int>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
[](auto x, auto y) { return x / y; },
|
||||
UseDefaultBinaryOp(),
|
||||
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||
vDSP_vsdivi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
|
||||
},
|
||||
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||
vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
|
||||
});
|
||||
} else if (a.dtype() == float32) {
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
[](auto x, auto y) { return x / y; },
|
||||
[](const auto* s, const auto* vec, auto* o, auto n) {
|
||||
vDSP_svdiv((const float*)s, (const float*)vec, 1, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||
vDSP_vsdiv((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||
vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
vvexpm1f(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
assert(in.dtype() == out.dtype());
|
||||
if (in.data_size() == 1 && out.dtype() == float32) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
vDSP_vfill(in.data<float>(), out.data<float>(), 1, out.size());
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
vvlogf(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
break;
|
||||
case Base::two:
|
||||
vvlog2f(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
break;
|
||||
case Base::ten:
|
||||
vvlog10f(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
vvlog1pf(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
[](auto x, auto y) { return x * y; },
|
||||
[](const auto* s, const auto* vec, auto* o, auto n) {
|
||||
vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||
vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||
vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
if (out.dtype() == float32 && a.flags().row_contiguous &&
|
||||
b.flags().row_contiguous) {
|
||||
int size = a.size();
|
||||
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
|
||||
out.copy_shared_buffer(a);
|
||||
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
@ -484,120 +44,4 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvsinf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvsinhf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
if (recip_) {
|
||||
vvrsqrtf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
vvsqrtf(out.data<float>(), in.data<float>(), &size);
|
||||
}
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
[](auto x, auto y) { return x - y; },
|
||||
[](const auto* s, const auto* vec, auto* o, auto n) {
|
||||
float minus_1 = -1;
|
||||
vDSP_vsmsa(
|
||||
(const float*)vec, 1, &minus_1, (const float*)s, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||
float val = -(*s);
|
||||
vDSP_vsadd((const float*)vec, 1, &val, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||
vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else if (a.dtype() == int32) {
|
||||
binary_op<int>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
[](auto x, auto y) { return x - y; },
|
||||
UseDefaultBinaryOp(),
|
||||
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||
int val = -(*s);
|
||||
vDSP_vsaddi((const int*)vec, 1, &val, (int*)o, 1, n);
|
||||
},
|
||||
UseDefaultBinaryOp());
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvtanf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvtanhf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -1,117 +0,0 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <simd/vector.h>
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
void _qmm_t_4_64(
|
||||
float* result,
|
||||
const float* x,
|
||||
const uint32_t* w,
|
||||
const float* scales,
|
||||
const float* biases,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int B,
|
||||
bool batched_w) {
|
||||
constexpr int bits = 4;
|
||||
constexpr int group_size = 64;
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int pack_factor = 32 / bits;
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
|
||||
int w_els = N * K / pack_factor;
|
||||
int g_els = w_els * pack_factor / group_size;
|
||||
|
||||
for (int i = 0; i < B; i++) {
|
||||
for (int m = 0; m < M; m++) {
|
||||
const uint32_t* w_local = w;
|
||||
const float* scales_local = scales;
|
||||
const float* biases_local = biases;
|
||||
|
||||
for (int n = 0; n < N; n++) {
|
||||
const simd_float16* x_local = (simd_float16*)x;
|
||||
simd_float16 sum = 0;
|
||||
for (int k = 0; k < K; k += group_size) {
|
||||
float scale = *scales_local++;
|
||||
float bias = *biases_local++;
|
||||
|
||||
for (int kw = 0; kw < packs_in_group; kw += 2) {
|
||||
// TODO: vectorize this properly
|
||||
simd_uint16 wi;
|
||||
for (int e = 0; e < 2; e++) {
|
||||
uint32_t wii = *w_local++;
|
||||
for (int p = 0; p < 8; p++) {
|
||||
wi[e * 8 + p] = wii & bitmask;
|
||||
wii >>= bits;
|
||||
}
|
||||
}
|
||||
simd_float16 wf = simd_float(wi);
|
||||
wf *= scale;
|
||||
wf += bias;
|
||||
|
||||
sum += (*x_local) * wf;
|
||||
x_local++;
|
||||
}
|
||||
}
|
||||
|
||||
*result = simd_reduce_add(sum);
|
||||
result++;
|
||||
}
|
||||
|
||||
x += K;
|
||||
}
|
||||
if (batched_w) {
|
||||
w += w_els;
|
||||
scales += g_els;
|
||||
biases += g_els;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 4);
|
||||
|
||||
auto& x = inputs[0];
|
||||
auto& w = inputs[1];
|
||||
auto& scales = inputs[2];
|
||||
auto& biases = inputs[3];
|
||||
|
||||
bool condition =
|
||||
(transpose_ && x.flags().row_contiguous && w.flags().row_contiguous &&
|
||||
scales.flags().row_contiguous && biases.flags().row_contiguous &&
|
||||
x.dtype() == float32 && bits_ == 4 && group_size_ == 64);
|
||||
|
||||
if (condition) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
int K = x.shape(-1);
|
||||
int M = x.shape(-2);
|
||||
int N = out.shape(-1);
|
||||
int B = x.size() / K / M;
|
||||
bool batched_w = w.ndim() > 2;
|
||||
_qmm_t_4_64(
|
||||
out.data<float>(),
|
||||
x.data<float>(),
|
||||
w.data<uint32_t>(),
|
||||
scales.data<float>(),
|
||||
biases.data<float>(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
B,
|
||||
batched_w);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -1,393 +0,0 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
|
||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
#include <simd/math.h>
|
||||
#include <simd/vector.h>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
/**
|
||||
* Compute exp(x) in an optimizer friendly way as follows:
|
||||
*
|
||||
* First change the problem to computing 2**y where y = x / ln(2).
|
||||
*
|
||||
* Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part
|
||||
* `ipart` and y2 is fractional part. For the integer part we perform bit
|
||||
* shifting and for the fractional part we use a polynomial approximation.
|
||||
*
|
||||
* The algorithm and constants of the polynomial taken from
|
||||
* https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them
|
||||
* from Cephes math library.
|
||||
*
|
||||
* Note: The implementation below is a general fast exp. There could be faster
|
||||
* implementations for numbers strictly < 0.
|
||||
*/
|
||||
inline simd_float16 simd_fast_exp(simd_float16 x_init) {
|
||||
auto x = x_init * 1.442695; // multiply with log_2(e)
|
||||
simd_float16 ipart, fpart;
|
||||
simd_int16 epart;
|
||||
x = simd_clamp(x, -80, 80);
|
||||
ipart = simd::floor(x + 0.5);
|
||||
fpart = x - ipart;
|
||||
|
||||
x = 1.535336188319500e-4f;
|
||||
x = x * fpart + 1.339887440266574e-3f;
|
||||
x = x * fpart + 9.618437357674640e-3f;
|
||||
x = x * fpart + 5.550332471162809e-2f;
|
||||
x = x * fpart + 2.402264791363012e-1f;
|
||||
x = x * fpart + 6.931472028550421e-1f;
|
||||
x = x * fpart + 1.000000000000000f;
|
||||
|
||||
// generate 2**ipart in the floating point representation using integer
|
||||
// bitshifting
|
||||
epart = (simd_int(ipart) + 127) << 23;
|
||||
|
||||
// Avoid supressing NaNs
|
||||
simd_int16 eq = (x_init == x_init);
|
||||
return simd_bitselect(x_init, (*(simd_float16*)&epart) * x, eq);
|
||||
}
|
||||
|
||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
/**
|
||||
* The ARM neon equivalent of the fast exp above.
|
||||
*/
|
||||
inline float16x8_t neon_fast_exp(float16x8_t x) {
|
||||
x = vmulq_f16(x, vdupq_n_f16(float16_t(1.442695f))); // multiply with log_2(e)
|
||||
x = vmaxq_f16(x, vdupq_n_f16(float16_t(-14.f))); // clamp under with -14
|
||||
x = vminq_f16(x, vdupq_n_f16(float16_t(14.f))); // clamp over with 14
|
||||
|
||||
float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(float16_t(0.5f))));
|
||||
float16x8_t fpart = vsubq_f16(x, ipart);
|
||||
|
||||
x = vdupq_n_f16(float16_t(1.535336188319500e-4f));
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(6.931472028550421e-1f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(1.000000000000000f)), x, fpart);
|
||||
|
||||
// generate 2**ipart in the floating point representation using integer
|
||||
// bitshifting
|
||||
int16x8_t epart = vcvtq_s16_f16(ipart);
|
||||
epart = vaddq_s16(epart, vdupq_n_s16(15));
|
||||
epart = vshlq_n_s16(epart, 10);
|
||||
|
||||
return vmulq_f16(vreinterpretq_f16_s16(epart), x);
|
||||
}
|
||||
|
||||
/**
|
||||
* Implementation of folding maximum for ARM neon. This should possibly be
|
||||
* refactored out of softmax.cpp at some point.
|
||||
*/
|
||||
inline float16_t neon_reduce_max(float16x8_t x) {
|
||||
float16x4_t y;
|
||||
y = vpmax_f16(vget_low_f16(x), vget_high_f16(x));
|
||||
y = vpmax_f16(y, y);
|
||||
y = vpmax_f16(y, y);
|
||||
return vget_lane_f16(y, 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Implementation of folding sum for ARM neon. This should possibly be
|
||||
* refactored out of softmax.cpp at some point.
|
||||
*/
|
||||
inline float16_t neon_reduce_add(float16x8_t x) {
|
||||
float16x4_t y;
|
||||
float16x4_t zero = vdup_n_f16(0);
|
||||
y = vpadd_f16(vget_low_f16(x), vget_high_f16(x));
|
||||
y = vpadd_f16(y, zero);
|
||||
y = vpadd_f16(y, zero);
|
||||
return vget_lane_f16(y, 0);
|
||||
}
|
||||
|
||||
template <typename T, typename VT>
|
||||
struct NeonFp16SimdOps {
|
||||
VT init(T a) {
|
||||
return vdupq_n_f16(a);
|
||||
}
|
||||
|
||||
VT load(const T* a) {
|
||||
return vld1q_f16(a);
|
||||
}
|
||||
|
||||
void store(T* dst, VT x) {
|
||||
vst1q_f16(dst, x);
|
||||
}
|
||||
|
||||
VT max(VT a, VT b) {
|
||||
return vmaxq_f16(a, b);
|
||||
}
|
||||
|
||||
VT exp(VT x) {
|
||||
return neon_fast_exp(x);
|
||||
}
|
||||
|
||||
VT add(VT a, VT b) {
|
||||
return vaddq_f16(a, b);
|
||||
}
|
||||
|
||||
VT sub(VT a, T b) {
|
||||
return vsubq_f16(a, vdupq_n_f16(b));
|
||||
}
|
||||
|
||||
VT mul(VT a, VT b) {
|
||||
return vmulq_f16(a, b);
|
||||
}
|
||||
|
||||
VT mul(VT a, T b) {
|
||||
return vmulq_f16(a, vdupq_n_f16(b));
|
||||
}
|
||||
|
||||
T reduce_max(VT x) {
|
||||
return neon_reduce_max(x);
|
||||
}
|
||||
|
||||
T reduce_add(VT x) {
|
||||
return neon_reduce_add(x);
|
||||
}
|
||||
};
|
||||
|
||||
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
|
||||
template <typename T, typename VT>
|
||||
struct AccelerateSimdOps {
|
||||
VT init(T a) {
|
||||
return a;
|
||||
}
|
||||
|
||||
VT load(const T* a) {
|
||||
return *(VT*)a;
|
||||
}
|
||||
|
||||
void store(T* dst, VT x) {
|
||||
*(VT*)dst = x;
|
||||
}
|
||||
|
||||
VT max(VT a, VT b) {
|
||||
return simd_max(a, b);
|
||||
}
|
||||
|
||||
VT exp(VT x) {
|
||||
return simd_fast_exp(x);
|
||||
}
|
||||
|
||||
VT add(VT a, VT b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
VT sub(VT a, T b) {
|
||||
return a - b;
|
||||
}
|
||||
|
||||
VT mul(VT a, VT b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
VT mul(VT a, T b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
T reduce_max(VT x) {
|
||||
return simd_reduce_max(x);
|
||||
}
|
||||
|
||||
T reduce_add(VT x) {
|
||||
return simd_reduce_add(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename AccT, typename VT, typename Ops, int N>
|
||||
void softmax(const array& in, array& out) {
|
||||
Ops ops;
|
||||
|
||||
const T* in_ptr = in.data<T>();
|
||||
T* out_ptr = out.data<T>();
|
||||
int M = in.shape().back();
|
||||
int L = in.data_size() / M;
|
||||
const T* current_in_ptr;
|
||||
T* current_out_ptr;
|
||||
|
||||
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
|
||||
// Find the maximum
|
||||
current_in_ptr = in_ptr;
|
||||
VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
|
||||
size_t s = M;
|
||||
while (s >= N) {
|
||||
VT vals;
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
vals = ops.load(current_in_ptr);
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
vals[i] = static_cast<AccT>(current_in_ptr[i]);
|
||||
}
|
||||
}
|
||||
vmaximum = ops.max(vals, vmaximum);
|
||||
current_in_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
AccT maximum = ops.reduce_max(vmaximum);
|
||||
while (s-- > 0) {
|
||||
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
|
||||
current_in_ptr++;
|
||||
}
|
||||
|
||||
// Compute the normalizer and the exponentials
|
||||
VT vnormalizer = ops.init(0.0);
|
||||
current_out_ptr = out_ptr;
|
||||
current_in_ptr = in_ptr;
|
||||
s = M;
|
||||
while (s >= N) {
|
||||
VT vexp;
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
vexp = ops.load(current_in_ptr);
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
|
||||
}
|
||||
}
|
||||
vexp = ops.exp(ops.sub(vexp, maximum));
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
ops.store(current_out_ptr, vexp);
|
||||
}
|
||||
vnormalizer = ops.add(vnormalizer, vexp);
|
||||
current_in_ptr += N;
|
||||
current_out_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
AccT normalizer = ops.reduce_add(vnormalizer);
|
||||
while (s-- > 0) {
|
||||
AccT _exp = std::exp(*current_in_ptr - maximum);
|
||||
if (std::is_same<T, AccT>::value) {
|
||||
*current_out_ptr = _exp;
|
||||
}
|
||||
normalizer += _exp;
|
||||
current_in_ptr++;
|
||||
current_out_ptr++;
|
||||
}
|
||||
normalizer = 1 / normalizer;
|
||||
|
||||
// Normalize
|
||||
current_out_ptr = out_ptr;
|
||||
current_in_ptr = in_ptr;
|
||||
s = M;
|
||||
while (s >= N) {
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
|
||||
} else {
|
||||
VT vexp;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
|
||||
}
|
||||
vexp = ops.mul(ops.exp(ops.sub(vexp, maximum)), normalizer);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
current_out_ptr[i] = vexp[i];
|
||||
}
|
||||
current_in_ptr += N;
|
||||
}
|
||||
current_out_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
while (s-- > 0) {
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
*current_out_ptr *= normalizer;
|
||||
} else {
|
||||
AccT _exp = std::exp(*current_in_ptr - maximum);
|
||||
*current_out_ptr = static_cast<T>(_exp * normalizer);
|
||||
current_in_ptr++;
|
||||
}
|
||||
current_out_ptr++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
auto check_input = [](array x) {
|
||||
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
||||
if (x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy(x, x_copy, CopyType::General);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
array in = check_input(std::move(inputs[0]));
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case uint16:
|
||||
case uint32:
|
||||
case uint64:
|
||||
case int8:
|
||||
case int16:
|
||||
case int32:
|
||||
case int64:
|
||||
throw std::invalid_argument(
|
||||
"Softmax is defined only for floating point types");
|
||||
break;
|
||||
case float32:
|
||||
softmax<
|
||||
float,
|
||||
float,
|
||||
simd_float16,
|
||||
AccelerateSimdOps<float, simd_float16>,
|
||||
16>(in, out);
|
||||
break;
|
||||
case float16:
|
||||
if (precise_) {
|
||||
softmax<
|
||||
float16_t,
|
||||
float,
|
||||
simd_float16,
|
||||
AccelerateSimdOps<float, simd_float16>,
|
||||
16>(in, out);
|
||||
} else {
|
||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
softmax<
|
||||
float16_t,
|
||||
float16_t,
|
||||
float16x8_t,
|
||||
NeonFp16SimdOps<float16_t, float16x8_t>,
|
||||
8>(in, out);
|
||||
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
eval(inputs, out); // Redirect to common backend for consistency
|
||||
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
}
|
||||
break;
|
||||
case bfloat16:
|
||||
eval(inputs, out);
|
||||
break;
|
||||
case complex64:
|
||||
eval(inputs, out);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -5,6 +5,18 @@ else()
|
||||
set(COMPILER ${CMAKE_CXX_COMPILER})
|
||||
endif()
|
||||
|
||||
set(COMPILE_DEPS
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
|
||||
simd/simd.h
|
||||
simd/base_simd.h
|
||||
simd/math.h
|
||||
simd/type.h
|
||||
unary_ops.h
|
||||
binary_ops.h)
|
||||
|
||||
if(MSVC)
|
||||
set(SHELL_EXT ps1)
|
||||
set(SHELL_CMD powershell -ExecutionPolicy Bypass -File)
|
||||
@ -19,13 +31,8 @@ add_custom_command(
|
||||
${SHELL_CMD} ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.${SHELL_EXT}
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
|
||||
${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR}
|
||||
DEPENDS make_compiled_preamble.${SHELL_EXT}
|
||||
compiled_preamble.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
|
||||
ops.h)
|
||||
DEPENDS make_compiled_preamble.${SHELL_EXT} compiled_preamble.h
|
||||
${COMPILE_DEPS})
|
||||
|
||||
add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp)
|
||||
|
||||
@ -60,6 +67,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
|
||||
|
||||
|
@ -61,7 +61,7 @@ void arg_reduce_dispatch(
|
||||
|
||||
} // namespace
|
||||
|
||||
void ArgReduce::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
@ -6,8 +6,8 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/binary_ops.h"
|
||||
#include "mlx/backend/common/binary_two.h"
|
||||
#include "mlx/backend/common/ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
@ -15,69 +15,61 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void comparison_op(const array& a, const array& b, array& out, Op op) {
|
||||
DefaultScalarVector<T, U, Op> opsv(op);
|
||||
DefaultVectorScalar<T, U, Op> opvs(op);
|
||||
DefaultVectorVector<T, U, Op> opvv(op);
|
||||
binary_op<T, U>(a, b, out, op, opsv, opvs, opvv);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void comparison_op(const array& a, const array& b, array& out, Op op) {
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
comparison_op<bool, bool>(a, b, out, op);
|
||||
binary_op<bool, bool>(a, b, out, op);
|
||||
break;
|
||||
case uint8:
|
||||
comparison_op<uint8_t, bool>(a, b, out, op);
|
||||
binary_op<uint8_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case uint16:
|
||||
comparison_op<uint16_t, bool>(a, b, out, op);
|
||||
binary_op<uint16_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case uint32:
|
||||
comparison_op<uint32_t, bool>(a, b, out, op);
|
||||
binary_op<uint32_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case uint64:
|
||||
comparison_op<uint64_t, bool>(a, b, out, op);
|
||||
binary_op<uint64_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case int8:
|
||||
comparison_op<int8_t, bool>(a, b, out, op);
|
||||
binary_op<int8_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case int16:
|
||||
comparison_op<int16_t, bool>(a, b, out, op);
|
||||
binary_op<int16_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case int32:
|
||||
comparison_op<int32_t, bool>(a, b, out, op);
|
||||
binary_op<int32_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case int64:
|
||||
comparison_op<int64_t, bool>(a, b, out, op);
|
||||
binary_op<int64_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case float16:
|
||||
comparison_op<float16_t, bool>(a, b, out, op);
|
||||
binary_op<float16_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case float32:
|
||||
comparison_op<float, bool>(a, b, out, op);
|
||||
binary_op<float, bool>(a, b, out, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
comparison_op<bfloat16_t, bool>(a, b, out, op);
|
||||
binary_op<bfloat16_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
comparison_op<complex64_t, bool>(a, b, out, op);
|
||||
binary_op<complex64_t, bool>(a, b, out, op);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Add::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Add());
|
||||
}
|
||||
|
||||
void DivMod::eval(
|
||||
void DivMod::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 2);
|
||||
@ -132,50 +124,68 @@ void DivMod::eval(
|
||||
}
|
||||
}
|
||||
|
||||
void Divide::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Divide());
|
||||
}
|
||||
|
||||
void Remainder::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Remainder());
|
||||
}
|
||||
|
||||
void Equal::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
if (equal_nan_) {
|
||||
comparison_op(inputs[0], inputs[1], out, detail::NaNEqual());
|
||||
switch (a.dtype()) {
|
||||
case float16:
|
||||
binary_op<float16_t, bool>(a, b, out, detail::NaNEqual());
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, bool>(a, b, out, detail::NaNEqual());
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, bool>(a, b, out, detail::NaNEqual());
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, bool>(a, b, out, detail::NaNEqual());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[NanEqual::eval_cpu] Only for floating point types.");
|
||||
}
|
||||
} else {
|
||||
comparison_op(inputs[0], inputs[1], out, detail::Equal());
|
||||
comparison_op(a, b, out, detail::Equal());
|
||||
}
|
||||
}
|
||||
|
||||
void Greater::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::Greater());
|
||||
}
|
||||
|
||||
void GreaterEqual::eval(const std::vector<array>& inputs, array& out) {
|
||||
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual());
|
||||
}
|
||||
|
||||
void Less::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::Less());
|
||||
}
|
||||
|
||||
void LessEqual::eval(const std::vector<array>& inputs, array& out) {
|
||||
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::LessEqual());
|
||||
}
|
||||
|
||||
void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
|
||||
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
@ -196,54 +206,54 @@ void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
|
||||
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalAnd());
|
||||
}
|
||||
|
||||
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
|
||||
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalOr());
|
||||
}
|
||||
|
||||
void Maximum::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Maximum());
|
||||
}
|
||||
|
||||
void Minimum::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Minimum());
|
||||
}
|
||||
|
||||
void Multiply::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Multiply());
|
||||
}
|
||||
|
||||
void NotEqual::eval(const std::vector<array>& inputs, array& out) {
|
||||
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::NotEqual());
|
||||
}
|
||||
|
||||
void Power::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Power());
|
||||
}
|
||||
|
||||
void Subtract::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
@ -307,7 +317,7 @@ void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTan2::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
|
@ -7,6 +7,8 @@
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
@ -122,16 +124,22 @@ void set_binary_op_output_data(
|
||||
}
|
||||
}
|
||||
|
||||
struct UseDefaultBinaryOp {};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
struct DefaultVectorScalar {
|
||||
template <typename Op>
|
||||
struct VectorScalar {
|
||||
Op op;
|
||||
|
||||
DefaultVectorScalar(Op op_) : op(op_) {}
|
||||
VectorScalar(Op op_) : op(op_) {}
|
||||
|
||||
template <typename T, typename U>
|
||||
void operator()(const T* a, const T* b, U* dst, int size) {
|
||||
T scalar = *b;
|
||||
constexpr int N = simd::max_size<T>;
|
||||
while (size >= N) {
|
||||
simd::store(dst, op(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));
|
||||
dst += N;
|
||||
a += N;
|
||||
size -= N;
|
||||
}
|
||||
while (size-- > 0) {
|
||||
*dst = op(*a, scalar);
|
||||
dst++;
|
||||
@ -140,14 +148,22 @@ struct DefaultVectorScalar {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
struct DefaultScalarVector {
|
||||
template <typename Op>
|
||||
struct ScalarVector {
|
||||
Op op;
|
||||
|
||||
DefaultScalarVector(Op op_) : op(op_) {}
|
||||
ScalarVector(Op op_) : op(op_) {}
|
||||
|
||||
template <typename T, typename U>
|
||||
void operator()(const T* a, const T* b, U* dst, int size) {
|
||||
T scalar = *a;
|
||||
constexpr int N = simd::max_size<T>;
|
||||
while (size >= N) {
|
||||
simd::store(dst, op(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));
|
||||
dst += N;
|
||||
b += N;
|
||||
size -= N;
|
||||
}
|
||||
while (size-- > 0) {
|
||||
*dst = op(scalar, *b);
|
||||
dst++;
|
||||
@ -156,13 +172,22 @@ struct DefaultScalarVector {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
struct DefaultVectorVector {
|
||||
template <typename Op>
|
||||
struct VectorVector {
|
||||
Op op;
|
||||
|
||||
DefaultVectorVector(Op op_) : op(op_) {}
|
||||
VectorVector(Op op_) : op(op_) {}
|
||||
|
||||
template <typename T, typename U>
|
||||
void operator()(const T* a, const T* b, U* dst, int size) {
|
||||
constexpr int N = simd::max_size<T>;
|
||||
while (size >= N) {
|
||||
simd::store(dst, op(simd::load<T, N>(a), simd::load<T, N>(b)));
|
||||
dst += N;
|
||||
a += N;
|
||||
b += N;
|
||||
size -= N;
|
||||
}
|
||||
while (size-- > 0) {
|
||||
*dst = op(*a, *b);
|
||||
dst++;
|
||||
@ -277,21 +302,8 @@ void binary_op_dispatch_dims(
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
typename OpSV,
|
||||
typename OpVS,
|
||||
typename OpVV>
|
||||
void binary_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
OpSV opsv,
|
||||
OpVS opvs,
|
||||
OpVV opvv) {
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op(const array& a, const array& b, array& out, Op op) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
@ -303,19 +315,19 @@ void binary_op(
|
||||
|
||||
// The full computation is scalar vector so delegate to the op
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
|
||||
ScalarVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is vector scalar so delegate to the op
|
||||
if (bopt == BinaryOpType::VectorScalar) {
|
||||
opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
|
||||
VectorScalar{op}(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is vector vector so delegate to the op
|
||||
if (bopt == BinaryOpType::VectorVector) {
|
||||
opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
|
||||
VectorVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
|
||||
return;
|
||||
}
|
||||
|
||||
@ -376,15 +388,39 @@ void binary_op(
|
||||
switch (bopt) {
|
||||
case BinaryOpType::VectorVector:
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a, b, out, opvv, dim, new_shape, a_strides, b_strides, strides);
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
VectorVector{op},
|
||||
dim,
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a, b, out, opvs, dim, new_shape, a_strides, b_strides, strides);
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
VectorScalar{op},
|
||||
dim,
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a, b, out, opsv, dim, new_shape, a_strides, b_strides, strides);
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
ScalarVector{op},
|
||||
dim,
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
default:
|
||||
binary_op_dispatch_dims<T, U, false>(
|
||||
@ -393,134 +429,52 @@ void binary_op(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
|
||||
void binary_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
OpSV opsv,
|
||||
OpVS opvs,
|
||||
OpVV opvv) {
|
||||
// TODO: The following mess of constexpr evaluations can probably be achieved
|
||||
// with template specializations and overloading. Would it be simpler?
|
||||
|
||||
if constexpr (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
|
||||
if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
|
||||
if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
// All ops are UseDefaultBinaryOp (why oh why would someone call that?)
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
op,
|
||||
DefaultScalarVector<T, T, Op>(op),
|
||||
DefaultVectorScalar<T, T, Op>(op),
|
||||
DefaultVectorVector<T, T, Op>(op));
|
||||
} else {
|
||||
// opsv and opvs were UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
op,
|
||||
DefaultScalarVector<T, T, Op>(op),
|
||||
DefaultVectorScalar<T, T, Op>(op),
|
||||
opvv);
|
||||
}
|
||||
} else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
|
||||
value) {
|
||||
// opsv and opvv were UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
op,
|
||||
DefaultScalarVector<T, T, Op>(op),
|
||||
opvs,
|
||||
DefaultVectorVector<T, T, Op>(op));
|
||||
} else {
|
||||
// opsv was UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
|
||||
}
|
||||
} else if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::
|
||||
value) {
|
||||
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
// opvs and opvv were UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
op,
|
||||
opsv,
|
||||
DefaultVectorScalar<T, T, Op>(op),
|
||||
DefaultVectorVector<T, T, Op>(op));
|
||||
} else {
|
||||
// opvs was UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
|
||||
}
|
||||
} else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
|
||||
value) {
|
||||
// opvv was UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
|
||||
} else {
|
||||
// All ops provided
|
||||
binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
void binary_op(const array& a, const array& b, array& out, Op op) {
|
||||
DefaultScalarVector<T, T, Op> opsv(op);
|
||||
DefaultVectorScalar<T, T, Op> opvs(op);
|
||||
DefaultVectorVector<T, T, Op> opvv(op);
|
||||
binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
|
||||
binary_op<T, T>(a, b, out, op);
|
||||
}
|
||||
|
||||
template <typename... Ops>
|
||||
void binary(const array& a, const array& b, array& out, Ops... ops) {
|
||||
template <typename Op>
|
||||
void binary(const array& a, const array& b, array& out, Op op) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool>(a, b, out, ops...);
|
||||
binary_op<bool>(a, b, out, op);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t>(a, b, out, ops...);
|
||||
binary_op<uint8_t>(a, b, out, op);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t>(a, b, out, ops...);
|
||||
binary_op<uint16_t>(a, b, out, op);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t>(a, b, out, ops...);
|
||||
binary_op<uint32_t>(a, b, out, op);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t>(a, b, out, ops...);
|
||||
binary_op<uint64_t>(a, b, out, op);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t>(a, b, out, ops...);
|
||||
binary_op<int8_t>(a, b, out, op);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t>(a, b, out, ops...);
|
||||
binary_op<int16_t>(a, b, out, op);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t>(a, b, out, ops...);
|
||||
binary_op<int32_t>(a, b, out, op);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t>(a, b, out, ops...);
|
||||
binary_op<int64_t>(a, b, out, op);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t>(a, b, out, ops...);
|
||||
binary_op<float16_t>(a, b, out, op);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float>(a, b, out, ops...);
|
||||
binary_op<float>(a, b, out, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t>(a, b, out, ops...);
|
||||
binary_op<bfloat16_t>(a, b, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t>(a, b, out, ops...);
|
||||
binary_op<complex64_t>(a, b, out, op);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
98
mlx/backend/common/binary_ops.h
Normal file
98
mlx/backend/common/binary_ops.h
Normal file
@ -0,0 +1,98 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
|
||||
namespace mlx::core::detail {
|
||||
|
||||
using namespace mlx::core::simd;
|
||||
|
||||
#define BINARY_SINGLE() \
|
||||
template <typename T> \
|
||||
T operator()(T x, T y) { \
|
||||
return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value; \
|
||||
}
|
||||
|
||||
#define DEFAULT_BINARY_OP(Op, op) \
|
||||
struct Op { \
|
||||
template <int N, typename T> \
|
||||
Simd<T, N> operator()(Simd<T, N> x, Simd<T, N> y) { \
|
||||
return op(x, y); \
|
||||
} \
|
||||
BINARY_SINGLE() \
|
||||
};
|
||||
|
||||
DEFAULT_BINARY_OP(Add, operator+)
|
||||
DEFAULT_BINARY_OP(ArcTan2, atan2)
|
||||
DEFAULT_BINARY_OP(Divide, operator/)
|
||||
DEFAULT_BINARY_OP(Multiply, operator*)
|
||||
DEFAULT_BINARY_OP(Subtract, operator-)
|
||||
DEFAULT_BINARY_OP(LogicalAnd, operator&&)
|
||||
DEFAULT_BINARY_OP(LogicalOr, operator||)
|
||||
DEFAULT_BINARY_OP(BitwiseAnd, operator&)
|
||||
DEFAULT_BINARY_OP(BitwiseOr, operator|)
|
||||
DEFAULT_BINARY_OP(BitwiseXor, operator^)
|
||||
DEFAULT_BINARY_OP(LeftShift, operator<<)
|
||||
DEFAULT_BINARY_OP(RightShift, operator>>)
|
||||
DEFAULT_BINARY_OP(Remainder, remainder)
|
||||
DEFAULT_BINARY_OP(Maximum, maximum)
|
||||
DEFAULT_BINARY_OP(Minimum, minimum)
|
||||
DEFAULT_BINARY_OP(Power, pow)
|
||||
|
||||
#define DEFAULT_BOOL_OP(Op, op) \
|
||||
struct Op { \
|
||||
template <int N, typename T> \
|
||||
Simd<bool, N> operator()(Simd<T, N> x, Simd<T, N> y) { \
|
||||
return op(x, y); \
|
||||
} \
|
||||
template <typename T> \
|
||||
bool operator()(T x, T y) { \
|
||||
return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value; \
|
||||
} \
|
||||
};
|
||||
|
||||
DEFAULT_BOOL_OP(Equal, operator==)
|
||||
DEFAULT_BOOL_OP(Greater, operator>)
|
||||
DEFAULT_BOOL_OP(GreaterEqual, operator>=)
|
||||
DEFAULT_BOOL_OP(Less, operator<)
|
||||
DEFAULT_BOOL_OP(LessEqual, operator<=)
|
||||
DEFAULT_BOOL_OP(NotEqual, operator!=)
|
||||
|
||||
struct NaNEqual {
|
||||
template <int N, typename T>
|
||||
Simd<bool, N> operator()(Simd<T, N> x, Simd<T, N> y) {
|
||||
return x == y || (isnan(x) && isnan(y));
|
||||
}
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value;
|
||||
}
|
||||
};
|
||||
|
||||
struct LogAddExp {
|
||||
template <int N, typename T>
|
||||
Simd<T, N> operator()(Simd<T, N> x, Simd<T, N> y) {
|
||||
auto maxval = maximum(x, y);
|
||||
auto minval = minimum(x, y);
|
||||
auto mask = minval == -inf || maxval == inf;
|
||||
auto out = maxval + log1p(exp(minval - maxval));
|
||||
return select(mask, Simd<T, N>(maxval), Simd<T, N>(out));
|
||||
}
|
||||
BINARY_SINGLE()
|
||||
};
|
||||
|
||||
struct Select {
|
||||
template <typename T>
|
||||
T operator()(bool condition, T x, T y) {
|
||||
return (*this)(Simd<bool, 1>(condition), Simd<T, 1>(x), Simd<T, 1>(y))
|
||||
.value;
|
||||
}
|
||||
|
||||
template <int N, typename T>
|
||||
Simd<T, N> operator()(Simd<bool, N> condition, Simd<T, N> x, Simd<T, N> y) {
|
||||
return select(condition, x, y);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::detail
|
@ -64,7 +64,7 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
|
||||
}
|
||||
}
|
||||
|
||||
void Cholesky::eval(const std::vector<array>& inputs, array& output) {
|
||||
void Cholesky::eval_cpu(const std::vector<array>& inputs, array& output) {
|
||||
if (inputs[0].dtype() != float32) {
|
||||
throw std::runtime_error("[Cholesky::eval] only supports float32.");
|
||||
}
|
||||
|
@ -5,7 +5,8 @@
|
||||
// clang-format off
|
||||
#include "mlx/types/half_types.h"
|
||||
#include "mlx/types/complex.h"
|
||||
#include "mlx/backend/common/ops.h"
|
||||
#include "mlx/backend/common/unary_ops.h"
|
||||
#include "mlx/backend/common/binary_ops.h"
|
||||
// clang-format on
|
||||
|
||||
const char* get_kernel_preamble();
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@ -23,6 +24,7 @@ template <typename SrcT, typename DstT>
|
||||
void copy_vector(const array& src, array& dst) {
|
||||
auto src_ptr = src.data<SrcT>();
|
||||
auto dst_ptr = dst.data<DstT>();
|
||||
size_t size = src.data_size();
|
||||
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
|
||||
}
|
||||
|
||||
|
@ -21,98 +21,9 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
DEFAULT(Abs)
|
||||
DEFAULT(Add)
|
||||
DEFAULT(Arange)
|
||||
DEFAULT(ArcCos)
|
||||
DEFAULT(ArcCosh)
|
||||
DEFAULT(ArcSin)
|
||||
DEFAULT(ArcSinh)
|
||||
DEFAULT(ArcTan)
|
||||
DEFAULT(ArcTan2)
|
||||
DEFAULT(ArcTanh)
|
||||
DEFAULT(ArgPartition)
|
||||
DEFAULT(ArgReduce)
|
||||
DEFAULT(ArgSort)
|
||||
DEFAULT(AsType)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(BroadcastAxes)
|
||||
DEFAULT(BlockMaskedMM)
|
||||
DEFAULT(GatherMM)
|
||||
DEFAULT(GatherQMM)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Conjugate)
|
||||
DEFAULT(Convolution)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT(Cos)
|
||||
DEFAULT(Cosh)
|
||||
DEFAULT_MULTI(CustomTransforms)
|
||||
DEFAULT_MULTI(Depends)
|
||||
DEFAULT(Divide)
|
||||
DEFAULT(NumberOfElements)
|
||||
DEFAULT(Remainder)
|
||||
DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(Exp)
|
||||
DEFAULT(ExpandDims)
|
||||
DEFAULT(Expm1)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
DEFAULT(Full)
|
||||
DEFAULT(Gather)
|
||||
DEFAULT(Greater)
|
||||
DEFAULT(GreaterEqual)
|
||||
DEFAULT(Hadamard)
|
||||
DEFAULT(Less)
|
||||
DEFAULT(LessEqual)
|
||||
DEFAULT(Load)
|
||||
DEFAULT(Log)
|
||||
DEFAULT(Log1p)
|
||||
DEFAULT(LogicalNot)
|
||||
DEFAULT(LogicalAnd)
|
||||
DEFAULT(LogicalOr)
|
||||
DEFAULT(LogAddExp)
|
||||
DEFAULT(Maximum)
|
||||
DEFAULT(Minimum)
|
||||
DEFAULT(Multiply)
|
||||
DEFAULT(Negative)
|
||||
DEFAULT(NotEqual)
|
||||
DEFAULT(Pad)
|
||||
DEFAULT(Partition)
|
||||
DEFAULT(Power)
|
||||
DEFAULT_MULTI(QRF)
|
||||
DEFAULT(QuantizedMatmul)
|
||||
DEFAULT(RandomBits)
|
||||
DEFAULT(Reduce)
|
||||
DEFAULT(Round)
|
||||
DEFAULT(Scan)
|
||||
DEFAULT(Scatter)
|
||||
DEFAULT(Select)
|
||||
DEFAULT(Sigmoid)
|
||||
DEFAULT(Sign)
|
||||
DEFAULT(Sin)
|
||||
DEFAULT(Sinh)
|
||||
DEFAULT(Slice)
|
||||
DEFAULT(SliceUpdate)
|
||||
DEFAULT(Softmax)
|
||||
DEFAULT(Sort)
|
||||
DEFAULT_MULTI(Split)
|
||||
DEFAULT(Square)
|
||||
DEFAULT(Squeeze)
|
||||
DEFAULT(Sqrt)
|
||||
DEFAULT(StopGradient)
|
||||
DEFAULT(Subtract)
|
||||
DEFAULT_MULTI(SVD)
|
||||
DEFAULT(Tan)
|
||||
DEFAULT(Tanh)
|
||||
DEFAULT(Transpose)
|
||||
DEFAULT(Inverse)
|
||||
DEFAULT(Cholesky)
|
||||
DEFAULT_MULTI(Eigh)
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -45,7 +45,9 @@ void ssyevd(
|
||||
|
||||
} // namespace
|
||||
|
||||
void Eigh::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
|
||||
void Eigh::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
const auto& a = inputs[0];
|
||||
auto& values = outputs[0];
|
||||
|
||||
|
@ -8,7 +8,7 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void FFT::eval(const std::vector<array>& inputs, array& out) {
|
||||
void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
std::vector<std::ptrdiff_t> strides_in(
|
||||
in.strides().begin(), in.strides().end());
|
||||
|
@ -82,7 +82,7 @@ void hadamard(array& out, int n, int m, float scale) {
|
||||
}
|
||||
}
|
||||
|
||||
void Hadamard::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
|
@ -162,7 +162,7 @@ void dispatch_gather(
|
||||
}
|
||||
}
|
||||
|
||||
void Gather::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& src = inputs[0];
|
||||
@ -337,7 +337,7 @@ void dispatch_scatter(
|
||||
}
|
||||
}
|
||||
|
||||
void Scatter::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() >= 2);
|
||||
|
||||
auto& src = inputs[0];
|
||||
|
@ -110,7 +110,7 @@ void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
|
||||
}
|
||||
}
|
||||
|
||||
void Inverse::eval(const std::vector<array>& inputs, array& output) {
|
||||
void Inverse::eval_cpu(const std::vector<array>& inputs, array& output) {
|
||||
if (inputs[0].dtype() != float32) {
|
||||
throw std::runtime_error("[Inverse::eval] only supports float32.");
|
||||
}
|
||||
|
@ -11,7 +11,7 @@
|
||||
#define lapack_complex_double std::complex<double>
|
||||
#endif
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#ifdef MLX_USE_ACCELERATE
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#else
|
||||
#include <cblas.h>
|
||||
|
@ -1,12 +1,9 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <utility>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/load.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace {
|
||||
|
||||
@ -51,11 +48,4 @@ void load(
|
||||
}
|
||||
}
|
||||
|
||||
void Load::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
load(out, offset_, reader_, swap_endianness_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -53,7 +53,7 @@ inline void mask_matrix(
|
||||
|
||||
} // namespace
|
||||
|
||||
void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.dtype() != float32) {
|
||||
throw std::runtime_error(
|
||||
"[BlockMaskedMM::eval] Currently only supports float32.");
|
||||
@ -210,7 +210,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void GatherMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.dtype() != float32) {
|
||||
throw std::runtime_error(
|
||||
"[GatherMM::eval] Currently only supports float32.");
|
||||
|
@ -1,680 +0,0 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
#include <stdint.h>
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
|
||||
namespace mlx::core::detail {
|
||||
|
||||
namespace {
|
||||
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||
} // namespace
|
||||
|
||||
typedef union {
|
||||
int i;
|
||||
float f;
|
||||
} IntOrFloat;
|
||||
|
||||
inline float fast_exp(float x) {
|
||||
if (x == -std::numeric_limits<float>::infinity()) {
|
||||
return 0.0f;
|
||||
} else if (x == std::numeric_limits<float>::infinity() || std::isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
x *= 1.442695; // multiply with log_2(e)
|
||||
float ipart, fpart;
|
||||
IntOrFloat epart;
|
||||
x = std::max(-80.f, std::min(x, 80.f));
|
||||
ipart = std::floor(x + 0.5);
|
||||
fpart = x - ipart;
|
||||
|
||||
x = 1.535336188319500e-4f;
|
||||
x = x * fpart + 1.339887440266574e-3f;
|
||||
x = x * fpart + 9.618437357674640e-3f;
|
||||
x = x * fpart + 5.550332471162809e-2f;
|
||||
x = x * fpart + 2.402264791363012e-1f;
|
||||
x = x * fpart + 6.931472028550421e-1f;
|
||||
x = x * fpart + 1.000000000000000f;
|
||||
|
||||
// generate 2**ipart in the floating point representation using integer
|
||||
// bitshifting
|
||||
epart.i = (int(ipart) + 127) << 23;
|
||||
|
||||
return epart.f * x;
|
||||
}
|
||||
|
||||
inline float fast_erf(float a) {
|
||||
float r, s, t, u;
|
||||
t = std::abs(a);
|
||||
s = a * a;
|
||||
if (t > 0.927734375f) {
|
||||
// maximum error 0.99527 ulp
|
||||
r = std::fma(
|
||||
-1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
|
||||
u = std::fma(
|
||||
-3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
|
||||
r = std::fma(r, s, u);
|
||||
r = std::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4
|
||||
r = std::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1
|
||||
r = std::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3
|
||||
r = std::fma(r, t, -t);
|
||||
// TODO, replace with expm1 when implemented
|
||||
r = 1.0f - std::exp(r);
|
||||
r = std::copysign(r, a);
|
||||
} else {
|
||||
// maximum error 0.98929 ulp
|
||||
r = -5.96761703e-4f; // -0x1.38e000p-11
|
||||
r = std::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8
|
||||
r = std::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6
|
||||
r = std::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4
|
||||
r = std::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2
|
||||
r = std::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3
|
||||
r = std::fma(r, a, a);
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
inline float fast_erfinv(float a) {
|
||||
auto t = std::fma(a, 0.0f - a, 1.0f);
|
||||
t = std::log(t);
|
||||
float p;
|
||||
if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793
|
||||
p = 3.03697567e-10f; // 0x1.4deb44p-32
|
||||
p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
|
||||
p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
|
||||
p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
|
||||
p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
|
||||
p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
|
||||
p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
|
||||
p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
|
||||
p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
|
||||
} else { // maximum ulp error = 2.35002
|
||||
p = 5.43877832e-9f; // 0x1.75c000p-28
|
||||
p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
|
||||
p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
|
||||
p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
|
||||
p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
|
||||
p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
|
||||
p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
|
||||
p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
|
||||
p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
|
||||
p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
|
||||
}
|
||||
return a * p;
|
||||
}
|
||||
|
||||
struct Abs {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::abs(x);
|
||||
}
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
}
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
}
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
}
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
}
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcCos {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::acos(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcCosh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::acosh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcSin {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::asin(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcSinh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::asinh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTan {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::atan(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTan2 {
|
||||
template <typename T>
|
||||
T operator()(T y, T x) {
|
||||
return std::atan2(y, x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTanh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::atanh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Ceil {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::ceil(x);
|
||||
}
|
||||
int8_t operator()(int8_t x) {
|
||||
return x;
|
||||
}
|
||||
int16_t operator()(int16_t x) {
|
||||
return x;
|
||||
}
|
||||
int32_t operator()(int32_t x) {
|
||||
return x;
|
||||
}
|
||||
int64_t operator()(int64_t x) {
|
||||
return x;
|
||||
}
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
}
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
}
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
}
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
}
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Conjugate {
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return std::conj(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Cos {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::cos(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Cosh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::cosh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Erf {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<T>(fast_erf(static_cast<float>(x)));
|
||||
}
|
||||
};
|
||||
|
||||
struct ErfInv {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<T>(fast_erfinv(static_cast<float>(x)));
|
||||
}
|
||||
};
|
||||
|
||||
struct Exp {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return fast_exp(x);
|
||||
}
|
||||
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return std::exp(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Expm1 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return expm1(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::floor(x);
|
||||
}
|
||||
int8_t operator()(int8_t x) {
|
||||
return x;
|
||||
}
|
||||
int16_t operator()(int16_t x) {
|
||||
return x;
|
||||
}
|
||||
int32_t operator()(int32_t x) {
|
||||
return x;
|
||||
}
|
||||
int64_t operator()(int64_t x) {
|
||||
return x;
|
||||
}
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
}
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
}
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
}
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
}
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Imag {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::imag(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Log {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::log(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Log2 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::log2(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Log10 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::log10(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Log1p {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return log1p(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalNot {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return !x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Negative {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return -x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Real {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::real(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Round {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::rint(x);
|
||||
}
|
||||
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {std::rint(x.real()), std::rint(x.imag())};
|
||||
}
|
||||
};
|
||||
|
||||
struct Sigmoid {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
auto one = static_cast<decltype(x)>(1.0);
|
||||
return one / (one + fast_exp(-x));
|
||||
}
|
||||
};
|
||||
|
||||
struct Sign {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return (x > T(0)) - (x < T(0));
|
||||
}
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x != 0;
|
||||
}
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x != 0;
|
||||
}
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x != 0;
|
||||
}
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x != 0;
|
||||
}
|
||||
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return x == complex64_t(0) ? x : x / std::abs(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Sin {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::sin(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Sinh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::sinh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Square {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return x * x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Sqrt {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::sqrt(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<decltype(x)>(1.0) / std::sqrt(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Tan {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::tan(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Tanh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::tanh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Add {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x + y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Divide {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x / y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Remainder {
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_integral_v<T> & !std::is_signed_v<T>, T> operator()(
|
||||
T numerator,
|
||||
T denominator) {
|
||||
return numerator % denominator;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_integral_v<T> & std::is_signed_v<T>, T> operator()(
|
||||
T numerator,
|
||||
T denominator) {
|
||||
auto r = numerator % denominator;
|
||||
if (r != 0 && (r < 0 != denominator < 0))
|
||||
r += denominator;
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
|
||||
T numerator,
|
||||
T denominator) {
|
||||
auto r = std::fmod(numerator, denominator);
|
||||
if (r != 0 && (r < 0 != denominator < 0)) {
|
||||
r += denominator;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
complex64_t operator()(complex64_t numerator, complex64_t denominator) {
|
||||
return numerator % denominator;
|
||||
}
|
||||
};
|
||||
|
||||
struct Equal {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x == y;
|
||||
}
|
||||
};
|
||||
|
||||
struct NaNEqual {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
if constexpr (std::is_integral_v<T>) {
|
||||
// isnan always returns false for integers, and MSVC refuses to compile.
|
||||
return x == y;
|
||||
} else {
|
||||
return x == y || (std::isnan(x) && std::isnan(y));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Greater {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x > y;
|
||||
}
|
||||
};
|
||||
|
||||
struct GreaterEqual {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x >= y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Less {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x < y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LessEqual {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x <= y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Maximum {
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
return (x > y) ? x : y;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
if (std::isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return (x > y) ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Minimum {
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
return x < y ? x : y;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
if (std::isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return x < y ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LogAddExp {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||
auto maxval = Maximum()(x, y);
|
||||
auto minval = Minimum()(x, y);
|
||||
return (minval == -inf || maxval == inf)
|
||||
? maxval
|
||||
: static_cast<decltype(x)>(
|
||||
maxval + std::log1p(fast_exp(minval - maxval)));
|
||||
}
|
||||
};
|
||||
|
||||
struct Multiply {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x * y;
|
||||
}
|
||||
};
|
||||
|
||||
struct NotEqual {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x != y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Power {
|
||||
template <typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||
return std::pow(base, exp);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||
T res = 1;
|
||||
while (exp) {
|
||||
if (exp & 1) {
|
||||
res *= base;
|
||||
}
|
||||
exp >>= 1;
|
||||
base *= base;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
struct Subtract {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x - y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalAnd {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x && y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalOr {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x || y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Select {
|
||||
template <typename T>
|
||||
T operator()(bool condition, T x, T y) {
|
||||
return condition ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct BitwiseAnd {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x & y;
|
||||
}
|
||||
};
|
||||
|
||||
struct BitwiseOr {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x | y;
|
||||
}
|
||||
};
|
||||
|
||||
struct BitwiseXor {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x ^ y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LeftShift {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x << y;
|
||||
}
|
||||
};
|
||||
|
||||
struct RightShift {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x >> y;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::detail
|
@ -9,10 +9,9 @@
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/arange.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/ops.h"
|
||||
#include "mlx/backend/common/load.h"
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/common/threefry.h"
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
@ -58,112 +57,64 @@ int64_t compute_dynamic_offset(
|
||||
}
|
||||
}
|
||||
|
||||
void Abs::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), unsignedinteger)) {
|
||||
// No-op for unsigned types
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
unary(in, out, detail::Abs());
|
||||
}
|
||||
void AsStrided::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
void Broadcast::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
void BroadcastAxes::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
void Copy::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
void CustomTransforms::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
void Depends::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
void ExpandDims::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
void NumberOfElements::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
void Slice::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
void Split::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
void Squeeze::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
void StopGradient::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Arange::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
arange(inputs, out, start_, step_);
|
||||
}
|
||||
|
||||
void ArcCos::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcCos());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[arccos] Cannot compute inverse cosine of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcCosh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[arccosh] Cannot compute inverse hyperbolic cosine of elements in"
|
||||
" array with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void ArcSin::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcSin());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[arcsin] Cannot compute inverse sine of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcSinh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[arcsinh] Cannot compute inverse hyperbolic sine of elements in"
|
||||
" array with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTan::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcTan());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[arctan] Cannot compute inverse tangent of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTanh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcTanh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[arctanh] Cannot compute inverse hyperbolic tangent of elements in"
|
||||
" array with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void AsType::eval(const std::vector<array>& inputs, array& out) {
|
||||
void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype);
|
||||
}
|
||||
|
||||
void Ceil::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Ceil());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Concatenate::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
std::vector<int> sizes;
|
||||
sizes.push_back(0);
|
||||
for (auto& p : inputs) {
|
||||
@ -187,17 +138,6 @@ void Concatenate::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Conjugate::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == complex64) {
|
||||
unary_fp(in, out, detail::Conjugate());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[conjugate] conjugate must be called on complex input.");
|
||||
}
|
||||
}
|
||||
|
||||
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
@ -209,94 +149,6 @@ void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Cos::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Cos());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[cos] Cannot compute cosine of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Cosh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Cosh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[cosh] Cannot compute hyperbolic cosine of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Erf::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
switch (out.dtype()) {
|
||||
case float32:
|
||||
unary_op<float>(in, out, detail::Erf());
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(in, out, detail::Erf());
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(in, out, detail::Erf());
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[erf] Error function only defined for arrays"
|
||||
" with real floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void ErfInv::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
switch (out.dtype()) {
|
||||
case float32:
|
||||
unary_op<float>(in, out, detail::ErfInv());
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(in, out, detail::ErfInv());
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(in, out, detail::ErfInv());
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[erf_inv] Inverse error function only defined for arrays"
|
||||
" with real floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Exp::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Exp());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[exp] Cannot exponentiate elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Expm1::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Expm1());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[expm1] Cannot exponentiate elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Flatten::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
reshape(inputs[0], out);
|
||||
}
|
||||
@ -305,18 +157,7 @@ void Unflatten::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
reshape(inputs[0], out);
|
||||
}
|
||||
|
||||
void Floor::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Floor());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Full::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
assert(in.dtype() == out.dtype());
|
||||
@ -331,57 +172,14 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
|
||||
copy(in, out, ctype);
|
||||
}
|
||||
|
||||
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
|
||||
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
load(out, offset_, reader_, swap_endianness_);
|
||||
}
|
||||
|
||||
void Log::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
unary_fp(in, out, detail::Log());
|
||||
break;
|
||||
case Base::two:
|
||||
unary_fp(in, out, detail::Log2());
|
||||
break;
|
||||
case Base::ten:
|
||||
unary_fp(in, out, detail::Log10());
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[log] Cannot compute log of elements in array with"
|
||||
" non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Log1p::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Log1p());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[log1p] Cannot compute log of elements in array with"
|
||||
" non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, detail::LogicalNot());
|
||||
}
|
||||
|
||||
void Negative::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, detail::Negative());
|
||||
}
|
||||
|
||||
void Pad::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
// Inputs must be base input array and scalar val array
|
||||
assert(inputs.size() == 2);
|
||||
auto& in = inputs[0];
|
||||
@ -412,7 +210,7 @@ void Pad::eval(const std::vector<array>& inputs, array& out) {
|
||||
copy_inplace(in, out_slice, CopyType::GeneralGeneral);
|
||||
}
|
||||
|
||||
void RandomBits::eval(const std::vector<array>& inputs, array& out) {
|
||||
void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
// keys has shape (N1, ..., NK, 2)
|
||||
// out has shape (N1, ..., NK, M1, M2, ...)
|
||||
@ -460,71 +258,10 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
|
||||
}
|
||||
|
||||
void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
reshape(inputs[0], out);
|
||||
}
|
||||
|
||||
void Round::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Round());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Sigmoid());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[sigmoid] Cannot sigmoid of elements in array with"
|
||||
" non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Sign::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.dtype() == bool_) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
unary(in, out, detail::Sign());
|
||||
}
|
||||
}
|
||||
|
||||
void Sin::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Sin());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[sin] Cannot compute sine of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Sinh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Sinh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[sinh] Cannot compute hyperbolic sine of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Slice::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
if (out.size() == 0) {
|
||||
@ -596,7 +333,7 @@ void DynamicSliceUpdate::eval_cpu(
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral);
|
||||
}
|
||||
|
||||
void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
|
||||
void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
@ -632,46 +369,6 @@ void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral);
|
||||
}
|
||||
|
||||
void Square::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, detail::Square());
|
||||
}
|
||||
|
||||
void Sqrt::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (recip_) {
|
||||
unary_fp(in, out, detail::Rsqrt());
|
||||
} else {
|
||||
unary_fp(in, out, detail::Sqrt());
|
||||
}
|
||||
}
|
||||
|
||||
void Tan::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Tan());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[tan] Cannot compute tangent of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Tanh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Tanh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[tanh] Cannot compute hyperbolic tangent of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void View::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
@ -149,7 +149,9 @@ void qrf_impl(const array& a, array& q, array& r) {
|
||||
allocator::free(tau);
|
||||
}
|
||||
|
||||
void QRF::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
|
||||
void QRF::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
if (!(inputs[0].dtype() == float32)) {
|
||||
throw std::runtime_error("[QRF::eval] only supports float32.");
|
||||
}
|
||||
|
@ -3,7 +3,7 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/ops.h"
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
@ -151,6 +151,78 @@ void _qmm_t(
|
||||
}
|
||||
}
|
||||
|
||||
template <int bits, int S>
|
||||
simd::Simd<uint32_t, S> extract_bits_simd(const uint32_t* w) {
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
simd::Simd<uint32_t, S> wi;
|
||||
if constexpr (bits == 4 && S == 8) {
|
||||
constexpr std::array<uint32_t, 8> shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}};
|
||||
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
|
||||
wi = simd::Simd<uint32_t, S>(*w);
|
||||
wi = wi >> shifts;
|
||||
wi = wi & bitmask;
|
||||
} else if constexpr (bits == 8 && S == 8) {
|
||||
constexpr std::array<uint32_t, 8> shifts_ = {{0, 8, 16, 24, 0, 8, 16, 24}};
|
||||
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
|
||||
auto l = simd::Simd<uint32_t, 4>(*w++);
|
||||
auto r = simd::Simd<uint32_t, 4>(*w);
|
||||
wi = simd::Simd<uint32_t, S>(l, r);
|
||||
wi = wi >> shifts;
|
||||
wi = wi & bitmask;
|
||||
} else {
|
||||
// Appease compiler.. but should never get here
|
||||
throw std::runtime_error("Unsupported combination for simd qmm.");
|
||||
}
|
||||
return wi;
|
||||
}
|
||||
|
||||
template <typename T, int bits, int group_size>
|
||||
void _qmm_t_simd(
|
||||
T* result,
|
||||
const T* x,
|
||||
const uint32_t* w,
|
||||
const T* scales,
|
||||
const T* biases,
|
||||
int M,
|
||||
int N,
|
||||
int K) {
|
||||
constexpr int pack_factor = 32 / bits;
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
constexpr int S = simd::max_size<T>;
|
||||
static_assert(
|
||||
S % pack_factor == 0, "SIMD size must be divisible by pack factor");
|
||||
constexpr int packs_per_simd = S / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
const uint32_t* w_local = w;
|
||||
const T* scales_local = scales;
|
||||
const T* biases_local = biases;
|
||||
|
||||
for (int n = 0; n < N; n++) {
|
||||
simd::Simd<float, S> acc(0);
|
||||
auto x_local = x;
|
||||
for (int k = 0; k < K; k += group_size) {
|
||||
T scale = *scales_local++;
|
||||
T bias = *biases_local++;
|
||||
|
||||
for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) {
|
||||
auto wf = simd::Simd<float, S>(extract_bits_simd<bits, S>(w_local));
|
||||
w_local += packs_per_simd;
|
||||
wf = wf * scale;
|
||||
wf = wf + bias;
|
||||
simd::Simd<float, S> x_simd = simd::load<T, S>(x_local);
|
||||
acc = acc + x_simd * wf;
|
||||
x_local += S;
|
||||
}
|
||||
}
|
||||
|
||||
*result = T(simd::sum(acc));
|
||||
result++;
|
||||
}
|
||||
x += K;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int bits, int group_size>
|
||||
void _qmm_dispatch_transpose(
|
||||
T* result,
|
||||
@ -163,9 +235,14 @@ void _qmm_dispatch_transpose(
|
||||
int K,
|
||||
bool transposed_w) {
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
|
||||
// the simd size must be a multiple of the number of elements per word
|
||||
if constexpr (32 % bits == 0 && simd::max_size<T> % (32 / bits) == 0) {
|
||||
_qmm_t_simd<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
_qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
} else {
|
||||
return _qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
|
||||
_qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
}
|
||||
|
||||
@ -249,13 +326,13 @@ void _qmm_dispatch(
|
||||
int group_size,
|
||||
bool transposed_w) {
|
||||
int K = x.shape(-1);
|
||||
int M = x.shape(-2);
|
||||
int M = x.ndim() > 1 ? x.shape(-2) : 1;
|
||||
int N = out.shape(-1);
|
||||
|
||||
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
|
||||
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
|
||||
|
||||
int batch_size = x.size() / x.shape(-1) / x.shape(-2);
|
||||
int batch_size = x.size() / (K * M);
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
@ -384,7 +461,7 @@ void _bs_qmm_dispatch(
|
||||
|
||||
} // namespace
|
||||
|
||||
void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
|
||||
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 4);
|
||||
|
||||
auto& x_pre = inputs[0];
|
||||
@ -411,7 +488,7 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
|
||||
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
|
||||
}
|
||||
|
||||
void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 6);
|
||||
|
||||
auto& x_pre = inputs[0];
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/binary_ops.h"
|
||||
#include "mlx/backend/common/ternary.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@ -61,7 +62,7 @@ void select_op(
|
||||
|
||||
} // namespace
|
||||
|
||||
void Select::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Select::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 3);
|
||||
const auto& condition = inputs[0];
|
||||
const auto& a = inputs[1];
|
||||
|
56
mlx/backend/common/simd/accelerate_fp16_simd.h
Normal file
56
mlx/backend/common/simd/accelerate_fp16_simd.h
Normal file
@ -0,0 +1,56 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/simd/base_simd.h"
|
||||
|
||||
#if MLX_SIMD_LIBRARY_VERSION < 6
|
||||
#include "mlx/backend/common/simd/neon_fp16_simd.h"
|
||||
#endif
|
||||
|
||||
namespace mlx::core::simd {
|
||||
|
||||
#if MLX_SIMD_LIBRARY_VERSION >= 6
|
||||
constexpr int N = 8;
|
||||
template <int N>
|
||||
struct ScalarT<float16_t, N> {
|
||||
using v = _Float16;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
static constexpr int max_size<float16_t> = N;
|
||||
|
||||
#define SIMD_FP16_DEFAULT_UNARY(op) \
|
||||
template <> \
|
||||
inline Simd<float16_t, N> op(Simd<float16_t, N> v) { \
|
||||
Simd<float, N> in = v; \
|
||||
return op(in); \
|
||||
}
|
||||
|
||||
SIMD_FP16_DEFAULT_UNARY(acos)
|
||||
SIMD_FP16_DEFAULT_UNARY(acosh)
|
||||
SIMD_FP16_DEFAULT_UNARY(asin)
|
||||
SIMD_FP16_DEFAULT_UNARY(asinh)
|
||||
SIMD_FP16_DEFAULT_UNARY(atan)
|
||||
SIMD_FP16_DEFAULT_UNARY(atanh)
|
||||
SIMD_FP16_DEFAULT_UNARY(cosh)
|
||||
SIMD_FP16_DEFAULT_UNARY(expm1)
|
||||
SIMD_FP16_DEFAULT_UNARY(log)
|
||||
SIMD_FP16_DEFAULT_UNARY(log2)
|
||||
SIMD_FP16_DEFAULT_UNARY(log10)
|
||||
SIMD_FP16_DEFAULT_UNARY(log1p)
|
||||
SIMD_FP16_DEFAULT_UNARY(sinh)
|
||||
SIMD_FP16_DEFAULT_UNARY(tan)
|
||||
SIMD_FP16_DEFAULT_UNARY(tanh)
|
||||
|
||||
#define SIMD_FP16_DEFAULT_BINARY(op) \
|
||||
template <> \
|
||||
inline Simd<float16_t, N> op(Simd<float16_t, N> x, Simd<float16_t, N> y) { \
|
||||
Simd<float, N> a = x; \
|
||||
Simd<float, N> b = y; \
|
||||
return op(a, b); \
|
||||
}
|
||||
SIMD_FP16_DEFAULT_BINARY(atan2)
|
||||
SIMD_FP16_DEFAULT_BINARY(remainder)
|
||||
SIMD_FP16_DEFAULT_BINARY(pow)
|
||||
|
||||
} // namespace mlx::core::simd
|
291
mlx/backend/common/simd/accelerate_simd.h
Normal file
291
mlx/backend/common/simd/accelerate_simd.h
Normal file
@ -0,0 +1,291 @@
|
||||
#pragma once
|
||||
|
||||
#include <simd/math.h>
|
||||
#include <simd/vector.h>
|
||||
|
||||
#include <stdint.h>
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
|
||||
#include "mlx/backend/common/simd/base_simd.h"
|
||||
|
||||
// There seems to be a bug in sims/base.h
|
||||
// __XROS_2_0 is not defined, the expression evaluates
|
||||
// to true instead of false setting the SIMD library
|
||||
// higher than it should be even on macOS < 15
|
||||
#if __MAC_OS_X_VERSION_MIN_REQUIRED >= 150000 || \
|
||||
__IPHONE_OS_VERSION_MIN_REQUIRED >= 180000 || \
|
||||
__WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \
|
||||
__WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \
|
||||
__TV_OS_VERSION_MIN_REQUIRED >= 180000
|
||||
#define MLX_SIMD_LIBRARY_VERSION 6
|
||||
#else
|
||||
#define MLX_SIMD_LIBRARY_VERSION 5
|
||||
#endif
|
||||
|
||||
namespace mlx::core::simd {
|
||||
|
||||
// Apple simd namespace
|
||||
namespace asd = ::simd;
|
||||
|
||||
// This indirection is needed to remap certain types to ones that accelerate
|
||||
// SIMD can handle
|
||||
template <typename T, int N>
|
||||
struct ScalarT {
|
||||
using v = T;
|
||||
};
|
||||
template <int N>
|
||||
struct ScalarT<bool, N> {
|
||||
using v = char;
|
||||
};
|
||||
template <int N>
|
||||
struct ScalarT<int8_t, N> {
|
||||
using v = char;
|
||||
};
|
||||
template <int N>
|
||||
struct ScalarT<uint64_t, N> {
|
||||
using v = unsigned long;
|
||||
};
|
||||
template <int N>
|
||||
struct ScalarT<int64_t, N> {
|
||||
using v = long;
|
||||
};
|
||||
|
||||
template <typename T, int N>
|
||||
struct Simd {
|
||||
static constexpr int size = N;
|
||||
using scalar_t = typename ScalarT<T, N>::v;
|
||||
|
||||
Simd<T, N>() {}
|
||||
|
||||
template <typename U>
|
||||
Simd<T, N>(Simd<U, N> other) : value(asd::convert<scalar_t>(other.value)) {}
|
||||
|
||||
template <typename U>
|
||||
Simd<T, N>(U v) : value(v){};
|
||||
|
||||
Simd<T, N>(Simd<T, N / 2> x, Simd<T, N / 2> y) {
|
||||
value = asd::make<typename asd::Vector<scalar_t, N>::packed_t>(
|
||||
x.value, y.value);
|
||||
};
|
||||
|
||||
T operator[](int idx) const {
|
||||
return reinterpret_cast<const T*>(&value)[idx];
|
||||
}
|
||||
|
||||
T& operator[](int idx) {
|
||||
return reinterpret_cast<T*>(&value)[idx];
|
||||
}
|
||||
|
||||
typename asd::Vector<scalar_t, N>::packed_t value;
|
||||
};
|
||||
|
||||
// Values chosen based on benchmarks on M3 Max
|
||||
// TODO: consider choosing these more optimally
|
||||
template <>
|
||||
static constexpr int max_size<int8_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<int16_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<int> = 8;
|
||||
template <>
|
||||
static constexpr int max_size<int64_t> = 4;
|
||||
template <>
|
||||
static constexpr int max_size<uint8_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<uint16_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<uint32_t> = 8;
|
||||
template <>
|
||||
static constexpr int max_size<uint64_t> = 4;
|
||||
template <>
|
||||
static constexpr int max_size<float> = 8;
|
||||
template <>
|
||||
static constexpr int max_size<double> = 4;
|
||||
|
||||
#define SIMD_DEFAULT_UNARY(name, op) \
|
||||
template <typename T, int N> \
|
||||
Simd<T, N> name(Simd<T, N> v) { \
|
||||
return op(v.value); \
|
||||
}
|
||||
|
||||
SIMD_DEFAULT_UNARY(abs, asd::abs)
|
||||
SIMD_DEFAULT_UNARY(floor, asd::floor)
|
||||
SIMD_DEFAULT_UNARY(acos, asd::acos)
|
||||
SIMD_DEFAULT_UNARY(acosh, asd::acosh)
|
||||
SIMD_DEFAULT_UNARY(asin, asd::asin)
|
||||
SIMD_DEFAULT_UNARY(asinh, asd::asinh)
|
||||
SIMD_DEFAULT_UNARY(atan, asd::atan)
|
||||
SIMD_DEFAULT_UNARY(atanh, asd::atanh)
|
||||
SIMD_DEFAULT_UNARY(ceil, asd::ceil)
|
||||
SIMD_DEFAULT_UNARY(cosh, asd::cosh)
|
||||
SIMD_DEFAULT_UNARY(expm1, asd::expm1)
|
||||
SIMD_DEFAULT_UNARY(log, asd::log)
|
||||
SIMD_DEFAULT_UNARY(log2, asd::log2)
|
||||
SIMD_DEFAULT_UNARY(log10, asd::log10)
|
||||
SIMD_DEFAULT_UNARY(log1p, asd::log1p)
|
||||
SIMD_DEFAULT_UNARY(rint, asd::rint)
|
||||
SIMD_DEFAULT_UNARY(sinh, asd::sinh)
|
||||
SIMD_DEFAULT_UNARY(sqrt, asd::sqrt)
|
||||
SIMD_DEFAULT_UNARY(rsqrt, asd::rsqrt)
|
||||
SIMD_DEFAULT_UNARY(recip, asd::recip)
|
||||
SIMD_DEFAULT_UNARY(tan, asd::tan)
|
||||
SIMD_DEFAULT_UNARY(tanh, asd::tanh)
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> operator-(Simd<T, N> v) {
|
||||
return -v.value;
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<bool, N> isnan(Simd<T, N> v) {
|
||||
return asd::convert<char>(v.value != v.value);
|
||||
}
|
||||
|
||||
// No simd_boolN in accelerate, use int8_t instead
|
||||
template <typename T, int N>
|
||||
Simd<bool, N> operator!(Simd<T, N> v) {
|
||||
return asd::convert<char>(!v.value);
|
||||
}
|
||||
|
||||
#define SIMD_DEFAULT_BINARY(OP) \
|
||||
template <typename T, typename U, int N> \
|
||||
Simd<T, N> operator OP(Simd<T, N> x, U y) { \
|
||||
return asd::convert<typename Simd<T, N>::scalar_t>(x.value OP y); \
|
||||
} \
|
||||
template <typename T1, typename T2, int N> \
|
||||
Simd<T2, N> operator OP(T1 x, Simd<T2, N> y) { \
|
||||
return asd::convert<typename Simd<T2, N>::scalar_t>(x OP y.value); \
|
||||
} \
|
||||
template <typename T1, typename T2, int N> \
|
||||
Simd<T1, N> operator OP(Simd<T1, N> x, Simd<T2, N> y) { \
|
||||
return asd::convert<typename Simd<T1, N>::scalar_t>(x.value OP y.value); \
|
||||
}
|
||||
|
||||
SIMD_DEFAULT_BINARY(+)
|
||||
SIMD_DEFAULT_BINARY(-)
|
||||
SIMD_DEFAULT_BINARY(/)
|
||||
SIMD_DEFAULT_BINARY(*)
|
||||
SIMD_DEFAULT_BINARY(<<)
|
||||
SIMD_DEFAULT_BINARY(>>)
|
||||
SIMD_DEFAULT_BINARY(|)
|
||||
SIMD_DEFAULT_BINARY(^)
|
||||
SIMD_DEFAULT_BINARY(&)
|
||||
SIMD_DEFAULT_BINARY(&&)
|
||||
SIMD_DEFAULT_BINARY(||)
|
||||
|
||||
#define SIMD_DEFAULT_COMPARISONS(OP) \
|
||||
template <int N, typename T, typename U> \
|
||||
Simd<bool, N> operator OP(Simd<T, N> a, U b) { \
|
||||
return asd::convert<char>(a.value OP b); \
|
||||
} \
|
||||
template <int N, typename T, typename U> \
|
||||
Simd<bool, N> operator OP(T a, Simd<U, N> b) { \
|
||||
return asd::convert<char>(a OP b.value); \
|
||||
} \
|
||||
template <int N, typename T1, typename T2> \
|
||||
Simd<bool, N> operator OP(Simd<T1, N> a, Simd<T2, N> b) { \
|
||||
return asd::convert<char>(a.value OP b.value); \
|
||||
}
|
||||
|
||||
SIMD_DEFAULT_COMPARISONS(>)
|
||||
SIMD_DEFAULT_COMPARISONS(<)
|
||||
SIMD_DEFAULT_COMPARISONS(>=)
|
||||
SIMD_DEFAULT_COMPARISONS(<=)
|
||||
SIMD_DEFAULT_COMPARISONS(==)
|
||||
SIMD_DEFAULT_COMPARISONS(!=)
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
|
||||
return asd::atan2(a.value, b.value);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> maximum(Simd<T, N> a, Simd<T, N> b) {
|
||||
// TODO add isnan
|
||||
return asd::max(a.value, b.value);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> minimum(Simd<T, N> a, Simd<T, N> b) {
|
||||
// TODO add isnan
|
||||
return asd::min(a.value, b.value);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> remainder(Simd<T, N> a, Simd<T, N> b) {
|
||||
Simd<T, N> r;
|
||||
if constexpr (!std::is_integral_v<T>) {
|
||||
r = asd::remainder(a.value, b.value);
|
||||
} else {
|
||||
r = a - b * (a / b);
|
||||
}
|
||||
if constexpr (std::is_signed_v<T>) {
|
||||
auto mask = r != 0 && (r < 0 != b < 0);
|
||||
r = select(mask, r + b, r);
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename MaskT, typename T1, typename T2, int N>
|
||||
Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) {
|
||||
if constexpr (sizeof(T1) == 1) {
|
||||
return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value));
|
||||
} else if constexpr (sizeof(T1) == 2) {
|
||||
return asd::bitselect(y.value, x.value, asd::convert<short>(mask.value));
|
||||
} else if constexpr (sizeof(T1) == 4) {
|
||||
return asd::bitselect(y.value, x.value, asd::convert<int>(mask.value));
|
||||
} else {
|
||||
return asd::bitselect(y.value, x.value, asd::convert<long>(mask.value));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {
|
||||
if constexpr (!std::is_integral_v<T>) {
|
||||
return asd::pow(base.value, exp.value);
|
||||
} else {
|
||||
Simd<T, N> res = 1;
|
||||
while (any(exp)) {
|
||||
res = select(exp & 1, res * base, res);
|
||||
base = select(exp, base * base, base);
|
||||
exp = exp >> 1;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> clamp(Simd<T, N> v, Simd<T, N> min, Simd<T, N> max) {
|
||||
return asd::clamp(v.value, min.value, max.value);
|
||||
}
|
||||
|
||||
template <typename T, typename U, int N>
|
||||
Simd<T, N> fma(Simd<T, N> x, Simd<T, N> y, U z) {
|
||||
return asd::muladd(x.value, y.value, Simd<T, N>(z).value);
|
||||
}
|
||||
|
||||
// Reductions
|
||||
|
||||
template <typename T, int N>
|
||||
bool any(Simd<T, N> x) {
|
||||
return asd::any(x.value);
|
||||
}
|
||||
template <typename T, int N>
|
||||
T sum(Simd<T, N> x) {
|
||||
return asd::reduce_add(x.value);
|
||||
}
|
||||
template <typename T, int N>
|
||||
T max(Simd<T, N> x) {
|
||||
return asd::reduce_max(x.value);
|
||||
}
|
||||
template <typename T, int N>
|
||||
T min(Simd<T, N> x) {
|
||||
return asd::reduce_min(x.value);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::simd
|
||||
|
||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
#include "mlx/backend/common/simd/accelerate_fp16_simd.h"
|
||||
#endif
|
252
mlx/backend/common/simd/base_simd.h
Normal file
252
mlx/backend/common/simd/base_simd.h
Normal file
@ -0,0 +1,252 @@
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
|
||||
namespace mlx::core::simd {
|
||||
template <typename T, int N>
|
||||
struct Simd;
|
||||
|
||||
template <typename T>
|
||||
static constexpr int max_size = 1;
|
||||
|
||||
template <typename T>
|
||||
struct Simd<T, 1> {
|
||||
static constexpr int size = 1;
|
||||
T value;
|
||||
Simd() {}
|
||||
template <typename U>
|
||||
Simd(Simd<U, 1> v) : value(v.value) {}
|
||||
template <typename U>
|
||||
Simd(U v) : value(v) {}
|
||||
};
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> load(const T* x) {
|
||||
return *(Simd<T, N>*)x;
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
void store(T* dst, Simd<T, N> x) {
|
||||
// Maintain invariant that bool is either 0 or 1 as
|
||||
// simd comparison ops set all bits in the result to 1
|
||||
if constexpr (std::is_same_v<T, bool> && N > 1) {
|
||||
x = x & 1;
|
||||
}
|
||||
*(Simd<T, N>*)dst = x;
|
||||
}
|
||||
|
||||
template <typename, typename = void>
|
||||
constexpr bool is_complex = false;
|
||||
|
||||
template <typename T>
|
||||
constexpr bool is_complex<T, std::void_t<decltype(std::declval<T>().real())>> =
|
||||
true;
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> rint(Simd<T, 1> in) {
|
||||
if constexpr (is_complex<T>) {
|
||||
return Simd<T, 1>{
|
||||
T{std::rint(in.value.real()), std::rint(in.value.imag())}};
|
||||
} else {
|
||||
return Simd<T, 1>{std::rint(in.value)};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> rsqrt(Simd<T, 1> in) {
|
||||
return T(1.0) / sqrt(in);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> recip(Simd<T, 1> in) {
|
||||
return T(1.0) / in;
|
||||
}
|
||||
|
||||
#define DEFAULT_UNARY(name, op) \
|
||||
template <typename T> \
|
||||
Simd<T, 1> name(Simd<T, 1> in) { \
|
||||
return op(in.value); \
|
||||
}
|
||||
|
||||
DEFAULT_UNARY(operator-, std::negate{})
|
||||
DEFAULT_UNARY(operator!, std::logical_not{})
|
||||
DEFAULT_UNARY(abs, std::abs)
|
||||
DEFAULT_UNARY(acos, std::acos)
|
||||
DEFAULT_UNARY(acosh, std::acosh)
|
||||
DEFAULT_UNARY(asin, std::asin)
|
||||
DEFAULT_UNARY(asinh, std::asinh)
|
||||
DEFAULT_UNARY(atan, std::atan)
|
||||
DEFAULT_UNARY(atanh, std::atanh)
|
||||
DEFAULT_UNARY(ceil, std::ceil)
|
||||
DEFAULT_UNARY(conj, std::conj)
|
||||
DEFAULT_UNARY(cosh, std::cosh)
|
||||
DEFAULT_UNARY(expm1, std::expm1)
|
||||
DEFAULT_UNARY(floor, std::floor)
|
||||
DEFAULT_UNARY(log, std::log)
|
||||
DEFAULT_UNARY(log2, std::log2)
|
||||
DEFAULT_UNARY(log10, std::log10)
|
||||
DEFAULT_UNARY(log1p, std::log1p)
|
||||
DEFAULT_UNARY(sinh, std::sinh)
|
||||
DEFAULT_UNARY(sqrt, std::sqrt)
|
||||
DEFAULT_UNARY(tan, std::tan)
|
||||
DEFAULT_UNARY(tanh, std::tanh)
|
||||
|
||||
template <typename T>
|
||||
auto real(Simd<T, 1> in) -> Simd<decltype(std::real(in.value)), 1> {
|
||||
return std::real(in.value);
|
||||
}
|
||||
template <typename T>
|
||||
auto imag(Simd<T, 1> in) -> Simd<decltype(std::imag(in.value)), 1> {
|
||||
return std::imag(in.value);
|
||||
}
|
||||
template <typename T>
|
||||
Simd<bool, 1> isnan(Simd<T, 1> in) {
|
||||
return std::isnan(in.value);
|
||||
}
|
||||
|
||||
#define DEFAULT_BINARY(OP) \
|
||||
template <typename T1, typename T2> \
|
||||
auto operator OP(Simd<T1, 1> a, Simd<T2, 1> b) \
|
||||
->Simd<decltype(a.value OP b.value), 1> { \
|
||||
return a.value OP b.value; \
|
||||
} \
|
||||
template <typename T1, typename T2> \
|
||||
auto operator OP(T1 a, Simd<T2, 1> b)->Simd<decltype(a OP b.value), 1> { \
|
||||
return a OP b.value; \
|
||||
} \
|
||||
template <typename T1, typename T2> \
|
||||
auto operator OP(Simd<T1, 1> a, T2 b)->Simd<decltype(a.value OP b), 1> { \
|
||||
return a.value OP b; \
|
||||
}
|
||||
|
||||
DEFAULT_BINARY(+)
|
||||
DEFAULT_BINARY(-)
|
||||
DEFAULT_BINARY(*)
|
||||
DEFAULT_BINARY(/)
|
||||
DEFAULT_BINARY(<<)
|
||||
DEFAULT_BINARY(>>)
|
||||
DEFAULT_BINARY(|)
|
||||
DEFAULT_BINARY(^)
|
||||
DEFAULT_BINARY(&)
|
||||
DEFAULT_BINARY(&&)
|
||||
DEFAULT_BINARY(||)
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {
|
||||
T a = a_.value;
|
||||
T b = b_.value;
|
||||
T r;
|
||||
if constexpr (std::is_integral_v<T>) {
|
||||
r = a % b;
|
||||
} else {
|
||||
r = std::remainder(a, b);
|
||||
}
|
||||
if constexpr (std::is_signed_v<T>) {
|
||||
if (r != 0 && (r < 0 != b < 0)) {
|
||||
r += b;
|
||||
}
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> maximum(Simd<T, 1> a_, Simd<T, 1> b_) {
|
||||
T a = a_.value;
|
||||
T b = b_.value;
|
||||
if constexpr (!std::is_integral_v<T>) {
|
||||
if (std::isnan(a)) {
|
||||
return a;
|
||||
}
|
||||
}
|
||||
return (a > b) ? a : b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> minimum(Simd<T, 1> a_, Simd<T, 1> b_) {
|
||||
T a = a_.value;
|
||||
T b = b_.value;
|
||||
if constexpr (!std::is_integral_v<T>) {
|
||||
if (std::isnan(a)) {
|
||||
return a;
|
||||
}
|
||||
}
|
||||
return (a < b) ? a : b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> pow(Simd<T, 1> a, Simd<T, 1> b) {
|
||||
T base = a.value;
|
||||
T exp = b.value;
|
||||
if constexpr (!std::is_integral_v<T>) {
|
||||
return std::pow(base, exp);
|
||||
} else {
|
||||
T res = 1;
|
||||
while (exp) {
|
||||
if (exp & 1) {
|
||||
res *= base;
|
||||
}
|
||||
exp >>= 1;
|
||||
base *= base;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> atan2(Simd<T, 1> a, Simd<T, 1> b) {
|
||||
return std::atan2(a.value, b.value);
|
||||
}
|
||||
|
||||
#define DEFAULT_COMPARISONS(OP) \
|
||||
template <typename T1, typename T2> \
|
||||
Simd<bool, 1> operator OP(Simd<T1, 1> a, Simd<T2, 1> b) { \
|
||||
return a.value OP b.value; \
|
||||
} \
|
||||
template <typename T1, typename T2> \
|
||||
Simd<bool, 1> operator OP(T1 a, Simd<T2, 1> b) { \
|
||||
return a OP b.value; \
|
||||
} \
|
||||
template <typename T1, typename T2> \
|
||||
Simd<bool, 1> operator OP(Simd<T1, 1> a, T2 b) { \
|
||||
return a.value OP b; \
|
||||
}
|
||||
|
||||
DEFAULT_COMPARISONS(>)
|
||||
DEFAULT_COMPARISONS(<)
|
||||
DEFAULT_COMPARISONS(>=)
|
||||
DEFAULT_COMPARISONS(<=)
|
||||
DEFAULT_COMPARISONS(==)
|
||||
DEFAULT_COMPARISONS(!=)
|
||||
|
||||
template <typename MaskT, typename T>
|
||||
Simd<T, 1> select(Simd<MaskT, 1> mask, Simd<T, 1> x, Simd<T, 1> y) {
|
||||
return mask.value ? x.value : y.value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> clamp(Simd<T, 1> v, Simd<T, 1> min, Simd<T, 1> max) {
|
||||
return std::clamp(v.value, min.value, max.value);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
Simd<T, 1> fma(Simd<T, 1> x, Simd<T, 1> y, U z) {
|
||||
return std::fma(x.value, y.value, Simd<T, 1>(z).value);
|
||||
}
|
||||
|
||||
// Reductions
|
||||
#define DEFAULT_REDUCTION(name, type) \
|
||||
template <typename T> \
|
||||
type name(Simd<T, 1> x) { \
|
||||
return x.value; \
|
||||
}
|
||||
|
||||
DEFAULT_REDUCTION(max, T)
|
||||
DEFAULT_REDUCTION(min, T)
|
||||
DEFAULT_REDUCTION(sum, T)
|
||||
DEFAULT_REDUCTION(any, bool)
|
||||
DEFAULT_REDUCTION(all, bool)
|
||||
|
||||
} // namespace mlx::core::simd
|
193
mlx/backend/common/simd/math.h
Normal file
193
mlx/backend/common/simd/math.h
Normal file
@ -0,0 +1,193 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/simd/type.h"
|
||||
|
||||
namespace mlx::core::simd {
|
||||
|
||||
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||
|
||||
/**
|
||||
* Compute exp(x) in an optimizer friendly way as follows:
|
||||
*
|
||||
* First change the problem to computing 2**y where y = x / ln(2).
|
||||
*
|
||||
* Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part
|
||||
* `ipart` and y2 is fractional part. For the integer part we perform bit
|
||||
* shifting and for the fractional part we use a polynomial approximation.
|
||||
*
|
||||
* The algorithm and constants of the polynomial taken from
|
||||
* https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them
|
||||
* from Cephes math library.
|
||||
*
|
||||
* Note: The implementation below is a general fast exp. There could be faster
|
||||
* implementations for numbers strictly < 0.
|
||||
*/
|
||||
template <typename T, int N>
|
||||
Simd<T, N> exp(Simd<T, N> in) {
|
||||
if constexpr (is_complex<T>) {
|
||||
return Simd<T, 1>{std::exp(in.value)};
|
||||
} else {
|
||||
Simd<float, N> x_init = in;
|
||||
auto x = x_init * 1.442695f; // multiply with log_2(e)
|
||||
Simd<float, N> ipart, fpart;
|
||||
ipart = floor(x + 0.5);
|
||||
fpart = x - ipart;
|
||||
|
||||
x = 1.535336188319500e-4f;
|
||||
x = fma(x, fpart, 1.339887440266574e-3f);
|
||||
x = fma(x, fpart, 9.618437357674640e-3f);
|
||||
x = fma(x, fpart, 5.550332471162809e-2f);
|
||||
x = fma(x, fpart, 2.402264791363012e-1f);
|
||||
x = fma(x, fpart, 6.931472028550421e-1f);
|
||||
x = fma(x, fpart, 1.000000000000000f);
|
||||
|
||||
// generate 2**ipart in the floating point representation using integer
|
||||
// bitshifting
|
||||
Simd<int, N> epart = (Simd<int, N>(ipart) + 127) << 23;
|
||||
|
||||
// Deal with NaN and Inf
|
||||
auto result = select(isnan(x_init), x_init, (*(Simd<float, N>*)&epart) * x);
|
||||
result = select(x_init > 88.0f, Simd<float, N>(inf), result);
|
||||
result = select(x_init < -88.0f, Simd<float, N>(0), result);
|
||||
return Simd<T, N>(result);
|
||||
}
|
||||
}
|
||||
|
||||
/* Implementation from:
|
||||
* https://github.com/JishinMaster/simd_utils/blob/3c1433a86fb38edcc9b02039f3c9a65b16640976/neon_mathfun.h#L357
|
||||
* which originally came from the Cephes math library.
|
||||
*/
|
||||
template <bool Sine, typename T, int N>
|
||||
Simd<T, N> sincos(Simd<T, N> in) {
|
||||
auto sign_mask_sin = in < 0;
|
||||
in = abs(in);
|
||||
Simd<float, N> x = in;
|
||||
|
||||
// scale by 4/Pi
|
||||
auto y = x * 1.27323954473516f;
|
||||
|
||||
// store the integer part of y in mm0
|
||||
Simd<uint32_t, N> emm2 = y;
|
||||
|
||||
// j=(j+1) & (~1) (see the cephes sources)
|
||||
emm2 = emm2 + 1;
|
||||
emm2 = emm2 & ~1;
|
||||
|
||||
y = emm2;
|
||||
|
||||
// Get the polynom selection mask. There is one polynom for 0 <= x <= Pi/4
|
||||
// and another one for Pi/4<x<=Pi/2. Both branches will be computed.
|
||||
auto poly_mask = (emm2 & 2) != 0;
|
||||
|
||||
// The magic pass: "Extended precision modular arithmetic"
|
||||
// x = ((x - y * DP1) - y * DP2) - y * DP3
|
||||
x = fma(y, Simd<float, N>(-0.78515625f), x);
|
||||
x = fma(y, Simd<float, N>(-2.4187564849853515625e-4f), x);
|
||||
x = fma(y, Simd<float, N>(-3.77489497744594108e-8f), x);
|
||||
|
||||
sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != 0);
|
||||
auto sign_mask_cos = ((emm2 - 2) & 4) != 0;
|
||||
|
||||
// Evaluate the first polynom (0 <= x <= Pi/4) in y1,
|
||||
// and the second polynom (Pi/4 <= x <= 0) in y2
|
||||
auto z = x * x;
|
||||
|
||||
auto y1 =
|
||||
fma(z, Simd<float, N>(2.443315711809948e-5f), -1.388731625493765e-3f);
|
||||
auto y2 = fma(z, Simd<float, N>(-1.9515295891e-4f), 8.3321608736e-3f);
|
||||
y1 = fma(y1, z, 4.166664568298827e-2f);
|
||||
y2 = fma(y2, z, -1.6666654611e-1f);
|
||||
y1 = y1 * z;
|
||||
y2 = y2 * z;
|
||||
y1 = y1 * z;
|
||||
y2 = fma(x, y2, x);
|
||||
y1 = fma(z, Simd<float, N>(-0.5f), y1);
|
||||
y1 = y1 + 1.0f;
|
||||
|
||||
if constexpr (Sine) {
|
||||
auto ys = select(poly_mask, y1, y2);
|
||||
return select(sign_mask_sin, -ys, ys);
|
||||
} else {
|
||||
auto yc = select(poly_mask, y2, y1);
|
||||
return select(sign_mask_cos, yc, -yc);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> sin(Simd<T, N> x) {
|
||||
if constexpr (is_complex<T>) {
|
||||
return std::sin(x.value);
|
||||
} else {
|
||||
return sincos<true>(x);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> cos(Simd<T, N> x) {
|
||||
if constexpr (is_complex<T>) {
|
||||
return std::cos(x.value);
|
||||
} else {
|
||||
return sincos<false>(x);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> erf(Simd<T, N> x) {
|
||||
// https://github.com/pytorch/pytorch/blob/abf28982a8cb43342e7669d859de9543fd804cc9/aten/src/ATen/cpu/vec/vec256/vec256_float.h#L175
|
||||
Simd<float, N> v = x;
|
||||
auto t = recip(fma(Simd<float, N>(0.3275911f), abs(v), 1.0f));
|
||||
auto r = fma(Simd<float, N>(1.061405429f), t, -1.453152027f);
|
||||
r = fma(r, t, 1.421413741f);
|
||||
r = fma(r, t, -0.284496736f);
|
||||
r = fma(r, t, 0.254829592f);
|
||||
auto e = -exp(-v * v);
|
||||
auto result = Simd<T, N>(fma(e * t, r, 1.0f));
|
||||
return select(x > 0, result, -result);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> erfinv(Simd<T, N> a_) {
|
||||
Simd<float, N> a = a_;
|
||||
auto t = fma(a, 0.0f - a, 1.0f);
|
||||
t = log(t);
|
||||
auto lhs = [](auto t) {
|
||||
Simd<float, N> p;
|
||||
p = 3.03697567e-10f; // 0x1.4deb44p-32
|
||||
p = fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
|
||||
p = fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
|
||||
p = fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
|
||||
p = fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
|
||||
p = fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
|
||||
p = fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
|
||||
p = fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
|
||||
return fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
|
||||
};
|
||||
auto rhs = [](auto t) {
|
||||
Simd<float, N> p;
|
||||
p = 5.43877832e-9f; // 0x1.75c000p-28
|
||||
p = fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
|
||||
p = fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
|
||||
p = fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
|
||||
p = fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
|
||||
p = fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
|
||||
p = fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
|
||||
p = fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
|
||||
p = fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
|
||||
return fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
|
||||
};
|
||||
auto thresh = 6.125f;
|
||||
// Compute both branches and select if N > 1
|
||||
if constexpr (N == 1) {
|
||||
if ((abs(t) > thresh).value) { // maximum ulp error = 2.35793
|
||||
return a * lhs(t);
|
||||
} else { // maximum ulp error = 2.35002
|
||||
return a * rhs(t);
|
||||
}
|
||||
} else {
|
||||
return a * select(t > thresh, lhs(t), rhs(t));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::simd
|
204
mlx/backend/common/simd/neon_fp16_simd.h
Normal file
204
mlx/backend/common/simd/neon_fp16_simd.h
Normal file
@ -0,0 +1,204 @@
|
||||
#pragma once
|
||||
|
||||
#include <arm_neon.h>
|
||||
|
||||
#include "mlx/backend/common/simd/base_simd.h"
|
||||
|
||||
namespace mlx::core::simd {
|
||||
|
||||
constexpr int N = 8;
|
||||
|
||||
template <>
|
||||
struct Simd<float16_t, N> {
|
||||
static constexpr int size = N;
|
||||
using scalar_t = float16_t;
|
||||
|
||||
Simd<float16_t, N>() {}
|
||||
|
||||
template <typename U>
|
||||
Simd<float16_t, N>(U v) : value(vdupq_n_f16(v)){};
|
||||
|
||||
Simd<float16_t, N>(float16x8_t v) : value(v){};
|
||||
|
||||
Simd<float16_t, N>(Simd<float, N> other) {
|
||||
auto f32x4_a = *(float32x4_t*)(&other);
|
||||
auto f32x4_b = *((float32x4_t*)(&other) + 1);
|
||||
value = vcvt_high_f16_f32(vcvt_f16_f32(f32x4_a), f32x4_b);
|
||||
};
|
||||
|
||||
Simd<float16_t, N>(Simd<uint16_t, N> other) {
|
||||
value = vcvtq_f16_u16(*(uint16x8_t*)(&other.value));
|
||||
};
|
||||
|
||||
operator Simd<int16_t, N>() {
|
||||
auto v = vcvtq_s16_f16(value);
|
||||
return load<int16_t, N>((int16_t*)&v);
|
||||
};
|
||||
|
||||
operator Simd<float, N>() {
|
||||
float32x4x2_t v;
|
||||
v.val[0] = vcvt_f32_f16(*(float16x4_t*)(&value));
|
||||
v.val[1] = vcvt_high_f32_f16(value);
|
||||
return load<float, N>((float*)&v);
|
||||
}
|
||||
float16_t operator[](int idx) const {
|
||||
return reinterpret_cast<const float16_t*>(&value)[idx];
|
||||
}
|
||||
|
||||
float16_t& operator[](int idx) {
|
||||
return reinterpret_cast<float16_t*>(&value)[idx];
|
||||
}
|
||||
|
||||
float16x8_t value;
|
||||
};
|
||||
|
||||
#define DEFINE_NEON_UNARY_OP(name, op) \
|
||||
inline Simd<float16_t, N> name(Simd<float16_t, N> a) { \
|
||||
return Simd<float16_t, N>{op(a.value)}; \
|
||||
}
|
||||
|
||||
DEFINE_NEON_UNARY_OP(abs, vabsq_f16)
|
||||
DEFINE_NEON_UNARY_OP(ceil, vrndpq_f16)
|
||||
DEFINE_NEON_UNARY_OP(floor, vrndmq_f16)
|
||||
DEFINE_NEON_UNARY_OP(sqrt, vsqrtq_f16)
|
||||
DEFINE_NEON_UNARY_OP(rsqrt, vrsqrteq_f16)
|
||||
DEFINE_NEON_UNARY_OP(recip, vrecpeq_f16)
|
||||
DEFINE_NEON_UNARY_OP(rint, vrndnq_f16)
|
||||
|
||||
#define DEFINE_NEON_BINARY_OP(name, op) \
|
||||
inline Simd<float16_t, N> name(Simd<float16_t, N> a, Simd<float16_t, N> b) { \
|
||||
return op(a.value, b.value); \
|
||||
} \
|
||||
template <typename T> \
|
||||
Simd<float16_t, N> name(Simd<float16_t, N> a, T b) { \
|
||||
return op(a.value, Simd<float16_t, N>(b).value); \
|
||||
} \
|
||||
template <typename T> \
|
||||
Simd<float16_t, N> name(T a, Simd<float16_t, N> b) { \
|
||||
return op(Simd<float16_t, N>(a).value, b.value); \
|
||||
}
|
||||
|
||||
inline Simd<float16_t, N> operator!(Simd<float16_t, N> v) {
|
||||
auto out = vceqzq_f16(v.value);
|
||||
return Simd<uint16_t, N>(*(uint16_t*)&out);
|
||||
}
|
||||
|
||||
inline Simd<float16_t, N> operator-(Simd<float16_t, N> v) {
|
||||
return vnegq_f16(v.value);
|
||||
}
|
||||
|
||||
DEFINE_NEON_BINARY_OP(maximum, vmaxq_f16)
|
||||
DEFINE_NEON_BINARY_OP(minimum, vminq_f16)
|
||||
DEFINE_NEON_BINARY_OP(operator+, vaddq_f16)
|
||||
DEFINE_NEON_BINARY_OP(operator-, vsubq_f16)
|
||||
DEFINE_NEON_BINARY_OP(operator*, vmulq_f16)
|
||||
DEFINE_NEON_BINARY_OP(operator/, vdivq_f16)
|
||||
|
||||
#define DEFINE_NEON_COMPARISON(Op, op) \
|
||||
template <typename T> \
|
||||
Simd<bool, N> operator Op(Simd<float16_t, N> a, T b) { \
|
||||
auto out = op(a.value, Simd<float16_t, N>(b).value); \
|
||||
return Simd<uint16_t, N>(*(uint16_t*)(&out)); \
|
||||
} \
|
||||
template <typename T> \
|
||||
Simd<bool, N> operator Op(T a, Simd<float16_t, N> b) { \
|
||||
auto out = op(Simd<float16_t, N>(a).value, b.value); \
|
||||
return Simd<uint16_t, N>(*(uint16_t*)(&out)); \
|
||||
} \
|
||||
inline Simd<bool, N> operator Op( \
|
||||
Simd<float16_t, N> a, Simd<float16_t, N> b) { \
|
||||
auto out = op(a.value, b.value); \
|
||||
return Simd<uint16_t, N>(*(uint16_t*)(&out)); \
|
||||
}
|
||||
|
||||
DEFINE_NEON_COMPARISON(==, vceqq_f16)
|
||||
DEFINE_NEON_COMPARISON(>=, vcgeq_f16)
|
||||
DEFINE_NEON_COMPARISON(<=, vcleq_f16)
|
||||
DEFINE_NEON_COMPARISON(>, vcgtq_f16)
|
||||
DEFINE_NEON_COMPARISON(<, vcltq_f16)
|
||||
|
||||
template <typename T>
|
||||
Simd<bool, N> operator!=(Simd<float16_t, N> a, T b) {
|
||||
return !(a == b);
|
||||
}
|
||||
template <typename T>
|
||||
Simd<bool, N> operator!=(T a, Simd<float16_t, N> b) {
|
||||
return !(a == b);
|
||||
}
|
||||
inline Simd<bool, N> operator!=(Simd<float16_t, N> a, Simd<float16_t, N> b) {
|
||||
return !(a == b);
|
||||
}
|
||||
|
||||
inline Simd<float16_t, N> operator||(
|
||||
Simd<float16_t, N> a,
|
||||
Simd<float16_t, N> b) {
|
||||
return Simd<uint16_t, N>((a != 0) || (b != 0));
|
||||
}
|
||||
template <typename T>
|
||||
Simd<float16_t, N> operator||(Simd<float16_t, N> a, T b) {
|
||||
return Simd<uint16_t, N>((a != 0) || (b != 0));
|
||||
}
|
||||
template <typename T>
|
||||
Simd<float16_t, N> operator||(T a, Simd<float16_t, N> b) {
|
||||
return Simd<uint16_t, N>((a != 0) || (b != 0));
|
||||
}
|
||||
inline Simd<float16_t, N> operator&&(
|
||||
Simd<float16_t, N> a,
|
||||
Simd<float16_t, N> b) {
|
||||
return Simd<uint16_t, N>((a != 0) && (b != 0));
|
||||
}
|
||||
template <typename T>
|
||||
Simd<float16_t, N> operator&&(Simd<float16_t, N> a, T b) {
|
||||
return Simd<uint16_t, N>((a != 0) && (b != 0));
|
||||
}
|
||||
template <typename T>
|
||||
Simd<float16_t, N> operator&&(T a, Simd<float16_t, N> b) {
|
||||
return Simd<uint16_t, N>((a != 0) && (b != 0));
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Simd<bool, N> isnan(Simd<float16_t, N> v) {
|
||||
return v != v;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Simd<float16_t, N>
|
||||
clamp(Simd<float16_t, N> v, Simd<float16_t, N> min, Simd<float16_t, N> max) {
|
||||
return minimum(maximum(v, min), max);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<float16_t, N> fma(Simd<float16_t, N> x, Simd<float16_t, N> y, T z) {
|
||||
return vfmaq_f16(x.value, y.value, Simd<float16_t, N>(z).value);
|
||||
}
|
||||
|
||||
template <typename MaskT>
|
||||
Simd<float16_t, N>
|
||||
select(Simd<MaskT, N> mask, Simd<float16_t, N> x, Simd<float16_t, N> y) {
|
||||
return vbslq_f16(Simd<uint16_t, N>(mask).value, x.value, y.value);
|
||||
}
|
||||
|
||||
// Reductions
|
||||
inline float16_t max(Simd<float16_t, N> x) {
|
||||
float16x4_t y;
|
||||
y = vpmax_f16(vget_low_f16(x.value), vget_high_f16(x.value));
|
||||
y = vpmax_f16(y, y);
|
||||
y = vpmax_f16(y, y);
|
||||
return vget_lane_f16(y, 0);
|
||||
}
|
||||
inline float16_t min(Simd<float16_t, N> x) {
|
||||
float16x4_t y;
|
||||
y = vpmin_f16(vget_low_f16(x.value), vget_high_f16(x.value));
|
||||
y = vpmin_f16(y, y);
|
||||
y = vpmin_f16(y, y);
|
||||
return vget_lane_f16(y, 0);
|
||||
}
|
||||
inline float16_t sum(Simd<float16_t, N> x) {
|
||||
float16x4_t y;
|
||||
y = vpadd_f16(vget_low_f16(x.value), vget_high_f16(x.value));
|
||||
y = vpadd_f16(y, y);
|
||||
y = vpadd_f16(y, y);
|
||||
return vget_lane_f16(y, 0);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::simd
|
4
mlx/backend/common/simd/simd.h
Normal file
4
mlx/backend/common/simd/simd.h
Normal file
@ -0,0 +1,4 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/simd/math.h"
|
||||
#include "mlx/backend/common/simd/type.h"
|
7
mlx/backend/common/simd/type.h
Normal file
7
mlx/backend/common/simd/type.h
Normal file
@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/simd/base_simd.h"
|
||||
|
||||
#ifdef MLX_USE_ACCELERATE
|
||||
#include "mlx/backend/common/simd/accelerate_simd.h"
|
||||
#endif
|
@ -4,61 +4,107 @@
|
||||
#include <cmath>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace mlx::core::simd;
|
||||
|
||||
template <typename T, typename AccT>
|
||||
void softmax(const array& in, array& out) {
|
||||
constexpr bool same_t = std::is_same_v<T, AccT>;
|
||||
constexpr int N = std::min(max_size<AccT>, max_size<T>);
|
||||
|
||||
const T* in_ptr = in.data<T>();
|
||||
T* out_ptr = out.data<T>();
|
||||
int N = in.shape().back();
|
||||
int M = in.data_size() / N;
|
||||
int M = in.shape().back();
|
||||
int L = in.data_size() / M;
|
||||
const T* current_in_ptr;
|
||||
T* current_out_ptr;
|
||||
|
||||
for (int i = 0; i < M; i++, in_ptr += N, out_ptr += N) {
|
||||
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
|
||||
// Find the maximum
|
||||
current_in_ptr = in_ptr;
|
||||
AccT maximum = *current_in_ptr;
|
||||
for (int j = 0; j < N; j++, current_in_ptr++) {
|
||||
maximum = (maximum < *current_in_ptr) ? static_cast<AccT>(*current_in_ptr)
|
||||
: maximum;
|
||||
Simd<AccT, N> vmaximum(-std::numeric_limits<float>::infinity());
|
||||
size_t s = M;
|
||||
while (s >= N) {
|
||||
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
|
||||
vmaximum = maximum(vals, vmaximum);
|
||||
current_in_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
|
||||
AccT maximum = max(vmaximum);
|
||||
while (s-- > 0) {
|
||||
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
|
||||
current_in_ptr++;
|
||||
}
|
||||
|
||||
// Compute the normalizer and the exponentials
|
||||
AccT normalizer = 0;
|
||||
Simd<AccT, N> vnormalizer(0.0);
|
||||
current_out_ptr = out_ptr;
|
||||
current_in_ptr = in_ptr;
|
||||
for (int j = 0; j < N; j++, current_out_ptr++, current_in_ptr++) {
|
||||
AccT expv = std::exp(*current_in_ptr - maximum);
|
||||
normalizer += expv;
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
*current_out_ptr = expv;
|
||||
s = M;
|
||||
while (s >= N) {
|
||||
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
|
||||
vexp = exp(vexp - maximum);
|
||||
if constexpr (same_t) {
|
||||
store(current_out_ptr, vexp);
|
||||
}
|
||||
vnormalizer = vnormalizer + vexp;
|
||||
current_in_ptr += N;
|
||||
current_out_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
AccT normalizer = sum(vnormalizer);
|
||||
while (s-- > 0) {
|
||||
AccT _exp = std::exp(*current_in_ptr - maximum);
|
||||
if constexpr (same_t) {
|
||||
*current_out_ptr = _exp;
|
||||
}
|
||||
normalizer += _exp;
|
||||
current_in_ptr++;
|
||||
current_out_ptr++;
|
||||
}
|
||||
normalizer = 1 / normalizer;
|
||||
|
||||
// Normalize
|
||||
current_in_ptr = in_ptr;
|
||||
current_out_ptr = out_ptr;
|
||||
for (int j = 0; j < N; j++, current_out_ptr++) {
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
current_in_ptr = in_ptr;
|
||||
s = M;
|
||||
while (s >= N) {
|
||||
if constexpr (same_t) {
|
||||
store(
|
||||
current_out_ptr,
|
||||
Simd<T, N>(load<T, N>(current_out_ptr) * normalizer));
|
||||
} else {
|
||||
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
|
||||
vexp = exp(vexp - maximum) * normalizer;
|
||||
store(current_out_ptr, Simd<T, N>(vexp));
|
||||
current_in_ptr += N;
|
||||
}
|
||||
current_out_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
while (s-- > 0) {
|
||||
if constexpr (same_t) {
|
||||
*current_out_ptr *= normalizer;
|
||||
} else {
|
||||
auto v = std::exp(*current_in_ptr - maximum);
|
||||
*current_out_ptr = static_cast<T>(v * normalizer);
|
||||
AccT _exp = std::exp(*current_in_ptr - maximum);
|
||||
*current_out_ptr = static_cast<T>(_exp * normalizer);
|
||||
current_in_ptr++;
|
||||
}
|
||||
current_out_ptr++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Softmax::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
@ -97,7 +143,7 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
|
||||
case int16:
|
||||
case int32:
|
||||
case int64:
|
||||
throw std::invalid_argument(
|
||||
throw std::runtime_error(
|
||||
"Softmax is defined only for floating point types");
|
||||
break;
|
||||
case float32:
|
||||
|
@ -287,7 +287,7 @@ void argpartition(const array& in, array& out, int axis, int kth) {
|
||||
|
||||
} // namespace
|
||||
|
||||
void ArgSort::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
@ -321,7 +321,7 @@ void ArgSort::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Sort::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
@ -355,7 +355,7 @@ void Sort::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void ArgPartition::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
@ -389,7 +389,7 @@ void ArgPartition::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Partition::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
|
@ -137,7 +137,9 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
|
||||
}
|
||||
}
|
||||
|
||||
void SVD::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
|
||||
void SVD::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
if (!(inputs[0].dtype() == float32)) {
|
||||
throw std::runtime_error("[SVD::eval] only supports float32.");
|
||||
}
|
||||
|
@ -3,8 +3,8 @@
|
||||
#pragma once
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/ops.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
285
mlx/backend/common/unary.cpp
Normal file
285
mlx/backend/common/unary.cpp
Normal file
@ -0,0 +1,285 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/common/unary_ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), unsignedinteger) || in.dtype() == bool_) {
|
||||
// No-op for unsigned types
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
auto op = detail::Abs{};
|
||||
switch (out.dtype()) {
|
||||
case int8:
|
||||
unary_op<int8_t>(in, out, op);
|
||||
break;
|
||||
case int16:
|
||||
unary_op<int16_t>(in, out, op);
|
||||
break;
|
||||
case int32:
|
||||
unary_op<int32_t>(in, out, op);
|
||||
break;
|
||||
case int64:
|
||||
unary_op<int64_t>(in, out, op);
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(in, out, op);
|
||||
break;
|
||||
case float32:
|
||||
unary_op<float>(in, out, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(in, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
unary_op<complex64_t>(in, out, op);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("[Abs] Called on unsigned type");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcCos());
|
||||
}
|
||||
|
||||
void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcCosh());
|
||||
}
|
||||
|
||||
void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcSin());
|
||||
}
|
||||
|
||||
void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcSinh());
|
||||
}
|
||||
|
||||
void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcTan());
|
||||
}
|
||||
|
||||
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcTanh());
|
||||
}
|
||||
|
||||
void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Ceil());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Conjugate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
unary_op<complex64_t>(inputs[0], out, detail::Conjugate());
|
||||
}
|
||||
|
||||
void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Cos());
|
||||
}
|
||||
|
||||
void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Cosh());
|
||||
}
|
||||
|
||||
void Erf::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
switch (out.dtype()) {
|
||||
case float32:
|
||||
unary_op<float>(in, out, detail::Erf());
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(in, out, detail::Erf());
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(in, out, detail::Erf());
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[erf] Error function only defined for arrays"
|
||||
" with real floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
switch (out.dtype()) {
|
||||
case float32:
|
||||
unary_op<float>(in, out, detail::ErfInv());
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(in, out, detail::ErfInv());
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(in, out, detail::ErfInv());
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[erf_inv] Inverse error function only defined for arrays"
|
||||
" with real floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Exp());
|
||||
}
|
||||
|
||||
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Expm1());
|
||||
}
|
||||
|
||||
void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Floor());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
|
||||
}
|
||||
|
||||
void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
unary_fp(in, out, detail::Log());
|
||||
break;
|
||||
case Base::two:
|
||||
unary_fp(in, out, detail::Log2());
|
||||
break;
|
||||
case Base::ten:
|
||||
unary_fp(in, out, detail::Log10());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Log1p());
|
||||
}
|
||||
|
||||
void LogicalNot::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, detail::LogicalNot());
|
||||
}
|
||||
|
||||
void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, detail::Negative());
|
||||
}
|
||||
|
||||
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
|
||||
}
|
||||
|
||||
void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Round());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Sigmoid::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Sigmoid());
|
||||
}
|
||||
|
||||
void Sign::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.dtype() == bool_) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
unary(in, out, detail::Sign());
|
||||
}
|
||||
}
|
||||
|
||||
void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Sin());
|
||||
}
|
||||
|
||||
void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Sinh());
|
||||
}
|
||||
|
||||
void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, detail::Square());
|
||||
}
|
||||
|
||||
void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (recip_) {
|
||||
unary_fp(in, out, detail::Rsqrt());
|
||||
} else {
|
||||
unary_fp(in, out, detail::Sqrt());
|
||||
}
|
||||
}
|
||||
|
||||
void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Tan());
|
||||
}
|
||||
|
||||
void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Tanh());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -4,6 +4,7 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
@ -38,8 +39,19 @@ void unary_op(const array& a, array& out, Op op) {
|
||||
if (a.flags().contiguous) {
|
||||
set_unary_output_data(a, out);
|
||||
U* dst = out.data<U>();
|
||||
for (size_t i = 0; i < a.data_size(); ++i) {
|
||||
dst[i] = op(a_ptr[i]);
|
||||
constexpr int N = simd::max_size<T>;
|
||||
size_t size = a.data_size();
|
||||
while (size >= N) {
|
||||
simd::store(dst, op(simd::load<T, N>(a_ptr)));
|
||||
size -= N;
|
||||
a_ptr += N;
|
||||
dst += N;
|
||||
}
|
||||
while (size > 0) {
|
||||
*dst = op(*a_ptr);
|
||||
size--;
|
||||
dst++;
|
||||
a_ptr++;
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
108
mlx/backend/common/unary_ops.h
Normal file
108
mlx/backend/common/unary_ops.h
Normal file
@ -0,0 +1,108 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
|
||||
namespace mlx::core::detail {
|
||||
|
||||
using namespace mlx::core::simd;
|
||||
|
||||
#define SINGLE() \
|
||||
template <typename T> \
|
||||
T operator()(T x) { \
|
||||
return (*this)(Simd<T, 1>(x)).value; \
|
||||
}
|
||||
|
||||
#define DEFAULT_OP(Op, op) \
|
||||
struct Op { \
|
||||
template <int N, typename T> \
|
||||
Simd<T, N> operator()(Simd<T, N> x) { \
|
||||
return simd::op(x); \
|
||||
} \
|
||||
SINGLE() \
|
||||
};
|
||||
|
||||
DEFAULT_OP(Abs, abs)
|
||||
DEFAULT_OP(ArcCos, acos)
|
||||
DEFAULT_OP(ArcCosh, acosh)
|
||||
DEFAULT_OP(ArcSin, asin)
|
||||
DEFAULT_OP(ArcSinh, asinh)
|
||||
DEFAULT_OP(ArcTan, atan)
|
||||
DEFAULT_OP(ArcTanh, atanh)
|
||||
DEFAULT_OP(Ceil, ceil)
|
||||
DEFAULT_OP(Conjugate, conj)
|
||||
DEFAULT_OP(Cos, cos)
|
||||
DEFAULT_OP(Cosh, cosh)
|
||||
DEFAULT_OP(Erf, erf)
|
||||
DEFAULT_OP(ErfInv, erfinv)
|
||||
DEFAULT_OP(Exp, exp)
|
||||
DEFAULT_OP(Expm1, expm1)
|
||||
DEFAULT_OP(Floor, floor);
|
||||
DEFAULT_OP(Log, log);
|
||||
DEFAULT_OP(Log2, log2);
|
||||
DEFAULT_OP(Log10, log10);
|
||||
DEFAULT_OP(Log1p, log1p);
|
||||
DEFAULT_OP(LogicalNot, operator!)
|
||||
DEFAULT_OP(Negative, operator-)
|
||||
DEFAULT_OP(Round, rint);
|
||||
DEFAULT_OP(Sin, sin)
|
||||
DEFAULT_OP(Sinh, sinh)
|
||||
DEFAULT_OP(Sqrt, sqrt)
|
||||
DEFAULT_OP(Rsqrt, rsqrt)
|
||||
DEFAULT_OP(Tan, tan)
|
||||
DEFAULT_OP(Tanh, tanh)
|
||||
|
||||
struct Imag {
|
||||
template <int N>
|
||||
Simd<float, N> operator()(Simd<complex64_t, N> x) {
|
||||
return simd::imag(x);
|
||||
}
|
||||
SINGLE()
|
||||
};
|
||||
|
||||
struct Real {
|
||||
template <int N>
|
||||
Simd<float, N> operator()(Simd<complex64_t, N> x) {
|
||||
return simd::real(x);
|
||||
}
|
||||
SINGLE()
|
||||
};
|
||||
|
||||
struct Sigmoid {
|
||||
template <int N, typename T>
|
||||
Simd<T, N> operator()(Simd<T, N> x) {
|
||||
return 1.0f / (1.0f + simd::exp(-x));
|
||||
}
|
||||
SINGLE()
|
||||
};
|
||||
|
||||
struct Sign {
|
||||
template <int N, typename T>
|
||||
Simd<T, N> operator()(Simd<T, N> x) {
|
||||
auto z = Simd<T, N>{0};
|
||||
if constexpr (std::is_unsigned_v<T>) {
|
||||
return x != z;
|
||||
} else if constexpr (std::is_same_v<T, complex64_t>) {
|
||||
return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x)));
|
||||
} else {
|
||||
return simd::select(
|
||||
x < z, Simd<T, N>{-1}, simd::select(x > z, Simd<T, N>{1}, z));
|
||||
}
|
||||
}
|
||||
SINGLE()
|
||||
};
|
||||
|
||||
struct Square {
|
||||
template <int N, typename T>
|
||||
Simd<T, N> operator()(Simd<T, N> x) {
|
||||
return x * x;
|
||||
}
|
||||
SINGLE()
|
||||
};
|
||||
|
||||
} // namespace mlx::core::detail
|
193
mlx/primitives.h
193
mlx/primitives.h
@ -163,9 +163,6 @@ class Abs : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Abs)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Add : public UnaryPrimitive {
|
||||
@ -180,9 +177,6 @@ class Add : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Add)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class AddMM : public UnaryPrimitive {
|
||||
@ -226,8 +220,6 @@ class Arange : public UnaryPrimitive {
|
||||
double start_;
|
||||
double stop_;
|
||||
double step_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class ArcCos : public UnaryPrimitive {
|
||||
@ -242,9 +234,6 @@ class ArcCos : public UnaryPrimitive {
|
||||
DEFINE_PRINT(ArcCos)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class ArcCosh : public UnaryPrimitive {
|
||||
@ -259,9 +248,6 @@ class ArcCosh : public UnaryPrimitive {
|
||||
DEFINE_PRINT(ArcCosh)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class ArcSin : public UnaryPrimitive {
|
||||
@ -276,9 +262,6 @@ class ArcSin : public UnaryPrimitive {
|
||||
DEFINE_PRINT(ArcSin)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class ArcSinh : public UnaryPrimitive {
|
||||
@ -293,9 +276,6 @@ class ArcSinh : public UnaryPrimitive {
|
||||
DEFINE_PRINT(ArcSinh)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class ArcTan : public UnaryPrimitive {
|
||||
@ -310,9 +290,6 @@ class ArcTan : public UnaryPrimitive {
|
||||
DEFINE_PRINT(ArcTan)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class ArcTan2 : public UnaryPrimitive {
|
||||
@ -327,9 +304,6 @@ class ArcTan2 : public UnaryPrimitive {
|
||||
DEFINE_PRINT(ArcTan2)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class ArcTanh : public UnaryPrimitive {
|
||||
@ -344,9 +318,6 @@ class ArcTanh : public UnaryPrimitive {
|
||||
DEFINE_PRINT(ArcTanh)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class ArgPartition : public UnaryPrimitive {
|
||||
@ -369,8 +340,6 @@ class ArgPartition : public UnaryPrimitive {
|
||||
private:
|
||||
int kth_;
|
||||
int axis_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class ArgReduce : public UnaryPrimitive {
|
||||
@ -398,8 +367,6 @@ class ArgReduce : public UnaryPrimitive {
|
||||
private:
|
||||
ReduceType reduce_type_;
|
||||
int axis_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class ArgSort : public UnaryPrimitive {
|
||||
@ -420,8 +387,6 @@ class ArgSort : public UnaryPrimitive {
|
||||
|
||||
private:
|
||||
int axis_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class AsType : public UnaryPrimitive {
|
||||
@ -443,8 +408,6 @@ class AsType : public UnaryPrimitive {
|
||||
|
||||
private:
|
||||
Dtype dtype_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class AsStrided : public UnaryPrimitive {
|
||||
@ -518,8 +481,6 @@ class BlockMaskedMM : public UnaryPrimitive {
|
||||
|
||||
private:
|
||||
int block_size_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class GatherMM : public UnaryPrimitive {
|
||||
@ -537,9 +498,6 @@ class GatherMM : public UnaryPrimitive {
|
||||
|
||||
DEFINE_PRINT(GatherMM)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class BroadcastAxes : public UnaryPrimitive {
|
||||
@ -603,9 +561,6 @@ class Ceil : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Ceil)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Compiled : public Primitive {
|
||||
@ -669,8 +624,6 @@ class Concatenate : public UnaryPrimitive {
|
||||
|
||||
private:
|
||||
int axis_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Conjugate : public UnaryPrimitive {
|
||||
@ -684,9 +637,6 @@ class Conjugate : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Conjugate)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Contiguous : public UnaryPrimitive {
|
||||
@ -787,9 +737,6 @@ class Cos : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Cos)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Cosh : public UnaryPrimitive {
|
||||
@ -804,9 +751,6 @@ class Cosh : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Cosh)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class CustomTransforms : public Primitive {
|
||||
@ -894,9 +838,6 @@ class Divide : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Divide)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class DivMod : public Primitive {
|
||||
@ -915,9 +856,6 @@ class DivMod : public Primitive {
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
|
||||
return std::vector{inputs[0].shape(), inputs[0].shape()};
|
||||
}
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||
};
|
||||
|
||||
class Select : public UnaryPrimitive {
|
||||
@ -932,9 +870,6 @@ class Select : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Select)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Remainder : public UnaryPrimitive {
|
||||
@ -949,9 +884,6 @@ class Remainder : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Remainder)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Equal : public UnaryPrimitive {
|
||||
@ -979,7 +911,6 @@ class Equal : public UnaryPrimitive {
|
||||
};
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
bool equal_nan_;
|
||||
};
|
||||
|
||||
@ -995,9 +926,6 @@ class Erf : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Erf)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class ErfInv : public UnaryPrimitive {
|
||||
@ -1012,9 +940,6 @@ class ErfInv : public UnaryPrimitive {
|
||||
DEFINE_PRINT(ErfInv)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Exp : public UnaryPrimitive {
|
||||
@ -1029,9 +954,6 @@ class Exp : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Exp)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Expm1 : public UnaryPrimitive {
|
||||
@ -1045,9 +967,6 @@ class Expm1 : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Expm1)
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class ExpandDims : public UnaryPrimitive {
|
||||
@ -1100,8 +1019,6 @@ class FFT : public UnaryPrimitive {
|
||||
std::vector<size_t> axes_;
|
||||
bool inverse_;
|
||||
bool real_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Flatten : public UnaryPrimitive {
|
||||
@ -1141,9 +1058,6 @@ class Floor : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Floor)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Full : public UnaryPrimitive {
|
||||
@ -1157,9 +1071,6 @@ class Full : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Full)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Gather : public UnaryPrimitive {
|
||||
@ -1182,7 +1093,6 @@ class Gather : public UnaryPrimitive {
|
||||
}
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
std::vector<int> axes_;
|
||||
Shape slice_sizes_;
|
||||
};
|
||||
@ -1199,9 +1109,6 @@ class Greater : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Greater)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class GreaterEqual : public UnaryPrimitive {
|
||||
@ -1216,9 +1123,6 @@ class GreaterEqual : public UnaryPrimitive {
|
||||
DEFINE_PRINT(GreaterEqual)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Hadamard : public UnaryPrimitive {
|
||||
@ -1241,8 +1145,6 @@ class Hadamard : public UnaryPrimitive {
|
||||
|
||||
private:
|
||||
float scale_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Imag : public UnaryPrimitive {
|
||||
@ -1271,9 +1173,6 @@ class Less : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Less)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class LessEqual : public UnaryPrimitive {
|
||||
@ -1288,9 +1187,6 @@ class LessEqual : public UnaryPrimitive {
|
||||
DEFINE_PRINT(LessEqual)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Load : public UnaryPrimitive {
|
||||
@ -1319,7 +1215,6 @@ class Load : public UnaryPrimitive {
|
||||
static Stream io_stream = new_stream(Device::cpu);
|
||||
return io_stream;
|
||||
};
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
std::shared_ptr<io::Reader> reader_;
|
||||
size_t offset_;
|
||||
bool swap_endianness_;
|
||||
@ -1360,7 +1255,6 @@ class Log : public UnaryPrimitive {
|
||||
|
||||
private:
|
||||
Base base_;
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Log1p : public UnaryPrimitive {
|
||||
@ -1374,9 +1268,6 @@ class Log1p : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Log1p)
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class LogicalNot : public UnaryPrimitive {
|
||||
@ -1391,9 +1282,6 @@ class LogicalNot : public UnaryPrimitive {
|
||||
DEFINE_PRINT(LogicalNot)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class LogicalAnd : public UnaryPrimitive {
|
||||
@ -1408,9 +1296,6 @@ class LogicalAnd : public UnaryPrimitive {
|
||||
DEFINE_PRINT(LogicalAnd)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class LogicalOr : public UnaryPrimitive {
|
||||
@ -1425,9 +1310,6 @@ class LogicalOr : public UnaryPrimitive {
|
||||
DEFINE_PRINT(LogicalOr)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class LogAddExp : public UnaryPrimitive {
|
||||
@ -1442,9 +1324,6 @@ class LogAddExp : public UnaryPrimitive {
|
||||
DEFINE_PRINT(LogAddExp)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Matmul : public UnaryPrimitive {
|
||||
@ -1473,9 +1352,6 @@ class Maximum : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Maximum)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Minimum : public UnaryPrimitive {
|
||||
@ -1490,9 +1366,6 @@ class Minimum : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Minimum)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Multiply : public UnaryPrimitive {
|
||||
@ -1507,9 +1380,6 @@ class Multiply : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Multiply)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Negative : public UnaryPrimitive {
|
||||
@ -1524,9 +1394,6 @@ class Negative : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Negative)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class NotEqual : public UnaryPrimitive {
|
||||
@ -1541,9 +1408,6 @@ class NotEqual : public UnaryPrimitive {
|
||||
DEFINE_PRINT(NotEqual)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class NumberOfElements : public UnaryPrimitive {
|
||||
@ -1606,8 +1470,6 @@ class Pad : public UnaryPrimitive {
|
||||
std::vector<int> axes_;
|
||||
Shape low_pad_size_;
|
||||
Shape high_pad_size_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Partition : public UnaryPrimitive {
|
||||
@ -1630,8 +1492,6 @@ class Partition : public UnaryPrimitive {
|
||||
private:
|
||||
int kth_;
|
||||
int axis_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Power : public UnaryPrimitive {
|
||||
@ -1646,9 +1506,6 @@ class Power : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Power)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class QuantizedMatmul : public UnaryPrimitive {
|
||||
@ -1679,8 +1536,6 @@ class QuantizedMatmul : public UnaryPrimitive {
|
||||
int group_size_;
|
||||
int bits_;
|
||||
bool transpose_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class GatherQMM : public UnaryPrimitive {
|
||||
@ -1706,8 +1561,6 @@ class GatherQMM : public UnaryPrimitive {
|
||||
int group_size_;
|
||||
int bits_;
|
||||
bool transpose_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class RandomBits : public UnaryPrimitive {
|
||||
@ -1728,8 +1581,6 @@ class RandomBits : public UnaryPrimitive {
|
||||
private:
|
||||
Shape shape_;
|
||||
int width_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Real : public UnaryPrimitive {
|
||||
@ -1837,9 +1688,6 @@ class Round : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Round)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Scan : public UnaryPrimitive {
|
||||
@ -1936,7 +1784,6 @@ class Scatter : public UnaryPrimitive {
|
||||
};
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
ReduceType reduce_type_;
|
||||
std::vector<int> axes_;
|
||||
};
|
||||
@ -1953,9 +1800,6 @@ class Sigmoid : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Sigmoid)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Sign : public UnaryPrimitive {
|
||||
@ -1970,9 +1814,6 @@ class Sign : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Sign)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Sin : public UnaryPrimitive {
|
||||
@ -1987,9 +1828,6 @@ class Sin : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Sin)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Sinh : public UnaryPrimitive {
|
||||
@ -2004,9 +1842,6 @@ class Sinh : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Sinh)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Slice : public UnaryPrimitive {
|
||||
@ -2036,7 +1871,6 @@ class Slice : public UnaryPrimitive {
|
||||
Shape start_indices_;
|
||||
Shape end_indices_;
|
||||
Shape strides_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
@ -2068,8 +1902,6 @@ class SliceUpdate : public UnaryPrimitive {
|
||||
Shape start_indices_;
|
||||
Shape end_indices_;
|
||||
Shape strides_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class DynamicSlice : public UnaryPrimitive {
|
||||
@ -2136,7 +1968,6 @@ class Softmax : public UnaryPrimitive {
|
||||
};
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
bool precise_;
|
||||
};
|
||||
|
||||
@ -2159,8 +1990,6 @@ class Sort : public UnaryPrimitive {
|
||||
|
||||
private:
|
||||
int axis_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Split : public Primitive {
|
||||
@ -2200,9 +2029,6 @@ class Square : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Square)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Sqrt : public UnaryPrimitive {
|
||||
@ -2230,7 +2056,6 @@ class Sqrt : public UnaryPrimitive {
|
||||
}
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
bool recip_;
|
||||
};
|
||||
|
||||
@ -2262,9 +2087,6 @@ class Subtract : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Subtract)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Squeeze : public UnaryPrimitive {
|
||||
@ -2304,9 +2126,6 @@ class Tan : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Tan)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Tanh : public UnaryPrimitive {
|
||||
@ -2321,9 +2140,6 @@ class Tanh : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Tanh)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Unflatten : public UnaryPrimitive {
|
||||
@ -2404,9 +2220,6 @@ class QRF : public Primitive {
|
||||
override;
|
||||
|
||||
DEFINE_PRINT(QRF)
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||
};
|
||||
|
||||
/* SVD primitive. */
|
||||
@ -2421,9 +2234,6 @@ class SVD : public Primitive {
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(SVD)
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||
};
|
||||
|
||||
/* Matrix inversion primitive. */
|
||||
@ -2442,7 +2252,6 @@ class Inverse : public UnaryPrimitive {
|
||||
}
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& output);
|
||||
bool tri_;
|
||||
bool upper_;
|
||||
};
|
||||
@ -2462,7 +2271,6 @@ class Cholesky : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Cholesky)
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& output);
|
||||
bool upper_;
|
||||
};
|
||||
|
||||
@ -2489,7 +2297,6 @@ class Eigh : public Primitive {
|
||||
}
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||
std::string uplo_;
|
||||
bool compute_eigenvectors_;
|
||||
};
|
||||
|
@ -14,6 +14,7 @@ inline constexpr bool can_convert_to_complex128 =
|
||||
!std::is_same_v<T, complex128_t> && std::is_convertible_v<T, double>;
|
||||
|
||||
struct complex128_t : public std::complex<double> {
|
||||
complex128_t() : std::complex<double>() {};
|
||||
complex128_t(double v, double u) : std::complex<double>(v, u) {};
|
||||
complex128_t(std::complex<double> v) : std::complex<double>(v) {};
|
||||
|
||||
@ -32,6 +33,7 @@ inline constexpr bool can_convert_to_complex64 =
|
||||
!std::is_same_v<T, complex64_t> && std::is_convertible_v<T, float>;
|
||||
|
||||
struct complex64_t : public std::complex<float> {
|
||||
complex64_t() : std::complex<float>() {};
|
||||
complex64_t(float v, float u) : std::complex<float>(v, u) {};
|
||||
complex64_t(std::complex<float> v) : std::complex<float>(v) {};
|
||||
|
||||
|
@ -1,11 +1,12 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
|
||||
|
||||
#include <arm_fp16.h>
|
||||
namespace mlx::core {
|
||||
typedef __fp16 float16_t;
|
||||
using ::float16_t;
|
||||
} // namespace mlx::core
|
||||
|
||||
#else
|
||||
@ -17,11 +18,12 @@ typedef struct _MLX_Float16 float16_t;
|
||||
} // namespace mlx::core
|
||||
|
||||
#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
|
||||
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
|
||||
#include <arm_bf16.h>
|
||||
namespace mlx::core {
|
||||
typedef __bf16 bfloat16_t;
|
||||
using ::bfloat16_t;
|
||||
} // namespace mlx::core
|
||||
|
||||
#else
|
||||
|
@ -741,7 +741,7 @@ void init_array(nb::module_& m) {
|
||||
[](const mx::array& a) {
|
||||
if (mx::issubdtype(a.dtype(), mx::inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or bitwise inversion.");
|
||||
"Floating point types not allowed with bitwise inversion.");
|
||||
}
|
||||
if (a.dtype() != mx::bool_) {
|
||||
throw std::invalid_argument(
|
||||
@ -791,7 +791,7 @@ void init_array(nb::module_& m) {
|
||||
if (mx::issubdtype(a.dtype(), mx::inexact) ||
|
||||
mx::issubdtype(b.dtype(), mx::inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or bitwise or.");
|
||||
"Floating point types not allowed with bitwise or.");
|
||||
}
|
||||
return mx::bitwise_or(a, b);
|
||||
},
|
||||
@ -806,7 +806,7 @@ void init_array(nb::module_& m) {
|
||||
if (mx::issubdtype(a.dtype(), mx::inexact) ||
|
||||
mx::issubdtype(b.dtype(), mx::inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or bitwise or.");
|
||||
"Floating point types not allowed with bitwise or.");
|
||||
}
|
||||
a.overwrite_descriptor(mx::bitwise_or(a, b));
|
||||
return a;
|
||||
@ -838,7 +838,7 @@ void init_array(nb::module_& m) {
|
||||
if (mx::issubdtype(a.dtype(), mx::inexact) ||
|
||||
mx::issubdtype(b.dtype(), mx::inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or left shift.");
|
||||
"Floating point types not allowed with left shift.");
|
||||
}
|
||||
a.overwrite_descriptor(mx::left_shift(a, b));
|
||||
return a;
|
||||
@ -870,7 +870,7 @@ void init_array(nb::module_& m) {
|
||||
if (mx::issubdtype(a.dtype(), mx::inexact) ||
|
||||
mx::issubdtype(b.dtype(), mx::inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or right shift.");
|
||||
"Floating point types not allowed with right shift.");
|
||||
}
|
||||
a.overwrite_descriptor(mx::right_shift(a, b));
|
||||
return a;
|
||||
|
@ -289,6 +289,9 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(y.tolist(), [0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
|
||||
self.assertEqual(z.tolist(), [0, -4, -3, -2, -1, 0, -4, -3, -2, -1])
|
||||
|
||||
z = -mx.ones(64) % mx.full(64, 2)
|
||||
self.assertTrue(mx.array_equal(z, mx.ones(64)))
|
||||
|
||||
def test_comparisons(self):
|
||||
a = mx.array([0.0, 1.0, 5.0])
|
||||
b = mx.array([-1.0, 2.0, 5.0])
|
||||
|
@ -207,8 +207,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
|
||||
x_shape = (1, N) if B == 0 else (B, 1, N)
|
||||
w_shape = (N, M) if B == 0 else (B, N, M)
|
||||
x = mx.random.normal(shape=x_shape, key=k1)
|
||||
w = mx.random.normal(shape=w_shape, key=k2)
|
||||
x = 1e-1 * mx.random.normal(shape=x_shape, key=k1)
|
||||
w = 1e-1 * mx.random.normal(shape=w_shape, key=k2)
|
||||
w_q, scales, biases = mx.quantize(w, group_size, bits)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
|
||||
y_q = mx.quantized_matmul(
|
||||
|
Loading…
Reference in New Issue
Block a user