Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-25 20:58:13 +08:00)

awni's commit files
mlx/backend/accelerate/CMakeLists.txt (new file, 9 lines)
@@ -0,0 +1,9 @@
target_sources(
  mlx
  PRIVATE
  ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
)
mlx/backend/accelerate/matmul.cpp (new file, 167 lines)
@@ -0,0 +1,167 @@
#include <cassert>

#include <vecLib/BNNS/bnns.h>
#include <vecLib/cblas_new.h>

#include "mlx/backend/accelerate/utils.h"
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mlx::core {

namespace {

std::tuple<bool, size_t, array> check_transpose(const array& arr) {
  auto stx = arr.strides()[arr.ndim() - 2];
  auto sty = arr.strides()[arr.ndim() - 1];
  if (stx == arr.shape(-1) && sty == 1) {
    return std::make_tuple(false, stx, arr);
  } else if (stx == 1 && sty == arr.shape(-2)) {
    return std::make_tuple(true, sty, arr);
  } else {
    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
    copy(arr, arr_copy, CopyType::General);
    size_t stx = arr.shape(-1);
    return std::make_tuple(false, stx, arr_copy);
  }
}

inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
  if (out.dtype() != float32) {
    throw std::runtime_error(
        "[matmul_cblas] on CPU currently only supports float32");
  }
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto [a_transposed, lda, a] = check_transpose(a_pre);
  auto [b_transposed, ldb, b] = check_transpose(b_pre);
  size_t M = a.shape(-2);
  size_t N = b.shape(-1);
  size_t K = a.shape(-1);

  for (int i = 0; i < (a.size() / (M * K)); ++i) {
    cblas_sgemm(
        CblasRowMajor,
        a_transposed ? CblasTrans : CblasNoTrans, // transA
        b_transposed ? CblasTrans : CblasNoTrans, // transB
        M,
        N,
        K,
        1.0f, // alpha
        a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
        lda,
        b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
        ldb,
        0.0f, // beta
        out.data<float>() + M * N * i,
        out.shape(-1) // ldc
    );
  }
}

inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
  // TODO: Update to utilize BNNS broadcasting
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto [a_transposed, lda, a] = check_transpose(a_pre);
  auto [b_transposed, ldb, b] = check_transpose(b_pre);
  size_t M = a.shape(-2);
  size_t N = b.shape(-1);
  size_t K = a.shape(-1);

  BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());

  const BNNSLayerParametersBroadcastMatMul gemm_params{
      /* float alpha = */ 1.0,
      /* float beta = */ 0.0,
      /* bool transA = */ a_transposed,
      /* bool transB = */ b_transposed,
      /* bool quadratic = */ false,
      /* bool a_is_weights = */ false,
      /* bool b_is_weights = */ false,
      /* BNNSNDArrayDescriptor iA_desc = */
      BNNSNDArrayDescriptor{
          /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
          /* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,

          /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
          {lda, (M * K) / lda, 0, 0, 0, 0, 0, 0},
          /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
          {1, lda, 0, 0, 0, 0, 0, 0},

          /* void * _Nullable data = */ nullptr,
          /* BNNSDataType data_type = */ bnns_dtype,

          /* void * _Nullable table_data = */ nullptr,
          /* BNNSDataType table_data_type = */ bnns_dtype,

          /* float data_scale = */ 1.0,
          /* float data_bias = */ 0.0,
      },
      /* BNNSNDArrayDescriptor iB_desc = */
      BNNSNDArrayDescriptor{
          /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
          /* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,

          /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
          {ldb, (K * N) / ldb, 0, 0, 0, 0, 0, 0},
          /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
          {1, ldb, 0, 0, 0, 0, 0, 0},

          /* void * _Nullable data = */ nullptr,
          /* BNNSDataType data_type = */ bnns_dtype,

          /* void * _Nullable table_data = */ nullptr,
          /* BNNSDataType table_data_type = */ bnns_dtype,

          /* float data_scale = */ 1.0,
          /* float data_bias = */ 0.0,
      },
      /* BNNSNDArrayDescriptor o_desc = */
      BNNSNDArrayDescriptor{
          /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
          /* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,

          /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
          {N, M, 0, 0, 0, 0, 0, 0},
          /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
          {1, N, 0, 0, 0, 0, 0, 0},

          /* void * _Nullable data = */ nullptr,
          /* BNNSDataType data_type = */ bnns_dtype,

          /* void * _Nullable table_data = */ nullptr,
          /* BNNSDataType table_data_type = */ bnns_dtype,

          /* float data_scale = */ 1.0,
          /* float data_bias = */ 0.0,
      },
  };

  auto bnns_filter =
      BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);

  for (int i = 0; i < (a.size() / (M * K)); ++i) {
    BNNSFilterApplyTwoInput(
        bnns_filter,
        a.data<uint8_t>() +
            elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(),
        b.data<uint8_t>() +
            elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(),
        out.data<uint8_t>() + M * N * i * out.itemsize());
  }

  BNNSFilterDestroy(bnns_filter);
}

} // namespace

void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
  if (out.dtype() == float32) {
    return matmul_cblas(inputs[0], inputs[1], out);
  }
  return matmul_bnns(inputs[0], inputs[1], out);
}

} // namespace mlx::core
mlx/backend/accelerate/primitives.cpp (new file, 672 lines)
@@ -0,0 +1,672 @@
#include <cassert>
#include <cmath>

#include <vecLib/vDSP.h>
#include <vecLib/vForce.h>

#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/unary.h"
#include "mlx/primitives.h"

#define DEFAULT(primitive)                                                 \
  void primitive::eval_cpu(const std::vector<array>& inputs, array& out) { \
    primitive::eval(inputs, out);                                          \
  }

namespace mlx::core {

// Use the default implementation for the following primitives
DEFAULT(Arange)
DEFAULT(ArgPartition)
DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(Concatenate)
DEFAULT(Copy)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(FFT)
DEFAULT(Gather)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)
DEFAULT(LogicalNot)
DEFAULT(LogAddExp)
DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT(RandomBits)
DEFAULT(Reshape)
DEFAULT(Scatter)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Slice)
DEFAULT(Sort)
DEFAULT(StopGradient)
DEFAULT(Transpose)

void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  if (in.dtype() == float32 && in.flags().contiguous) {
    auto size = in.data_size();
    out.set_data(
        allocator::malloc_or_wait(size * out.itemsize()),
        size,
        in.strides(),
        in.flags());
    vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, size);
  } else if (in.dtype() == int32 && in.flags().contiguous) {
    auto size = in.data_size();
    out.set_data(
        allocator::malloc_or_wait(size * out.itemsize()),
        size,
        in.strides(),
        in.flags());
    vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, size);
  } else if (is_unsigned(in.dtype())) {
    // No-op for unsigned types
    out.copy_shared_buffer(in);
  } else {
    unary(in, out, AbsOp());
  }
}

void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];

  if (a.dtype() == float32) {
    binary(
        a,
        b,
        out,
        [](auto x, auto y) { return x + y; },
        [](const auto* s, const auto* vec, auto* o, auto n) {
          vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
        },
        [](const auto* vec, const auto* s, auto* o, auto n) {
          vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
        },
        [](const auto* a, const auto* b, auto* o, auto n) {
          vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
        });
  } else if (a.dtype() == int32) {
    binary(
        a,
        b,
        out,
        [](auto x, auto y) { return x + y; },
        [](const auto* s, const auto* vec, auto* o, auto n) {
          vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
        },
        [](const auto* vec, const auto* s, auto* o, auto n) {
          vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
        },
        [](const auto* a, const auto* b, auto* o, auto n) {
          vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
        });
  } else {
    binary(a, b, out, [](auto x, auto y) { return x + y; });
  }
}

void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    int size = in.data_size();
    out.set_data(
        allocator::malloc_or_wait(size * out.itemsize()),
        size,
        in.strides(),
        in.flags());
    vvacosf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    int size = in.data_size();
    out.set_data(
        allocator::malloc_or_wait(size * out.itemsize()),
        size,
        in.strides(),
        in.flags());
    vvacoshf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    int size = in.data_size();
    out.set_data(
        allocator::malloc_or_wait(size * out.itemsize()),
        size,
        in.strides(),
        in.flags());
    vvasinf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    int size = in.data_size();
    out.set_data(
        allocator::malloc_or_wait(size * out.itemsize()),
        size,
        in.strides(),
        in.flags());
    vvasinhf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    int size = in.data_size();
    out.set_data(
        allocator::malloc_or_wait(size * out.itemsize()),
        size,
        in.strides(),
        in.flags());
    vvatanf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    int size = in.data_size();
    out.set_data(
        allocator::malloc_or_wait(size * out.itemsize()),
        size,
        in.strides(),
        in.flags());
    vvatanhf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];

  if (in.flags().contiguous) {
    auto allocfn = [&in, &out]() {
      out.set_data(
          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
          in.data_size(),
          in.strides(),
          in.flags());
    };
    // Use accelerate functions if possible
    if (in.dtype() == float32 && out.dtype() == uint32) {
      allocfn();
      vDSP_vfixu32(
          in.data<float>(), 1, out.data<uint32_t>(), 1, in.data_size());
      return;
    } else if (in.dtype() == float32 && out.dtype() == int32) {
      allocfn();
      vDSP_vfix32(in.data<float>(), 1, out.data<int32_t>(), 1, in.data_size());
      return;
    } else if (in.dtype() == uint32 && out.dtype() == float32) {
      allocfn();
      vDSP_vfltu32(
          in.data<uint32_t>(), 1, out.data<float>(), 1, in.data_size());
      return;
    } else if (in.dtype() == int32 && out.dtype() == float32) {
      allocfn();
      vDSP_vflt32(in.data<int32_t>(), 1, out.data<float>(), 1, in.data_size());
      return;
    }
  }
  eval(inputs, out);
}

void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    int size = in.data_size();
    out.set_data(
        allocator::malloc_or_wait(size * out.itemsize()),
        size,
        in.strides(),
        in.flags());
    vvcosf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    int size = in.data_size();
    out.set_data(
        allocator::malloc_or_wait(size * out.itemsize()),
        size,
        in.strides(),
        in.flags());
    vvcoshf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];

  if (a.dtype() == int32) {
    binary(
        a,
        b,
        out,
        [](auto x, auto y) { return x / y; },
        UseDefaultBinaryOp(),
        [](const auto* vec, const auto* s, auto* o, auto n) {
          vDSP_vsdivi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
        },
        [](const auto* a, const auto* b, auto* o, auto n) {
          vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
        });
  } else if (a.dtype() == float32) {
    binary(
        a,
        b,
        out,
        [](auto x, auto y) { return x / y; },
        [](const auto* s, const auto* vec, auto* o, auto n) {
          vDSP_svdiv((const float*)s, (const float*)vec, 1, (float*)o, 1, n);
        },
        [](const auto* vec, const auto* s, auto* o, auto n) {
          vDSP_vsdiv((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
        },
        [](const auto* a, const auto* b, auto* o, auto n) {
          vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
        });
  } else {
    binary(a, b, out, [](auto x, auto y) { return x / y; });
  }
}

void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    auto size = in.data_size();
    out.set_data(
        allocator::malloc_or_wait(size * out.itemsize()),
        size,
        in.strides(),
        in.flags());
    vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
  } else if (is_floating_point(out.dtype())) {
    unary_fp(in, out, [](auto x) { return std::exp(x); });
  } else {
    throw std::invalid_argument(
        "[exp] Cannot exponentiate elements in array"
        " with non floating point type.");
  }
}

void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  assert(in.dtype() == out.dtype());
  if (in.data_size() == 1 && out.dtype() == float32) {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
    vDSP_vfill(in.data<float>(), out.data<float>(), 1, out.size());
  } else {
    eval(inputs, out);
  }
}

void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    auto size = in.data_size();
    out.set_data(
        allocator::malloc_or_wait(size * out.itemsize()),
        size,
        in.strides(),
        in.flags());
    switch (base_) {
      case Base::e:
        vvlogf(
            out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
        break;
      case Base::two:
        vvlog2f(
            out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
        break;
      case Base::ten:
        vvlog10f(
            out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
        break;
    }
  } else {
    eval(inputs, out);
  }
}

void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    auto size = in.data_size();
    out.set_data(
        allocator::malloc_or_wait(size * out.itemsize()),
        size,
        in.strides(),
        in.flags());
    vvlog1pf(
        out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
  } else if (is_floating_point(out.dtype())) {
    unary_fp(in, out, [](auto x) { return std::log1p(x); });
  } else {
    throw std::invalid_argument(
        "[log1p] Cannot compute log of elements in array with"
        " non floating point type.");
  }
}

void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  if (out.dtype() == float32) {
    binary(
        a,
        b,
        out,
        [](auto x, auto y) { return (x > y) ? x : y; },
        UseDefaultBinaryOp(),
        UseDefaultBinaryOp(),
        [](const auto* a, const auto* b, auto* out, int n) {
          vDSP_vmax((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
        });
  } else {
    binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
  }
}

void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];

  if (out.dtype() == float32) {
    binary(
        a,
        b,
        out,
        [](auto x, auto y) { return (x < y) ? x : y; },
        UseDefaultBinaryOp(),
        UseDefaultBinaryOp(),
        [](const auto* a, const auto* b, auto* out, int n) {
          vDSP_vmin((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
        });
  } else {
    binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
  }
}

void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];

  if (a.dtype() == float32) {
    binary(
        a,
        b,
        out,
        [](auto x, auto y) { return x * y; },
        [](const auto* s, const auto* vec, auto* o, auto n) {
          vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
        },
        [](const auto* vec, const auto* s, auto* o, auto n) {
          vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
        },
        [](const auto* a, const auto* b, auto* o, auto n) {
          vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
        });
  } else {
    binary(a, b, out, [](auto x, auto y) { return x * y; });
  }
}

void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  if (in.dtype() == float32 && in.flags().contiguous) {
    auto size = in.data_size();
    out.set_data(
        allocator::malloc_or_wait(size * out.itemsize()),
        size,
        in.strides(),
        in.flags());
    vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, size);
  } else {
    unary(in, out, [](auto x) { return -x; });
  }
}

void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  if (out.dtype() == float32 && a.flags().row_contiguous &&
      b.flags().row_contiguous) {
    int size = a.size();
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
    vvpowf(out.data<float>(), a.data<float>(), b.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (reduce_type_ == Scan::Sum && out.dtype() == float32 &&
      in.flags().row_contiguous && in.strides()[axis_] == 1 && !inclusive_) {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
    int stride = in.shape(axis_);
    int count = in.size() / stride;
    const float* input = in.data<float>();
    float* output = out.data<float>();
    float s = 1.0;
    if (!reverse_) {
      for (int i = 0; i < count; i++) {
        vDSP_vrsum(input - 1, 1, &s, output, 1, stride);
        input += stride;
        output += stride;
      }
    } else {
      for (int i = 0; i < count; i++) {
        input += stride - 1;
        output += stride - 1;
        vDSP_vrsum(input + 1, -1, &s, output, -1, stride);
        input++;
        output++;
      }
    }
  } else {
    eval(inputs, out);
  }
}

void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    int size = in.data_size();
    out.set_data(
        allocator::malloc_or_wait(size * out.itemsize()),
        size,
        in.strides(),
        in.flags());
    vvsinf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    int size = in.data_size();
    out.set_data(
        allocator::malloc_or_wait(size * out.itemsize()),
        size,
        in.strides(),
        in.flags());
    vvsinhf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  if (in.dtype() == float32 && in.flags().contiguous) {
    auto size = in.data_size();
    out.set_data(
        allocator::malloc_or_wait(size * out.itemsize()),
        size,
        in.strides(),
        in.flags());
    vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
  } else {
    unary(in, out, [](auto x) { return x * x; });
  }
}

void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  if (in.dtype() == float32 && in.flags().contiguous) {
    int size = in.data_size();
    out.set_data(
        allocator::malloc_or_wait(size * out.itemsize()),
        size,
        in.strides(),
        in.flags());
    if (recip_) {
      vvrsqrtf(out.data<float>(), in.data<float>(), &size);
    } else {
      vvsqrtf(out.data<float>(), in.data<float>(), &size);
    }
  } else {
    eval(inputs, out);
  }
}

void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];

  if (a.dtype() == float32) {
    binary(
        a,
        b,
        out,
        [](auto x, auto y) { return x - y; },
        [](const auto* s, const auto* vec, auto* o, auto n) {
          float minus_1 = -1;
          vDSP_vsmsa(
              (const float*)vec, 1, &minus_1, (const float*)s, (float*)o, 1, n);
        },
        [](const auto* vec, const auto* s, auto* o, auto n) {
          float val = -(*s);
          vDSP_vsadd((const float*)vec, 1, &val, (float*)o, 1, n);
        },
        [](const auto* a, const auto* b, auto* o, auto n) {
          vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
        });
  } else if (a.dtype() == int32) {
    binary(
        a,
        b,
        out,
        [](auto x, auto y) { return x - y; },
        UseDefaultBinaryOp(),
        [](const auto* vec, const auto* s, auto* o, auto n) {
          int val = -(*s);
          vDSP_vsaddi((const int*)vec, 1, &val, (int*)o, 1, n);
        },
        UseDefaultBinaryOp());
  } else {
    binary(a, b, out, [](auto x, auto y) { return x - y; });
  }
}

void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    int size = in.data_size();
    out.set_data(
        allocator::malloc_or_wait(size * out.itemsize()),
        size,
        in.strides(),
        in.flags());
    vvtanf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    int size = in.data_size();
    out.set_data(
        allocator::malloc_or_wait(size * out.itemsize()),
        size,
        in.strides(),
        in.flags());
    vvtanhf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

} // namespace mlx::core
mlx/backend/accelerate/reduce.cpp (new file, 147 lines)
@@ -0,0 +1,147 @@
#include <cassert>

#include <simd/vector.h>
#include <vecLib/vDSP.h>

#include "mlx/backend/common/reduce.h"
#include "mlx/primitives.h"

namespace mlx::core {

template <typename T, typename VT, int N>
void _vectorized_strided_sum(const T* x, T* accum, int size, size_t stride) {
  for (int i = 0; i < size; i++) {
    size_t s = stride;
    T* a = accum;
    while (s >= N) {
      VT val = (*(VT*)x);
      *(VT*)a += val;
      x += N;
      a += N;
      s -= N;
    }
    while (s-- > 0) {
      *a++ += *x++;
    }
  }
}

// TODO: Add proper templates for the strided reduce algorithm so we don't have
// to write max/min/sum etc.
template <typename T, typename VT, int N>
void _vectorized_strided_max(const T* x, T* accum, int size, size_t stride) {
  for (int i = 0; i < size; i++) {
    size_t s = stride;
    T* a = accum;
    while (s >= N) {
      *(VT*)a = simd_max((*(VT*)x), (*(VT*)a));
      x += N;
      a += N;
      s -= N;
    }
    while (s-- > 0) {
      *a = std::max(*a, *x);
      a++;
      x++;
    }
  }
}

template <typename T, typename VT, int N>
void _vectorized_strided_min(const T* x, T* accum, int size, size_t stride) {
  for (int i = 0; i < size; i++) {
    size_t s = stride;
    T* a = accum;
    while (s >= N) {
      *(VT*)a = simd_min((*(VT*)x), (*(VT*)a));
      x += N;
      a += N;
      s -= N;
    }
    while (s-- > 0) {
      *a = std::min(*a, *x);
      a++;
      x++;
    }
  }
}

template <typename T, typename VT, int N>
void _vectorized_sum(const T* x, T* accum, int size) {
  VT _sum = {0};
  while (size >= N) {
    _sum += (*(VT*)x);
    x += N;
    size -= N;
  }
  T sum = _sum[0];
  for (int i = 1; i < N; i++) {
    sum += _sum[i];
  }
  *accum += sum;
}

void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];

  if (in.dtype() == float32) {
    if (reduce_type_ == Reduce::Sum) {
      reduction_op<float, float>(
          in,
          out,
          axes_,
          0,
          [](const auto* x, auto* accum, int size, size_t stride) {
            _vectorized_strided_sum<float, simd_float16, 16>(
                (const float*)x, (float*)accum, size, stride);
          },
          [](const auto* x, auto* accum, int size) {
            float acc;
            vDSP_sve((const float*)x, 1, &acc, size);
            (*accum) += acc;
          },
          [](auto* accum, auto x) { *accum += x; });
      return;
    } else if (reduce_type_ == Reduce::Max) {
      reduction_op<float, float>(
          in,
          out,
          axes_,
          -std::numeric_limits<float>::infinity(),
          [](const auto* x, auto* accum, int size, size_t stride) {
            _vectorized_strided_max<float, simd_float16, 16>(
                (const float*)x, (float*)accum, size, stride);
          },
          [](const auto* x, auto* accum, int size) {
            float max;
            vDSP_maxv((const float*)x, 1, &max, size);
            (*accum) = (*accum < max) ? max : *accum;
          },
          [](auto* accum, auto x) { (*accum) = (*accum < x) ? x : *accum; });
      return;
    } else if (reduce_type_ == Reduce::Min) {
      reduction_op<float, float>(
          in,
          out,
          axes_,
          std::numeric_limits<float>::infinity(),
          [](const auto* x, auto* accum, int size, size_t stride) {
            _vectorized_strided_min<float, simd_float16, 16>(
                (const float*)x, (float*)accum, size, stride);
          },
          [](const auto* x, auto* accum, int size) {
            float min;
            vDSP_minv((const float*)x, 1, &min, size);
            (*accum) = (*accum > min) ? min : *accum;
          },
          [](auto* accum, auto x) { (*accum) = (*accum > x) ? x : *accum; });
      return;
    }
  }
  // TODO: Add integer addition and min/max using the templates above and
  // simd_int16 and friends.
  eval(inputs, out);
}

} // namespace mlx::core
mlx/backend/common/CMakeLists.txt (new file, 18 lines)
@@ -0,0 +1,18 @@
target_sources(
  mlx
  PRIVATE
  ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
)
mlx/backend/common/arange.h (new file, 72 lines)
@@ -0,0 +1,72 @@
#pragma once

#include "mlx/allocator.h"
#include "mlx/array.h"

namespace mlx::core {

namespace {

template <typename T>
void arange(T start, T next, array& out, size_t size) {
  auto ptr = out.data<T>();
  auto step_size = next - start;
  for (int i = 0; i < size; ++i) {
    ptr[i] = start;
    start += step_size;
  }
}

} // namespace

void arange(
    const std::vector<array>& inputs,
    array& out,
    double start,
    double step) {
  assert(inputs.size() == 0);
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  switch (out.dtype()) {
    case bool_:
      throw std::runtime_error("Bool type unsupported for arange.");
      break;
    case uint8:
      arange<uint8_t>(start, start + step, out, out.size());
      break;
    case uint16:
      arange<uint16_t>(start, start + step, out, out.size());
      break;
    case uint32:
      arange<uint32_t>(start, start + step, out, out.size());
      break;
    case uint64:
      arange<uint64_t>(start, start + step, out, out.size());
      break;
    case int8:
      arange<int8_t>(start, start + step, out, out.size());
      break;
    case int16:
      arange<int16_t>(start, start + step, out, out.size());
      break;
    case int32:
      arange<int32_t>(start, start + step, out, out.size());
      break;
    case int64:
      arange<int64_t>(start, start + step, out, out.size());
      break;
    case float16:
      arange<float16_t>(start, start + step, out, out.size());
      break;
    case float32:
      arange<float>(start, start + step, out, out.size());
      break;
    case bfloat16:
      arange<bfloat16_t>(start, start + step, out, out.size());
      break;
    case complex64:
      arange<complex64_t>(start, start + step, out, out.size());
      break;
  }
}

} // namespace mlx::core
mlx/backend/common/arg_reduce.cpp (new file, 110 lines)
@@ -0,0 +1,110 @@
#include <cassert>

#include "mlx/primitives.h"
#include "utils.h"

namespace mlx::core {

namespace {

template <typename InT, typename OpT>
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
  auto axis_size = in.shape()[axis];
  auto axis_stride = in.strides()[axis];
  std::vector<size_t> strides = in.strides();
  std::vector<int> shape = in.shape();
  strides.erase(strides.begin() + axis);
  shape.erase(shape.begin() + axis);
  for (uint32_t i = 0; i < out.size(); ++i) {
    auto loc = elem_to_loc(i, shape, strides);
    auto in_ptr = in.data<InT>() + loc;
    uint32_t ind_v = 0;
    InT v = (*in_ptr);
    for (uint32_t j = 0; j < axis_size; ++j, in_ptr += axis_stride) {
      op(j, (*in_ptr), &ind_v, &v);
    }
    out.data<uint32_t>()[i] = ind_v;
  }
}

template <typename InT>
void arg_reduce_dispatch(
    const array& in,
    array& out,
    ArgReduce::ReduceType rtype,
    int axis) {
  switch (rtype) {
    case ArgReduce::ArgMin: {
      auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
        if (x < (*y)) {
          (*y) = x;
          (*ind_y) = ind_x;
        }
      };
      arg_reduce<InT>(in, out, op, axis);
      break;
    }
    case ArgReduce::ArgMax: {
      auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
        if (x > (*y)) {
          (*y) = x;
          (*ind_y) = ind_x;
        }
      };
      arg_reduce<InT>(in, out, op, axis);
      break;
    }
  }
}

} // namespace

void ArgReduce::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  switch (in.dtype()) {
    case bool_:
      arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
      break;
    case uint8:
      arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
      break;
    case uint16:
      arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
      break;
    case uint32:
      arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
      break;
    case uint64:
      arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
      break;
    case int8:
      arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
      break;
    case int16:
      arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
      break;
    case int32:
      arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
      break;
    case int64:
      arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
      break;
    case float16:
      arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
      break;
    case float32:
      arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
      break;
    case bfloat16:
      arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
      break;
    case complex64:
      arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
      break;
  }
}

} // namespace mlx::core
541
mlx/backend/common/conv.cpp
Normal file
541
mlx/backend/common/conv.cpp
Normal file
@@ -0,0 +1,541 @@
|
||||
#include <cassert>
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <vecLib/cblas_new.h>
|
||||
#else
|
||||
#include <cblas.h>
|
||||
#endif
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Naive reference conv
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
void slow_conv_1D(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation) {
|
||||
const T* start_wt_ptr = wt.data<T>();
|
||||
|
||||
const T* in_ptr = in.data<T>();
|
||||
T* out_ptr = out.data<T>();
|
||||
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
const int iH = in.shape(1); // Input spatial dim
|
||||
const int oH = out.shape(1); // Output spatial dim
|
||||
const int O = wt.shape(0); // Out channels
|
||||
const int C = wt.shape(2); // In channels
|
||||
const int wH = wt.shape(1); // Weight spatial dim
|
||||
|
||||
const size_t in_stride_N = in.strides()[0];
|
||||
const size_t in_stride_H = in.strides()[1];
|
||||
const size_t in_stride_C = in.strides()[2];
|
||||
|
||||
const size_t wt_stride_O = wt.strides()[0];
|
||||
const size_t wt_stride_H = wt.strides()[1];
|
||||
const size_t wt_stride_C = wt.strides()[2];
|
||||
|
||||
const size_t out_stride_N = out.strides()[0];
|
||||
const size_t out_stride_H = out.strides()[1];
|
||||
const size_t out_stride_O = out.strides()[2];
|
||||
|
||||
for (int n = 0; n < N; ++n) {
|
||||
for (int oh = 0; oh < oH; ++oh) {
|
||||
for (int o = 0; o < O; ++o) {
|
||||
const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;
|
||||
float r = 0.;
|
||||
|
||||
for (int wh = 0; wh < wH; ++wh) {
|
||||
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
|
||||
|
||||
int ih = oh * wt_strides[0] - padding[0] + wh * wt_dilation[0];
|
||||
|
||||
if (ih >= 0 && ih < iH) {
|
||||
for (int c = 0; c < C; ++c) {
|
||||
r += static_cast<float>(
|
||||
in_ptr[ih * in_stride_H + c * in_stride_C]) *
|
||||
static_cast<float>(wt_ptr[c * wt_stride_C]);
|
||||
} // c
|
||||
|
||||
} // ih check
|
||||
} // wh
|
||||
|
||||
out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);
|
||||
} // o
|
||||
} // oh
|
||||
|
||||
in_ptr += in_stride_N;
|
||||
out_ptr += out_stride_N;
|
||||
|
||||
} // n
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void slow_conv_2D(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation) {
|
||||
const T* st_wt_ptr = wt.data<T>();
|
||||
const T* st_in_ptr = in.data<T>();
|
||||
T* st_out_ptr = out.data<T>();
|
||||
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
const int iH = in.shape(1); // Input spatial dim
|
||||
const int iW = in.shape(2); // Input spatial dim
|
||||
const int oH = out.shape(1); // Output spatial dim
|
||||
const int oW = out.shape(2); // Output spatial dim
|
||||
const int O = wt.shape(0); // Out channels
|
||||
const int C = wt.shape(3); // In channels
|
||||
const int wH = wt.shape(1); // Weight spatial dim
|
||||
const int wW = wt.shape(2); // Weight spatial dim
|
||||
|
||||
const size_t in_stride_N = in.strides()[0];
|
||||
const size_t in_stride_H = in.strides()[1];
|
||||
const size_t in_stride_W = in.strides()[2];
|
||||
const size_t in_stride_C = in.strides()[3];
|
||||
|
||||
const size_t wt_stride_O = wt.strides()[0];
|
||||
const size_t wt_stride_H = wt.strides()[1];
|
||||
const size_t wt_stride_W = wt.strides()[2];
|
||||
const size_t wt_stride_C = wt.strides()[3];
|
||||
|
||||
const size_t out_stride_N = out.strides()[0];
|
||||
const size_t out_stride_H = out.strides()[1];
|
||||
const size_t out_stride_W = out.strides()[2];
|
||||
const size_t out_stride_O = out.strides()[3];
|
||||
|
||||
auto pt_conv_no_checks =
|
||||
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
|
||||
out_ptr += oh * out_stride_H + ow * out_stride_W;
|
||||
int ih_base = oh * wt_strides[0] - padding[0];
|
||||
int iw_base = ow * wt_strides[1] - padding[1];
|
||||
|
||||
for (int o = 0; o < O; ++o) {
|
||||
float r = 0.;
|
||||
|
||||
for (int wh = 0; wh < wH; ++wh) {
|
||||
for (int ww = 0; ww < wW; ++ww) {
|
||||
int ih = ih_base + wh * wt_dilation[0];
|
||||
int iw = iw_base + ww * wt_dilation[1];
|
||||
|
||||
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
|
||||
|
||||
for (int c = 0; c < C; ++c) {
|
||||
r += static_cast<float>(in_ptr_pt[0]) *
|
||||
static_cast<float>(wt_ptr_pt[0]);
|
||||
in_ptr_pt += in_stride_C;
|
||||
wt_ptr_pt += wt_stride_C;
|
||||
} // c
|
||||
|
||||
} // ww
|
||||
} // wh
|
||||
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
};
|
||||
|
||||
auto pt_conv_all_checks =
|
||||
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
|
||||
out_ptr += oh * out_stride_H + ow * out_stride_W;
|
||||
int ih_base = oh * wt_strides[0] - padding[0];
|
||||
int iw_base = ow * wt_strides[1] - padding[1];
|
||||
|
||||
for (int o = 0; o < O; ++o) {
|
||||
float r = 0.;
|
||||
|
||||
for (int wh = 0; wh < wH; ++wh) {
|
||||
for (int ww = 0; ww < wW; ++ww) {
|
||||
int ih = ih_base + wh * wt_dilation[0];
|
||||
int iw = iw_base + ww * wt_dilation[1];
|
||||
|
||||
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
|
||||
const T* wt_ptr_pt =
|
||||
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||
const T* in_ptr_pt =
|
||||
in_ptr + ih * in_stride_H + iw * in_stride_W;
|
||||
|
||||
for (int c = 0; c < C; ++c) {
|
||||
r += static_cast<float>(in_ptr_pt[0]) *
|
||||
static_cast<float>(wt_ptr_pt[0]);
|
||||
in_ptr_pt += in_stride_C;
|
||||
wt_ptr_pt += wt_stride_C;
|
||||
} // c
|
||||
|
||||
} // ih, iw check
|
||||
} // ww
|
||||
} // wh
|
||||
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
};
|
||||
|
||||
int oH_border_0 = 0;
|
||||
int oH_border_1 = (padding[0] + wt_strides[0] + 1) / wt_strides[0];
|
||||
int oH_border_2 = (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0];
|
||||
int oH_border_3 = oH;
|
||||
|
||||
int oW_border_0 = 0;
|
||||
int oW_border_1 = (padding[1] + wt_strides[0] + 1) / wt_strides[1];
|
||||
int oW_border_2 = (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1];
|
||||
int oW_border_3 = oW;
|
||||
|
||||
for (int n = 0; n < N; ++n) {
|
||||
// Case 1: oh might put us out of bounds
|
||||
for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
|
||||
for (int ow = 0; ow < oW; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||
} // ow
|
||||
} // oh
|
||||
|
||||
// Case 2: oh in bounds
|
||||
for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
|
||||
// Case a: ow might put us out of bounds
|
||||
for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||
} // ow
|
||||
|
||||
// Case b: ow in bounds
|
||||
for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
|
||||
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||
} // ow
|
||||
|
||||
// Case c: ow might put us out of bounds
|
||||
for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||
} // ow
|
||||
|
||||
} // oh
|
||||
|
||||
// Case 3: oh might put us out of bounds
|
||||
for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
|
||||
for (int ow = 0; ow < oW; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||
} // ow
|
||||
} // oh
|
||||
|
||||
st_in_ptr += in_stride_N;
|
||||
st_out_ptr += out_stride_N;
|
||||
|
||||
} // n
|
||||
}
|
||||
|
||||
void dispatch_slow_conv_1D(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation) {
|
||||
if (in.dtype() == float32) {
|
||||
return slow_conv_1D<float>(in, wt, out, padding, wt_strides, wt_dilation);
|
||||
} else if (in.dtype() == float16) {
|
||||
return slow_conv_1D<float16_t>(
|
||||
in, wt, out, padding, wt_strides, wt_dilation);
|
||||
} else if (in.dtype() == bfloat16) {
|
||||
return slow_conv_1D<bfloat16_t>(
|
||||
in, wt, out, padding, wt_strides, wt_dilation);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[Convolution::eval] got unsupported data type.");
|
||||
}
|
||||
}
|
||||
|
||||
void dispatch_slow_conv_2D(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation) {
|
||||
if (in.dtype() == float32) {
|
||||
return slow_conv_2D<float>(in, wt, out, padding, wt_strides, wt_dilation);
|
||||
} else if (in.dtype() == float16) {
|
||||
return slow_conv_2D<float16_t>(
|
||||
in, wt, out, padding, wt_strides, wt_dilation);
|
||||
} else if (in.dtype() == bfloat16) {
|
||||
return slow_conv_2D<bfloat16_t>(
|
||||
in, wt, out, padding, wt_strides, wt_dilation);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[Convolution::eval] got unsupported data type.");
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Explicit gemm conv
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void explicit_gemm_conv_1D_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation) {
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
const int iH = in.shape(1); // Input spatial dim
|
||||
const int oH = out.shape(1); // Output spatial dim
|
||||
const int O = wt.shape(0); // Out channels
|
||||
const int C = wt.shape(2); // In channels
|
||||
const int wH = wt.shape(1); // Weight spatial dim
|
||||
|
||||
auto conv_dtype = float32;
|
||||
|
||||
// Pad input
|
||||
std::vector<int> padded_shape = {N, iH + 2 * padding[0], C};
|
||||
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
||||
|
||||
// Fill with zeros
|
||||
copy(array(0, conv_dtype), in_padded, CopyType::Scalar);
|
||||
|
||||
// Pick input slice from padded
|
||||
size_t data_offset = padding[0] * in_padded.strides()[1];
|
||||
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||
in_padded_slice.copy_shared_buffer(
|
||||
in_padded,
|
||||
in_padded.strides(),
|
||||
in_padded.flags(),
|
||||
in_padded_slice.size(),
|
||||
data_offset);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
|
||||
|
||||
// Make strided view
|
||||
std::vector<int> strided_shape = {N, oH, wH, C};
|
||||
|
||||
std::vector<size_t> strided_strides = {
|
||||
in_padded.strides()[0],
|
||||
in_padded.strides()[1] * wt_strides[0],
|
||||
in_padded.strides()[1],
|
||||
in_padded.strides()[2]};
|
||||
auto flags = in_padded.flags();
|
||||
|
||||
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
|
||||
in_strided_view.copy_shared_buffer(
|
||||
in_padded, strided_strides, flags, in_strided_view.size(), 0);
|
||||
|
||||
// Materialize strided view
|
||||
std::vector<int> strided_reshape = {N * oH, wH * C};
|
||||
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||
copy(in_strided_view, in_strided, CopyType::General);
|
||||
|
||||
// Check wt dtype and prepare
|
||||
auto gemm_wt = wt;
|
||||
auto gemm_out = out;
|
||||
|
||||
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
|
||||
auto ctype =
|
||||
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||
gemm_wt = array(wt.shape(), float32, nullptr, {});
|
||||
copy(wt, gemm_wt, ctype);
|
||||
}
|
||||
|
||||
if (out.dtype() != float32) {
|
||||
gemm_out = array(out.shape(), float32, nullptr, {});
|
||||
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
|
||||
}
|
||||
|
||||
// Peform gemm
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
CblasNoTrans, // no trans A
|
||||
CblasTrans, // transB
|
||||
strided_reshape[0], // M
|
||||
O, // N
|
||||
strided_reshape[1], // K
|
||||
1.0f, // alpha
|
||||
in_strided.data<float>(),
|
||||
strided_reshape[1], // lda
|
||||
gemm_wt.data<float>(),
|
||||
strided_reshape[1], // ldb
|
||||
0.0f, // beta
|
||||
gemm_out.data<float>(),
|
||||
O // ldc
|
||||
);
|
||||
|
||||
// Copy results if needed
|
||||
if (out.dtype() != float32) {
|
||||
copy(gemm_out, out, CopyType::Vector);
|
||||
}
|
||||
}
|
||||
|
||||
void explicit_gemm_conv_2D_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation) {
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
const int iH = in.shape(1); // Input spatial dim
|
||||
const int iW = in.shape(2); // Input spatial dim
|
||||
const int oH = out.shape(1); // Output spatial dim
|
||||
const int oW = out.shape(2); // Output spatial dim
|
||||
const int O = wt.shape(0); // Out channels
|
||||
const int C = wt.shape(3); // In channels
|
||||
const int wH = wt.shape(1); // Weight spatial dim
|
||||
const int wW = wt.shape(2); // Weight spatial dim
|
||||
|
||||
auto conv_dtype = out.dtype();
|
||||
|
||||
// Pad input
|
||||
std::vector<int> padded_shape = {
|
||||
N, iH + 2 * padding[0], iW + 2 * padding[1], C};
|
||||
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
||||
|
||||
// Fill with zeros
|
||||
copy(array(0, conv_dtype), in_padded, CopyType::Scalar);
|
||||
|
||||
// Pick input slice from padded
|
||||
size_t data_offset =
|
||||
padding[0] * in_padded.strides()[1] + padding[1] * in_padded.strides()[2];
|
||||
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||
in_padded_slice.copy_shared_buffer(
|
||||
in_padded,
|
||||
in_padded.strides(),
|
||||
in_padded.flags(),
|
||||
in_padded_slice.size(),
|
||||
data_offset);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
|
||||
|
||||
// Make strided view
|
||||
std::vector<int> strided_shape = {N, oH, oW, wH, wW, C};
|
||||
|
||||
std::vector<size_t> strided_strides = {
|
||||
in_padded.strides()[0],
|
||||
in_padded.strides()[1] * wt_strides[0],
|
||||
in_padded.strides()[2] * wt_strides[1],
|
||||
in_padded.strides()[1],
|
||||
in_padded.strides()[2],
|
||||
in_padded.strides()[3]};
|
||||
auto flags = in_padded.flags();
|
||||
|
||||
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
|
||||
in_strided_view.copy_shared_buffer(
|
||||
in_padded, strided_strides, flags, in_strided_view.size(), 0);
|
||||
|
||||
// Materialize strided view
|
||||
std::vector<int> strided_reshape = {N * oH * oW, wH * wW * C};
|
||||
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||
copy(in_strided_view, in_strided, CopyType::General);
|
||||
|
||||
// Check wt dtype and prepare
|
||||
auto gemm_wt = wt;
|
||||
auto gemm_out = out;
|
||||
|
||||
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
|
||||
auto ctype =
|
||||
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||
gemm_wt = array(wt.shape(), float32, nullptr, {});
|
||||
copy(wt, gemm_wt, ctype);
|
||||
}
|
||||
|
||||
if (out.dtype() != float32) {
|
||||
gemm_out = array(out.shape(), float32, nullptr, {});
|
||||
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
|
||||
}
|
||||
|
||||
// Peform gemm
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
CblasNoTrans, // no trans A
|
||||
CblasTrans, // transB
|
||||
strided_reshape[0], // M
|
||||
O, // N
|
||||
strided_reshape[1], // K
|
||||
1.0f, // alpha
|
||||
in_strided.data<float>(),
|
||||
strided_reshape[1], // lda
|
||||
gemm_wt.data<float>(),
|
||||
strided_reshape[1], // ldb
|
||||
0.0f, // beta
|
||||
gemm_out.data<float>(),
|
||||
O // ldc
|
||||
);
|
||||
|
||||
// Copy results if needed
|
||||
if (out.dtype() != float32) {
|
||||
copy(gemm_out, out, CopyType::Vector);
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Conv routing
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void conv_1D_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation) {
|
||||
if (wt_dilation[0] == 1) {
|
||||
return explicit_gemm_conv_1D_cpu(
|
||||
in, wt, out, padding, wt_strides, wt_dilation);
|
||||
}
|
||||
|
||||
return dispatch_slow_conv_1D(in, wt, out, padding, wt_strides, wt_dilation);
|
||||
}
|
||||
|
||||
void conv_2D_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation) {
|
||||
return dispatch_slow_conv_2D(in, wt, out, padding, wt_strides, wt_dilation);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Convolution::eval(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& in = inputs[0];
|
||||
auto& wt = inputs[1];
|
||||
|
||||
// 2D convolution
|
||||
if (in.ndim() == (2 + 2)) {
|
||||
return conv_2D_cpu(
|
||||
in, wt, out, padding_, kernel_strides_, kernel_dilation_);
|
||||
}
|
||||
// 1D convolution
|
||||
else if (in.ndim() == (1 + 2)) {
|
||||
return conv_1D_cpu(
|
||||
in, wt, out, padding_, kernel_strides_, kernel_dilation_);
|
||||
}
|
||||
// Throw error
|
||||
else {
|
||||
std::ostringstream msg;
|
||||
msg << "[Convolution::eval] Convolution currently only supports"
|
||||
<< " 1D and 2D convolutions. Got inputs with " << in.ndim() - 2
|
||||
<< " spatial dimensions";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
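dispatch_slow_conv_1D and dispatch_slow_conv_2D are implemented earlier in this file. As a reference for the routing above, here is a minimal standalone loop for the 1D case (an editor's sketch, not code from this commit; the NLC input layout, the O x wL x C weight layout, and the standard output-length formula are assumptions):

#include <vector>

std::vector<float> naive_conv_1d(
    const std::vector<float>& in, // N x L x C, NLC layout
    const std::vector<float>& wt, // O x wL x C
    int N, int L, int C, int O, int wL,
    int stride, int padding, int dilation) {
  // Standard output length with zero padding and kernel dilation
  int oL = (L + 2 * padding - dilation * (wL - 1) - 1) / stride + 1;
  std::vector<float> out(static_cast<size_t>(N) * oL * O, 0.0f);
  for (int n = 0; n < N; ++n) {
    for (int ol = 0; ol < oL; ++ol) {
      for (int o = 0; o < O; ++o) {
        float acc = 0.0f;
        for (int k = 0; k < wL; ++k) {
          int l = ol * stride - padding + k * dilation;
          if (l < 0 || l >= L) {
            continue; // reads in the padded region contribute zero
          }
          for (int c = 0; c < C; ++c) {
            acc += in[(n * L + l) * C + c] * wt[(o * wL + k) * C + c];
          }
        }
        out[(n * oL + ol) * O + o] = acc;
      }
    }
  }
  return out;
}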
|
||||
308
mlx/backend/common/copy.cpp
Normal file
@@ -0,0 +1,308 @@
|
||||
#include <numeric>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_single(const array& src, array& dst) {
|
||||
auto val = static_cast<DstT>(src.data<SrcT>()[0]);
|
||||
auto dst_ptr = dst.data<DstT>();
|
||||
for (int i = 0; i < dst.size(); ++i) {
|
||||
dst_ptr[i] = val;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_vector(const array& src, array& dst) {
|
||||
auto src_ptr = src.data<SrcT>();
|
||||
auto dst_ptr = dst.data<DstT>();
|
||||
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_general_dim1(const array& src, array& dst) {
|
||||
const SrcT* src_ptr = src.data<SrcT>();
|
||||
DstT* dst_ptr = dst.data<DstT>();
|
||||
size_t src_idx = 0;
|
||||
size_t dst_idx = 0;
|
||||
for (size_t i = 0; i < src.shape()[0]; ++i) {
|
||||
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
|
||||
src_idx += src.strides()[0];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_general_dim2(const array& src, array& dst) {
|
||||
const SrcT* src_ptr = src.data<SrcT>();
|
||||
DstT* dst_ptr = dst.data<DstT>();
|
||||
size_t src_idx = 0;
|
||||
size_t dst_idx = 0;
|
||||
for (size_t i = 0; i < src.shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < src.shape()[1]; ++j) {
|
||||
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
|
||||
src_idx += src.strides()[1];
|
||||
}
|
||||
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_general_dim3(const array& src, array& dst) {
|
||||
const SrcT* src_ptr = src.data<SrcT>();
|
||||
DstT* dst_ptr = dst.data<DstT>();
|
||||
size_t src_idx = 0;
|
||||
size_t dst_idx = 0;
|
||||
for (size_t i = 0; i < src.shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < src.shape()[1]; ++j) {
|
||||
for (size_t k = 0; k < src.shape()[2]; ++k) {
|
||||
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
|
||||
src_idx += src.strides()[2];
|
||||
}
|
||||
src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2];
|
||||
}
|
||||
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_general_dim4(const array& src, array& dst) {
|
||||
const SrcT* src_ptr = src.data<SrcT>();
|
||||
DstT* dst_ptr = dst.data<DstT>();
|
||||
size_t src_idx = 0;
|
||||
size_t dst_idx = 0;
|
||||
for (size_t i = 0; i < src.shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < src.shape()[1]; ++j) {
|
||||
for (size_t k = 0; k < src.shape()[2]; ++k) {
|
||||
for (size_t ii = 0; ii < src.shape()[3]; ++ii) {
|
||||
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
|
||||
src_idx += src.strides()[3];
|
||||
}
|
||||
src_idx += src.strides()[2] - src.strides()[3] * src.shape()[3];
|
||||
}
|
||||
src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2];
|
||||
}
|
||||
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_general(const array& src, array& dst) {
|
||||
switch (src.ndim()) {
|
||||
case 1:
|
||||
copy_general_dim1<SrcT, DstT>(src, dst);
|
||||
return;
|
||||
case 2:
|
||||
copy_general_dim2<SrcT, DstT>(src, dst);
|
||||
return;
|
||||
case 3:
|
||||
copy_general_dim3<SrcT, DstT>(src, dst);
|
||||
return;
|
||||
case 4:
|
||||
copy_general_dim4<SrcT, DstT>(src, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
auto src_ptr = src.data<SrcT>();
|
||||
auto dst_ptr = dst.data<DstT>();
|
||||
for (size_t i = 0; i < dst.size(); ++i) {
|
||||
size_t src_elem = elem_to_loc(i, src.shape(), src.strides());
|
||||
dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
|
||||
}
|
||||
}
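// Editor's note (illustrative, not part of the commit): the fallback above
// relies on elem_to_loc() from mlx/backend/common/utils.h. Presumably it maps
// a row-major linear element index to a memory offset through the strides,
// along the lines of this sketch (the name below is an assumption):
//
//   size_t elem_to_loc_sketch(
//       size_t elem,
//       const std::vector<int>& shape,
//       const std::vector<size_t>& strides) {
//     size_t loc = 0;
//     for (int i = shape.size() - 1; i >= 0; --i) {
//       loc += (elem % shape[i]) * strides[i];
//       elem /= shape[i];
//     }
//     return loc;
//   }
//
// Worked example: shape (2, 3) with strides (1, 2) (column-major storage),
// linear index 4 -> coordinates (1, 1) -> offset 1 * 1 + 1 * 2 = 3.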
|
||||
|
||||
template <typename SrcT, typename DstT, int D>
|
||||
inline void copy_general_general_dims(
|
||||
const array& src,
|
||||
array& dst,
|
||||
size_t offset_src,
|
||||
size_t offset_dst) {
|
||||
if constexpr (D > 1) {
|
||||
int axis = src.ndim() - D;
|
||||
auto stride_src = src.strides()[axis];
|
||||
auto stride_dst = dst.strides()[axis];
|
||||
auto N = src.shape(axis);
|
||||
for (int i = 0; i < N; i++) {
|
||||
copy_general_general_dims<SrcT, DstT, D - 1>(
|
||||
src, dst, offset_src, offset_dst);
|
||||
offset_src += stride_src;
|
||||
offset_dst += stride_dst;
|
||||
}
|
||||
} else {
|
||||
int axis = src.ndim() - 1;
|
||||
auto stride_src = src.strides()[axis];
|
||||
auto stride_dst = dst.strides()[axis];
|
||||
auto N = src.shape(axis);
|
||||
const SrcT* src_ptr = src.data<SrcT>() + offset_src;
|
||||
DstT* dst_ptr = dst.data<DstT>() + offset_dst;
|
||||
for (int i = 0; i < N; i++) {
|
||||
*dst_ptr = static_cast<DstT>(*src_ptr);
|
||||
src_ptr += stride_src;
|
||||
dst_ptr += stride_dst;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_general_general(const array& src, array& dst) {
|
||||
switch (src.ndim()) {
|
||||
case 1:
|
||||
copy_general_general_dims<SrcT, DstT, 1>(src, dst, 0, 0);
|
||||
return;
|
||||
case 2:
|
||||
copy_general_general_dims<SrcT, DstT, 2>(src, dst, 0, 0);
|
||||
return;
|
||||
case 3:
|
||||
copy_general_general_dims<SrcT, DstT, 3>(src, dst, 0, 0);
|
||||
return;
|
||||
case 4:
|
||||
copy_general_general_dims<SrcT, DstT, 4>(src, dst, 0, 0);
|
||||
return;
|
||||
case 5:
|
||||
copy_general_general_dims<SrcT, DstT, 5>(src, dst, 0, 0);
|
||||
return;
|
||||
}
|
||||
|
||||
int size = std::accumulate(
|
||||
src.shape().end() - 5, src.shape().end(), 1, std::multiplies<int>());
|
||||
for (int i = 0; i < src.size(); i += size) {
|
||||
size_t offset_src = elem_to_loc(i, src.shape(), src.strides());
|
||||
size_t offset_dst = elem_to_loc(i, dst.shape(), dst.strides());
|
||||
copy_general_general_dims<SrcT, DstT, 5>(src, dst, offset_src, offset_dst);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy(const array& src, array& dst, CopyType ctype) {
|
||||
switch (ctype) {
|
||||
case CopyType::Scalar:
|
||||
copy_single<SrcT, DstT>(src, dst);
|
||||
return;
|
||||
case CopyType::Vector:
|
||||
copy_vector<SrcT, DstT>(src, dst);
|
||||
return;
|
||||
case CopyType::General:
|
||||
copy_general<SrcT, DstT>(src, dst);
|
||||
return;
|
||||
case CopyType::GeneralGeneral:
|
||||
copy_general_general<SrcT, DstT>(src, dst);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT>
|
||||
void copy(const array& src, array& dst, CopyType ctype) {
|
||||
switch (dst.dtype()) {
|
||||
case bool_:
|
||||
copy<SrcT, bool>(src, dst, ctype);
|
||||
break;
|
||||
case uint8:
|
||||
copy<SrcT, uint8_t>(src, dst, ctype);
|
||||
break;
|
||||
case uint16:
|
||||
copy<SrcT, uint16_t>(src, dst, ctype);
|
||||
break;
|
||||
case uint32:
|
||||
copy<SrcT, uint32_t>(src, dst, ctype);
|
||||
break;
|
||||
case uint64:
|
||||
copy<SrcT, uint64_t>(src, dst, ctype);
|
||||
break;
|
||||
case int8:
|
||||
copy<SrcT, int8_t>(src, dst, ctype);
|
||||
break;
|
||||
case int16:
|
||||
copy<SrcT, int16_t>(src, dst, ctype);
|
||||
break;
|
||||
case int32:
|
||||
copy<SrcT, int32_t>(src, dst, ctype);
|
||||
break;
|
||||
case int64:
|
||||
copy<SrcT, int64_t>(src, dst, ctype);
|
||||
break;
|
||||
case float16:
|
||||
copy<SrcT, float16_t>(src, dst, ctype);
|
||||
break;
|
||||
case float32:
|
||||
copy<SrcT, float>(src, dst, ctype);
|
||||
break;
|
||||
case bfloat16:
|
||||
copy<SrcT, bfloat16_t>(src, dst, ctype);
|
||||
break;
|
||||
case complex64:
|
||||
copy<SrcT, complex64_t>(src, dst, ctype);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void copy_inplace(const array& src, array& dst, CopyType ctype) {
|
||||
switch (src.dtype()) {
|
||||
case bool_:
|
||||
copy<bool>(src, dst, ctype);
|
||||
break;
|
||||
case uint8:
|
||||
copy<uint8_t>(src, dst, ctype);
|
||||
break;
|
||||
case uint16:
|
||||
copy<uint16_t>(src, dst, ctype);
|
||||
break;
|
||||
case uint32:
|
||||
copy<uint32_t>(src, dst, ctype);
|
||||
break;
|
||||
case uint64:
|
||||
copy<uint64_t>(src, dst, ctype);
|
||||
break;
|
||||
case int8:
|
||||
copy<int8_t>(src, dst, ctype);
|
||||
break;
|
||||
case int16:
|
||||
copy<int16_t>(src, dst, ctype);
|
||||
break;
|
||||
case int32:
|
||||
copy<int32_t>(src, dst, ctype);
|
||||
break;
|
||||
case int64:
|
||||
copy<int64_t>(src, dst, ctype);
|
||||
break;
|
||||
case float16:
|
||||
copy<float16_t>(src, dst, ctype);
|
||||
break;
|
||||
case float32:
|
||||
copy<float>(src, dst, ctype);
|
||||
break;
|
||||
case bfloat16:
|
||||
copy<bfloat16_t>(src, dst, ctype);
|
||||
break;
|
||||
case complex64:
|
||||
copy<complex64_t>(src, dst, ctype);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void copy(const array& src, array& dst, CopyType ctype) {
|
||||
// Allocate the output
|
||||
switch (ctype) {
|
||||
case CopyType::Vector:
|
||||
dst.set_data(
|
||||
allocator::malloc_or_wait(src.data_size() * dst.itemsize()),
|
||||
src.data_size(),
|
||||
src.strides(),
|
||||
src.flags());
|
||||
break;
|
||||
case CopyType::Scalar:
|
||||
case CopyType::General:
|
||||
case CopyType::GeneralGeneral:
|
||||
dst.set_data(allocator::malloc_or_wait(dst.nbytes()));
|
||||
break;
|
||||
}
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy_inplace(src, dst, ctype);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
130
mlx/backend/common/default_primitives.cpp
Normal file
@@ -0,0 +1,130 @@
|
||||
#include <cblas.h>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#define DEFAULT(primitive) \
|
||||
void primitive::eval_cpu(const std::vector<array>& inputs, array& out) { \
|
||||
primitive::eval(inputs, out); \
|
||||
}
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
DEFAULT(Abs)
|
||||
DEFAULT(Add)
|
||||
DEFAULT(Arange)
|
||||
DEFAULT(ArcCos)
|
||||
DEFAULT(ArcCosh)
|
||||
DEFAULT(ArcSin)
|
||||
DEFAULT(ArcSinh)
|
||||
DEFAULT(ArcTan)
|
||||
DEFAULT(ArcTanh)
|
||||
DEFAULT(ArgPartition)
|
||||
DEFAULT(ArgReduce)
|
||||
DEFAULT(ArgSort)
|
||||
DEFAULT(AsType)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Convolution)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT(Cos)
|
||||
DEFAULT(Cosh)
|
||||
DEFAULT(Divide)
|
||||
DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(Exp)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Full)
|
||||
DEFAULT(Gather)
|
||||
DEFAULT(Greater)
|
||||
DEFAULT(GreaterEqual)
|
||||
DEFAULT(Less)
|
||||
DEFAULT(LessEqual)
|
||||
DEFAULT(Load)
|
||||
DEFAULT(Log)
|
||||
DEFAULT(Log1p)
|
||||
DEFAULT(LogicalNot)
|
||||
DEFAULT(LogAddExp)
|
||||
DEFAULT(Maximum)
|
||||
DEFAULT(Minimum)
|
||||
DEFAULT(Multiply)
|
||||
DEFAULT(Negative)
|
||||
DEFAULT(NotEqual)
|
||||
DEFAULT(Pad)
|
||||
DEFAULT(Partition)
|
||||
DEFAULT(Power)
|
||||
DEFAULT(RandomBits)
|
||||
DEFAULT(Reduce)
|
||||
DEFAULT(Reshape)
|
||||
DEFAULT(Scan)
|
||||
DEFAULT(Scatter)
|
||||
DEFAULT(Sigmoid)
|
||||
DEFAULT(Sign)
|
||||
DEFAULT(Sin)
|
||||
DEFAULT(Sinh)
|
||||
DEFAULT(Slice)
|
||||
DEFAULT(Softmax)
|
||||
DEFAULT(Sort)
|
||||
DEFAULT(Square)
|
||||
DEFAULT(Sqrt)
|
||||
DEFAULT(StopGradient)
|
||||
DEFAULT(Subtract)
|
||||
DEFAULT(Tan)
|
||||
DEFAULT(Tanh)
|
||||
DEFAULT(Transpose)
|
||||
|
||||
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.dtype() != float32) {
|
||||
throw std::runtime_error(
|
||||
"[Matmul::eval_cpu] Currently only supports float32.");
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
|
||||
auto check_transpose = [](const array& arr) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (stx == arr.shape(-1) && sty == 1) {
|
||||
return std::make_tuple(false, stx, arr);
|
||||
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
size_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
};
|
||||
|
||||
auto [a_transposed, lda, a] = check_transpose(a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(b_pre);
|
||||
int M = a.shape(-2);
|
||||
int N = b.shape(-1);
|
||||
int K = a.shape(-1);
|
||||
for (int i = 0; i < (a.size() / (M * K)); ++i) {
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
a_transposed ? CblasTrans : CblasNoTrans, // transA
|
||||
b_transposed ? CblasTrans : CblasNoTrans, // transB
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
1.0f, // alpha
|
||||
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
|
||||
lda,
|
||||
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
|
||||
ldb,
|
||||
0.0f, // beta
|
||||
out.data<float>() + M * N * i,
|
||||
out.shape(-1) // ldc
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
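For readers unfamiliar with the leading-dimension arguments passed to cblas_sgemm above, a minimal standalone call looks like this (editor's sketch; it only assumes a CBLAS implementation is linked in):

#include <cblas.h>
#include <cstdio>

int main() {
  // C (2x2) = A (2x3) * B (3x2), all row-major, alpha = 1, beta = 0.
  float A[6] = {1, 2, 3, 4, 5, 6};
  float B[6] = {7, 8, 9, 10, 11, 12};
  float C[4] = {0, 0, 0, 0};
  cblas_sgemm(
      CblasRowMajor,
      CblasNoTrans,
      CblasNoTrans,
      2, 2, 3,     // M, N, K
      1.0f, A, 3,  // lda = K for a non-transposed row-major A
      B, 2,        // ldb = N for a non-transposed row-major B
      0.0f, C, 2); // ldc = N
  std::printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]); // 58 64 / 139 154
  return 0;
}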
|
||||
38
mlx/backend/common/erf.cpp
Normal file
@@ -0,0 +1,38 @@
|
||||
#include <cmath>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
/* Approximation to the inverse error function.
|
||||
* Based on code from:
|
||||
* https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348
|
||||
*/
|
||||
float erfinv(float a) {
|
||||
auto t = std::fma(a, 0.0f - a, 1.0f);
|
||||
t = std::log(t);
|
||||
float p;
|
||||
if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793
|
||||
p = 3.03697567e-10f; // 0x1.4deb44p-32
|
||||
p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
|
||||
p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
|
||||
p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
|
||||
p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
|
||||
p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
|
||||
p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
|
||||
p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
|
||||
p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
|
||||
} else { // maximum ulp error = 2.35002
|
||||
p = 5.43877832e-9f; // 0x1.75c000p-28
|
||||
p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
|
||||
p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
|
||||
p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
|
||||
p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
|
||||
p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
|
||||
p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
|
||||
p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
|
||||
p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
|
||||
p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
|
||||
}
|
||||
return a * p;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
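A quick standalone sanity check of the approximation above (editor's sketch; it assumes this translation unit is linked in so that mlx::core::erfinv is available):

#include <cmath>
#include <cstdio>

namespace mlx::core {
float erfinv(float a); // declared in mlx/backend/common/erf.h
}

int main() {
  // std::erf(erfinv(y)) should recover y to within a few ulps.
  const float ys[] = {-0.9f, -0.5f, 0.0f, 0.25f, 0.75f, 0.999f};
  for (float y : ys) {
    std::printf(
        "y = %+.6f  erf(erfinv(y)) = %+.6f\n", y, std::erf(mlx::core::erfinv(y)));
  }
  return 0;
}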
|
||||
10
mlx/backend/common/erf.h
Normal file
@@ -0,0 +1,10 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
/* Approximation to the inverse error function.
|
||||
* Based on code from:
|
||||
* https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348
|
||||
*/
|
||||
float erfinv(float a);
|
||||
|
||||
} // namespace mlx::core
|
||||
377
mlx/backend/common/indexing.cpp
Normal file
@@ -0,0 +1,377 @@
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename IdxT>
|
||||
inline size_t offset_neg_idx(IdxT idx, size_t size) {
|
||||
return (idx < 0) ? idx + size : idx;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline size_t offset_neg_idx(bool idx, size_t) {
|
||||
return idx;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline size_t offset_neg_idx(uint32_t idx, size_t) {
|
||||
return idx;
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT>
|
||||
void gather(
|
||||
const array& src,
|
||||
const std::vector<array>& inds,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<int>& slice_sizes) {
|
||||
// If the array is row contiguous then we can do a contiguous copy given
|
||||
// two conditions on the slice size:
|
||||
// - Any number of leading ones in the slice sizes are allowed
|
||||
// - All other slice sizes match the corresponding dimension except the
|
||||
// first non-singleton slice size
|
||||
// If the array is col contiguous then the reverse is the case:
|
||||
// - Any number of trailing ones in the slice sizes are allowed
|
||||
// - All other slice sizes match the corresponding dimension except the
|
||||
// first non-singleton slice size from the end
|
||||
|
||||
bool can_copy = false;
|
||||
if (src.flags().row_contiguous) {
|
||||
can_copy = true;
|
||||
|
||||
// Ignore leading 1s
|
||||
int i = 0;
|
||||
for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i)
|
||||
;
|
||||
|
||||
// Skip the first non-singleton slice size and check the remaining
|
||||
i++;
|
||||
for (; i < src.ndim() && can_copy; ++i) {
|
||||
can_copy = (src.shape(i) == slice_sizes[i]);
|
||||
}
|
||||
} else if (src.flags().col_contiguous) {
|
||||
can_copy = true;
|
||||
|
||||
// Ignore trailing 1s
|
||||
int i = slice_sizes.size() - 1;
|
||||
for (; i >= 0 && slice_sizes[i] == 1; --i)
|
||||
;
|
||||
|
||||
// Skip the next slice size and check the remaining
|
||||
i--;
|
||||
for (; i >= 0 && can_copy; --i) {
|
||||
can_copy = (src.shape(i) == slice_sizes[i]);
|
||||
}
|
||||
}
|
||||
size_t slice_size = 1;
|
||||
for (auto s : slice_sizes) {
|
||||
slice_size *= s;
|
||||
}
|
||||
size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
|
||||
const T* src_ptr = src.data<T>();
|
||||
T* dst_ptr = out.data<T>();
|
||||
size_t out_idx = 0;
|
||||
|
||||
for (int idx = 0; idx < ind_size; idx++) {
|
||||
size_t src_idx = 0;
|
||||
for (int ii = 0; ii < inds.size(); ++ii) {
|
||||
auto ax = axes[ii];
|
||||
auto idx_loc = elem_to_loc(idx, inds[ii]);
|
||||
auto idx_val =
|
||||
offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
|
||||
src_idx += (idx_val * src.strides()[ax]);
|
||||
}
|
||||
|
||||
if (slice_size == 1) {
|
||||
dst_ptr[out_idx++] = src_ptr[src_idx];
|
||||
} else if (can_copy) {
|
||||
std::copy(
|
||||
src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx);
|
||||
out_idx += slice_size;
|
||||
} else {
|
||||
for (int jj = 0; jj < slice_size; jj++) {
|
||||
auto src_offset = elem_to_loc(jj, slice_sizes, src.strides());
|
||||
dst_ptr[out_idx++] = src_ptr[src_idx + src_offset];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename IdxT>
|
||||
void dispatch_gather(
|
||||
const array& src,
|
||||
const std::vector<array>& inds,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<int>& size) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
gather<bool, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case uint8:
|
||||
gather<uint8_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case uint16:
|
||||
gather<uint16_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case uint32:
|
||||
gather<uint32_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case uint64:
|
||||
gather<uint64_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case int8:
|
||||
gather<int8_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case int16:
|
||||
gather<int16_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case int32:
|
||||
gather<int32_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case int64:
|
||||
gather<int64_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case float16:
|
||||
gather<float16_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case float32:
|
||||
gather<float, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case bfloat16:
|
||||
gather<bfloat16_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case complex64:
|
||||
gather<complex64_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Gather::eval(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& src = inputs[0];
|
||||
std::vector<array> inds(inputs.begin() + 1, inputs.end());
|
||||
|
||||
if (inds.empty()) {
|
||||
dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
|
||||
return;
|
||||
}
|
||||
|
||||
switch (inds[0].dtype()) {
|
||||
case bool_:
|
||||
dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case uint8:
|
||||
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case uint16:
|
||||
dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case uint32:
|
||||
dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case uint64:
|
||||
dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case int8:
|
||||
dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case int16:
|
||||
dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case int32:
|
||||
dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case int64:
|
||||
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case float16:
|
||||
case float32:
|
||||
case bfloat16:
|
||||
case complex64:
|
||||
throw std::runtime_error(
|
||||
"[Gather::eval] Cannot gather with floating point indices.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InT, typename IdxT, typename OpT>
|
||||
void scatter(
|
||||
const array& updates,
|
||||
array& out,
|
||||
const std::vector<array>& inds,
|
||||
const std::vector<int>& axes,
|
||||
const OpT& op) {
|
||||
int nind = inds.size();
|
||||
auto inds_ndim = updates.ndim() - out.ndim();
|
||||
size_t n_updates = nind ? inds[0].size() : 1;
|
||||
|
||||
std::vector<int> update_shape(
|
||||
updates.shape().begin() + inds_ndim, updates.shape().end());
|
||||
size_t update_size = 1;
|
||||
for (auto us : update_shape) {
|
||||
update_size *= us;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_updates; ++i) {
|
||||
size_t out_offset = 0;
|
||||
for (int j = 0; j < nind; ++j) {
|
||||
auto ax = axes[j];
|
||||
auto idx_loc = elem_to_loc(i, inds[j]);
|
||||
auto idx_val =
|
||||
offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
|
||||
out_offset += (idx_val * out.strides()[ax]);
|
||||
}
|
||||
for (int j = 0; j < update_size; ++j) {
|
||||
auto update_loc = elem_to_loc(i * update_size + j, updates);
|
||||
auto out_loc = elem_to_loc(j, update_shape, out.strides());
|
||||
op(updates.data<InT>()[update_loc],
|
||||
out.data<InT>() + out_offset + out_loc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InT, typename IdxT>
|
||||
void dispatch_scatter_inds(
|
||||
array& out,
|
||||
const std::vector<array>& indices,
|
||||
const array& updates,
|
||||
const std::vector<int>& axes,
|
||||
Scatter::ReduceType rtype) {
|
||||
switch (rtype) {
|
||||
case Scatter::None:
|
||||
scatter<InT, IdxT>(
|
||||
updates, out, indices, axes, [](auto x, auto* y) { (*y) = x; });
|
||||
break;
|
||||
case Scatter::Sum:
|
||||
scatter<InT, IdxT>(
|
||||
updates, out, indices, axes, [](auto x, auto* y) { (*y) += x; });
|
||||
break;
|
||||
case Scatter::Prod:
|
||||
scatter<InT, IdxT>(
|
||||
updates, out, indices, axes, [](auto x, auto* y) { (*y) *= x; });
|
||||
break;
|
||||
case Scatter::Max:
|
||||
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
|
||||
(*y) = (*y > x) ? *y : x;
|
||||
});
|
||||
break;
|
||||
case Scatter::Min:
|
||||
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
|
||||
(*y) = (*y < x) ? *y : x;
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InT>
|
||||
void dispatch_scatter(
|
||||
array& out,
|
||||
const std::vector<array>& inds,
|
||||
const array& updates,
|
||||
const std::vector<int>& axes,
|
||||
Scatter::ReduceType rtype) {
|
||||
if (inds.empty()) {
|
||||
dispatch_scatter_inds<InT, bool>(out, inds, updates, axes, rtype);
|
||||
return;
|
||||
}
|
||||
|
||||
switch (inds[0].dtype()) {
|
||||
case bool_:
|
||||
dispatch_scatter_inds<InT, bool>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case uint8:
|
||||
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case uint16:
|
||||
dispatch_scatter_inds<InT, uint16_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case uint32:
|
||||
dispatch_scatter_inds<InT, uint32_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case uint64:
|
||||
dispatch_scatter_inds<InT, uint64_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case int8:
|
||||
dispatch_scatter_inds<InT, int8_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case int16:
|
||||
dispatch_scatter_inds<InT, int16_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case int32:
|
||||
dispatch_scatter_inds<InT, int32_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case int64:
|
||||
dispatch_scatter_inds<InT, int64_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case float16:
|
||||
case float32:
|
||||
case bfloat16:
|
||||
case complex64:
|
||||
throw std::runtime_error(
|
||||
"[Scatter::eval_cpu] Cannot scatter with floating point indices.");
|
||||
}
|
||||
}
|
||||
|
||||
void Scatter::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() >= 2);
|
||||
|
||||
auto& src = inputs[0];
|
||||
std::vector<array> inds(inputs.begin() + 1, inputs.end() - 1);
|
||||
auto& updates = inputs.back();
|
||||
|
||||
// Copy src into out (copy allocates memory for out)
|
||||
copy(src, out, CopyType::General);
|
||||
|
||||
switch (src.dtype()) {
|
||||
case bool_:
|
||||
dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case uint8:
|
||||
dispatch_scatter<uint8_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case uint16:
|
||||
dispatch_scatter<uint16_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case uint32:
|
||||
dispatch_scatter<uint32_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case uint64:
|
||||
dispatch_scatter<uint64_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case int8:
|
||||
dispatch_scatter<int8_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case int16:
|
||||
dispatch_scatter<int16_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case int32:
|
||||
dispatch_scatter<int32_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case int64:
|
||||
dispatch_scatter<int64_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case float16:
|
||||
dispatch_scatter<float16_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case float32:
|
||||
dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case bfloat16:
|
||||
dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case complex64:
|
||||
dispatch_scatter<complex64_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
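To make the reduce-type lambdas concrete, here is the scatter-add idea on plain buffers (editor's sketch, not mlx API; it mirrors the Scatter::Sum case without the generality of strided arrays):

#include <cstdio>
#include <vector>

// Scatter-add rows of `updates` into the rows of a flat (rows x cols) buffer.
void scatter_add_rows(
    std::vector<float>& out,
    int cols,
    const std::vector<int>& row_indices,
    const std::vector<float>& updates) { // row_indices.size() x cols
  for (size_t i = 0; i < row_indices.size(); ++i) {
    for (int j = 0; j < cols; ++j) {
      out[row_indices[i] * cols + j] += updates[i * cols + j];
    }
  }
}

int main() {
  std::vector<float> out(2 * 3, 0.0f); // 2 x 3 zeros
  // Two of the three updates land in row 1 and accumulate.
  scatter_add_rows(out, 3, {1, 0, 1}, {1, 1, 1, 2, 2, 2, 3, 3, 3});
  for (float v : out) {
    std::printf("%g ", v); // 2 2 2 4 4 4
  }
  std::printf("\n");
  return 0;
}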
|
||||
52
mlx/backend/common/load.cpp
Normal file
@@ -0,0 +1,52 @@
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <utility>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/load.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <const uint8_t scalar_size>
|
||||
void swap_endianess(uint8_t* data_bytes, size_t N) {
|
||||
struct Elem {
|
||||
uint8_t bytes[scalar_size];
|
||||
};
|
||||
|
||||
Elem* data = reinterpret_cast<Elem*>(data_bytes);
|
||||
|
||||
for (size_t i = 0; i < N; i++) {
|
||||
for (size_t j = 0; j < (scalar_size / 2); j++) {
|
||||
std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Load::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
reader_->seek(offset_, std::ios_base::beg);
|
||||
reader_->read(out.data<char>(), out.nbytes());
|
||||
|
||||
if (swap_endianness_) {
|
||||
switch (out.itemsize()) {
|
||||
case 2:
|
||||
swap_endianess<2>(out.data<uint8_t>(), out.data_size());
|
||||
break;
|
||||
case 4:
|
||||
swap_endianess<4>(out.data<uint8_t>(), out.data_size());
|
||||
break;
|
||||
case 8:
|
||||
swap_endianess<8>(out.data<uint8_t>(), out.data_size());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
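For reference, the 4-byte specialization of swap_endianess above reverses the bytes of each element in place; a standalone equivalent (editor's sketch):

#include <cstdint>
#include <cstdio>
#include <utility>

int main() {
  uint32_t x = 0x01020304u;
  auto* b = reinterpret_cast<uint8_t*>(&x);
  for (int j = 0; j < 2; ++j) {
    std::swap(b[j], b[4 - j - 1]); // swap symmetric byte pairs
  }
  std::printf("0x%08x\n", x); // prints 0x04030201
  return 0;
}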
|
||||
622
mlx/backend/common/primitives.cpp
Normal file
@@ -0,0 +1,622 @@
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/arange.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/erf.h"
|
||||
#include "mlx/backend/common/threefry.h"
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void Abs::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (is_unsigned(in.dtype())) {
|
||||
// No-op for unsigned types
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
unary(in, out, AbsOp());
|
||||
}
|
||||
}
|
||||
|
||||
void Arange::eval(const std::vector<array>& inputs, array& out) {
|
||||
arange(inputs, out, start_, step_);
|
||||
}
|
||||
|
||||
void ArcCos::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
unary_fp(in, out, [](auto x) { return std::acos(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[arccos] Cannot compute inverse cosine of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
unary_fp(in, out, [](auto x) { return std::acosh(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[arccosh] Cannot compute inverse hyperbolic cosine of elements in"
|
||||
" array with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void ArcSin::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
unary_fp(in, out, [](auto x) { return std::asin(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[arcsin] Cannot compute inverse sine of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
unary_fp(in, out, [](auto x) { return std::asinh(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[arcsinh] Cannot compute inverse hyperbolic sine of elements in"
|
||||
" array with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTan::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
unary_fp(in, out, [](auto x) { return std::atan(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[arctan] Cannot compute inverse tangent of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTanh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
unary_fp(in, out, [](auto x) { return std::atanh(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[arctanh] Cannot compute inverse hyperbolic tangent of elements in"
|
||||
" array with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void AsType::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype);
|
||||
}
|
||||
|
||||
void AsStrided::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
auto& in = inputs[0];
|
||||
|
||||
if (!in.flags().row_contiguous) {
|
||||
// Just ensuring that inputs[0] came from the ops which would ensure the
|
||||
// input is row contiguous.
|
||||
throw std::runtime_error(
|
||||
"AsStrided must be used with row contiguous arrays only.");
|
||||
}
|
||||
|
||||
// Compute the flags given the shape and strides
|
||||
bool row_contiguous = true, col_contiguous = true;
|
||||
size_t r = 1, c = 1;
|
||||
for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
|
||||
row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
|
||||
col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
|
||||
r *= shape_[i];
|
||||
c *= shape_[j];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
// TODO: Compute the contiguous flag in a better way cause now we are
|
||||
// unnecessarily strict.
|
||||
flags.contiguous = row_contiguous || col_contiguous;
|
||||
flags.row_contiguous = row_contiguous;
|
||||
flags.col_contiguous = col_contiguous;
|
||||
|
||||
// There is no easy way to compute the actual data size so we use out.size().
|
||||
// The contiguous flag will almost certainly not be set so no code should
|
||||
// rely on data_size anyway.
|
||||
size_t data_size = out.size();
|
||||
|
||||
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
|
||||
}
|
||||
|
||||
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
std::vector<size_t> strides(out.ndim(), 0);
|
||||
int diff = out.ndim() - in.ndim();
|
||||
for (int i = in.ndim() - 1; i >= 0; --i) {
|
||||
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
if (out.size() > in.size()) {
|
||||
flags.row_contiguous = flags.col_contiguous = false;
|
||||
}
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
}
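// Worked example (editor's note, illustrative only): broadcasting an input of
// shape (3, 1) with strides (1, 1) to an output of shape (2, 3, 4) gives
// diff = 1 and output strides (0, 1, 0): the prepended axis and the size-1
// axis both get stride 0, so every output element aliases the shared buffer
// and no data is copied. Since out.size() > in.size(), the row/col contiguous
// flags are cleared.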
|
||||
|
||||
void Concatenate::eval(const std::vector<array>& inputs, array& out) {
|
||||
std::vector<int> sizes;
|
||||
sizes.push_back(0);
|
||||
for (auto& p : inputs) {
|
||||
sizes.push_back(p.shape(axis_));
|
||||
}
|
||||
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto strides = out.strides();
|
||||
auto flags = out.flags();
|
||||
flags.row_contiguous = false;
|
||||
flags.col_contiguous = false;
|
||||
flags.contiguous = false;
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
|
||||
size_t data_offset = strides[axis_] * sizes[i];
|
||||
out_slice.copy_shared_buffer(
|
||||
out, strides, flags, out_slice.size(), data_offset);
|
||||
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral);
|
||||
}
|
||||
}
|
||||
|
||||
void Copy::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.copy_shared_buffer(inputs[0]);
|
||||
}
|
||||
|
||||
void Cos::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
unary_fp(in, out, [](auto x) { return std::cos(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[cos] Cannot compute cosine of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Cosh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
unary_fp(in, out, [](auto x) { return std::cosh(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[cosh] Cannot compute hyperbolic cosine of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Erf::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
switch (out.dtype()) {
|
||||
case float32:
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
unary_op<float>(in, out, [](auto x) { return std::erf(x); });
|
||||
break;
|
||||
case float16:
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
unary_op<float16_t>(in, out, [](auto x) {
|
||||
return static_cast<float16_t>(std::erf(static_cast<float>(x)));
|
||||
});
|
||||
break;
|
||||
case bfloat16:
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
unary_op<bfloat16_t>(in, out, [](auto x) {
|
||||
return static_cast<bfloat16_t>(std::erf(static_cast<float>(x)));
|
||||
});
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[erf] Error function only defined for arrays"
|
||||
" with real floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void ErfInv::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
switch (out.dtype()) {
|
||||
case float32:
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
unary_op<float>(in, out, [](auto x) { return erfinv(x); });
|
||||
break;
|
||||
case float16:
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
unary_op<float16_t>(in, out, [](auto x) {
|
||||
return static_cast<float16_t>(erfinv(static_cast<float>(x)));
|
||||
});
|
||||
break;
|
||||
case bfloat16:
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
unary_op<bfloat16_t>(in, out, [](auto x) {
|
||||
return static_cast<bfloat16_t>(erfinv(static_cast<float>(x)));
|
||||
});
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[erf_inv] Inverse error function only defined for arrays"
|
||||
" with real floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Exp::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
|
||||
if (is_floating_point(out.dtype())) {
|
||||
unary_fp(in, out, [](auto x) { return std::exp(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[exp] Cannot exponentiate elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Full::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
assert(in.dtype() == out.dtype());
|
||||
CopyType ctype;
|
||||
if (in.data_size() == 1) {
|
||||
ctype = CopyType::Scalar;
|
||||
} else if (in.flags().contiguous) {
|
||||
ctype = CopyType::Vector;
|
||||
} else {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy(in, out, ctype);
|
||||
}
|
||||
|
||||
void Log::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
unary_fp(in, out, [](auto x) { return std::log(x); });
|
||||
break;
|
||||
case Base::two:
|
||||
unary_fp(in, out, [](auto x) { return std::log2(x); });
|
||||
break;
|
||||
case Base::ten:
|
||||
unary_fp(in, out, [](auto x) { return std::log10(x); });
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[log] Cannot compute log of elements in array with"
|
||||
" non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Log1p::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
unary_fp(in, out, [](auto x) { return std::log1p(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[log1p] Cannot compute log of elements in array with"
|
||||
" non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, [](auto x) { return !x; });
|
||||
}
|
||||
|
||||
void Negative::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, [](auto x) { return -x; });
|
||||
}
|
||||
|
||||
void Pad::eval(const std::vector<array>& inputs, array& out) {
|
||||
// Inputs must be base input array and scalar val array
|
||||
assert(inputs.size() == 2);
|
||||
auto& in = inputs[0];
|
||||
auto& val = inputs[1];
|
||||
|
||||
// Padding value must be a scalar
|
||||
assert(val.size() == 1);
|
||||
|
||||
// Padding value, input and output must be of the same type
|
||||
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
|
||||
|
||||
// Fill output with val
|
||||
copy(val, out, CopyType::Scalar);
|
||||
|
||||
// Find offset for start of input values
|
||||
size_t data_offset = 0;
|
||||
for (int i = 0; i < axes_.size(); i++) {
|
||||
auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
|
||||
data_offset += out.strides()[ax] * low_pad_size_[i];
|
||||
}
|
||||
|
||||
// Extract slice from output where input will be pasted
|
||||
array out_slice(in.shape(), out.dtype(), nullptr, {});
|
||||
out_slice.copy_shared_buffer(
|
||||
out, out.strides(), out.flags(), out_slice.size(), data_offset);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_inplace(in, out_slice, CopyType::GeneralGeneral);
|
||||
}
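// Worked example (editor's note, illustrative only): padding a (2, 3) input
// with low/high pad of (1, 2) on both axes gives an output of shape (4, 7)
// with row-major strides (7, 1). After filling the output with the pad value,
// the slice holding the original values starts at
// data_offset = 1 * 7 + 2 * 1 = 9, i.e. row 1, column 2.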
|
||||
|
||||
void RandomBits::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
// keys has shape (N1, ..., NK, 2)
|
||||
// out has shape (N1, ..., NK, M1, M2, ...)
|
||||
auto& keys = inputs[0];
|
||||
size_t num_keys = keys.size() / 2;
|
||||
|
||||
size_t elems_per_key = out.size() / num_keys;
|
||||
size_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto kptr = inputs[0].data<uint32_t>();
|
||||
auto cptr = out.data<char>();
|
||||
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
|
||||
auto half_size = out_skip / 2;
|
||||
bool even = out_skip % 2 == 0;
|
||||
for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
|
||||
auto ptr = reinterpret_cast<uint32_t*>(cptr);
|
||||
// Get ith key
|
||||
auto kidx = 2 * i;
|
||||
auto k1_elem = elem_to_loc(kidx, keys.shape(), keys.strides());
|
||||
auto k2_elem = elem_to_loc(kidx + 1, keys.shape(), keys.strides());
|
||||
auto key = std::make_pair(kptr[k1_elem], kptr[k2_elem]);
|
||||
|
||||
std::pair<uintptr_t, uintptr_t> count{0, half_size + !even};
|
||||
for (; count.first + 1 < half_size; count.first++, count.second++) {
|
||||
std::tie(ptr[count.first], ptr[count.second]) =
|
||||
random::threefry2x32_hash(key, count);
|
||||
}
|
||||
if (count.first < half_size) {
|
||||
auto rb = random::threefry2x32_hash(key, count);
|
||||
ptr[count.first++] = rb.first;
|
||||
if (bytes_per_key % 4 > 0) {
|
||||
std::copy(
|
||||
reinterpret_cast<char*>(&rb.second),
|
||||
reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
|
||||
cptr + 4 * count.second);
|
||||
} else {
|
||||
ptr[count.second] = rb.second;
|
||||
}
|
||||
}
|
||||
if (!even) {
|
||||
count.second = 0;
|
||||
ptr[half_size] = random::threefry2x32_hash(key, count).first;
|
||||
}
|
||||
}
|
||||
}
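// Worked example (editor's note, illustrative only): with itemsize 4 and
// elems_per_key = 3, bytes_per_key = 12 and out_skip = (12 + 3) / 4 = 3
// 32-bit words per key, so half_size = 1 and even = false. The counter starts
// at {0, 2}: the main loop is skipped (0 + 1 is not < 1), the tail writes
// ptr[0] and ptr[2] from threefry2x32_hash(key, {0, 2}), and the odd-word
// fix-up writes ptr[1] from threefry2x32_hash(key, {1, 0}).first, filling all
// 12 bytes of this key's output.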
|
||||
|
||||
void Reshape::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (in.flags().row_contiguous) {
|
||||
// For row contiguous reshapes:
|
||||
// - Shallow copy the buffer
|
||||
// - If reshaping into a vector (all singleton dimensions except one) it
|
||||
// becomes col contiguous again.
|
||||
auto flags = in.flags();
|
||||
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||
out.copy_shared_buffer(in, out.strides(), flags, in.data_size());
|
||||
} else {
|
||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General);
|
||||
}
|
||||
}
|
||||
|
||||
void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
auto sigmoid_op = [](auto x) {
|
||||
auto one = static_cast<decltype(x)>(1.0);
|
||||
return one / (one + std::exp(-x));
|
||||
};
|
||||
unary_fp(in, out, sigmoid_op);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[sigmoid] Cannot sigmoid of elements in array with"
|
||||
" non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Sign::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.dtype() == bool_) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
unary(in, out, SignOp());
|
||||
}
|
||||
}
|
||||
|
||||
void Sin::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
unary_fp(in, out, [](auto x) { return std::sin(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[sin] Cannot compute sine of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Sinh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
unary_fp(in, out, [](auto x) { return std::sinh(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[sinh] Cannot compute hyperbolic sine of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Slice::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
auto& in = inputs[0];
|
||||
auto strides = in.strides();
|
||||
auto flags = in.flags();
|
||||
size_t data_offset = 0;
|
||||
for (int i = 0; i < in.ndim(); ++i) {
|
||||
data_offset += start_indices_[i] * in.strides()[i];
|
||||
strides[i] *= strides_[i];
|
||||
}
|
||||
|
||||
// Compute row/col contiguity
|
||||
size_t data_size = 1;
|
||||
size_t f_stride = 1;
|
||||
size_t b_stride = 1;
|
||||
flags.row_contiguous = true;
|
||||
flags.col_contiguous = true;
|
||||
for (int i = 0, ri = out.ndim() - 1; ri >= 0; i++, ri--) {
|
||||
flags.col_contiguous &= strides[i] == f_stride || out.shape(i) == 1;
|
||||
flags.row_contiguous &= strides[ri] == b_stride || out.shape(ri) == 1;
|
||||
f_stride *= out.shape(i);
|
||||
b_stride *= out.shape(ri);
|
||||
if (strides[i] > 0) {
|
||||
data_size *= out.shape(i);
|
||||
}
|
||||
}
|
||||
|
||||
if (data_size == 1) {
|
||||
// Broadcasted scalar array is contiguous.
|
||||
flags.contiguous = true;
|
||||
} else if (data_size == in.data_size()) {
|
||||
// Means we sliced a broadcasted dimension so leave the "no holes" flag
|
||||
// alone.
|
||||
} else {
|
||||
// We sliced something. So either we are row or col contiguous or we
|
||||
// punched a hole.
|
||||
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
|
||||
}
|
||||
|
||||
out.copy_shared_buffer(in, strides, flags, data_size, data_offset);
|
||||
}
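// Worked example (editor's note, illustrative only): slicing a row-contiguous
// (4, 6) input with start indices (1, 2) and slice strides (1, 2) to an
// output of shape (3, 2) gives data_offset = 1 * 6 + 2 * 1 = 8 and output
// strides (6, 2). Neither stride pattern is row or col contiguous for shape
// (3, 2), so the contiguous flag is cleared, and data_size = 3 * 2 = 6.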
|
||||
|
||||
void Square::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, [](auto x) { return x * x; });
|
||||
}
|
||||
|
||||
void Sqrt::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (recip_) {
|
||||
unary_fp(in, out, [](auto x) {
|
||||
return static_cast<decltype(x)>(1.0) / sqrt(x);
|
||||
});
|
||||
} else {
|
||||
unary_fp(in, out, [](auto x) { return sqrt(x); });
|
||||
}
|
||||
}
|
||||
|
||||
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.copy_shared_buffer(inputs[0]);
|
||||
}
|
||||
|
||||
void Tan::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
unary_fp(in, out, [](auto x) { return std::tan(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[tan] Cannot compute tangent of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Tanh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
unary_fp(in, out, [](auto x) { return std::tanh(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[tanh] Cannot compute hyperbolic tangent of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
std::vector<size_t> out_strides(out.ndim());
|
||||
auto& in = inputs[0];
|
||||
for (int ax = 0; ax < axes_.size(); ++ax) {
|
||||
out_strides[ax] = in.strides()[axes_[ax]];
|
||||
}
|
||||
|
||||
// Conditions for {row/col}_contiguous
|
||||
// - array must be contiguous (no gaps)
|
||||
// - underlying buffer size should have the same size as the array
|
||||
// - cumulative product of shapes is equal to the strides (we can ignore axes
|
||||
// with size == 1)
|
||||
// - in the forward direction (column contiguous)
|
||||
// - in the reverse direction (row contiguous)
|
||||
// - vectors are both row and col contiguous (hence if both row/col are
|
||||
// true, they stay true)
|
||||
auto flags = in.flags();
|
||||
if (flags.contiguous && in.data_size() == in.size()) {
|
||||
size_t f_stride = 1;
|
||||
size_t b_stride = 1;
|
||||
flags.col_contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
|
||||
flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
|
||||
f_stride *= out.shape(i);
|
||||
flags.row_contiguous &=
|
||||
(out_strides[ri] == b_stride || out.shape(ri) == 1);
|
||||
b_stride *= out.shape(ri);
|
||||
}
|
||||
}
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
215
mlx/backend/common/reduce.cpp
Normal file
@@ -0,0 +1,215 @@
|
||||
#include <cassert>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename U>
|
||||
struct Limits {
|
||||
static const U max;
|
||||
static const U min;
|
||||
};
|
||||
|
||||
#define instantiate_default_limit(type) \
|
||||
template <> \
|
||||
struct Limits<type> { \
|
||||
static constexpr type max = std::numeric_limits<type>::max(); \
|
||||
static constexpr type min = std::numeric_limits<type>::min(); \
|
||||
};
|
||||
|
||||
instantiate_default_limit(uint8_t);
|
||||
instantiate_default_limit(uint16_t);
|
||||
instantiate_default_limit(uint32_t);
|
||||
instantiate_default_limit(uint64_t);
|
||||
instantiate_default_limit(int8_t);
|
||||
instantiate_default_limit(int16_t);
|
||||
instantiate_default_limit(int32_t);
|
||||
instantiate_default_limit(int64_t);
|
||||
|
||||
#define instantiate_float_limit(type) \
|
||||
template <> \
|
||||
struct Limits<type> { \
|
||||
static const type max; \
|
||||
static const type min; \
|
||||
};
|
||||
|
||||
instantiate_float_limit(float16_t);
|
||||
instantiate_float_limit(bfloat16_t);
|
||||
instantiate_float_limit(float);
|
||||
instantiate_float_limit(complex64_t);
|
||||
|
||||
template <>
|
||||
struct Limits<bool> {
|
||||
static constexpr bool max = true;
|
||||
static constexpr bool min = false;
|
||||
};
|
||||
|
||||
const float Limits<float>::max = std::numeric_limits<float>::infinity();
|
||||
const float Limits<float>::min = -std::numeric_limits<float>::infinity();
|
||||
const bfloat16_t Limits<bfloat16_t>::max =
|
||||
std::numeric_limits<float>::infinity();
|
||||
const bfloat16_t Limits<bfloat16_t>::min =
|
||||
-std::numeric_limits<float>::infinity();
|
||||
const float16_t Limits<float16_t>::max = std::numeric_limits<float>::infinity();
|
||||
const float16_t Limits<float16_t>::min =
|
||||
-std::numeric_limits<float>::infinity();
|
||||
const complex64_t Limits<complex64_t>::max =
|
||||
std::numeric_limits<float>::infinity();
|
||||
const complex64_t Limits<complex64_t>::min =
|
||||
-std::numeric_limits<float>::infinity();
|
||||
|
||||
struct AndReduce {
|
||||
template <typename T>
|
||||
void operator()(bool* a, T b) {
|
||||
(*a) &= (b != 0);
|
||||
}
|
||||
|
||||
void operator()(bool* y, bool x) {
|
||||
(*y) &= x;
|
||||
}
|
||||
};
|
||||
|
||||
struct OrReduce {
|
||||
template <typename T>
|
||||
void operator()(bool* a, T b) {
|
||||
(*a) |= (b != 0);
|
||||
}
|
||||
|
||||
void operator()(bool* y, bool x) {
|
||||
(*y) |= x;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename InT>
void reduce_dispatch_out(
    const array& in,
    array& out,
    Reduce::ReduceType rtype,
    const std::vector<int>& axes) {
  switch (rtype) {
    case Reduce::And: {
      reduction_op<InT, bool>(in, out, axes, true, AndReduce());
      break;
    }
    case Reduce::Or: {
      reduction_op<InT, bool>(in, out, axes, false, OrReduce());
      break;
    }
    case Reduce::Sum: {
      auto op = [](auto y, auto x) { (*y) = (*y) + x; };
      switch (out.dtype()) {
        case bool_:
          reduction_op<InT, bool>(in, out, axes, false, op);
          break;
        case uint8:
          reduction_op<InT, uint8_t>(in, out, axes, 0, op);
          break;
        case uint16:
          reduction_op<InT, uint16_t>(in, out, axes, 0, op);
          break;
        case uint32:
          reduction_op<InT, uint32_t>(in, out, axes, 0, op);
          break;
        case uint64:
          reduction_op<InT, uint64_t>(in, out, axes, 0, op);
          break;
        case int8:
          reduction_op<InT, int8_t>(in, out, axes, 0, op);
          break;
        case int16:
          reduction_op<InT, int16_t>(in, out, axes, 0, op);
          break;
        case int32:
          reduction_op<InT, int32_t>(in, out, axes, 0, op);
          break;
        case int64:
          reduction_op<InT, int64_t>(in, out, axes, 0, op);
          break;
        case float16:
          reduction_op<InT, float16_t>(in, out, axes, 0.0f, op);
          break;
        case float32:
          reduction_op<InT, float>(in, out, axes, 0.0f, op);
          break;
        case bfloat16:
          reduction_op<InT, bfloat16_t>(in, out, axes, 0.0f, op);
          break;
        case complex64:
          reduction_op<InT, complex64_t>(in, out, axes, complex64_t{0.0f}, op);
          break;
      }
    } break;
    case Reduce::Prod: {
      auto op = [](auto y, auto x) { (*y) *= x; };
      reduction_op<InT, InT>(in, out, axes, 1, op);
      break;
    }
    case Reduce::Max: {
      auto op = [](auto y, auto x) { (*y) = (*y > x) ? *y : x; };
      auto init = Limits<InT>::min;
      reduction_op<InT, InT>(in, out, axes, init, op);
      break;
    }
    case Reduce::Min: {
      auto op = [](auto y, auto x) { (*y) = (*y < x) ? *y : x; };
      auto init = Limits<InT>::max;
      reduction_op<InT, InT>(in, out, axes, init, op);
      break;
    }
  }
}

} // namespace

void Reduce::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  switch (in.dtype()) {
    case bool_:
      reduce_dispatch_out<bool>(in, out, reduce_type_, axes_);
      break;
    case uint8:
      reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
      break;
    case uint16:
      reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
      break;
    case uint32:
      reduce_dispatch_out<uint32_t>(in, out, reduce_type_, axes_);
      break;
    case uint64:
      reduce_dispatch_out<uint64_t>(in, out, reduce_type_, axes_);
      break;
    case int8:
      reduce_dispatch_out<int8_t>(in, out, reduce_type_, axes_);
      break;
    case int16:
      reduce_dispatch_out<int16_t>(in, out, reduce_type_, axes_);
      break;
    case int32:
      reduce_dispatch_out<int32_t>(in, out, reduce_type_, axes_);
      break;
    case int64:
      reduce_dispatch_out<int64_t>(in, out, reduce_type_, axes_);
      break;
    case float16:
      reduce_dispatch_out<float16_t>(in, out, reduce_type_, axes_);
      break;
    case float32:
      reduce_dispatch_out<float>(in, out, reduce_type_, axes_);
      break;
    case bfloat16:
      reduce_dispatch_out<bfloat16_t>(in, out, reduce_type_, axes_);
      break;
    case complex64:
      reduce_dispatch_out<complex64_t>(in, out, reduce_type_, axes_);
      break;
  }
}

} // namespace mlx::core
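The Max and Min cases above seed the accumulator with Limits<InT>::min and Limits<InT>::max respectively, i.e. the identity element of each reduction, so any actual input element can overwrite the seed. A minimal standalone sketch of the same idea in plain C++ (illustrative only, not the mlx API; the function name and the use of std::vector are assumptions):

#include <limits>
#include <vector>

// Identity-seeded max reduction: starting the accumulator at -infinity means
// any element of `data` can only raise the running maximum.
float max_reduce(const std::vector<float>& data) {
  float acc = -std::numeric_limits<float>::infinity();
  for (float x : data) {
    acc = (acc > x) ? acc : x;
  }
  return acc;
}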
364
mlx/backend/common/reduce.h
Normal file
@@ -0,0 +1,364 @@
#pragma once

#include "mlx/backend/common/utils.h"

namespace mlx::core {

namespace {

enum ReductionOpType {
  // Self-explanatory. Read everything and produce 1 output.
  ContiguousAllReduce,

  // The input is contiguous and the last axis is reduced
  // N1xR1xN2xR2x...xNnxRn
  ContiguousReduce,

  // The input is contiguous and the last axis is not reduced
  // R1xN1xR2xN2x...xRnxNn
  ContiguousStridedReduce,

  // The input is not contiguous but the last axis is and it is reduced so we
  // need to figure out the offsets but we can call the contiguous reduce after
  // that.
  // N3xR1xN1xR4x...xRn
  GeneralContiguousReduce,

  // The input is not contiguous but the last reduction axis and the last axis
  // are so we need to figure out the offset but we can call the strided reduce
  // after that.
  GeneralStridedReduce,

  // The input is not contiguous after the reduction axis and it may contain
  // 0-stride axes or transpositions. We could copy the strides and produce a
  // transposed outcome or we can read the input out of order and write the
  // output in order.
  GeneralReduce
};
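// Illustrative mapping (an added example, not from the original source): for a
// row-contiguous N1 x R1 array with R1 > 1, reducing the last axis touches R1
// consecutive elements per output and is planned as ContiguousReduce, while
// reducing the first axis strides through memory and is planned as
// ContiguousStridedReduce (see get_reduction_plan below).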

// Helper for the ndimensional strided loop
// Should this be in utils?
inline void nd_loop(
    std::function<void(int)> callback,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  std::function<void(int, int)> loop_inner;
  loop_inner = [&](int dim, int offset) {
    if (dim < shape.size() - 1) {
      int size = shape[dim];
      size_t stride = strides[dim];
      for (int i = 0; i < size; i++) {
        loop_inner(dim + 1, offset + i * stride);
      }
    } else {
      int size = shape[dim];
      size_t stride = strides[dim];
      for (int i = 0; i < size; i++) {
        callback(offset + i * stride);
      }
    }
  };
  loop_inner(0, 0);
}
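// Example (an added illustration, not from the original source): for
// shape = {2, 3} and strides = {3, 1}, nd_loop calls the callback with offsets
// 0, 1, 2, 3, 4, 5; for the transposed strides = {1, 2} it visits
// 0, 2, 4, 1, 3, 5 instead, i.e. it always walks the logical array in
// row-major order while the memory offsets follow the given strides.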
|
||||
|
||||
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
const std::vector<int>& axes) {
|
||||
std::vector<int> shape = x.shape();
|
||||
std::vector<size_t> strides = x.strides();
|
||||
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int a = axes[i];
|
||||
shape.erase(shape.begin() + a);
|
||||
strides.erase(strides.begin() + a);
|
||||
}
|
||||
|
||||
return std::make_pair(shape, strides);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
struct DefaultStridedReduce {
|
||||
Op op;
|
||||
|
||||
DefaultStridedReduce(Op op_) : op(op_) {}
|
||||
|
||||
void operator()(const T* x, U* accumulator, int size, size_t stride) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
U* moving_accumulator = accumulator;
|
||||
for (int j = 0; j < stride; j++) {
|
||||
op(moving_accumulator, *x);
|
||||
moving_accumulator++;
|
||||
x++;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
struct DefaultContiguousReduce {
|
||||
Op op;
|
||||
|
||||
DefaultContiguousReduce(Op op_) : op(op_) {}
|
||||
|
||||
void operator()(const T* x, U* accumulator, int size) {
|
||||
while (size-- > 0) {
|
||||
op(accumulator, *x);
|
||||
x++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct ReductionPlan {
|
||||
ReductionOpType type;
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
|
||||
ReductionPlan(
|
||||
ReductionOpType type_,
|
||||
std::vector<int> shape_,
|
||||
std::vector<size_t> strides_)
|
||||
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
|
||||
ReductionPlan(ReductionOpType type_) : type(type_) {}
|
||||
};
|
||||
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
|
||||
// The data is all there and we are reducing over everything
|
||||
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
||||
(x.flags().row_contiguous || x.flags().col_contiguous)) {
|
||||
return ContiguousAllReduce;
|
||||
}
|
||||
|
||||
// Row contiguous input so the output is row contiguous
|
||||
if (x.flags().row_contiguous) {
|
||||
// Merge consecutive axes
|
||||
std::vector<int> shape = {x.shape(axes[0])};
|
||||
std::vector<size_t> strides = {x.strides()[axes[0]]};
|
||||
for (int i = 1; i < axes.size(); i++) {
|
||||
if (axes[i] - 1 == axes[i - 1]) {
|
||||
shape.back() *= x.shape(axes[i]);
|
||||
strides.back() = x.strides()[axes[i]];
|
||||
} else {
|
||||
shape.push_back(x.shape(axes[i]));
|
||||
strides.push_back(x.strides()[axes[i]]);
|
||||
}
|
||||
}
|
||||
|
||||
if (strides.back() == 1) {
|
||||
return ReductionPlan(ContiguousReduce, shape, strides);
|
||||
} else if (strides.back() > 1) {
|
||||
return ReductionPlan(ContiguousStridedReduce, shape, strides);
|
||||
}
|
||||
}
|
||||
|
||||
// Let's check if we can optimize our access patterns
|
||||
//
|
||||
// 1. We have a reduction axis with stride 1. Simply call
|
||||
// GeneralContiguousReduce and be done with it.
|
||||
// 2. We have transpositions and we are not reducing over the axis with
|
||||
// stride 1. However, we are reducing over an axis where everything is
|
||||
// contiguous in memory to the right of that axis. We can call strided
|
||||
// reduce and be done with it.
|
||||
// 2. We have weird transpositions and expands. Copy the strides to the
|
||||
// output, then call strided reduce.
|
||||
|
||||
  // Sort reduction axes by stride in order to merge them and figure out if we
  // have a contiguous reduction.
  std::vector<std::pair<int, size_t>> reductions;
  for (auto a : axes) {
    reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
  }
  std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
    return a.second > b.second;
  });
  // Extract the two smallest and try to merge them in case the contiguous
  // reduction can be bigger than just the last axis.
  for (int i = reductions.size() - 1; i >= 1; i--) {
    auto a = reductions[i];
    auto b = reductions[i - 1];

    // If b.stride == a.shape * a.stride then a and b are contiguous
    if (b.second == a.first * a.second) {
      reductions.erase(reductions.begin() + i);
      reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
    }
  }
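  // Illustrative walk-through (an added example, not from the original
  // source): for x with shape (2, 5, 4), strides (20, 1, 5) (the last two axes
  // transposed) and axes = {1, 2}, the sorted reductions are {(4, 5), (5, 1)};
  // since 5 == 5 * 1 the two entries merge into (20, 1), and the checks below
  // select GeneralContiguousReduce over 20 contiguous elements per output.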
|
||||
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
for (auto r : reductions) {
|
||||
shape.push_back(r.first);
|
||||
strides.push_back(r.second);
|
||||
}
|
||||
|
||||
// We can call the contiguous reduction op for every weird way the input is
|
||||
// structured in the rest of the axes.
|
||||
if (strides.back() == 1) {
|
||||
return ReductionPlan(GeneralContiguousReduce, shape, strides);
|
||||
}
|
||||
|
||||
// Delegate to the general strided reduction op if the axes after
|
||||
// strides.back() are contiguous.
|
||||
if (strides.back() > 1) {
|
||||
int size = 1;
|
||||
for (int i = x.ndim() - 1; i >= 0; i--) {
|
||||
if (axes.back() == i) {
|
||||
continue;
|
||||
}
|
||||
if (x.strides()[i] != size) {
|
||||
break;
|
||||
}
|
||||
size *= x.shape(i);
|
||||
}
|
||||
if (size >= strides.back()) {
|
||||
return ReductionPlan(GeneralStridedReduce, shape, strides);
|
||||
}
|
||||
}
|
||||
|
||||
return ReductionPlan(GeneralReduce, shape, strides);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename OpS, typename OpC, typename Op>
|
||||
void reduction_op(
|
||||
const array& x,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
U init,
|
||||
OpS ops,
|
||||
OpC opc,
|
||||
Op op) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
ReductionPlan plan = get_reduction_plan(x, axes);
|
||||
|
||||
if (plan.type == ContiguousAllReduce) {
|
||||
U* out_ptr = out.data<U>();
|
||||
*out_ptr = init;
|
||||
opc(x.data<T>(), out_ptr, x.size());
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
|
||||
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
|
||||
int reduction_size = plan.shape[0];
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
|
||||
*out_ptr = init;
|
||||
opc(x_ptr, out_ptr, reduction_size);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == GeneralContiguousReduce || plan.type == ContiguousReduce) {
|
||||
int reduction_size = plan.shape.back();
|
||||
plan.shape.pop_back();
|
||||
plan.strides.pop_back();
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
// Unrolling the following loop (and implementing it in order for
// ContiguousReduce) should yield an extra performance boost.
|
||||
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
|
||||
if (plan.shape.size() == 0) {
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
*out_ptr = init;
|
||||
opc(x_ptr + offset, out_ptr, reduction_size);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
*out_ptr = init;
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
opc(x_ptr + offset + extra_offset, out_ptr, reduction_size);
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == ContiguousStridedReduce && plan.shape.size() == 1) {
|
||||
int reduction_size = plan.shape.back();
|
||||
size_t reduction_stride = plan.strides.back();
|
||||
plan.shape.pop_back();
|
||||
plan.strides.pop_back();
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
ops(x_ptr, out_ptr, reduction_size, reduction_stride);
|
||||
x_ptr += reduction_stride * reduction_size;
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == GeneralStridedReduce ||
|
||||
plan.type == ContiguousStridedReduce) {
|
||||
int reduction_size = plan.shape.back();
|
||||
size_t reduction_stride = plan.strides.back();
|
||||
plan.shape.pop_back();
|
||||
plan.strides.pop_back();
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
|
||||
if (plan.shape.size() == 0) {
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
ops(x_ptr + offset, out_ptr, reduction_size, reduction_stride);
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
ops(x_ptr + offset + extra_offset,
|
||||
out_ptr,
|
||||
reduction_size,
|
||||
reduction_stride);
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == GeneralReduce) {
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
U val = init;
|
||||
nd_loop(
|
||||
[&](int extra_offset) { op(&val, *(x_ptr + offset + extra_offset)); },
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
*out_ptr = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void reduction_op(
|
||||
const array& x,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
U init,
|
||||
Op op) {
|
||||
DefaultStridedReduce<T, U, Op> ops(op);
|
||||
DefaultContiguousReduce<T, U, Op> opc(op);
|
||||
reduction_op<T, U>(x, out, axes, init, ops, opc, op);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mlx::core
323
mlx/backend/common/scan.cpp
Normal file
@@ -0,0 +1,323 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
struct DefaultContiguousScan {
|
||||
Op op;
|
||||
U init;
|
||||
|
||||
DefaultContiguousScan(Op op_, U init_) : op(op_), init(init_) {}
|
||||
|
||||
void operator()(
|
||||
const T* input,
|
||||
U* output,
|
||||
int count,
|
||||
int stride,
|
||||
bool reverse,
|
||||
bool inclusive) {
|
||||
if (!reverse) {
|
||||
if (inclusive) {
|
||||
for (int i = 0; i < count; i++) {
|
||||
*output = *input;
|
||||
for (int j = 1; j < stride; j++) {
|
||||
input++;
|
||||
output++;
|
||||
op(output, output - 1, input);
|
||||
}
|
||||
output++;
|
||||
input++;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < count; i++) {
|
||||
*output = init;
|
||||
for (int j = 1; j < stride; j++) {
|
||||
op(output + 1, output, input);
|
||||
input++;
|
||||
output++;
|
||||
}
|
||||
output++;
|
||||
input++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (inclusive) {
|
||||
for (int i = 0; i < count; i++) {
|
||||
output += stride - 1;
|
||||
input += stride - 1;
|
||||
*output = *input;
|
||||
for (int j = 1; j < stride; j++) {
|
||||
input--;
|
||||
output--;
|
||||
op(output, output + 1, input);
|
||||
}
|
||||
output += stride;
|
||||
input += stride;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < count; i++) {
|
||||
output += stride - 1;
|
||||
input += stride - 1;
|
||||
*output = init;
|
||||
for (int j = 1; j < stride; j++) {
|
||||
op(output - 1, output, input);
|
||||
input--;
|
||||
output--;
|
||||
}
|
||||
output += stride;
|
||||
input += stride;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
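// Illustrative behaviour (an added example, not from the original source): for
// a contiguous row {1, 2, 3}, an inclusive Sum scan writes {1, 3, 6}, while the
// exclusive variant seeds the first output with `init` (0 for Sum) and writes
// {0, 1, 3}.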
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
struct DefaultStridedScan {
|
||||
Op op;
|
||||
U init;
|
||||
|
||||
DefaultStridedScan(Op op_, U init_) : op(op_), init(init_) {}
|
||||
|
||||
void operator()(
|
||||
const T* input,
|
||||
U* output,
|
||||
int count,
|
||||
int size,
|
||||
int stride,
|
||||
bool reverse,
|
||||
bool inclusive) {
|
||||
// TODO: Vectorize the following naive implementation
|
||||
if (!reverse) {
|
||||
if (inclusive) {
|
||||
for (int i = 0; i < count; i++) {
|
||||
std::copy(input, input + stride, output);
|
||||
output += stride;
|
||||
input += stride;
|
||||
for (int j = 1; j < size; j++) {
|
||||
for (int k = 0; k < stride; k++) {
|
||||
op(output, output - stride, input);
|
||||
output++;
|
||||
input++;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < count; i++) {
|
||||
std::fill(output, output + stride, init);
|
||||
output += stride;
|
||||
input += stride;
|
||||
for (int j = 1; j < size; j++) {
|
||||
for (int k = 0; k < stride; k++) {
|
||||
op(output, output - stride, input - stride);
|
||||
output++;
|
||||
input++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (inclusive) {
|
||||
for (int i = 0; i < count; i++) {
|
||||
output += (size - 1) * stride;
|
||||
input += (size - 1) * stride;
|
||||
std::copy(input, input + stride, output);
|
||||
for (int j = 1; j < size; j++) {
|
||||
for (int k = 0; k < stride; k++) {
|
||||
output--;
|
||||
input--;
|
||||
op(output, output + stride, input);
|
||||
}
|
||||
}
|
||||
output += size * stride;
|
||||
input += size * stride;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < count; i++) {
|
||||
output += (size - 1) * stride;
|
||||
input += (size - 1) * stride;
|
||||
std::fill(output, output + stride, init);
|
||||
for (int j = 1; j < size; j++) {
|
||||
for (int k = 0; k < stride; k++) {
|
||||
output--;
|
||||
input--;
|
||||
op(output, output + stride, input + stride);
|
||||
}
|
||||
}
|
||||
output += size * stride;
|
||||
input += size * stride;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename OpCS, typename OpSS>
|
||||
void scan_op(
|
||||
OpCS opcs,
|
||||
OpSS opss,
|
||||
const array& input,
|
||||
array& output,
|
||||
int axis,
|
||||
bool reverse,
|
||||
bool inclusive) {
|
||||
output.set_data(allocator::malloc_or_wait(output.nbytes()));
|
||||
|
||||
if (input.flags().row_contiguous) {
|
||||
if (input.strides()[axis] == 1) {
|
||||
opcs(
|
||||
input.data<T>(),
|
||||
output.data<U>(),
|
||||
input.size() / input.shape(axis),
|
||||
input.shape(axis),
|
||||
reverse,
|
||||
inclusive);
|
||||
} else {
|
||||
opss(
|
||||
input.data<T>(),
|
||||
output.data<U>(),
|
||||
input.size() / input.shape(axis) / input.strides()[axis],
|
||||
input.shape(axis),
|
||||
input.strides()[axis],
|
||||
reverse,
|
||||
inclusive);
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("Scan op supports only contiguous inputs");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
void scan_dispatch(
|
||||
Scan::ReduceType rtype,
|
||||
const array& input,
|
||||
array& output,
|
||||
int axis,
|
||||
bool reverse,
|
||||
bool inclusive) {
|
||||
switch (rtype) {
|
||||
case Scan::Sum: {
|
||||
auto op = [](U* o, const U* y, const T* x) { *o = *y + *x; };
|
||||
auto init = static_cast<U>(0);
|
||||
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
||||
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
|
||||
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
|
||||
break;
|
||||
}
|
||||
case Scan::Prod: {
|
||||
auto op = [](U* o, const U* y, const T* x) { *o = *y * (*x); };
|
||||
auto init = static_cast<U>(1);
|
||||
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
||||
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
|
||||
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
|
||||
break;
|
||||
}
|
||||
case Scan::Min: {
|
||||
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *x : *y; };
|
||||
auto init = (is_floating_point(input.dtype()))
|
||||
? static_cast<U>(std::numeric_limits<float>::infinity())
|
||||
: std::numeric_limits<U>::max();
|
||||
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
||||
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
|
||||
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
|
||||
break;
|
||||
}
|
||||
    case Scan::Max: {
      auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
      auto init = (is_floating_point(input.dtype()))
          ? static_cast<U>(-std::numeric_limits<float>::infinity())
          : std::numeric_limits<U>::min();
      auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
      auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
      scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
      break;
    }
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Scan::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
// Ensure contiguity
|
||||
auto in = inputs[0];
|
||||
if (!in.flags().row_contiguous) {
|
||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy(in, arr_copy, CopyType::General);
|
||||
in = arr_copy;
|
||||
}
|
||||
|
||||
switch (in.dtype()) {
|
||||
case bool_: {
|
||||
// We could do a full dtype x dtype switch but this is the only case
|
||||
// where we accumulate in a different type, for now.
|
||||
//
|
||||
// TODO: If we add the option to accumulate floats in higher precision
|
||||
// floats perhaps we should add the full all-to-all dispatch.
|
||||
if (reduce_type_ == Scan::Sum && out.dtype() == int32) {
|
||||
scan_dispatch<bool, int32_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
} else {
|
||||
scan_dispatch<bool, bool>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case uint8:
|
||||
scan_dispatch<uint8_t, uint8_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case uint16:
|
||||
scan_dispatch<uint16_t, uint16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case uint32:
|
||||
scan_dispatch<uint32_t, uint32_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case uint64:
|
||||
scan_dispatch<uint64_t, uint64_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case int8:
|
||||
scan_dispatch<int8_t, int8_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case int16:
|
||||
scan_dispatch<int16_t, int16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case int32:
|
||||
scan_dispatch<int32_t, int32_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case int64:
|
||||
scan_dispatch<int64_t, int64_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case float16:
|
||||
scan_dispatch<float16_t, float16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case float32:
|
||||
scan_dispatch<float, float>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case bfloat16:
|
||||
scan_dispatch<bfloat16_t, bfloat16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case complex64:
|
||||
throw std::runtime_error("Scan ops do not support complex types yet");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
29
mlx/backend/common/threefry.cpp
Normal file
@@ -0,0 +1,29 @@
#include "mlx/backend/common/threefry.h"

namespace mlx::core::random {

std::pair<uint32_t, uint32_t> threefry2x32_hash(
    const std::pair<uint32_t, uint32_t>& key,
    std::pair<uint32_t, uint32_t> count) {
  constexpr static uint32_t rotations[2][4] = {
      {13, 15, 26, 6}, {17, 29, 16, 24}};

  uint32_t ks[3] = {key.first, key.second, key.first ^ key.second ^ 0x1BD11BDA};

  count.first += ks[0];
  count.second += ks[1];

  for (int i = 0; i < 5; ++i) {
    for (auto r : rotations[i % 2]) {
      count.first += count.second;
      count.second = (count.second << r) | (count.second >> (32 - r));
      count.second ^= count.first;
    }
    count.first += ks[(i + 1) % 3];
    count.second += ks[(i + 2) % 3] + i + 1;
  }

  return count;
}

} // namespace mlx::core::random
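A hedged usage sketch (the caller code below is an assumption for illustration, not part of this file): hashing a (key, counter) pair yields two 32-bit words, and advancing only the counter produces new output words without modifying the key.

// Assumed caller code, for illustration only:
//   std::pair<uint32_t, uint32_t> key{0u, 42u};
//   auto bits0 = mlx::core::random::threefry2x32_hash(key, {0u, 0u});
//   auto bits1 = mlx::core::random::threefry2x32_hash(key, {0u, 1u});
//   // bits0 and bits1 are two draws derived from the same key.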
19
mlx/backend/common/threefry.h
Normal file
@@ -0,0 +1,19 @@
#pragma once

#include <cstdint>
#include <utility>

namespace mlx::core::random {

/** Applies the Threefry 2x32 hash function.
 * This code is based on the Jax counter-based and splittable PRNG
 * https://github.com/google/jax/blob/main/docs/jep/263-prng.md
 *
 * Original Threefry reference:
 * http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
 */
std::pair<uint32_t, uint32_t> threefry2x32_hash(
    const std::pair<uint32_t, uint32_t>& key,
    std::pair<uint32_t, uint32_t> count);

} // namespace mlx::core::random
147
mlx/backend/common/unary.h
Normal file
@@ -0,0 +1,147 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
struct AbsOp {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::abs(x);
|
||||
}
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
}
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
}
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
}
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
}
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
struct SignOp {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return (x > T(0)) - (x < T(0));
|
||||
}
|
||||
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x != 0;
|
||||
}
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x != 0;
|
||||
}
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x != 0;
|
||||
}
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x != 0;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Op>
|
||||
void unary_op(const array& a, array& out, Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
if (a.flags().contiguous) {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
|
||||
a.data_size(),
|
||||
a.strides(),
|
||||
a.flags());
|
||||
T* dst = out.data<T>();
|
||||
for (size_t i = 0; i < a.data_size(); ++i) {
|
||||
dst[i] = op(a_ptr[i]);
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
T* dst = out.data<T>();
|
||||
for (size_t i = 0; i < out.size(); ++i) {
|
||||
// TODO this is super inefficient, need to fix.
|
||||
int a_idx = elem_to_loc(i, a.shape(), a.strides());
|
||||
dst[i] = op(a_ptr[a_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void unary(const array& a, array& out, Op op) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
unary_op<bool>(a, out, op);
|
||||
break;
|
||||
case uint8:
|
||||
unary_op<uint8_t>(a, out, op);
|
||||
break;
|
||||
case uint16:
|
||||
unary_op<uint16_t>(a, out, op);
|
||||
break;
|
||||
case uint32:
|
||||
unary_op<uint32_t>(a, out, op);
|
||||
break;
|
||||
case uint64:
|
||||
unary_op<uint64_t>(a, out, op);
|
||||
break;
|
||||
case int8:
|
||||
unary_op<int8_t>(a, out, op);
|
||||
break;
|
||||
case int16:
|
||||
unary_op<int16_t>(a, out, op);
|
||||
break;
|
||||
case int32:
|
||||
unary_op<int32_t>(a, out, op);
|
||||
break;
|
||||
case int64:
|
||||
unary_op<int64_t>(a, out, op);
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(a, out, op);
|
||||
break;
|
||||
case float32:
|
||||
unary_op<float>(a, out, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(a, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
unary_op<complex64_t>(a, out, op);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void unary_fp(const array& a, array& out, Op op) {
|
||||
switch (out.dtype()) {
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(a, out, op);
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(a, out, op);
|
||||
break;
|
||||
case float32:
|
||||
unary_op<float>(a, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
unary_op<complex64_t>(a, out, op);
|
||||
break;
|
||||
default:
|
||||
std::ostringstream err;
|
||||
err << "[unary_fp] Does not support " << out.dtype();
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mlx::core
29
mlx/backend/common/utils.h
Normal file
@@ -0,0 +1,29 @@
#pragma once

#include <vector>

#include "mlx/array.h"

namespace mlx::core {

inline size_t elem_to_loc(
    int elem,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t loc = 0;
  for (int i = shape.size() - 1; i >= 0; --i) {
    auto q_and_r = ldiv(elem, shape[i]);
    loc += q_and_r.rem * strides[i];
    elem = q_and_r.quot;
  }
  return loc;
}
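// Example (an added illustration, not from the original source): with
// shape = {2, 3} and row-major strides = {3, 1}, elem_to_loc(4, shape, strides)
// returns 4; with the transposed strides = {1, 2} it returns 3, the memory
// location of logical element (1, 1).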

inline size_t elem_to_loc(int elem, const array& a) {
  if (a.flags().row_contiguous) {
    return elem;
  }
  return elem_to_loc(elem, a.shape(), a.strides());
}

} // namespace mlx::core
26
mlx/backend/metal/CMakeLists.txt
Normal file
@@ -0,0 +1,26 @@
target_sources(
  mlx
  PRIVATE
  ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
)

if (NOT MLX_METAL_PATH)
  set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
endif()

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)

target_compile_definitions(
  mlx PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib")
113
mlx/backend/metal/copy.cpp
Normal file
@@ -0,0 +1,113 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
|
||||
if (ctype == CopyType::Vector) {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy_gpu_inplace(in, out, ctype, s);
|
||||
}
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype) {
|
||||
copy_gpu(in, out, ctype, out.primitive().stream());
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
// Try to collapse contiguous dims
|
||||
auto [shape, strides] = collapse_contiguous_dims(in, out);
|
||||
auto& strides_in = strides[0];
|
||||
auto& strides_out = strides[1];
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
std::ostringstream kname;
|
||||
switch (ctype) {
|
||||
case CopyType::Scalar:
|
||||
kname << "scopy";
|
||||
break;
|
||||
case CopyType::Vector:
|
||||
kname << "vcopy";
|
||||
break;
|
||||
case CopyType::General:
|
||||
kname << "gcopy";
|
||||
break;
|
||||
case CopyType::GeneralGeneral:
|
||||
kname << "ggcopy";
|
||||
break;
|
||||
}
|
||||
kname << type_to_name(in) << type_to_name(out);
|
||||
if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) &&
|
||||
shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
|
||||
kname << "_" << shape.size();
|
||||
}
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
|
||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||
size_t ndim = shape.size();
|
||||
if (ndim > 3) {
|
||||
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
|
||||
compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 3);
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 4);
|
||||
}
|
||||
} else {
|
||||
// The shape is implicit in the grid for <= 3D
|
||||
compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 2);
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 3);
|
||||
}
|
||||
}
|
||||
|
||||
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
compute_encoder->setBytes(
|
||||
&ndim, sizeof(int), (ctype == CopyType::GeneralGeneral) ? 5 : 4);
|
||||
}
|
||||
|
||||
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
int rest = in.size() / (dim0 * dim1);
|
||||
|
||||
// NB assuming thread_group_size is a power of 2 larger than 32 x 32
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Metal::copy] Must use 1024 sized block");
|
||||
}
|
||||
auto group_dims = get_block_dims(dim0, dim1, rest);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
} else {
|
||||
size_t nthreads = out.data_size();
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
81
mlx/backend/metal/device.h
Normal file
@@ -0,0 +1,81 @@
|
||||
#pragma once
|
||||
|
||||
#include <Metal/Metal.hpp>
|
||||
#include <functional>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <filesystem>
|
||||
|
||||
#include "mlx/device.h"
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
|
||||
Dl_info info;
|
||||
std::string mtllib_path;
|
||||
std::string lib_ext = lib_name + ".metallib";
|
||||
|
||||
int success = dladdr((void*)get_colocated_mtllib_path, &info);
|
||||
if (success) {
|
||||
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
|
||||
mtllib_path = mtllib.c_str();
|
||||
}
|
||||
|
||||
return mtllib_path;
|
||||
}
|
||||
|
||||
class Device {
|
||||
public:
|
||||
Device();
|
||||
Device(const Device&) = delete;
|
||||
Device& operator=(const Device&) = delete;
|
||||
~Device();
|
||||
|
||||
MTL::Device* mtl_device() {
|
||||
return device_;
|
||||
};
|
||||
|
||||
void new_queue(int index);
|
||||
MTL::CommandBuffer* new_command_buffer(int index);
|
||||
MTL::CommandBuffer* get_command_buffer(int index);
|
||||
int get_command_buffer_ops(int index);
|
||||
void increment_command_buffer_ops(int index);
|
||||
void commit_command_buffer(int index);
|
||||
MTL::ComputeCommandEncoder* get_command_encoder(int index);
|
||||
void end_encoding(int index);
|
||||
|
||||
void register_library(
|
||||
const std::string& lib_name,
|
||||
const std::string& lib_path);
|
||||
void register_library(
|
||||
const std::string& lib_name,
|
||||
const std::function<std::string(const std::string&)>& lib_path_func =
|
||||
get_colocated_mtllib_path);
|
||||
|
||||
MTL::ComputePipelineState* get_kernel(
|
||||
const std::string& name,
|
||||
const std::string& lib_name = "mlx");
|
||||
|
||||
MTL::ArgumentEncoder* argument_encoder(
|
||||
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
|
||||
|
||||
private:
|
||||
NS::AutoreleasePool* pool_;
|
||||
MTL::Device* device_;
|
||||
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
|
||||
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
|
||||
std::unordered_map<int32_t, MTL::ComputeCommandEncoder*> encoder_map_;
|
||||
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
|
||||
std::unordered_map<std::string, MTL::Library*> library_map_;
|
||||
std::mutex mtx_;
|
||||
};
|
||||
|
||||
Device& device(mlx::core::Device);
|
||||
NS::AutoreleasePool*& thread_autorelease_pool();
|
||||
|
||||
} // namespace mlx::core::metal
296
mlx/backend/metal/indexing.cpp
Normal file
@@ -0,0 +1,296 @@
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
static constexpr int METAL_MAX_INDEX_ARRAYS = 10;
|
||||
|
||||
} // namespace
|
||||
|
||||
void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& src = inputs[0];
|
||||
int nidx = inputs.size() - 1;
|
||||
|
||||
if (nidx > METAL_MAX_INDEX_ARRAYS) {
|
||||
std::ostringstream msg;
|
||||
msg << "[Gather::eval_gpu] Gathering with more than "
|
||||
<< METAL_MAX_INDEX_ARRAYS << " index arrays not yet supported.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
std::ostringstream kname;
|
||||
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
|
||||
kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx;
|
||||
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
|
||||
size_t slice_size = 1;
|
||||
for (auto s : slice_sizes_) {
|
||||
slice_size *= s;
|
||||
}
|
||||
|
||||
size_t ndim = src.ndim();
|
||||
size_t nthreads = out.size();
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Make the argument buffer to store the indices for the
|
||||
// `Indices` struct in kernels/indexing.metal
|
||||
std::vector<MTL::ArgumentDescriptor*> arg_descs(4);
|
||||
arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||
arg_descs[0]->setIndex(0);
|
||||
arg_descs[0]->setDataType(MTL::DataType::DataTypePointer);
|
||||
arg_descs[0]->setArrayLength(nidx);
|
||||
|
||||
// Shapes
|
||||
arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||
arg_descs[1]->setDataType(MTL::DataType::DataTypePointer);
|
||||
arg_descs[1]->setIndex(nidx + 1);
|
||||
|
||||
// Strides
|
||||
arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||
arg_descs[2]->setDataType(MTL::DataType::DataTypePointer);
|
||||
arg_descs[2]->setIndex(nidx + 2);
|
||||
|
||||
// Indices ndim
|
||||
arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||
arg_descs[3]->setDataType(MTL::DataType::DataTypeInt);
|
||||
arg_descs[3]->setIndex(nidx + 3);
|
||||
|
||||
// Get the argument encoder
|
||||
auto arg_enc = d.argument_encoder(arg_descs);
|
||||
|
||||
// Allocate and fill buffers for shapes and strides
|
||||
int idx_ndim = nidx ? inputs[1].ndim() : 0;
|
||||
auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim);
|
||||
auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim);
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
std::copy(
|
||||
inputs[i + 1].shape().begin(),
|
||||
inputs[i + 1].shape().end(),
|
||||
static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim);
|
||||
std::copy(
|
||||
inputs[i + 1].strides().begin(),
|
||||
inputs[i + 1].strides().end(),
|
||||
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
|
||||
}
|
||||
|
||||
// Allocate the argument buffer
|
||||
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
|
||||
|
||||
// Register data with the encoder
|
||||
arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0);
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
|
||||
}
|
||||
arg_enc->setBuffer(
|
||||
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
|
||||
compute_encoder->useResource(
|
||||
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), MTL::ResourceUsageRead);
|
||||
arg_enc->setBuffer(
|
||||
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
|
||||
compute_encoder->useResource(
|
||||
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), MTL::ResourceUsageRead);
|
||||
*static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;
|
||||
|
||||
// Set all the buffers
|
||||
set_array_buffer(compute_encoder, src, 0);
|
||||
compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 3);
|
||||
compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 6);
|
||||
compute_encoder->setBytes(&slice_size, sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 8);
|
||||
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
// Cleanup temporaries
|
||||
arg_enc->release();
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) {
|
||||
allocator::free(arg_buf);
|
||||
allocator::free(idx_shapes_buf);
|
||||
allocator::free(idx_strides_buf);
|
||||
});
|
||||
}
|
||||
|
||||
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (size_of(out.dtype()) == 8) {
|
||||
std::ostringstream msg;
|
||||
msg << "[Scatter::eval_gpu] Does not support " << out.dtype();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
int nidx = axes_.size();
|
||||
if (nidx > METAL_MAX_INDEX_ARRAYS) {
|
||||
std::ostringstream msg;
|
||||
msg << "[Scatter::eval_gpu] Gathering with more than "
|
||||
<< METAL_MAX_INDEX_ARRAYS << " index arrays not yet supported.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Copy src into out
|
||||
auto copy_type =
|
||||
inputs[0].data_size() == 1 ? CopyType::Scalar : CopyType::General;
|
||||
copy_gpu(inputs[0], out, copy_type);
|
||||
|
||||
// Get stream
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// Get kernel name
|
||||
std::ostringstream kname;
|
||||
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
|
||||
kname << "scatter" << type_to_name(out) << idx_type_name;
|
||||
switch (reduce_type_) {
|
||||
case Scatter::None:
|
||||
kname << "_none";
|
||||
break;
|
||||
case Scatter::Sum:
|
||||
kname << "_sum";
|
||||
break;
|
||||
case Scatter::Prod:
|
||||
kname << "_prod";
|
||||
break;
|
||||
case Scatter::Max:
|
||||
kname << "_max";
|
||||
break;
|
||||
case Scatter::Min:
|
||||
kname << "_min";
|
||||
break;
|
||||
}
|
||||
kname << "_" << nidx;
|
||||
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
|
||||
auto& upd = inputs.back();
|
||||
size_t nthreads = upd.size();
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Make the argument buffer to store the indices for the
|
||||
// `Indices` struct in kernels/indexing.metal
|
||||
std::vector<MTL::ArgumentDescriptor*> arg_descs(4);
|
||||
arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||
arg_descs[0]->setIndex(0);
|
||||
arg_descs[0]->setDataType(MTL::DataType::DataTypePointer);
|
||||
arg_descs[0]->setArrayLength(nidx);
|
||||
|
||||
// Shapes
|
||||
arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||
arg_descs[1]->setDataType(MTL::DataType::DataTypePointer);
|
||||
arg_descs[1]->setIndex(nidx + 1);
|
||||
|
||||
// Strides
|
||||
arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||
arg_descs[2]->setDataType(MTL::DataType::DataTypePointer);
|
||||
arg_descs[2]->setIndex(nidx + 2);
|
||||
|
||||
// Indices ndim
|
||||
arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||
arg_descs[3]->setDataType(MTL::DataType::DataTypeInt);
|
||||
arg_descs[3]->setIndex(nidx + 3);
|
||||
|
||||
// Get the argument encoder
|
||||
auto arg_enc = d.argument_encoder(arg_descs);
|
||||
|
||||
// Allocate and fill buffers for shapes and strides
|
||||
int idx_ndim = nidx ? inputs[1].ndim() : 0;
|
||||
auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim);
|
||||
auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim);
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
std::copy(
|
||||
inputs[i + 1].shape().begin(),
|
||||
inputs[i + 1].shape().end(),
|
||||
static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim);
|
||||
std::copy(
|
||||
inputs[i + 1].strides().begin(),
|
||||
inputs[i + 1].strides().end(),
|
||||
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
|
||||
}
|
||||
|
||||
// Allocate the argument buffer
|
||||
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
|
||||
|
||||
// Register data with the encoder
|
||||
arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0);
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
|
||||
}
|
||||
arg_enc->setBuffer(
|
||||
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
|
||||
compute_encoder->useResource(
|
||||
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), MTL::ResourceUsageRead);
|
||||
arg_enc->setBuffer(
|
||||
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
|
||||
compute_encoder->useResource(
|
||||
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), MTL::ResourceUsageRead);
|
||||
*static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;
|
||||
|
||||
compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 0);
|
||||
size_t upd_ndim = upd.ndim();
|
||||
size_t upd_size = 1;
|
||||
for (int i = idx_ndim; i < upd.ndim(); ++i) {
|
||||
upd_size *= upd.shape(i);
|
||||
}
|
||||
set_array_buffer(compute_encoder, upd, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
|
||||
compute_encoder->setBytes(upd.strides().data(), upd_ndim * sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
|
||||
|
||||
size_t out_ndim = out.ndim();
|
||||
compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
|
||||
compute_encoder->setBytes(out.strides().data(), out_ndim * sizeof(size_t), 8);
|
||||
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
|
||||
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
|
||||
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
// Cleanup temporaries
|
||||
arg_enc->release();
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) {
|
||||
allocator::free(arg_buf);
|
||||
allocator::free(idx_shapes_buf);
|
||||
allocator::free(idx_strides_buf);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
320
mlx/backend/metal/kernels/atomic.h
Normal file
@@ -0,0 +1,320 @@
|
||||
#pragma once
|
||||
|
||||
#include <metal_atomic>
|
||||
#include <metal_stdlib>
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Atomic utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#pragma METAL internals : enable
|
||||
template <typename T>
|
||||
constexpr constant bool is_metal_atomic = _disjunction<
|
||||
is_same<T, int>,
|
||||
is_same<T, uint>,
|
||||
is_same<T, ulong>,
|
||||
is_same<T, float>>::value;
|
||||
|
||||
#pragma METAL internals : disable
|
||||
|
||||
template <typename T, typename = void>
|
||||
struct mlx_atomic {
|
||||
atomic<uint> val;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct mlx_atomic<T, enable_if_t<is_metal_atomic<T>>> {
|
||||
atomic<T> val;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Native metal atomics
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC T
|
||||
mlx_atomic_load_explicit(device mlx_atomic<T>* object, int offset) {
|
||||
return atomic_load_explicit(&(object[offset].val), memory_order_relaxed);
|
||||
}
|
||||
|
||||
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void
|
||||
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||
atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||
}
|
||||
|
||||
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void
|
||||
mlx_atomic_fetch_and_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||
atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||
}
|
||||
|
||||
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void
|
||||
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||
atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||
}
|
||||
|
||||
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void
|
||||
mlx_atomic_fetch_min_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||
atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||
}
|
||||
|
||||
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void
|
||||
mlx_atomic_fetch_max_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||
atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||
}
|
||||
|
||||
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void
|
||||
mlx_atomic_fetch_add_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||
atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||
}
|
||||
|
||||
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void
|
||||
mlx_atomic_fetch_mul_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||
T expected = mlx_atomic_load_explicit(object, offset);
|
||||
while (!mlx_atomic_compare_exchange_weak_explicit(
|
||||
object, &expected, val * expected, offset)) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
|
||||
device mlx_atomic<T>* object,
|
||||
thread T* expected,
|
||||
T val,
|
||||
int offset) {
|
||||
return atomic_compare_exchange_weak_explicit(
|
||||
&(object[offset].val),
|
||||
expected,
|
||||
val,
|
||||
memory_order_relaxed,
|
||||
memory_order_relaxed);
|
||||
}
|
||||
|
||||
// Specialization for float since it does not atomic_fetch_min_explicit
|
||||
template <>
|
||||
METAL_FUNC void mlx_atomic_fetch_min_explicit<float>(
|
||||
device mlx_atomic<float>* object,
|
||||
float val,
|
||||
int offset) {
|
||||
float expected = mlx_atomic_load_explicit(object, offset);
|
||||
while (val < expected) {
|
||||
if (mlx_atomic_compare_exchange_weak_explicit(
|
||||
object, &expected, val, offset)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Specialization for float since it does not atomic_fetch_max_explicit
|
||||
template <>
|
||||
METAL_FUNC void mlx_atomic_fetch_max_explicit<float>(
|
||||
device mlx_atomic<float>* object,
|
||||
float val,
|
||||
int offset) {
|
||||
float expected = mlx_atomic_load_explicit(object, offset);
|
||||
while (val > expected) {
|
||||
if (mlx_atomic_compare_exchange_weak_explicit(
|
||||
object, &expected, val, offset)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Custom atomics
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
constexpr constant uint packing_size = sizeof(uint) / sizeof(T);
|
||||
|
||||
template <typename T>
|
||||
union uint_or_packed {
|
||||
T val[packing_size<T>];
|
||||
uint bits;
|
||||
};
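// Illustrative packing (an added example, not from the original source): for
// T = uint8_t, packing_size<uint8_t> == 4, so a logical offset of 6 maps to
// pack_offset 1 (the second 32-bit word) and elem_offset 2 (the third byte in
// that word) in the helpers below.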
|
||||
|
||||
template <typename T, typename Op>
|
||||
struct mlx_atomic_update_helper {
|
||||
uint operator()(uint_or_packed<T> init, T update, int elem_offset) {
|
||||
Op op;
|
||||
init.val[elem_offset] = op(update, init.val[elem_offset]);
|
||||
return init.bits;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Op>
|
||||
METAL_FUNC void mlx_atomic_update_and_store(
|
||||
device mlx_atomic<T>* object,
|
||||
T update,
|
||||
int offset) {
|
||||
int pack_offset = offset / packing_size<T>;
|
||||
int elem_offset = offset % packing_size<T>;
|
||||
|
||||
mlx_atomic_update_helper<T, Op> helper;
|
||||
uint_or_packed<T> expected;
|
||||
expected.bits =
|
||||
atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
|
||||
|
||||
while (Op::condition(update, expected.val[elem_offset]) &&
|
||||
!mlx_atomic_compare_exchange_weak_explicit(
|
||||
object,
|
||||
&(expected.bits),
|
||||
helper(expected, update, elem_offset),
|
||||
pack_offset)) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct __None {
|
||||
static bool condition(T a, T b) {
|
||||
#pragma unused(a)
|
||||
#pragma unused(b)
|
||||
return true;
|
||||
}
|
||||
|
||||
T operator()(T a, T b) {
|
||||
#pragma unused(b)
|
||||
return a;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct __Add {
|
||||
static bool condition(T a, T b) {
|
||||
#pragma unused(a)
|
||||
#pragma unused(b)
|
||||
return true;
|
||||
}
|
||||
|
||||
T operator()(T a, T b) {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct __Mul {
|
||||
static bool condition(T a, T b) {
|
||||
#pragma unused(a)
|
||||
return b != 0;
|
||||
}
|
||||
|
||||
T operator()(T a, T b) {
|
||||
return a * b;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct __Max {
|
||||
static bool condition(T a, T b) {
|
||||
return a > b;
|
||||
}
|
||||
|
||||
T operator()(T a, T b) {
|
||||
return max(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct __Min {
|
||||
static bool condition(T a, T b) {
|
||||
return a < b;
|
||||
}
|
||||
|
||||
T operator()(T a, T b) {
|
||||
return min(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC T
mlx_atomic_load_explicit(device mlx_atomic<T>* object, int offset) {
  int pack_offset = offset / packing_size<T>;
  int elem_offset = offset % packing_size<T>;
  uint_or_packed<T> packed_val;
  packed_val.bits =
      atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
  return packed_val.val[elem_offset];
}
|
||||
|
||||
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void
|
||||
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||
mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
|
||||
}
|
||||
|
||||
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void
|
||||
mlx_atomic_fetch_and_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||
int pack_offset = offset / packing_size<T>;
|
||||
int elem_offset = offset % packing_size<T>;
|
||||
uint_or_packed<T> identity;
|
||||
identity.bits = __UINT32_MAX__;
|
||||
identity.val[elem_offset] = val;
|
||||
|
||||
atomic_fetch_and_explicit(
|
||||
&(object[pack_offset].val), identity.bits, memory_order_relaxed);
|
||||
}
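
// Masking sketch (illustrative, assumes the little-endian packing implied by
// the union above): for T = uint8_t, elem_offset = 1 and val = 0x0F,
// identity.bits starts as 0xFFFFFFFF and becomes 0xFFFF0FFF. AND-ing the
// packed word with it leaves the other three bytes unchanged (x & 0xFF == x)
// and applies the AND only to the target byte.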
|
||||
|
||||
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void
|
||||
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||
int pack_offset = offset / packing_size<T>;
|
||||
int elem_offset = offset % packing_size<T>;
|
||||
uint_or_packed<T> identity;
|
||||
identity.bits = 0;
|
||||
identity.val[elem_offset] = val;
|
||||
|
||||
atomic_fetch_or_explicit(
|
||||
&(object[pack_offset].val), identity.bits, memory_order_relaxed);
|
||||
}
|
||||
|
||||
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void
|
||||
mlx_atomic_fetch_min_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||
mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
|
||||
}
|
||||
|
||||
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void
|
||||
mlx_atomic_fetch_max_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||
mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
|
||||
}
|
||||
|
||||
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void
|
||||
mlx_atomic_fetch_add_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||
mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
|
||||
}
|
||||
|
||||
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void
|
||||
mlx_atomic_fetch_mul_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||
mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
|
||||
}
|
||||
|
||||
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
|
||||
device mlx_atomic<T>* object,
|
||||
thread uint* expected,
|
||||
uint val,
|
||||
int offset) {
|
||||
return atomic_compare_exchange_weak_explicit(
|
||||
&(object[offset].val),
|
||||
expected,
|
||||
val,
|
||||
memory_order_relaxed,
|
||||
memory_order_relaxed);
|
||||
}
|
||||
315
mlx/backend/metal/kernels/bf16.h
Normal file
@@ -0,0 +1,315 @@
|
||||
#pragma once
|
||||
|
||||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
|
||||
typedef bfloat bfloat16_t;
|
||||
|
||||
#else
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Helpers
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
|
||||
// Check for nan
|
||||
if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
|
||||
_fp_encoding_traits<float>::inf_mask) {
|
||||
return uint16_t(as_type<uint32_t>(0x7FC0));
|
||||
}
|
||||
// Take bits
|
||||
uint32_t float_bits = as_type<uint32_t>(x);
|
||||
|
||||
// Round to nearest even
|
||||
float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);
|
||||
|
||||
// Take upper 16 bits
|
||||
return float_bits >> 16;
|
||||
}
|
||||
|
||||
constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
|
||||
// Upper 16 bits are the data and lower 16 bits are 0s
|
||||
return as_type<float>((uint32_t)x << 16);
|
||||
}
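
// Worked example (illustrative): 1.0f has bits 0x3F800000. Rounding adds
// ((bits >> 16) & 1) + 0x7FFF = 0x7FFF, leaving the upper halfword 0x3F80, so
// 1.0 round-trips exactly through float_to_bfloat_bits / bfloat_bits_to_float.
// The halfway pattern 0x3F808000 also truncates to 0x3F80 because the retained
// low bit is already even (round-to-nearest-even).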
|
||||
|
||||
struct _MLX_BFloat16;
|
||||
|
||||
template <typename T>
|
||||
static constexpr constant bool can_convert_to_bfloat =
|
||||
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;
|
||||
|
||||
template <typename T>
|
||||
static constexpr constant bool can_convert_from_bfloat =
|
||||
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Bfloat struct
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct _MLX_BFloat16 {
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Constructors
|
||||
uint16_t bits_;
|
||||
_MLX_BFloat16() thread = default;
|
||||
_MLX_BFloat16() threadgroup = default;
|
||||
_MLX_BFloat16() device = default;
|
||||
_MLX_BFloat16() constant = default;
|
||||
|
||||
struct bits_to_bfloat_struct {};
|
||||
static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
|
||||
return bits_to_bfloat_struct();
|
||||
}
|
||||
constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
|
||||
: bits_(bits) {}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Conversions to bfloat
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC _MLX_BFloat16(T x) thread
|
||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
|
||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC _MLX_BFloat16(T x) device
|
||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC _MLX_BFloat16(T x) constant
|
||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Conversions from bfloat
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC operator T() const thread {
|
||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC operator T() const threadgroup {
|
||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC operator T() const device {
|
||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC operator T() const constant {
|
||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Bfloat operators
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Unary ops
|
||||
constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
|
||||
return -static_cast<float>(x);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Binary operators
|
||||
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
|
||||
constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
|
||||
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
||||
}
|
||||
|
||||
#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
|
||||
constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
|
||||
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
||||
} \
|
||||
constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
|
||||
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Arithmetic Operators
|
||||
#define bfloat_binop(_op_, _operator_) \
|
||||
bfloat_binop_base( \
|
||||
_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, float, float, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, float, half, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
|
||||
|
||||
bfloat_binop(+, operator+);
|
||||
bfloat_binop(-, operator-);
|
||||
bfloat_binop(*, operator*);
|
||||
bfloat_binop(/, operator/);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Comparison ops
|
||||
#define bfloat_compop(__op__, __operator__) \
|
||||
bfloat_binop_base( \
|
||||
__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, float, float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, half, float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
|
||||
|
||||
bfloat_compop(>, operator>);
|
||||
bfloat_compop(<, operator<);
|
||||
bfloat_compop(>=, operator>=);
|
||||
bfloat_compop(<=, operator<=);
|
||||
bfloat_compop(==, operator==);
|
||||
bfloat_compop(!=, operator!=);
|
||||
|
||||
#undef bfloat_compop
|
||||
#undef bfloat_binop_base
|
||||
#undef bfloat_binop_helper
|
||||
#undef bfloat_binop
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Inplace Operators
|
||||
#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \
|
||||
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
|
||||
addr_space _MLX_BFloat16& lhs, itype rhs) { \
|
||||
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
||||
return lhs; \
|
||||
} \
|
||||
constexpr METAL_FUNC addr_space itype& __operator__( \
|
||||
addr_space itype& lhs, _MLX_BFloat16 rhs) { \
|
||||
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
||||
return lhs; \
|
||||
}
|
||||
|
||||
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, itype, device); \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);
|
||||
|
||||
#define bfloat_inplace_op(itype) \
|
||||
bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
|
||||
bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
|
||||
bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
|
||||
bfloat_inplace_op_addr_space_helper(/, operator/=, itype);
|
||||
|
||||
bfloat_inplace_op(float);
|
||||
bfloat_inplace_op(half);
|
||||
bfloat_inplace_op(int16_t);
|
||||
bfloat_inplace_op(int32_t);
|
||||
bfloat_inplace_op(int64_t);
|
||||
bfloat_inplace_op(uint16_t);
|
||||
bfloat_inplace_op(uint32_t);
|
||||
bfloat_inplace_op(uint64_t);
|
||||
|
||||
#undef bfloat_inplace_op_helper
|
||||
#undef bfloat_inplace_op_addr_space_helper
|
||||
#undef bfloat_inplace_op
|
||||
|
||||
#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \
|
||||
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
|
||||
addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
|
||||
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
||||
return lhs; \
|
||||
}
|
||||
|
||||
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, device); \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, thread); \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, threadgroup);
|
||||
|
||||
bfloat_inplace_op_addr_space_helper(+, operator+=);
|
||||
bfloat_inplace_op_addr_space_helper(-, operator-=);
|
||||
bfloat_inplace_op_addr_space_helper(*, operator*=);
|
||||
bfloat_inplace_op_addr_space_helper(/, operator/=);
|
||||
|
||||
#undef bfloat_inplace_op_helper
|
||||
#undef bfloat_inplace_op_addr_space_helper
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Bfloat typedef
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
typedef struct _MLX_BFloat16 bfloat16_t;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Bfloat numeric limits
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#pragma METAL internals : enable
|
||||
|
||||
namespace metal {
|
||||
|
||||
template <>
|
||||
struct _numeric_limits_impl<bfloat16_t> : _fp_numeric_limits_impl_base {
|
||||
static constexpr constant int digits = 8;
|
||||
static constexpr constant int digits10 = 2;
|
||||
static constexpr constant int max_digits10 = 4;
|
||||
static constexpr constant int radix = 2;
|
||||
static constexpr constant int min_exponent = -125;
|
||||
static constexpr constant int min_exponent10 = -37;
|
||||
static constexpr constant int max_exponent = 128;
|
||||
static constexpr constant int max_exponent10 = 38;
|
||||
|
||||
static constexpr bfloat16_t min() {
|
||||
return _MLX_BFloat16(0x0080, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t lowest() {
|
||||
return _MLX_BFloat16(0xFF7F, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t max() {
|
||||
return _MLX_BFloat16(0x7F7F, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t epsilon() {
|
||||
return _MLX_BFloat16(0x3C00, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t round_error() {
|
||||
return _MLX_BFloat16(0x3F00, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t infinity() {
|
||||
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t quiet_NaN() {
|
||||
return _MLX_BFloat16(0x7FC0, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t signaling_NaN() {
|
||||
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t denorm_min() {
|
||||
return _MLX_BFloat16(0x0001, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
};
|
||||
|
||||
METAL_FUNC bool isnan(_MLX_BFloat16 x) {
|
||||
return x != x;
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
|
||||
#pragma METAL internals : disable
|
||||
|
||||
#endif // defined(__HAVE_BFLOAT__)
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16_math.h"
|
||||
369
mlx/backend/metal/kernels/binary.metal
Normal file
@@ -0,0 +1,369 @@
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
|
||||
struct Add {
|
||||
template <typename T> T operator()(T x, T y) { return x + y; }
|
||||
};
|
||||
|
||||
struct Divide {
|
||||
template <typename T> T operator()(T x, T y) { return x / y; }
|
||||
};
|
||||
|
||||
struct Equal {
|
||||
template <typename T> bool operator()(T x, T y) { return x == y; }
|
||||
};
|
||||
|
||||
struct NaNEqual {
|
||||
template <typename T> bool operator()(T x, T y) {
|
||||
return x == y || (metal::isnan(x) && metal::isnan(y));
|
||||
}
|
||||
template <>
|
||||
bool operator()(complex64_t x, complex64_t y) {
|
||||
return x == y ||
|
||||
(metal::isnan(x.real) && metal::isnan(y.real)
|
||||
&& metal::isnan(x.imag) && metal::isnan(y.imag)) ||
|
||||
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
|
||||
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
|
||||
}
|
||||
};
|
||||
|
||||
struct Greater {
|
||||
template <typename T> bool operator()(T x, T y) { return x > y; }
|
||||
};
|
||||
|
||||
struct GreaterEqual {
|
||||
template <typename T> bool operator()(T x, T y) { return x >= y; }
|
||||
};
|
||||
|
||||
struct Less {
|
||||
template <typename T> bool operator()(T x, T y) { return x < y; }
|
||||
};
|
||||
|
||||
struct LessEqual {
|
||||
template <typename T> bool operator()(T x, T y) { return x <= y; }
|
||||
};
|
||||
|
||||
struct LogAddExp {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
constexpr T inf = metal::numeric_limits<T>::infinity();
|
||||
T maxval = metal::max(x, y);
|
||||
T minval = metal::min(x, y);
|
||||
return (minval == -inf || maxval == inf) ? maxval :
|
||||
(maxval + log1p(metal::exp(minval - maxval)));
|
||||
};
|
||||
};
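
// Numerical sketch (illustrative): LogAddExp()(1000.0f, 1000.0f) evaluates
// 1000 + log1p(exp(0)) = 1000 + log(2) ≈ 1000.693, whereas a naive
// log(exp(1000) + exp(1000)) would overflow to inf in float32.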
|
||||
|
||||
struct Maximum {
|
||||
template <typename T> T operator()(T x, T y) { return metal::max(x, y); }
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
return x >= y ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Minimum {
|
||||
template <typename T> T operator()(T x, T y) { return metal::min(x, y); }
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
return x <= y ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Multiply {
|
||||
template <typename T> T operator()(T x, T y) { return x * y; }
|
||||
};
|
||||
|
||||
struct NotEqual {
|
||||
template <typename T> bool operator()(T x, T y) { return x != y; }
|
||||
template <>
|
||||
bool operator()(complex64_t x, complex64_t y) {
|
||||
return x.real != y.real || x.imag != y.imag;
|
||||
}
|
||||
};
|
||||
|
||||
struct Power {
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||
return metal::pow(base, exp);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||
T res = 1;
|
||||
while (exp) {
|
||||
if (exp & 1) {
|
||||
res *= base;
|
||||
}
|
||||
exp >>= 1;
|
||||
base *= base;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
auto x_theta = metal::atan(x.imag / x.real);
|
||||
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
|
||||
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
|
||||
auto phase = y.imag * x_ln_r + y.real * x_theta;
|
||||
return {mag * metal::cos(phase), mag * metal::sin(phase)};
|
||||
}
|
||||
};
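
// Worked example (illustrative) of the integer fast-exponentiation branch:
// Power()(3, 5) scans the bits of 5 (0b101): bit 0 sets res = 3, base squares
// to 9 and then 81, and bit 2 multiplies in 81 to give res = 243 = 3^5 in
// three loop iterations instead of five multiplies.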
|
||||
|
||||
|
||||
struct Subtract {
|
||||
template <typename T> T operator()(T x, T y) { return x - y; }
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_s2s(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[0], b[0]);
|
||||
}
|
||||
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_ss(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[0], b[0]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_sv(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[0], b[index]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_vs(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[index], b[0]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_vv(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[index], b[index]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_g_nd1(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto a_idx = elem_to_loc_1(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1(index, b_stride);
|
||||
c[index] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_g_nd2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_2(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2(index, b_strides);
|
||||
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
|
||||
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_g_nd3(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int DIM>
|
||||
[[kernel]] void binary_op_g_nd(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const int shape[DIM],
|
||||
constant const size_t a_strides[DIM],
|
||||
constant const size_t b_strides[DIM],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
|
||||
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
c[out_idx] = Op()(a[idx.x], b[idx.y]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_g(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
|
||||
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
|
||||
c[out_idx] = Op()(a[idx.x], b[idx.y]);
|
||||
}
|
||||
|
||||
#define instantiate_binary(name, itype, otype, op, bopt) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void binary_op_##bopt<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
|
||||
template [[host_name(name "_" #dims)]] \
|
||||
[[kernel]] void binary_op_g_nd<itype, otype, op, dims>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const int shape[dims], \
|
||||
constant const size_t a_strides[dims], \
|
||||
constant const size_t b_strides[dims], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
#define instantiate_binary_g_nd(name, itype, otype, op) \
|
||||
template [[host_name(name "_1")]] \
|
||||
[[kernel]] void binary_op_g_nd1<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const size_t& a_stride, \
|
||||
constant const size_t& b_stride, \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name(name "_2")]] \
|
||||
[[kernel]] void binary_op_g_nd2<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const size_t a_strides[2], \
|
||||
constant const size_t b_strides[2], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name(name "_3")]] \
|
||||
[[kernel]] void binary_op_g_nd3<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const size_t a_strides[3], \
|
||||
constant const size_t b_strides[3], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
instantiate_binary_g_dim(name, itype, otype, op, 4) \
|
||||
instantiate_binary_g_dim(name, itype, otype, op, 5)
|
||||
|
||||
|
||||
#define instantiate_binary_g(name, itype, otype, op) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void binary_op_g<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const int* shape, \
|
||||
constant const size_t* a_strides, \
|
||||
constant const size_t* b_strides, \
|
||||
constant const int& ndim, \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
#define instantiate_binary_all(name, tname, itype, otype, op) \
|
||||
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
|
||||
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
|
||||
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
|
||||
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
|
||||
instantiate_binary_g("g" #name #tname, itype, otype, op) \
|
||||
instantiate_binary_g_nd("g" #name #tname, itype, otype, op)
|
||||
|
||||
#define instantiate_binary_float(name, op) \
|
||||
instantiate_binary_all(name, float16, half, half, op) \
|
||||
instantiate_binary_all(name, float32, float, float, op) \
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
|
||||
|
||||
#define instantiate_binary_types(name, op) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
|
||||
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
|
||||
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
|
||||
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
|
||||
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
|
||||
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
|
||||
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
|
||||
instantiate_binary_all(name, int64, int64_t, int64_t, op) \
|
||||
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
|
||||
instantiate_binary_float(name, op)
|
||||
|
||||
#define instantiate_binary_types_bool(name, op) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, bool, op) \
|
||||
instantiate_binary_all(name, uint16, uint16_t, bool, op) \
|
||||
instantiate_binary_all(name, uint32, uint32_t, bool, op) \
|
||||
instantiate_binary_all(name, uint64, uint64_t, bool, op) \
|
||||
instantiate_binary_all(name, int8, int8_t, bool, op) \
|
||||
instantiate_binary_all(name, int16, int16_t, bool, op) \
|
||||
instantiate_binary_all(name, int32, int32_t, bool, op) \
|
||||
instantiate_binary_all(name, int64, int64_t, bool, op) \
|
||||
instantiate_binary_all(name, float16, half, bool, op) \
|
||||
instantiate_binary_all(name, float32, float, bool, op) \
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
|
||||
instantiate_binary_all(name, complex64, complex64_t, bool, op)
|
||||
|
||||
instantiate_binary_types(add, Add)
|
||||
instantiate_binary_float(div, Divide)
|
||||
instantiate_binary_types_bool(eq, Equal)
|
||||
instantiate_binary_types_bool(ge, Greater)
|
||||
instantiate_binary_types_bool(geq, GreaterEqual)
|
||||
instantiate_binary_types_bool(le, Less)
|
||||
instantiate_binary_types_bool(leq, LessEqual)
|
||||
instantiate_binary_types_bool(neq, NotEqual)
|
||||
instantiate_binary_float(lae, LogAddExp)
|
||||
instantiate_binary_types(max, Maximum)
|
||||
instantiate_binary_types(min, Minimum)
|
||||
instantiate_binary_types(mul, Multiply)
|
||||
instantiate_binary_types(sub, Subtract)
|
||||
instantiate_binary_types(pow, Power)
|
||||
|
||||
// NaNEqual only needed for floating point types with boolean output
|
||||
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
|
||||
instantiate_binary_all(naneq, float32, float, bool, NaNEqual)
|
||||
instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual)
|
||||
instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
|
||||
110
mlx/backend/metal/kernels/complex.h
Normal file
@@ -0,0 +1,110 @@
|
||||
#pragma once
|
||||
|
||||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
struct complex64_t;
|
||||
|
||||
template <typename T>
|
||||
static constexpr constant bool can_convert_to_complex64 =
|
||||
!is_same_v<T, complex64_t> && is_convertible_v<T, float>;
|
||||
|
||||
template <typename T>
|
||||
static constexpr constant bool can_convert_from_complex64 =
|
||||
!is_same_v<T, complex64_t> &&
|
||||
(is_convertible_v<float, T> || is_convertible_v<bfloat16_t, T>);
|
||||
|
||||
struct complex64_t {
|
||||
float real;
|
||||
float imag;
|
||||
|
||||
// Constructors
|
||||
constexpr complex64_t(float real, float imag) : real(real), imag(imag){};
|
||||
|
||||
// Conversions to complex64_t
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
||||
constexpr complex64_t(T x) thread : real(x), imag(0) {}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
||||
constexpr complex64_t(T x) threadgroup : real(x), imag(0) {}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
||||
constexpr complex64_t(T x) device : real(x), imag(0) {}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
||||
constexpr complex64_t(T x) constant : real(x), imag(0) {}
|
||||
|
||||
// Conversions from complex64_t
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
||||
constexpr operator T() const thread {
|
||||
return static_cast<T>(real);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
||||
constexpr operator T() const threadgroup {
|
||||
return static_cast<T>(real);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
||||
constexpr operator T() const device {
|
||||
return static_cast<T>(real);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
||||
constexpr operator T() const constant {
|
||||
return static_cast<T>(real);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr complex64_t operator-(complex64_t x) {
|
||||
return {-x.real, -x.imag};
|
||||
}
|
||||
|
||||
constexpr bool operator>=(complex64_t a, complex64_t b) {
|
||||
return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag);
|
||||
}
|
||||
|
||||
constexpr bool operator>(complex64_t a, complex64_t b) {
|
||||
return (a.real > b.real) || (a.real == b.real && a.imag > b.imag);
|
||||
}
|
||||
|
||||
constexpr bool operator<=(complex64_t a, complex64_t b) {
|
||||
return operator>=(b, a);
|
||||
}
|
||||
|
||||
constexpr bool operator<(complex64_t a, complex64_t b) {
|
||||
return operator>(b, a);
|
||||
}
|
||||
|
||||
constexpr bool operator==(complex64_t a, complex64_t b) {
|
||||
return a.real == b.real && a.imag == b.imag;
|
||||
}
|
||||
|
||||
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
|
||||
return {a.real + b.real, a.imag + b.imag};
|
||||
}
|
||||
|
||||
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
|
||||
return {a.real - b.real, a.imag - b.imag};
|
||||
}
|
||||
|
||||
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
|
||||
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
|
||||
}
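
// Worked example (illustrative): (1 + 2i) * (3 + 4i) = (1*3 - 2*4) +
// (1*4 + 2*3)i = -5 + 10i, matching the component formula above.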
|
||||
14
mlx/backend/metal/kernels/defines.h
Normal file
@@ -0,0 +1,14 @@
|
||||
#pragma once
|
||||
|
||||
#ifdef __METAL__
|
||||
#define MTL_CONST constant
|
||||
#else
|
||||
#define MTL_CONST
|
||||
#endif
|
||||
|
||||
static MTL_CONST constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
|
||||
static MTL_CONST constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
|
||||
static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
|
||||
static MTL_CONST constexpr int REDUCE_N_READS = 16;
|
||||
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
|
||||
static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
|
||||
479
mlx/backend/metal/kernels/gemm/conv.h
Normal file
@@ -0,0 +1,479 @@
|
||||
#pragma once
|
||||
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_simdgroup_matrix>
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/conv_params.h"
|
||||
|
||||
#define MLX_MTL_CONST static constant constexpr const
|
||||
|
||||
using namespace metal;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Loading helper
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int vec_size,
|
||||
int tgp_size,
|
||||
int tgp_padding = 0>
|
||||
struct Conv2DInputBlockLoader {
|
||||
// Destination dimensions
|
||||
MLX_MTL_CONST int dst_fd = BM;
|
||||
MLX_MTL_CONST int dst_ld = BK + tgp_padding;
|
||||
MLX_MTL_CONST int n_vecs = BK / vec_size;
|
||||
|
||||
// Stride along block row within the block
|
||||
MLX_MTL_CONST int bstride = tgp_size / n_vecs;
|
||||
MLX_MTL_CONST int n_rows = dst_fd / bstride;
|
||||
|
||||
// Thread location indices
|
||||
const short thread_idx;
|
||||
const short bi;
|
||||
const short bj;
|
||||
|
||||
// threadgroup and device memory
|
||||
threadgroup T* dst;
|
||||
const device T* src;
|
||||
|
||||
const constant MLXConvParams<2>& params;
|
||||
|
||||
int weight_h;
|
||||
int weight_w;
|
||||
|
||||
int offsets_n[n_rows];
|
||||
int offsets_oh[n_rows];
|
||||
int offsets_ow[n_rows];
|
||||
|
||||
/* Constructor */
|
||||
METAL_FUNC Conv2DInputBlockLoader(
|
||||
const device T* src_,
|
||||
threadgroup T* dst_,
|
||||
const constant MLXConvParams<2>& params_,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / n_vecs),
|
||||
bj(vec_size * (thread_idx % n_vecs)),
|
||||
dst(dst_ + bi * dst_ld + bj),
|
||||
src(src_ + bj),
|
||||
params(params_),
|
||||
weight_h(0),
|
||||
weight_w(0) {
|
||||
int out_n_pixels = params.oS[0] * params.oS[1];
|
||||
|
||||
for (int i = 0; i < n_rows; ++i) {
|
||||
int offset_nhw = tid.y * BM + bi + i * bstride;
|
||||
offsets_n[i] = offset_nhw / out_n_pixels;
|
||||
int hw = offset_nhw % out_n_pixels;
|
||||
offsets_oh[i] = hw / params.oS[1];
|
||||
offsets_ow[i] = hw % params.oS[1];
|
||||
}
|
||||
|
||||
(void)lid;
|
||||
}
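
  // Offset decomposition sketch (illustrative shapes): with params.oS = {4, 5}
  // the constructor above maps a linear offset_nhw of 27 to n = 27 / 20 = 1,
  // oh = (27 % 20) / 5 = 1 and ow = (27 % 20) % 5 = 2, i.e. batch 1, output
  // row 1, output column 2 for this thread's row of the input tile.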
|
||||
|
||||
/* Load from device memory into threadgroup memory - without bound checking */
|
||||
METAL_FUNC void load_unsafe() const {
|
||||
#pragma clang loop unroll(full)
|
||||
for (short i = 0, is = 0; i < n_rows; ++i, is += bstride) {
|
||||
int n = offsets_n[i];
|
||||
int oh = offsets_oh[i];
|
||||
int ow = offsets_ow[i];
|
||||
|
||||
int ih = oh * params.str[0] - params.pad[0] + weight_h * params.dil[0];
|
||||
int iw = ow * params.str[1] - params.pad[1] + weight_w * params.dil[1];
|
||||
|
||||
// Read from input if in bounds
|
||||
if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) {
|
||||
const device T* curr_src = src + n * params.in_strides[0] +
|
||||
ih * params.in_strides[1] + iw * params.in_strides[2];
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < vec_size; ++j) {
|
||||
dst[is * dst_ld + j] = curr_src[j];
|
||||
}
|
||||
}
|
||||
|
||||
// Zero pad otherwise
|
||||
else {
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < vec_size; ++j) {
|
||||
dst[is * dst_ld + j] = T(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Iteration helper */
|
||||
METAL_FUNC void next() {
|
||||
if (++weight_w < params.wS[1]) {
|
||||
return;
|
||||
}
|
||||
|
||||
weight_w = 0;
|
||||
|
||||
if (++weight_h < params.wS[0]) {
|
||||
return;
|
||||
}
|
||||
|
||||
weight_h = 0;
|
||||
|
||||
src += BK;
|
||||
}
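
  // Iteration sketch (illustrative): next() walks the filter window in
  // row-major order (weight_w fastest, then weight_h); once all wS[0] * wS[1]
  // positions are visited it advances src by BK elements along the contiguous
  // axis, i.e. to the next K-block of the implicit GEMM reduction.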
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int vec_size,
|
||||
int tgp_size,
|
||||
int tgp_padding = 0>
|
||||
struct Conv2DWeightBlockLoader {
|
||||
// Destination dimensions
|
||||
MLX_MTL_CONST int dst_fd = BN;
|
||||
MLX_MTL_CONST int dst_ld = BK + tgp_padding;
|
||||
MLX_MTL_CONST int n_vecs = BK / vec_size;
|
||||
|
||||
// Stride along block row within the block
|
||||
MLX_MTL_CONST int bstride = tgp_size / n_vecs;
|
||||
MLX_MTL_CONST int n_rows = dst_fd / bstride;
|
||||
|
||||
// Leading dimension for src
|
||||
const int src_ld;
|
||||
|
||||
// Thread location indices
|
||||
const short thread_idx;
|
||||
const short bi;
|
||||
const short bj;
|
||||
|
||||
// threadgroup and device memory
|
||||
threadgroup T* dst;
|
||||
const device T* src;
|
||||
|
||||
const constant MLXConvParams<2>& params;
|
||||
|
||||
int weight_h;
|
||||
int weight_w;
|
||||
|
||||
/* Constructor */
|
||||
METAL_FUNC Conv2DWeightBlockLoader(
|
||||
const device T* src_,
|
||||
threadgroup T* dst_,
|
||||
const constant MLXConvParams<2>& params_,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(params_.wt_strides[0]),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / n_vecs),
|
||||
bj(vec_size * (thread_idx % n_vecs)),
|
||||
dst(dst_ + bi * dst_ld + bj),
|
||||
src(src_ + bi * src_ld + bj),
|
||||
params(params_),
|
||||
weight_h(0),
|
||||
weight_w(0) {
|
||||
(void)lid;
|
||||
(void)tid;
|
||||
}
|
||||
|
||||
/* Load from device memory into threadgroup memory - without bound checking */
|
||||
METAL_FUNC void load_unsafe() const {
|
||||
const device T* curr_src =
|
||||
src + weight_h * params.wt_strides[1] + weight_w * params.wt_strides[2];
|
||||
#pragma clang loop unroll(full)
|
||||
for (short i = 0; i < dst_fd; i += bstride) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Iteration helper */
|
||||
METAL_FUNC void next() {
|
||||
if (++weight_w < params.wS[1]) {
|
||||
return;
|
||||
}
|
||||
|
||||
weight_w = 0;
|
||||
|
||||
if (++weight_h < params.wS[0]) {
|
||||
return;
|
||||
}
|
||||
|
||||
weight_h = 0;
|
||||
|
||||
src += BK;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Transforms
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename OutT, typename InT>
|
||||
struct TransformNone {
|
||||
static METAL_FUNC OutT apply(InT x) {
|
||||
return static_cast<OutT>(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct AccumHelper {
|
||||
typedef float accum_type;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MMA helper
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int tgp_padding_a = 0,
|
||||
int tgp_padding_b = 0,
|
||||
typename AccumType = typename AccumHelper<T>::accum_type,
|
||||
typename Epilogue = TransformNone<T, AccumType>>
|
||||
struct Conv2DBlockMMA {
|
||||
// Warp tile size along M
|
||||
MLX_MTL_CONST int TM = BM / (WM * 8);
|
||||
// Warp tile size along N
|
||||
MLX_MTL_CONST int TN = BN / (WN * 8);
|
||||
|
||||
// Warp tile simdgroup matrix strides along M
|
||||
MLX_MTL_CONST int TM_stride = 8 * WM;
|
||||
// Warp tile simdgroup matrix strides along N
|
||||
MLX_MTL_CONST int TN_stride = 8 * WN;
|
||||
|
||||
// Leading dimensions of threadgroup A, B blocks
|
||||
MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
|
||||
MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;
|
||||
|
||||
// Strides of A, B along reduction axis
|
||||
MLX_MTL_CONST short simd_stride_a =
|
||||
transpose_a ? TM_stride : TM_stride * lda_tgp;
|
||||
MLX_MTL_CONST short simd_stride_b =
|
||||
transpose_b ? TN_stride * ldb_tgp : TN_stride;
|
||||
|
||||
// Jump between elements
|
||||
MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
|
||||
MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;
|
||||
|
||||
// Offsets within threadgroup
|
||||
const int tm;
|
||||
const int tn;
|
||||
|
||||
// Simdgroup matrices
|
||||
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
|
||||
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
|
||||
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
|
||||
simdgroup_matrix<AccumType, 8, 8>(0)};
|
||||
|
||||
short sm;
|
||||
short sn;
|
||||
|
||||
/* Constructor */
|
||||
METAL_FUNC Conv2DBlockMMA(
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
|
||||
short qid = simd_lane_id / 4;
|
||||
sm = (qid & 4) + (simd_lane_id / 2) % 4;
|
||||
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
||||
}
|
||||
|
||||
/* (BM, BK) X (BK, BN) multiply accumulate function */
|
||||
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
|
||||
// Iterate over BK in blocks of 8
|
||||
#pragma clang loop unroll(full)
|
||||
for (short kk = 0; kk < BK; kk += 8) {
|
||||
short2 offset_a =
|
||||
transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
|
||||
short2 offset_b =
|
||||
transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);
|
||||
|
||||
const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
|
||||
const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
// Load elements from threadgroup A as simdgroup matrices
|
||||
#pragma clang loop unroll(full)
|
||||
for (short i = 0; i < TM; i++) {
|
||||
Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
|
||||
Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
|
||||
As__ += simd_stride_a;
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
// Load elements from threadgroup B as simdgroup matrices
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < TN; j++) {
|
||||
Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
|
||||
Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
|
||||
Bs__ += simd_stride_b;
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
// Multiply and accumulate into result simdgroup matrices
|
||||
#pragma clang loop unroll(full)
|
||||
for (short i = 0; i < TM; i++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < TN; j++) {
|
||||
simdgroup_multiply_accumulate(
|
||||
results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Store results from simdgroup_matrix results into device memory */
|
||||
METAL_FUNC void store_result(device T* C, const int ldc) const {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int i = 0; i < TM; i++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int j = 0; j < TN; j++) {
|
||||
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
|
||||
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
|
||||
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
|
||||
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void
|
||||
store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int i = 0; i < TM; i++) {
|
||||
if (tm + i * TM_stride + sm < dst_tile_dims.y) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int j = 0; j < TN; j++) {
|
||||
if (tn + j * TN_stride + sn < dst_tile_dims.x) {
|
||||
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
|
||||
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
|
||||
}
|
||||
|
||||
if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
|
||||
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
|
||||
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
typename AccumType = typename AccumHelper<T>::accum_type,
|
||||
typename Epilogue = TransformNone<T, AccumType>>
|
||||
struct Conv2DImplicitGEMMKernel {
|
||||
MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
|
||||
MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
|
||||
MLX_MTL_CONST short tgp_mem_size_a =
|
||||
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
|
||||
MLX_MTL_CONST short tgp_mem_size_b =
|
||||
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
|
||||
MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
|
||||
|
||||
MLX_MTL_CONST short tgp_size = WM * WN * 32;
|
||||
MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;
|
||||
|
||||
using loader_a_t =
|
||||
Conv2DInputBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_a>;
|
||||
using loader_b_t =
|
||||
Conv2DWeightBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_b>;
|
||||
using mma_t = Conv2DBlockMMA<
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
tgp_padding_a,
|
||||
tgp_padding_b,
|
||||
AccumType,
|
||||
Epilogue>;
|
||||
|
||||
/* Main kernel function */
|
||||
static METAL_FUNC void run(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
device T* C [[buffer(2)]],
|
||||
const constant MLXConvParams<2>& params [[buffer(3)]],
|
||||
threadgroup T* tgp_memory [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
const int c_row = tid.y * BM;
|
||||
const int c_col = tid.x * BN;
|
||||
const int K = params.wt_strides[0];
|
||||
const int N = params.O;
|
||||
|
||||
B += c_col * K;
|
||||
C += c_row * N + c_col;
|
||||
|
||||
// Prepare threadgroup memory for loading
|
||||
threadgroup T* As = tgp_memory;
|
||||
threadgroup T* Bs = tgp_memory + tgp_mem_size_a;
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
loader_a_t loader_a(A, As, params, tid, lid, simd_gid, simd_lid);
|
||||
loader_b_t loader_b(B, Bs, params, tid, lid, simd_gid, simd_lid);
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
mma_t mma_op(simd_gid, simd_lid);
|
||||
|
||||
for (int k = 0; k < K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Store results to device memory
|
||||
mma_op.store_result(C, N);
|
||||
}
|
||||
};
|
||||
536
mlx/backend/metal/kernels/gemm/gemm.h
Normal file
@@ -0,0 +1,536 @@
|
||||
#pragma once
|
||||
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_simdgroup_matrix>
|
||||
#include <metal_stdlib>
|
||||
|
||||
#define MLX_MTL_CONST static constant constexpr const
|
||||
|
||||
using namespace metal;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Loading helper
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BROWS,
|
||||
int BCOLS,
|
||||
int BK,
|
||||
int vec_size,
|
||||
int tgp_size,
|
||||
bool transpose,
|
||||
bool ldK,
|
||||
int tgp_padding = 0>
|
||||
struct BlockLoader {
|
||||
// Destination dimensions
|
||||
MLX_MTL_CONST int dst_fd = transpose ? BCOLS : BROWS;
|
||||
MLX_MTL_CONST int dst_ld = (transpose ? BROWS : BCOLS) + tgp_padding;
|
||||
MLX_MTL_CONST int n_vecs = (transpose ? BROWS : BCOLS) / vec_size;
|
||||
|
||||
// Stride along block row within the block
|
||||
MLX_MTL_CONST int bstride = tgp_size / n_vecs;
|
||||
|
||||
// Leading dimension for src
|
||||
const int src_ld;
|
||||
// Stride along reduction axis between blocks
|
||||
const int tstride;
|
||||
|
||||
// Thread location indices
|
||||
const short thread_idx;
|
||||
const short bi;
|
||||
const short bj;
|
||||
|
||||
// threadgroup and device memory
|
||||
threadgroup T* dst;
|
||||
const device T* src;
|
||||
|
||||
/* Constructor */
|
||||
METAL_FUNC BlockLoader(
|
||||
const device T* src_,
|
||||
const int src_ld_,
|
||||
threadgroup T* dst_,
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(src_ld_),
|
||||
tstride(
|
||||
BK * ((int)(transpose ^ !ldK) * src_ld + (int)(transpose ^ ldK))),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / n_vecs),
|
||||
bj(vec_size * (thread_idx % n_vecs)),
|
||||
dst(dst_ + bi * dst_ld + bj),
|
||||
src(src_ + bi * src_ld + bj) {}
|
||||
|
||||
/* Load from device memory into threadgroup memory - without bound checking */
|
||||
METAL_FUNC void load_unsafe() const {
|
||||
#pragma clang loop unroll(full)
|
||||
for (short i = 0; i < dst_fd; i += bstride) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = src[i * src_ld + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Load from device memory into threadgroup memory - with bound checking */
|
||||
METAL_FUNC void load_safe(short2 src_tile_dim) const {
|
||||
src_tile_dim = transpose ? src_tile_dim.yx : src_tile_dim.xy;
|
||||
|
||||
// Iterate over rows of block
|
||||
#pragma clang loop unroll(full)
|
||||
for (short i = 0; i < dst_fd; i += bstride) {
|
||||
// Row is in bounds, we check against column
|
||||
if ((bi + i) < src_tile_dim.y) {
|
||||
// Use fast thread memory for bound checks
|
||||
short tmp_idx[vec_size];
|
||||
T tmp_val[vec_size];
|
||||
|
||||
// Make sure tmp_idx only contains valid indices
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_idx[j] = bj + j < src_tile_dim.x ? j : 0;
|
||||
}
|
||||
|
||||
// Read all valid indices into tmp_val
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = src[i * src_ld + tmp_idx[j]];
|
||||
}
|
||||
|
||||
// Zero out unneeded values
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = bj + j < src_tile_dim.x ? tmp_val[j] : T(0);
|
||||
}
|
||||
|
||||
// Copy values to threadgroup memory
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = tmp_val[j];
|
||||
}
|
||||
}
|
||||
|
||||
// Row is out of bounds, we just fill tgp memory with zeros
|
||||
else {
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = T(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Iteration helper */
|
||||
METAL_FUNC void next() {
|
||||
src += tstride;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Transforms
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename OutT, typename InT>
|
||||
struct TransformNone {
|
||||
static METAL_FUNC OutT apply(InT x) {
|
||||
return static_cast<OutT>(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct AccumHelper {
|
||||
typedef float accum_type;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MMA helper
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int tgp_padding_a = 0,
|
||||
int tgp_padding_b = 0,
|
||||
typename AccumType = typename AccumHelper<T>::accum_type,
|
||||
typename Epilogue = TransformNone<T, AccumType>>
|
||||
struct BlockMMA {
|
||||
// Warp tile size along M
|
||||
MLX_MTL_CONST int TM = BM / (WM * 8);
|
||||
// Warp tile size along N
|
||||
MLX_MTL_CONST int TN = BN / (WN * 8);
|
||||
|
||||
// Warp tile simdgroup matrix strides along M
|
||||
MLX_MTL_CONST int TM_stride = 8 * WM;
|
||||
// Warp tile simdgroup matrix strides along N
|
||||
MLX_MTL_CONST int TN_stride = 8 * WN;
|
||||
|
||||
// Leading dimensions of threadgroup A, B blocks
|
||||
MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
|
||||
MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;
|
||||
|
||||
// Strides of A, B along reduction axis
|
||||
MLX_MTL_CONST short simd_stride_a =
|
||||
transpose_a ? TM_stride : TM_stride * lda_tgp;
|
||||
MLX_MTL_CONST short simd_stride_b =
|
||||
transpose_b ? TN_stride * ldb_tgp : TN_stride;
|
||||
|
||||
// Jump between elements
|
||||
MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
|
||||
MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;
|
||||
|
||||
// Offsets within threadgroup
|
||||
const int tm;
|
||||
const int tn;
|
||||
|
||||
// Simdgroup matrices
|
||||
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
|
||||
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
|
||||
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
|
||||
simdgroup_matrix<AccumType, 8, 8>(0)};
|
||||
|
||||
short sm;
|
||||
short sn;
|
||||
|
||||
/* Constructor */
|
||||
METAL_FUNC BlockMMA(
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
|
||||
short qid = simd_lane_id / 4;
|
||||
sm = (qid & 4) + (simd_lane_id / 2) % 4;
|
||||
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
||||
}
|
||||
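// The exact element-to-lane mapping of simdgroup_matrix is not documented by
// Metal; the arithmetic above appears to assume the layout where lane 0 owns
// (0, 0), lane 1 owns (0, 2), lane 2 owns (1, 0), lane 8 owns (0, 4) and
// lane 16 owns (4, 0), i.e. each lane holds the horizontally adjacent pair
// (sm, sn) and (sm, sn + 1). This matches the two thread_elements() accesses
// in mma() and the +0 / +1 column writes in store_result() below.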
|
||||
/* (BM, BK) X (BK, BN) multiply accumulate function */
|
||||
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
|
||||
// Iterate over BK in blocks of 8
|
||||
#pragma clang loop unroll(full)
|
||||
for (short kk = 0; kk < BK; kk += 8) {
|
||||
short2 offset_a =
|
||||
transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
|
||||
short2 offset_b =
|
||||
transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);
|
||||
|
||||
const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
|
||||
const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
// Load elements from threadgroup A as simdgroup matrices
|
||||
#pragma clang loop unroll(full)
|
||||
for (short i = 0; i < TM; i++) {
|
||||
Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
|
||||
Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
|
||||
As__ += simd_stride_a;
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
// Load elements from threadgroup B as simdgroup matrices
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < TN; j++) {
|
||||
Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
|
||||
Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
|
||||
Bs__ += simd_stride_b;
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
// Multiply and accumulate into result simdgroup matrices
|
||||
#pragma clang loop unroll(full)
|
||||
for (short i = 0; i < TM; i++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < TN; j++) {
|
||||
simdgroup_multiply_accumulate(
|
||||
results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Store results from simdgroup_matrix results into device memory */
|
||||
METAL_FUNC void store_result(device T* C, const int ldc) const {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int i = 0; i < TM; i++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int j = 0; j < TN; j++) {
|
||||
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
|
||||
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
|
||||
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
|
||||
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void
|
||||
store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int i = 0; i < TM; i++) {
|
||||
if (tm + i * TM_stride + sm < dst_tile_dims.y) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int j = 0; j < TN; j++) {
|
||||
if (tn + j * TN_stride + sn < dst_tile_dims.x) {
|
||||
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
|
||||
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
|
||||
}
|
||||
|
||||
if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
|
||||
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
|
||||
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned,
|
||||
typename AccumType = typename AccumHelper<T>::accum_type,
|
||||
typename Epilogue = TransformNone<T, AccumType>>
|
||||
struct GEMMKernel {
|
||||
MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
|
||||
MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
|
||||
MLX_MTL_CONST short tgp_mem_size_a =
|
||||
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
|
||||
MLX_MTL_CONST short tgp_mem_size_b =
|
||||
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
|
||||
MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
|
||||
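// As a rough budget check (assuming a typical host-side tile of BM = BN = 64,
// BK = 16 with float32, which is not fixed by this file): tgp_padding = 4
// elements, tgp_mem_size_a = 64 * (16 + 4) = 1280 and
// tgp_mem_size_b = 16 * (64 + 4) = 1088 elements, so the kernel requests
// 2368 floats, roughly 9.3 KB of threadgroup memory, comfortably under the
// typical 32 KB threadgroup memory limit on Apple GPUs.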
|
||||
MLX_MTL_CONST short tgp_size = WM * WN * 32;
|
||||
MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;
|
||||
|
||||
using loader_a_t = BlockLoader<
|
||||
T,
|
||||
BM,
|
||||
BK,
|
||||
BK,
|
||||
vec_size,
|
||||
tgp_size,
|
||||
transpose_a,
|
||||
true,
|
||||
tgp_padding_a>;
|
||||
using loader_b_t = BlockLoader<
|
||||
T,
|
||||
BK,
|
||||
BN,
|
||||
BK,
|
||||
vec_size,
|
||||
tgp_size,
|
||||
transpose_b,
|
||||
false,
|
||||
tgp_padding_b>;
|
||||
using mma_t = BlockMMA<
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
tgp_padding_a,
|
||||
tgp_padding_b,
|
||||
AccumType,
|
||||
Epilogue>;
|
||||
|
||||
/* Main kernel function */
|
||||
static METAL_FUNC void run(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
device T* C [[buffer(2)]],
|
||||
const constant int& M [[buffer(3)]],
|
||||
const constant int& N [[buffer(4)]],
|
||||
const constant int& K [[buffer(5)]],
|
||||
const constant int& batch_stride_a [[buffer(6)]],
|
||||
const constant int& batch_stride_b [[buffer(7)]],
|
||||
const constant int& batch_stride_c [[buffer(8)]],
|
||||
threadgroup T* tgp_memory [[threadgroup(0)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
// Pacifying compiler
|
||||
(void)lid;
|
||||
|
||||
// Adjust for batch
|
||||
A += batch_stride_a * tid.z;
|
||||
B += batch_stride_b * tid.z;
|
||||
C += batch_stride_c * tid.z;
|
||||
|
||||
// Adjust for transpose
|
||||
const int lda_dev = transpose_a ? M : K;
|
||||
const int ldb_dev = transpose_b ? K : N;
|
||||
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid.y * BM;
|
||||
const int c_col = tid.x * BN;
|
||||
|
||||
A += transpose_a ? c_row : c_row * K;
|
||||
B += transpose_b ? c_col * K : c_col;
|
||||
C += c_row * N + c_col;
|
||||
|
||||
// Prepare threadgroup memory for loading
|
||||
threadgroup T* As = tgp_memory;
|
||||
threadgroup T* Bs = tgp_memory + tgp_mem_size_a;
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
loader_a_t loader_a(A, lda_dev, As, simd_group_id, simd_lane_id);
|
||||
loader_b_t loader_b(B, ldb_dev, Bs, simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MNK aligned loop
|
||||
if (MN_aligned && K_aligned) {
|
||||
for (int k = 0; k < K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Store results to device memory
|
||||
mma_op.store_result(C, N);
|
||||
return;
|
||||
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MN aligned, K unaligned loop
|
||||
else if (MN_aligned && !K_aligned) {
|
||||
// Main loop
|
||||
int k = 0;
|
||||
for (; k + BK <= K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
// Loop tail
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
loader_a.load_safe(short2(K - k, BM));
|
||||
loader_b.load_safe(short2(BN, K - k));
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Store results to device memory
|
||||
mma_op.store_result(C, N);
|
||||
return;
|
||||
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MNK unaligned loop
|
||||
else { // Loop over K - unaligned case
|
||||
|
||||
short2 src_tile_dims(min(BN, N - c_col), min(BM, M - c_row));
|
||||
|
||||
if (src_tile_dims.y == BM && src_tile_dims.x == BN) {
|
||||
int k = 0;
|
||||
for (; k + BK <= K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
if (k < K) {
|
||||
loader_a.load_safe(short2(K - k, BM));
|
||||
loader_b.load_safe(short2(BN, K - k));
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
|
||||
mma_op.store_result(C, N);
|
||||
return;
|
||||
|
||||
} else {
|
||||
int k = 0;
|
||||
for (; k + BK <= K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_safe(short2(BK, src_tile_dims.y));
|
||||
loader_b.load_safe(short2(src_tile_dims.x, BK));
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
if (k < K) {
|
||||
loader_a.load_safe(short2(K - k, src_tile_dims.y));
|
||||
loader_b.load_safe(short2(src_tile_dims.x, K - k));
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
mma_op.store_result_safe(C, N, src_tile_dims);
|
||||
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
302
mlx/backend/metal/kernels/gemv.metal
Normal file
302
mlx/backend/metal/kernels/gemv.metal
Normal file
@@ -0,0 +1,302 @@
|
||||
#include <metal_stdlib>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/// Matrix vector multiplication
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static constant constexpr int SIMD_SIZE = 32;
|
||||
|
||||
template <typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN> /* Thread cols (in elements) */
|
||||
[[kernel]] void gemv(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
device T* out_vec [[buffer(2)]],
|
||||
const constant int& in_vec_size [[buffer(3)]],
|
||||
const constant int& out_vec_size [[buffer(4)]],
|
||||
const constant int& vector_batch_stride [[buffer(5)]],
|
||||
const constant int& matrix_batch_stride [[buffer(6)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
|
||||
|
||||
// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
|
||||
// into blocks of (BM * TM, BN * TN) divided among threadgroups
|
||||
// - Every thread works on a block of (TM, TN)
|
||||
// - We assume each threadgroup is launched with (BN, BM, 1) threads
|
||||
//
|
||||
// 1. A thread loads TN elements each from mat along TM contiguous rows
|
||||
// and the corresponding TN values from the vector
|
||||
// 2. The thread then multiplies and adds to accumulate its local result for the block
|
||||
// 3. At the end, each thread has accumulated results over all blocks across the rows
|
||||
// These are then summed up across the simdgroup
|
||||
// 4. Each threadgroup writes its accumulated BM * TM outputs
|
||||
//
|
||||
// Edge case handling:
|
||||
// - The threadgroup with the largest tid will have blocks that exceed the matrix
|
||||
// * The blocks that start outside the matrix are never read (thread results remain zero)
|
||||
// * The last simdgroup that partially overlaps with the matrix is shifted inwards
|
||||
// such that the thread block fits exactly in the matrix
|
||||
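// As a concrete instance of the scheme above, take the instantiated
// configuration (BM, BN, TM, TN) = (4, 32, 4, 4):
//   - the threadgroup is (BN, BM, 1) = (32, 4, 1) threads, i.e. 4 simdgroups
//   - simdgroup simd_gid owns output rows [(tid.x * 4 + simd_gid) * 4, +4)
//   - each pass of the bn loop consumes BN * TN = 128 elements of in_vec,
//     with thread simd_lid reading columns [bn, bn + 4)
//   - the per-row partial sums are reduced over the 32 lanes with simd_sum
//     and lane 0 writes the 4 outputs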
|
||||
// Update batch offsets
|
||||
in_vec += tid.z * vector_batch_stride;
|
||||
mat += tid.z * matrix_batch_stride;
|
||||
out_vec += tid.z * out_vec_size;
|
||||
|
||||
// Threadgroup in_vec cache
|
||||
threadgroup T in_vec_block[BN][TN * 2];
|
||||
|
||||
// Thread local accumulation results
|
||||
thread T result[TM] = {0};
|
||||
thread T inter[TN];
|
||||
thread T v_coeff[TN];
|
||||
|
||||
// Block position
|
||||
int out_row = (tid.x * BM + simd_gid) * TM;
|
||||
|
||||
// Exit simdgroup if rows out of bound
|
||||
if(out_row >= out_vec_size)
|
||||
return;
|
||||
|
||||
// Adjust tail simdgroup to ensure in bound reads
|
||||
out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
|
||||
|
||||
// Advance matrix
|
||||
mat += out_row * in_vec_size;
|
||||
|
||||
// Loop over in_vec in blocks of BN * TN
|
||||
for(int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Prefetch in_vector for threadgroup use
|
||||
if(simd_gid == 0) {
|
||||
// Main load loop
|
||||
if(bn + TN <= in_vec_size) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
in_vec_block[simd_lid][tn] = in_vec[bn + tn];
|
||||
}
|
||||
} else { // Edge case
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
in_vec_block[simd_lid][tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load for all rows
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
v_coeff[tn] = in_vec_block[simd_lid][tn];
|
||||
}
|
||||
|
||||
// Per thread work loop
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tm = 0; tm < TM; tm++) {
|
||||
// Load for the row
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[tm * in_vec_size + bn + tn];
|
||||
}
|
||||
|
||||
// Accumulate results
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
result[tm] += inter[tn] * v_coeff[tn];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Simdgroup accumulations
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tm = 0; tm < TM; tm++) {
|
||||
result[tm] = simd_sum(result[tm]);
|
||||
}
|
||||
|
||||
// Write outputs
|
||||
if(simd_lid == 0) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tm = 0; tm < TM; tm++) {
|
||||
out_vec[out_row + tm] = result[tm];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
|
||||
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
|
||||
[[kernel]] void gemv<itype, bm, bn, tm, tn>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* vec [[buffer(1)]], \
|
||||
device itype* out [[buffer(2)]], \
|
||||
const constant int& in_vec_size [[buffer(3)]], \
|
||||
const constant int& out_vec_size [[buffer(4)]], \
|
||||
const constant int& vector_batch_stride [[buffer(5)]], \
|
||||
const constant int& matrix_batch_stride [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_gemv_blocks(name, itype) \
|
||||
instantiate_gemv(name, itype, 4, 32, 1, 4) \
|
||||
instantiate_gemv(name, itype, 4, 32, 4, 4) \
|
||||
instantiate_gemv(name, itype, 8, 32, 4, 4)
|
||||
|
||||
instantiate_gemv_blocks(float32, float)
|
||||
instantiate_gemv_blocks(float16, half)
|
||||
instantiate_gemv_blocks(bfloat16, bfloat16_t)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/// Vector matrix multiplication
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN> /* Thread cols (in elements) */
|
||||
[[kernel]] void gemv_t(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
device T* out_vec [[buffer(2)]],
|
||||
const constant int& in_vec_size [[buffer(3)]],
|
||||
const constant int& out_vec_size [[buffer(4)]],
|
||||
const constant int& vector_batch_stride [[buffer(5)]],
|
||||
const constant int& matrix_batch_stride [[buffer(6)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
|
||||
// into blocks of (BM * TM, BN * TN) divided among threadgroups
|
||||
// - Every thread works on a block of (TM, TN)
|
||||
// - We assume each threadgroup is launched with (BN, BM, 1) threads
|
||||
//
|
||||
// 1. A thread loads TN elements each from mat along TM contiguous rows
|
||||
// and the corresponding TM values from the vector
|
||||
// 2. The thread then multiplies and adds to accumulate its local result for the block
|
||||
// 3. At the end, each thread has accumulated results over all blocks across the rows
|
||||
// These are then summed up across the threadgroup
|
||||
// 4. Each threadgroup writes its accumulated BN * TN outputs
|
||||
//
|
||||
// Edge case handling:
|
||||
// - The threadgroup with the largest tid will have blocks that exceed the matrix
|
||||
// * The blocks that start outside the matrix are never read (thread results remain zero)
|
||||
// * The last thread that partially overlaps with the matrix is shifted inwards
|
||||
// such that the thread block fits exactly in the matrix
|
||||
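// Unlike gemv above, the reduction here runs down the rows of mat, so the
// partial sums are combined through threadgroup memory rather than a single
// simd_sum. With the instantiated configuration (BM, BN, TM, TN) = (8, 32, 4, 4):
//   - the threadgroup is (BN, BM, 1) = (32, 8, 1) threads
//   - thread (lid.x, lid.y) accumulates TN = 4 output columns starting at
//     out_col = (tid.x * 32 + lid.x) * 4, stepping over in_vec by
//     BM * TM = 32 rows per pass
//   - the BM = 8 partial results per column group are summed via tgp_results,
//     and the lid.y == 0 threads write the final 4 outputs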
|
||||
// Update batch offsets
|
||||
in_vec += tid.z * vector_batch_stride;
|
||||
mat += tid.z * matrix_batch_stride;
|
||||
out_vec += tid.z * out_vec_size;
|
||||
|
||||
// Thread local accumulation results
|
||||
T result[TN] = {0};
|
||||
T inter[TN];
|
||||
T v_coeff[TM];
|
||||
|
||||
// Threadgroup accumulation results
|
||||
threadgroup T tgp_results[BN][BM][TM];
|
||||
|
||||
int out_col = (tid.x * BN + lid.x) * TN;
|
||||
int in_row = lid.y * TM;
|
||||
|
||||
// Edge case handling
|
||||
if (out_col < out_vec_size) {
|
||||
// Shift tail threads inwards so the TN-wide reads stay in bounds
|
||||
out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;
|
||||
|
||||
// Per thread accumulation main loop
|
||||
int bm = in_row;
|
||||
for(; bm < in_vec_size; bm += BM * TM) {
|
||||
// Adding a threadgroup_barrier improves performance slightly
|
||||
// This is possibly because it helps exploit the cache better
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
if(bm + TM <= in_vec_size) {
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tm = 0; tm < TM; tm++) {
|
||||
v_coeff[tm] = in_vec[bm + tm];
|
||||
}
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tm = 0; tm < TM; tm++) {
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
|
||||
}
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
result[tn] += v_coeff[tm] * inter[tn];
|
||||
}
|
||||
}
|
||||
|
||||
} else { // Edge case handling
|
||||
for(int tm = 0; bm + tm < in_vec_size; tm++) {
|
||||
v_coeff[tm] = in_vec[bm + tm];
|
||||
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
|
||||
}
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
result[tn] += v_coeff[tm] * inter[tn];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Threadgroup collection
|
||||
for(int i = 0; i < TN; i++) {
|
||||
tgp_results[lid.x][lid.y][i] = result[i];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if(lid.y == 0 && out_col < out_vec_size) {
|
||||
// Threadgroup accumulation
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 1; i < BM; i++) {
|
||||
for(int j = 0; j < TN; j++) {
|
||||
result[j] += tgp_results[lid.x][i][j];
|
||||
}
|
||||
}
|
||||
|
||||
for(int j = 0; j < TN; j++) {
|
||||
out_vec[out_col + j] = result[j];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
|
||||
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
|
||||
[[kernel]] void gemv_t<itype, bm, bn, tm, tn>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* vec [[buffer(1)]], \
|
||||
device itype* out [[buffer(2)]], \
|
||||
const constant int& in_vec_size [[buffer(3)]], \
|
||||
const constant int& out_vec_size [[buffer(4)]], \
|
||||
const constant int& vector_batch_stride [[buffer(5)]], \
|
||||
const constant int& matrix_batch_stride [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
#define instantiate_gemv_t_blocks(name, itype) \
|
||||
instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
|
||||
instantiate_gemv_t(name, itype, 8, 8, 4, 4) \
|
||||
instantiate_gemv_t(name, itype, 8, 16, 4, 4) \
|
||||
instantiate_gemv_t(name, itype, 8, 32, 4, 4) \
|
||||
instantiate_gemv_t(name, itype, 8, 64, 4, 4) \
|
||||
instantiate_gemv_t(name, itype, 8, 128, 4, 4)
|
||||
|
||||
instantiate_gemv_t_blocks(float32, float)
|
||||
instantiate_gemv_t_blocks(float16, half)
|
||||
instantiate_gemv_t_blocks(bfloat16, bfloat16_t)
|
||||
226
mlx/backend/metal/kernels/softmax.metal
Normal file
226
mlx/backend/metal/kernels/softmax.metal
Normal file
@@ -0,0 +1,226 @@
|
||||
#include <metal_atomic>
|
||||
#include <metal_common>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
template <typename T>
|
||||
inline T softmax_exp(T x) {
|
||||
// Softmax doesn't need a high precision exponential because x will be
|
||||
// in (-oo, 0] anyway and the result is subsequently divided by
|
||||
// sum(exp(x_i)).
|
||||
return fast::exp(x);
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = SOFTMAX_N_READS>
|
||||
[[kernel]] void softmax_single_row(
|
||||
const device T* in,
|
||||
device T* out,
|
||||
constant int& axis_size,
|
||||
threadgroup T* local_max [[threadgroup(0)]],
|
||||
threadgroup T* local_normalizer [[threadgroup(1)]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint _lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
int lid = _lid;
|
||||
|
||||
T ld[N_READS];
|
||||
|
||||
in += gid * axis_size + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
ld[i] = in[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
ld[i] =
|
||||
((lid * N_READS + i) < axis_size) ? in[i] : T(Limits<T>::finite_min);
|
||||
}
|
||||
}
|
||||
if (simd_group_id == 0) {
|
||||
local_max[simd_lane_id] = Limits<T>::finite_min;
|
||||
local_normalizer[simd_lane_id] = 0;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Get the max
|
||||
T maxval = Limits<T>::finite_min;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
maxval = (maxval < ld[i]) ? ld[i] : maxval;
|
||||
}
|
||||
maxval = simd_max(maxval);
|
||||
if (simd_lane_id == 0) {
|
||||
local_max[simd_group_id] = maxval;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_group_id == 0) {
|
||||
maxval = simd_max(local_max[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
local_max[0] = maxval;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
maxval = local_max[0];
|
||||
|
||||
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
|
||||
T normalizer = 0;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
T exp_x = softmax_exp(ld[i] - maxval);
|
||||
ld[i] = exp_x;
|
||||
normalizer += exp_x;
|
||||
}
|
||||
normalizer = simd_sum(normalizer);
|
||||
if (simd_lane_id == 0) {
|
||||
local_normalizer[simd_group_id] = normalizer;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_group_id == 0) {
|
||||
normalizer = simd_sum(local_normalizer[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
local_normalizer[0] = normalizer;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
normalizer = 1 / local_normalizer[0];
|
||||
|
||||
// Normalize and write to the output
|
||||
out += gid * axis_size + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
out[i] = ld[i] * normalizer;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
out[i] = ld[i] * normalizer;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = SOFTMAX_N_READS>
|
||||
[[kernel]] void softmax_looped(
|
||||
const device T* in,
|
||||
device T* out,
|
||||
constant int& axis_size,
|
||||
threadgroup T* local_max [[threadgroup(0)]],
|
||||
threadgroup T* local_normalizer [[threadgroup(1)]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
in += gid * axis_size;
|
||||
|
||||
// Get the max and the normalizer in one go
|
||||
T prevmax;
|
||||
T maxval = Limits<T>::finite_min;
|
||||
T normalizer = 0;
|
||||
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
|
||||
r++) {
|
||||
int offset = r * lsize * N_READS + lid * N_READS;
|
||||
T vals[N_READS];
|
||||
if (offset + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = in[offset + i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] =
|
||||
(offset + i < axis_size) ? in[offset + i] : T(Limits<T>::finite_min);
|
||||
}
|
||||
}
|
||||
prevmax = maxval;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
maxval = (maxval < vals[i]) ? vals[i] : maxval;
|
||||
}
|
||||
normalizer *= softmax_exp(prevmax - maxval);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
normalizer += softmax_exp(vals[i] - maxval);
|
||||
}
|
||||
}
|
||||
// Now we have a partial normalizer covering N_READS * ceildiv(axis_size, N_READS *
|
||||
// lsize) elements per thread. We need to combine them.
|
||||
// 1. We start by finding the max across simd groups
|
||||
// 2. We then change the partial normalizers to account for a possible
|
||||
// change in max
|
||||
// 3. We sum all normalizers
|
||||
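// The rescaling below is the usual online-softmax identity: if
// S_old = sum_i exp(x_i - m_old) and the running max rises to m_new, then
// sum_i exp(x_i - m_new) = S_old * exp(m_old - m_new), which is exactly
// normalizer *= softmax_exp(prevmax - maxval). For example, values {0, -1}
// give S_old = 1 + exp(-1) under m_old = 0; if another simdgroup raises the
// max to m_new = 2, the partial sum becomes (1 + exp(-1)) * exp(-2)
// = exp(0 - 2) + exp(-1 - 2), i.e. the same sum taken against the new max.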
prevmax = maxval;
|
||||
maxval = simd_max(maxval);
|
||||
normalizer *= softmax_exp(prevmax - maxval);
|
||||
normalizer = simd_sum(normalizer);
|
||||
|
||||
// Now the normalizer and max value are correct for each simdgroup. We write
|
||||
// them to shared memory and combine them.
|
||||
prevmax = maxval;
|
||||
if (simd_lane_id == 0) {
|
||||
local_max[simd_group_id] = maxval;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
maxval = simd_max(local_max[simd_lane_id]);
|
||||
normalizer *= softmax_exp(prevmax - maxval);
|
||||
if (simd_lane_id == 0) {
|
||||
local_normalizer[simd_group_id] = normalizer;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
normalizer = simd_sum(local_normalizer[simd_lane_id]);
|
||||
normalizer = 1 / normalizer;
|
||||
|
||||
// Finally given the normalizer and max value we can directly write the
|
||||
// softmax output
|
||||
out += gid * axis_size;
|
||||
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
|
||||
r++) {
|
||||
int offset = r * lsize * N_READS + lid * N_READS;
|
||||
if (offset + N_READS <= axis_size) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if (offset + i < axis_size) {
|
||||
out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define instantiate_softmax_single_row(name, itype) \
|
||||
template [[host_name("softmax_" #name)]] [[kernel]] void \
|
||||
softmax_single_row<itype>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
threadgroup itype* local_max [[threadgroup(0)]], \
|
||||
threadgroup itype* local_normalizer [[threadgroup(1)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint _lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_softmax_looped(name, itype) \
|
||||
template [[host_name("softmax_looped_" #name)]] [[kernel]] void \
|
||||
softmax_looped<itype>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
threadgroup itype* local_max [[threadgroup(0)]], \
|
||||
threadgroup itype* local_normalizer [[threadgroup(1)]], \
|
||||
uint gid [[threadgroup_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_softmax(name, itype) \
|
||||
instantiate_softmax_single_row(name, itype) \
|
||||
instantiate_softmax_looped(name, itype)
|
||||
|
||||
instantiate_softmax(float32, float) instantiate_softmax(float16, half)
|
||||
instantiate_softmax(bfloat16, bfloat16_t)
|
||||
818
mlx/backend/metal/kernels/sort.metal
Normal file
818
mlx/backend/metal/kernels/sort.metal
Normal file
@@ -0,0 +1,818 @@
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#define MLX_MTL_CONST static constant constexpr const
|
||||
#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)")
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// Based on GPU merge sort algorithm at https://github.com/NVIDIA/cccl/tree/main/cub/cub
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Thread-level sort
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void thread_swap(thread T& a, thread T& b) {
|
||||
T w = a;
|
||||
a = b;
|
||||
b = w;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct LessThan {
|
||||
static constexpr constant T init = Limits<T>::max;
|
||||
|
||||
METAL_FUNC bool operator()(T a, T b) {
|
||||
return a < b;
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename val_t,
|
||||
typename idx_t,
|
||||
bool ARG_SORT,
|
||||
short N_PER_THREAD,
|
||||
typename CompareOp>
|
||||
struct ThreadSort {
|
||||
static METAL_FUNC void sort(
|
||||
thread val_t (&vals)[N_PER_THREAD],
|
||||
thread idx_t (&idxs)[N_PER_THREAD]) {
|
||||
|
||||
CompareOp op;
|
||||
|
||||
MLX_MTL_LOOP_UNROLL
|
||||
for(short i = 0; i < N_PER_THREAD; ++i) {
|
||||
MLX_MTL_LOOP_UNROLL
|
||||
for(short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
|
||||
if(op(vals[j + 1], vals[j])) {
|
||||
thread_swap(vals[j + 1], vals[j]);
|
||||
thread_swap(idxs[j + 1], idxs[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
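// The loop above is an odd-even transposition sort: pass i compares the
// (even, odd) or (odd, even) neighbor pairs depending on i & 1, and
// N_PER_THREAD passes are enough to fully sort the register array. For
// example, with N_PER_THREAD = 4 and vals = {3, 1, 2, 0}:
//   pass 0 (j = 0, 2): {1, 3, 0, 2}
//   pass 1 (j = 1)   : {1, 0, 3, 2}
//   pass 2 (j = 0, 2): {0, 1, 2, 3}
//   pass 3 (j = 1)   : no swaps
// idxs is permuted in lock-step so the same routine also serves argsort.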
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Threadgroup-level sort
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename val_t,
|
||||
typename idx_t,
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD,
|
||||
typename CompareOp>
|
||||
struct BlockMergeSort {
|
||||
using thread_sort_t = ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>;
|
||||
static METAL_FUNC int merge_partition(
|
||||
const threadgroup val_t* As,
|
||||
const threadgroup val_t* Bs,
|
||||
short A_sz,
|
||||
short B_sz,
|
||||
short sort_md) {
|
||||
|
||||
CompareOp op;
|
||||
|
||||
short A_st = max(0, sort_md - B_sz);
|
||||
short A_ed = min(sort_md, A_sz);
|
||||
|
||||
while(A_st < A_ed) {
|
||||
short md = A_st + (A_ed - A_st) / 2;
|
||||
auto a = As[md];
|
||||
auto b = Bs[sort_md - 1 - md];
|
||||
|
||||
if(op(b, a)) {
|
||||
A_ed = md;
|
||||
} else {
|
||||
A_st = md + 1;
|
||||
}
|
||||
}
|
||||
|
||||
return A_ed;
|
||||
|
||||
}
|
||||
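// merge_partition is a co-rank binary search: it returns how many elements of
// As belong to the first sort_md elements of merge(As, Bs). For example, with
// As = {1, 4, 7}, Bs = {2, 3, 9} and sort_md = 3 it returns 1, because the
// first three merged elements {1, 2, 3} take one value from As and two from
// Bs. Each merge lane then starts its merge_step at As + partition and
// Bs + (sort_md - partition).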
|
||||
static METAL_FUNC void merge_step(
|
||||
const threadgroup val_t* As,
|
||||
const threadgroup val_t* Bs,
|
||||
const threadgroup idx_t* As_idx,
|
||||
const threadgroup idx_t* Bs_idx,
|
||||
short A_sz,
|
||||
short B_sz,
|
||||
thread val_t (&vals)[N_PER_THREAD],
|
||||
thread idx_t (&idxs)[N_PER_THREAD]) {
|
||||
|
||||
CompareOp op;
|
||||
short a_idx = 0;
|
||||
short b_idx = 0;
|
||||
|
||||
for(int i = 0; i < N_PER_THREAD; ++i) {
|
||||
auto a = As[a_idx];
|
||||
auto b = Bs[b_idx];
|
||||
bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
|
||||
|
||||
vals[i] = pred ? b : a;
|
||||
idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx];
|
||||
|
||||
b_idx += short(pred);
|
||||
a_idx += short(!pred);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
static METAL_FUNC void sort(
|
||||
threadgroup val_t* tgp_vals [[threadgroup(0)]],
|
||||
threadgroup idx_t* tgp_idxs [[threadgroup(1)]],
|
||||
int size_sorted_axis,
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
// Get thread location
|
||||
int idx = lid.x * N_PER_THREAD;
|
||||
|
||||
// Load from shared memory
|
||||
thread val_t thread_vals[N_PER_THREAD];
|
||||
thread idx_t thread_idxs[N_PER_THREAD];
|
||||
for(int i = 0; i < N_PER_THREAD; ++i) {
|
||||
thread_vals[i] = tgp_vals[idx + i];
|
||||
if(ARG_SORT) {
|
||||
thread_idxs[i] = tgp_idxs[idx + i];
|
||||
}
|
||||
}
|
||||
|
||||
// Per thread sort
|
||||
if(idx < size_sorted_axis) {
|
||||
thread_sort_t::sort(thread_vals, thread_idxs);
|
||||
}
|
||||
|
||||
// Do merges using threadgroup memory
|
||||
for (int merge_threads = 2; merge_threads <= BLOCK_THREADS; merge_threads *= 2) {
|
||||
// Update threadgroup memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for(int i = 0; i < N_PER_THREAD; ++i) {
|
||||
tgp_vals[idx + i] = thread_vals[i];
|
||||
if(ARG_SORT) {
|
||||
tgp_idxs[idx + i] = thread_idxs[i];
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Find location in merge step
|
||||
int merge_group = lid.x / merge_threads;
|
||||
int merge_lane = lid.x % merge_threads;
|
||||
|
||||
int sort_sz = N_PER_THREAD * merge_threads;
|
||||
int sort_st = N_PER_THREAD * merge_threads * merge_group;
|
||||
|
||||
// As = tgp_vals[A_st:A_ed] is sorted
|
||||
// Bs = tgp_vals[B_st:B_ed] is sorted
|
||||
int A_st = sort_st;
|
||||
int A_ed = sort_st + sort_sz / 2;
|
||||
int B_st = sort_st + sort_sz / 2;
|
||||
int B_ed = sort_st + sort_sz;
|
||||
|
||||
const threadgroup val_t* As = tgp_vals + A_st;
|
||||
const threadgroup val_t* Bs = tgp_vals + B_st;
|
||||
int A_sz = A_ed - A_st;
|
||||
int B_sz = B_ed - B_st;
|
||||
|
||||
// Find a partition of merge elements
|
||||
// Ci = merge(As[partition:], Bs[sort_md - partition:])
|
||||
// of size N_PER_THREAD for each merge lane i
|
||||
// C = [Ci] is sorted
|
||||
int sort_md = N_PER_THREAD * merge_lane;
|
||||
int partition = merge_partition(
|
||||
As,
|
||||
Bs,
|
||||
A_sz,
|
||||
B_sz,
|
||||
sort_md);
|
||||
|
||||
As += partition;
|
||||
Bs += sort_md - partition;
|
||||
|
||||
A_sz -= partition;
|
||||
B_sz -= sort_md - partition;
|
||||
|
||||
const threadgroup idx_t* As_idx = ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
|
||||
const threadgroup idx_t* Bs_idx = ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;
|
||||
|
||||
// Merge starting at the partition and store results in thread registers
|
||||
merge_step(
|
||||
As,
|
||||
Bs,
|
||||
As_idx,
|
||||
Bs_idx,
|
||||
A_sz,
|
||||
B_sz,
|
||||
thread_vals,
|
||||
thread_idxs);
|
||||
|
||||
}
|
||||
|
||||
// Write out to shared memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for(int i = 0; i < N_PER_THREAD; ++i) {
|
||||
tgp_vals[idx + i] = thread_vals[i];
|
||||
if(ARG_SORT) {
|
||||
tgp_idxs[idx + i] = thread_idxs[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Kernel sort
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD,
|
||||
typename CompareOp = LessThan<T>>
|
||||
struct KernelMergeSort {
|
||||
using val_t = T;
|
||||
using idx_t = uint;
|
||||
using block_merge_sort_t = BlockMergeSort<
|
||||
val_t,
|
||||
idx_t,
|
||||
ARG_SORT,
|
||||
BLOCK_THREADS,
|
||||
N_PER_THREAD,
|
||||
CompareOp>;
|
||||
|
||||
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
|
||||
|
||||
static METAL_FUNC void block_sort(
|
||||
const device T* inp,
|
||||
device U* out,
|
||||
const constant int& size_sorted_axis,
|
||||
const constant int& stride_sorted_axis,
|
||||
const constant int& stride_segment_axis,
|
||||
threadgroup val_t* tgp_vals,
|
||||
threadgroup idx_t* tgp_idxs,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
// tid.y tells us the segment index
|
||||
inp += tid.y * stride_segment_axis;
|
||||
out += tid.y * stride_segment_axis;
|
||||
|
||||
// Copy into threadgroup memory
|
||||
for(short i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
|
||||
tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis] : val_t(CompareOp::init);
|
||||
if(ARG_SORT) {
|
||||
tgp_idxs[i] = i;
|
||||
}
|
||||
}
|
||||
|
||||
// Sort elements within the block
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Write output
|
||||
for(int i = lid.x; i < size_sorted_axis; i+= BLOCK_THREADS) {
|
||||
if(ARG_SORT) {
|
||||
out[i * stride_sorted_axis] = tgp_idxs[i];
|
||||
} else {
|
||||
out[i * stride_sorted_axis] = tgp_vals[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD>
|
||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort(
|
||||
const device T* inp [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& stride_segment_axis [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
using sort_kernel = KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
|
||||
using val_t = typename sort_kernel::val_t;
|
||||
using idx_t = typename sort_kernel::idx_t;
|
||||
|
||||
if(ARG_SORT) {
|
||||
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
||||
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
|
||||
sort_kernel::block_sort(
|
||||
inp,
|
||||
out,
|
||||
size_sorted_axis,
|
||||
stride_sorted_axis,
|
||||
stride_segment_axis,
|
||||
tgp_vals,
|
||||
tgp_idxs,
|
||||
tid,
|
||||
lid);
|
||||
} else {
|
||||
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
||||
sort_kernel::block_sort(
|
||||
inp,
|
||||
out,
|
||||
size_sorted_axis,
|
||||
stride_sorted_axis,
|
||||
stride_segment_axis,
|
||||
tgp_vals,
|
||||
nullptr,
|
||||
tid,
|
||||
lid);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
constant constexpr const int zero_helper = 0;
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD>
|
||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort_nc(
|
||||
const device T* inp [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& nc_dim [[buffer(4)]],
|
||||
const device int* nc_shape [[buffer(5)]],
|
||||
const device size_t* nc_strides [[buffer(6)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
using sort_kernel = KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
|
||||
using val_t = typename sort_kernel::val_t;
|
||||
using idx_t = typename sort_kernel::idx_t;
|
||||
|
||||
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
|
||||
inp += block_idx;
|
||||
out += block_idx;
|
||||
|
||||
if(ARG_SORT) {
|
||||
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
||||
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
|
||||
sort_kernel::block_sort(
|
||||
inp,
|
||||
out,
|
||||
size_sorted_axis,
|
||||
stride_sorted_axis,
|
||||
zero_helper,
|
||||
tgp_vals,
|
||||
tgp_idxs,
|
||||
tid,
|
||||
lid);
|
||||
} else {
|
||||
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
||||
sort_kernel::block_sort(
|
||||
inp,
|
||||
out,
|
||||
size_sorted_axis,
|
||||
stride_sorted_axis,
|
||||
zero_helper,
|
||||
tgp_vals,
|
||||
nullptr,
|
||||
tid,
|
||||
lid);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Instantiations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
#define instantiate_block_sort(name, itname, itype, otname, otype, arg_sort, bn, tn) \
|
||||
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn)]] \
|
||||
[[kernel]] void block_sort<itype, otype, arg_sort, bn, tn>( \
|
||||
const device itype* inp [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant int& size_sorted_axis [[buffer(2)]], \
|
||||
const constant int& stride_sorted_axis [[buffer(3)]], \
|
||||
const constant int& stride_segment_axis [[buffer(4)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]); \
|
||||
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn "_nc")]] \
|
||||
[[kernel]] void block_sort_nc<itype, otype, arg_sort, bn, tn>( \
|
||||
const device itype* inp [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant int& size_sorted_axis [[buffer(2)]], \
|
||||
const constant int& stride_sorted_axis [[buffer(3)]], \
|
||||
const constant int& nc_dim [[buffer(4)]], \
|
||||
const device int* nc_shape [[buffer(5)]], \
|
||||
const device size_t* nc_strides [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
|
||||
instantiate_block_sort(arg_block_merge_sort, itname, itype, uint32, uint32_t, true, bn, tn)
|
||||
|
||||
#define instantiate_block_sort_base(itname, itype, bn, tn) \
|
||||
instantiate_block_sort(block_merge_sort, itname, itype, itname, itype, false, bn, tn)
|
||||
|
||||
#define instantiate_block_sort_tn(itname, itype, bn) \
|
||||
instantiate_block_sort_base(itname, itype, bn, 8) \
|
||||
instantiate_arg_block_sort_base(itname, itype, bn, 8)
|
||||
|
||||
#define instantiate_block_sort_bn(itname, itype) \
|
||||
instantiate_block_sort_tn(itname, itype, 128) \
|
||||
instantiate_block_sort_tn(itname, itype, 256) \
|
||||
instantiate_block_sort_tn(itname, itype, 512)
|
||||
|
||||
instantiate_block_sort_bn(uint8, uint8_t)
|
||||
instantiate_block_sort_bn(uint16, uint16_t)
|
||||
instantiate_block_sort_bn(uint32, uint32_t)
|
||||
instantiate_block_sort_bn(int8, int8_t)
|
||||
instantiate_block_sort_bn(int16, int16_t)
|
||||
instantiate_block_sort_bn(int32, int32_t)
|
||||
instantiate_block_sort_bn(float16, half)
|
||||
instantiate_block_sort_bn(float32, float)
|
||||
instantiate_block_sort_bn(bfloat16, bfloat16_t)
|
||||
|
||||
#define instantiate_block_sort_long(itname, itype) \
|
||||
instantiate_block_sort_tn(itname, itype, 128) \
|
||||
instantiate_block_sort_tn(itname, itype, 256)
|
||||
|
||||
instantiate_block_sort_long(uint64, uint64_t)
|
||||
instantiate_block_sort_long(int64, int64_t)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Multi block merge sort
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename val_t,
|
||||
typename idx_t,
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD,
|
||||
typename CompareOp = LessThan<val_t>>
|
||||
struct KernelMultiBlockMergeSort {
|
||||
using block_merge_sort_t = BlockMergeSort<
|
||||
val_t,
|
||||
idx_t,
|
||||
ARG_SORT,
|
||||
BLOCK_THREADS,
|
||||
N_PER_THREAD,
|
||||
CompareOp>;
|
||||
|
||||
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
|
||||
|
||||
static METAL_FUNC void block_sort(
|
||||
const device val_t* inp,
|
||||
device val_t* out_vals,
|
||||
device idx_t* out_idxs,
|
||||
const constant int& size_sorted_axis,
|
||||
const constant int& stride_sorted_axis,
|
||||
threadgroup val_t* tgp_vals,
|
||||
threadgroup idx_t* tgp_idxs,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
// tid.x tells us the block index along the sorted axis
|
||||
int base_idx = tid.x * N_PER_BLOCK;
|
||||
|
||||
// Copy into threadgroup memory
|
||||
for(short i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
|
||||
int idx = base_idx + i;
|
||||
tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis] : val_t(CompareOp::init);
|
||||
tgp_idxs[i] = idx;
|
||||
}
|
||||
|
||||
// Sort elements within the block
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Write output
|
||||
for(int i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
|
||||
int idx = base_idx + i;
|
||||
if(idx < size_sorted_axis) {
|
||||
out_vals[idx] = tgp_vals[i];
|
||||
out_idxs[idx] = tgp_idxs[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static METAL_FUNC int merge_partition(
|
||||
const device val_t* As,
|
||||
const device val_t* Bs,
|
||||
int A_sz,
|
||||
int B_sz,
|
||||
int sort_md) {
|
||||
|
||||
CompareOp op;
|
||||
|
||||
int A_st = max(0, sort_md - B_sz);
|
||||
int A_ed = min(sort_md, A_sz);
|
||||
|
||||
while(A_st < A_ed) {
|
||||
int md = A_st + (A_ed - A_st) / 2;
|
||||
auto a = As[md];
|
||||
auto b = Bs[sort_md - 1 - md];
|
||||
|
||||
if(op(b, a)) {
|
||||
A_ed = md;
|
||||
} else {
|
||||
A_st = md + 1;
|
||||
}
|
||||
}
|
||||
|
||||
return A_ed;
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename val_t,
|
||||
typename idx_t,
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD>
|
||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort(
|
||||
const device val_t* inp [[buffer(0)]],
|
||||
device val_t* out_vals [[buffer(1)]],
|
||||
device idx_t* out_idxs [[buffer(2)]],
|
||||
const constant int& size_sorted_axis [[buffer(3)]],
|
||||
const constant int& stride_sorted_axis [[buffer(4)]],
|
||||
const constant int& nc_dim [[buffer(5)]],
|
||||
const device int* nc_shape [[buffer(6)]],
|
||||
const device size_t* nc_strides [[buffer(7)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
using sort_kernel = KernelMultiBlockMergeSort<val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
|
||||
|
||||
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
|
||||
inp += block_idx;
|
||||
out_vals += tid.y * size_sorted_axis;
|
||||
out_idxs += tid.y * size_sorted_axis;
|
||||
|
||||
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
||||
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
|
||||
|
||||
sort_kernel::block_sort(
|
||||
inp,
|
||||
out_vals,
|
||||
out_idxs,
|
||||
size_sorted_axis,
|
||||
stride_sorted_axis,
|
||||
tgp_vals,
|
||||
tgp_idxs,
|
||||
tid,
|
||||
lid);
|
||||
}
|
||||
|
||||
template <
|
||||
typename val_t,
|
||||
typename idx_t,
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD>
|
||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partiton(
|
||||
device idx_t* block_partitions [[buffer(0)]],
|
||||
const device val_t* dev_vals [[buffer(1)]],
|
||||
const device idx_t* dev_idxs [[buffer(2)]],
|
||||
const constant int& size_sorted_axis [[buffer(3)]],
|
||||
const constant int& merge_tiles [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 tgp_dims [[threads_per_threadgroup]]) {
|
||||
|
||||
using sort_kernel = KernelMultiBlockMergeSort<
|
||||
val_t,
|
||||
idx_t,
|
||||
ARG_SORT,
|
||||
BLOCK_THREADS,
|
||||
N_PER_THREAD>;
|
||||
|
||||
block_partitions += tid.y * tgp_dims.x;
|
||||
dev_vals += tid.y * size_sorted_axis;
|
||||
dev_idxs += tid.y * size_sorted_axis;
|
||||
|
||||
// Find location in merge step
|
||||
int merge_group = lid.x / merge_tiles;
|
||||
int merge_lane = lid.x % merge_tiles;
|
||||
|
||||
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
|
||||
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
|
||||
|
||||
int A_st = min(size_sorted_axis, sort_st);
|
||||
int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
|
||||
int B_st = A_ed;
|
||||
int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);
|
||||
|
||||
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
|
||||
int partition = sort_kernel::merge_partition(
|
||||
dev_vals + A_st,
|
||||
dev_vals + B_st,
|
||||
A_ed - A_st,
|
||||
B_ed - B_st,
|
||||
partition_at);
|
||||
|
||||
block_partitions[lid.x] = A_st + partition;
|
||||
|
||||
}
|
||||
|
||||
template <
|
||||
typename val_t,
|
||||
typename idx_t,
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD,
|
||||
typename CompareOp = LessThan<val_t>>
|
||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_merge(
|
||||
const device idx_t* block_partitions [[buffer(0)]],
|
||||
const device val_t* dev_vals_in [[buffer(1)]],
|
||||
const device idx_t* dev_idxs_in [[buffer(2)]],
|
||||
device val_t* dev_vals_out [[buffer(3)]],
|
||||
device idx_t* dev_idxs_out [[buffer(4)]],
|
||||
const constant int& size_sorted_axis [[buffer(5)]],
|
||||
const constant int& merge_tiles [[buffer(6)]],
|
||||
const constant int& num_tiles [[buffer(7)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
using sort_kernel = KernelMultiBlockMergeSort<
|
||||
val_t,
|
||||
idx_t,
|
||||
ARG_SORT,
|
||||
BLOCK_THREADS,
|
||||
N_PER_THREAD,
|
||||
CompareOp>;
|
||||
|
||||
using block_sort_t = typename sort_kernel::block_merge_sort_t;
|
||||
|
||||
block_partitions += tid.y * (num_tiles + 1);
|
||||
dev_vals_in += tid.y * size_sorted_axis;
|
||||
dev_idxs_in += tid.y * size_sorted_axis;
|
||||
dev_vals_out += tid.y * size_sorted_axis;
|
||||
dev_idxs_out += tid.y * size_sorted_axis;
|
||||
|
||||
int block_idx = tid.x;
|
||||
int merge_group = block_idx / merge_tiles;
|
||||
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
|
||||
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
|
||||
int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st;
|
||||
|
||||
int A_st = block_partitions[block_idx + 0];
|
||||
int A_ed = block_partitions[block_idx + 1];
|
||||
int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz/2 + sort_md - A_st);
|
||||
int B_ed = min(size_sorted_axis, 2 * sort_st + sort_sz/2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);

  if ((block_idx % merge_tiles) == merge_tiles - 1) {
    A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
    B_ed = min(size_sorted_axis, sort_st + sort_sz);
  }

  int A_sz = A_ed - A_st;
  int B_sz = B_ed - B_st;

  // Load from global memory
  thread val_t thread_vals[N_PER_THREAD];
  thread idx_t thread_idxs[N_PER_THREAD];
  for (int i = 0; i < N_PER_THREAD; i++) {
    int idx = BLOCK_THREADS * i + lid.x;
    if (idx < (A_sz + B_sz)) {
      thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]
                                    : dev_vals_in[B_st + idx - A_sz];
      thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]
                                    : dev_idxs_in[B_st + idx - A_sz];
    } else {
      thread_vals[i] = CompareOp::init;
      thread_idxs[i] = 0;
    }
  }

  // Write to shared memory
  threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
  threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
  threadgroup_barrier(mem_flags::mem_threadgroup);
  for (int i = 0; i < N_PER_THREAD; i++) {
    int idx = BLOCK_THREADS * i + lid.x;
    tgp_vals[idx] = thread_vals[i];
    tgp_idxs[idx] = thread_idxs[i];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Merge
  int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x));

  int A_st_local = block_sort_t::merge_partition(
      tgp_vals,
      tgp_vals + A_sz,
      A_sz,
      B_sz,
      sort_md_local);
  int A_ed_local = A_sz;

  int B_st_local = sort_md_local - A_st_local;
  int B_ed_local = B_sz;

  int A_sz_local = A_ed_local - A_st_local;
  int B_sz_local = B_ed_local - B_st_local;

  // Do merge
  block_sort_t::merge_step(
      tgp_vals + A_st_local,
      tgp_vals + A_ed_local + B_st_local,
      tgp_idxs + A_st_local,
      tgp_idxs + A_ed_local + B_st_local,
      A_sz_local,
      B_sz_local,
      thread_vals,
      thread_idxs);

  threadgroup_barrier(mem_flags::mem_threadgroup);
  for (int i = 0; i < N_PER_THREAD; ++i) {
    int idx = lid.x * N_PER_THREAD;
    tgp_vals[idx + i] = thread_vals[i];
    tgp_idxs[idx + i] = thread_idxs[i];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Write output
  int base_idx = tid.x * sort_kernel::N_PER_BLOCK;
  for (int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) {
    int idx = base_idx + i;
    if (idx < size_sorted_axis) {
      dev_vals_out[idx] = tgp_vals[i];
      dev_idxs_out[idx] = tgp_idxs[i];
    }
  }
}

#define instantiate_multi_block_sort(vtname, vtype, itname, itype, arg_sort, bn, tn) \
  template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
  [[kernel]] void mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
      const device vtype* inp [[buffer(0)]], \
      device vtype* out_vals [[buffer(1)]], \
      device itype* out_idxs [[buffer(2)]], \
      const constant int& size_sorted_axis [[buffer(3)]], \
      const constant int& stride_sorted_axis [[buffer(4)]], \
      const constant int& nc_dim [[buffer(5)]], \
      const device int* nc_shape [[buffer(6)]], \
      const device size_t* nc_strides [[buffer(7)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]); \
  template [[host_name("mb_block_partiton_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
  [[kernel]] void mb_block_partiton<vtype, itype, arg_sort, bn, tn>( \
      device itype* block_partitions [[buffer(0)]], \
      const device vtype* dev_vals [[buffer(1)]], \
      const device itype* dev_idxs [[buffer(2)]], \
      const constant int& size_sorted_axis [[buffer(3)]], \
      const constant int& merge_tiles [[buffer(4)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 tgp_dims [[threads_per_threadgroup]]); \
  template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
  [[kernel]] void mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
      const device itype* block_partitions [[buffer(0)]], \
      const device vtype* dev_vals_in [[buffer(1)]], \
      const device itype* dev_idxs_in [[buffer(2)]], \
      device vtype* dev_vals_out [[buffer(3)]], \
      device itype* dev_idxs_out [[buffer(4)]], \
      const constant int& size_sorted_axis [[buffer(5)]], \
      const constant int& merge_tiles [[buffer(6)]], \
      const constant int& num_tiles [[buffer(7)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]);

#define instantiate_multi_block_sort_base(vtname, vtype) \
  instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8)

instantiate_multi_block_sort_base(uint8, uint8_t)
instantiate_multi_block_sort_base(uint16, uint16_t)
instantiate_multi_block_sort_base(uint32, uint32_t)
instantiate_multi_block_sort_base(int8, int8_t)
instantiate_multi_block_sort_base(int16, int16_t)
instantiate_multi_block_sort_base(int32, int32_t)
instantiate_multi_block_sort_base(float16, half)
instantiate_multi_block_sort_base(float32, float)
instantiate_multi_block_sort_base(bfloat16, bfloat16_t)

#define instantiate_multi_block_sort_long(vtname, vtype) \
  instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8)

instantiate_multi_block_sort_long(uint64, uint64_t)
instantiate_multi_block_sort_long(int64, int64_t)
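
// For reference, instantiate_multi_block_sort_base(float32, float) above
// expands to kernels with host names such as:
//   mb_block_sort_float32_uint32_bn512_tn8
//   mb_block_partiton_float32_uint32_bn512_tn8
//   mb_block_merge_float32_uint32_bn512_tn8
// while the 64-bit variants use bn256 via instantiate_multi_block_sort_long.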
244
mlx/backend/metal/kernels/utils.h
Normal file
@@ -0,0 +1,244 @@
|
||||
#pragma once
|
||||
|
||||
#include <metal_math>
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/complex.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Type limits utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename U>
|
||||
struct Limits {
|
||||
static const constant U max;
|
||||
static const constant U min;
|
||||
static const constant U finite_max;
|
||||
static const constant U finite_min;
|
||||
};
|
||||
|
||||
#define instantiate_default_limit(type) \
|
||||
template <> \
|
||||
struct Limits<type> { \
|
||||
static constexpr constant type max = metal::numeric_limits<type>::max(); \
|
||||
static constexpr constant type min = metal::numeric_limits<type>::min(); \
|
||||
static constexpr constant type finite_max = \
|
||||
metal::numeric_limits<type>::max(); \
|
||||
static constexpr constant type finite_min = \
|
||||
metal::numeric_limits<type>::min(); \
|
||||
};
|
||||
|
||||
instantiate_default_limit(uint8_t);
|
||||
instantiate_default_limit(uint16_t);
|
||||
instantiate_default_limit(uint32_t);
|
||||
instantiate_default_limit(uint64_t);
|
||||
instantiate_default_limit(int8_t);
|
||||
instantiate_default_limit(int16_t);
|
||||
instantiate_default_limit(int32_t);
|
||||
instantiate_default_limit(int64_t);
|
||||
|
||||
#define instantiate_float_limit(type) \
|
||||
template <> \
|
||||
struct Limits<type> { \
|
||||
static constexpr constant type max = \
|
||||
metal::numeric_limits<type>::infinity(); \
|
||||
static constexpr constant type min = \
|
||||
-metal::numeric_limits<type>::infinity(); \
|
||||
static constexpr constant type finite_max = \
|
||||
metal::numeric_limits<type>::max(); \
|
||||
static constexpr constant type finite_min = \
|
||||
-metal::numeric_limits<type>::max(); \
|
||||
};
|
||||
|
||||
instantiate_float_limit(half);
|
||||
instantiate_float_limit(float);
|
||||
instantiate_float_limit(bfloat16_t);
|
||||
|
||||
template <>
|
||||
struct Limits<bool> {
|
||||
static constexpr constant bool max = true;
|
||||
static constexpr constant bool min = false;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Indexing utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline size_t elem_to_loc(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides,
|
||||
int ndim) {
|
||||
size_t loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
loc += (elem % shape[i]) * strides[i];
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
inline size_t elem_to_loc(
|
||||
uint elem,
|
||||
constant const int* shape,
|
||||
constant const size_t* strides,
|
||||
int ndim) {
|
||||
size_t loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
loc += (elem % shape[i]) * strides[i];
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
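
// Illustrative example: for shape = {2, 3, 4} with row-major strides
// {12, 4, 1}, linear element 17 decomposes to indices (1, 1, 1), so
// elem_to_loc(17, shape, strides, 3) returns 1 * 12 + 1 * 4 + 1 * 1 = 17;
// with non-contiguous strides such as {24, 8, 2} it returns 34 instead.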
|
||||
|
||||
template <int NDIM>
|
||||
inline uint2 elem_to_loc_2_nd(
|
||||
uint3 elem,
|
||||
constant const int shape[NDIM],
|
||||
constant const size_t a_strides[NDIM],
|
||||
constant const size_t b_strides[NDIM]) {
|
||||
uint2 loc = {
|
||||
static_cast<uint>(
|
||||
elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
|
||||
static_cast<uint>(
|
||||
elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
|
||||
for (int d = NDIM - 3; d >= 0; --d) {
|
||||
uint l = elem.z % shape[d];
|
||||
loc.x += l * a_strides[d];
|
||||
loc.y += l * b_strides[d];
|
||||
elem.z /= shape[d];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <int NDIM>
|
||||
inline size_t elem_to_loc_nd(
|
||||
uint3 elem,
|
||||
constant const int shape[NDIM],
|
||||
constant const size_t strides[NDIM]) {
|
||||
size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
|
||||
for (int d = NDIM - 3; d >= 0; --d) {
|
||||
loc += (elem.z % shape[d]) * strides[d];
|
||||
elem.z /= shape[d];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) {
|
||||
return elem * stride;
|
||||
}
|
||||
|
||||
inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) {
|
||||
return elem.x * strides[1] + elem.y * strides[0];
|
||||
}
|
||||
|
||||
inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) {
|
||||
return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
|
||||
}
|
||||
|
||||
// Non templated version to handle arbitrary dims
|
||||
inline size_t elem_to_loc(
|
||||
uint3 elem,
|
||||
constant const int* shape,
|
||||
constant const size_t* strides,
|
||||
int ndim) {
|
||||
size_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
|
||||
for (int d = ndim - 3; d >= 0; --d) {
|
||||
loc += (elem.z % shape[d]) * strides[d];
|
||||
elem.z /= shape[d];
|
||||
}
|
||||
return loc;
|
||||
}
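
// In the uint3 overloads above and below, elem.x indexes the innermost
// dimension and elem.y the second innermost directly, while elem.z is
// decomposed across all remaining outer dimensions. This mirrors how kernels
// in this backend are dispatched with a grid of
// (shape[ndim - 1], shape[ndim - 2], remaining elements).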
|
||||
|
||||
inline uint2 elem_to_loc_2_nd(
|
||||
uint3 elem,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
int ndim) {
|
||||
uint2 loc = {
|
||||
static_cast<uint>(
|
||||
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
|
||||
static_cast<uint>(
|
||||
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
|
||||
for (int d = ndim - 3; d >= 0; --d) {
|
||||
uint l = elem.z % shape[d];
|
||||
loc.x += l * a_strides[d];
|
||||
loc.y += l * b_strides[d];
|
||||
elem.z /= shape[d];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <int NDIM>
|
||||
inline uint elem_to_loc_nd(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides);
|
||||
|
||||
template <>
|
||||
inline uint elem_to_loc_nd<1>(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides) {
|
||||
return (elem % shape[0]) * strides[0];
|
||||
}
|
||||
|
||||
template <>
|
||||
inline uint elem_to_loc_nd<2>(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides) {
|
||||
uint loc = (elem % shape[1]) * strides[1];
|
||||
elem /= shape[1];
|
||||
loc += (elem % shape[0]) * strides[0];
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline uint elem_to_loc_nd<3>(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides) {
|
||||
uint loc = (elem % shape[2]) * strides[2];
|
||||
elem /= shape[2];
|
||||
loc += (elem % shape[1]) * strides[1];
|
||||
elem /= shape[1];
|
||||
loc += (elem % shape[0]) * strides[0];
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline uint elem_to_loc_nd<4>(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides) {
|
||||
uint loc = (elem % shape[3]) * strides[3];
|
||||
elem /= shape[3];
|
||||
loc += (elem % shape[2]) * strides[2];
|
||||
elem /= shape[2];
|
||||
loc += (elem % shape[1]) * strides[1];
|
||||
elem /= shape[1];
|
||||
loc += (elem % shape[0]) * strides[0];
|
||||
return loc;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Calculation utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/** Compute ceil((float)N/(float)M) */
|
||||
inline size_t ceildiv(size_t N, size_t M) {
|
||||
return (N + M - 1) / M;
|
||||
}
|
||||
|
||||
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
|
||||
inline float log1p(float x) {
|
||||
float xp1 = 1.0f + x;
|
||||
return (xp1 == 1.0f) ? x : x * (metal::log(xp1) / (xp1 - 1.0f));
|
||||
}
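
// Why this works (see the reference above): let w = 1.0f + x as computed in
// floating point. If w == 1.0f, x is below the rounding threshold and
// log1p(x) ~= x. Otherwise log(w) / (w - 1.0f) is an accurate estimate of
// log1p(x) / x, since the rounding error in w affects numerator and
// denominator in nearly the same way, so multiplying by x avoids the
// cancellation of computing log(1 + x) directly for small x.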
|
||||
|
||||
inline bfloat16_t log1p(bfloat16_t x) {
|
||||
float xp1 = 1.0f + static_cast<float>(x);
|
||||
bfloat16_t ret =
|
||||
(xp1 == 1.0f) ? x : bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
|
||||
return ret;
|
||||
}
|
||||
446
mlx/backend/metal/matmul.cpp
Normal file
@@ -0,0 +1,446 @@
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/matmul.h"
|
||||
#include "mlx/backend/metal/mps/gemm.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
bool use_mps() {
|
||||
auto get_val = []() {
|
||||
if (const char* buff_str = std::getenv("MLX_USE_MPS")) {
|
||||
return std::string(buff_str) != "OFF";
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
static bool use_mps_ = get_val();
|
||||
return use_mps_;
|
||||
}
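
// Example: running with the environment variable set, e.g.
//   MLX_USE_MPS=1 ./my_program
// routes eligible matmuls through mps_matmul below; any value other than
// "OFF" enables it, and the result is cached after the first call.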
|
||||
|
||||
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
|
||||
|
||||
inline void mps_matmul(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int batch_size_out,
|
||||
int lda,
|
||||
int ldb,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
std::vector<array>& copies) {
|
||||
MPS::DataType mps_dtype = MPS::DataTypeFloat32;
|
||||
|
||||
if (out.dtype() == float16) {
|
||||
mps_dtype = MPS::DataTypeFloat16;
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
mps_dtype = MPS::DataTypeBFloat16;
|
||||
}
|
||||
|
||||
// Use batched MPSMatrixMultiplication if batch_size_out > 1
|
||||
// We only accept the following cases:
|
||||
// 1. Both a, b have batch_size_out matrices worth of data
|
||||
// 2. Only one of a or b has batch_size_out matrices worth of data and
//    the other has one matrix worth of data
|
||||
|
||||
// The matrix dimensions of a and b are sure to be regularly strided
|
||||
if (batch_size_out > 1) {
|
||||
// No broadcasting defaults
|
||||
auto batch_size_a = a.data_size() / (M * K);
|
||||
auto batch_size_b = b.data_size() / (K * N);
|
||||
|
||||
auto matrix_stride_a = M * K;
|
||||
auto matrix_stride_b = K * N;
|
||||
auto matrix_stride_out = M * N;
|
||||
|
||||
// At this point, batch_size_a, batch_size_b show the number of matrices
|
||||
// in data, no broadcasted strides considered
|
||||
if (batch_size_out == std::max(batch_size_a, batch_size_b)) {
|
||||
// Handle simple broadcasting
|
||||
if (std::min(batch_size_a, batch_size_b) == 1) {
|
||||
matrix_stride_a = (batch_size_a == 1) ? 0 : matrix_stride_a;
|
||||
matrix_stride_b = (batch_size_b == 1) ? 0 : matrix_stride_b;
|
||||
|
||||
batch_size_a = batch_size_out;
|
||||
batch_size_b = batch_size_out;
|
||||
}
|
||||
|
||||
// Only proceed if broadcasting between a and b is simple
|
||||
// At this point, batch_size_a, batch_size_b show the number of matrices
|
||||
// after broadcasting
|
||||
if (batch_size_a == batch_size_b) {
|
||||
auto a_desc = MPS::MatrixDescriptor::matrixDescriptor(
|
||||
(M * K) / lda,
|
||||
lda,
|
||||
batch_size_a,
|
||||
lda * a.itemsize(),
|
||||
(matrix_stride_a * a.itemsize()),
|
||||
mps_dtype);
|
||||
|
||||
auto b_desc = MPS::MatrixDescriptor::matrixDescriptor(
|
||||
(K * N) / ldb,
|
||||
ldb,
|
||||
batch_size_b,
|
||||
ldb * b.itemsize(),
|
||||
(matrix_stride_b * b.itemsize()),
|
||||
mps_dtype);
|
||||
|
||||
auto out_desc = MPS::MatrixDescriptor::matrixDescriptor(
|
||||
M,
|
||||
N,
|
||||
batch_size_out,
|
||||
N * out.itemsize(),
|
||||
matrix_stride_out * out.itemsize(),
|
||||
mps_dtype);
|
||||
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto a_mat = MPS::Matrix::alloc()->init(a_buf, a_desc);
|
||||
|
||||
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
|
||||
auto b_mat = MPS::Matrix::alloc()->init(b_buf, b_desc);
|
||||
|
||||
auto out_buf = static_cast<MTL::Buffer*>(out.buffer().ptr());
|
||||
auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);
|
||||
|
||||
auto kernel = MPS::MatrixMultiplication::alloc()->init(
|
||||
d.mtl_device(), transpose_a, transpose_b, M, N, K, 1.0, 0.0);
|
||||
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
kernel->setBatchSize(batch_size_out);
|
||||
kernel->setBatchStart(0);
|
||||
kernel->encodeToCommandBuffer(command_buffer, a_mat, b_mat, out_mat);
|
||||
command_buffer->addCompletedHandler(
|
||||
[a_mat, b_mat, out_mat, kernel, copies](
|
||||
MTL::CommandBuffer*) mutable {
|
||||
a_mat->release();
|
||||
b_mat->release();
|
||||
out_mat->release();
|
||||
kernel->release();
|
||||
copies.clear();
|
||||
});
|
||||
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Schedule as many calls to MPSMatrixMultiplication as needed otherwise
|
||||
auto a_desc = MPS::MatrixDescriptor::matrixDescriptor(
|
||||
a.data_size() / lda, lda, lda * a.itemsize(), mps_dtype);
|
||||
|
||||
auto b_desc = MPS::MatrixDescriptor::matrixDescriptor(
|
||||
b.data_size() / ldb, ldb, ldb * b.itemsize(), mps_dtype);
|
||||
|
||||
auto out_desc = MPS::MatrixDescriptor::matrixDescriptor(
|
||||
batch_size_out * M, N, N * out.itemsize(), mps_dtype);
|
||||
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto a_mat = MPS::Matrix::alloc()->init(a_buf, a_desc);
|
||||
|
||||
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
|
||||
auto b_mat = MPS::Matrix::alloc()->init(b_buf, b_desc);
|
||||
|
||||
auto out_buf = static_cast<MTL::Buffer*>(out.buffer().ptr());
|
||||
auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);
|
||||
|
||||
auto kernel = MPS::MatrixMultiplication::alloc()->init(
|
||||
d.mtl_device(), transpose_a, transpose_b, M, N, K, 1.0, 0.0);
|
||||
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
for (int i = 0; i < batch_size_out; ++i) {
|
||||
auto a_row = elem_to_loc(M * K * i, a.shape(), a.strides()) / lda;
|
||||
auto b_row = elem_to_loc(K * N * i, b.shape(), b.strides()) / ldb;
|
||||
kernel->setLeftMatrixOrigin({a_row, 0, 0});
|
||||
kernel->setRightMatrixOrigin({b_row, 0, 0});
|
||||
kernel->setResultMatrixOrigin({i * static_cast<size_t>(M), 0, 0});
|
||||
kernel->encodeToCommandBuffer(command_buffer, a_mat, b_mat, out_mat);
|
||||
}
|
||||
|
||||
command_buffer->addCompletedHandler(
|
||||
[a_mat, b_mat, out_mat, kernel, copies](MTL::CommandBuffer*) mutable {
|
||||
a_mat->release();
|
||||
b_mat->release();
|
||||
out_mat->release();
|
||||
kernel->release();
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlx_matmul(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int batch_size_out,
|
||||
int lda,
|
||||
int ldb,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
std::vector<array>& copies) {
|
||||
// Account for batch sizes and basic broadcasting
|
||||
int batch_size_a = a.data_size() / (M * K);
|
||||
int batch_size_b = b.data_size() / (K * N);
|
||||
|
||||
int matrix_stride_a = (batch_size_a == 1) ? 0 : M * K;
|
||||
int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N;
|
||||
int matrix_stride_out = M * N;
|
||||
|
||||
// Determine dispatch kernel
|
||||
int bm = 32, bn = 32, bk = 16;
|
||||
int wm = 2, wn = 2;
|
||||
|
||||
if ((size_t)batch_size_out * M * N >= 2ul << 20) {
|
||||
if (!transpose_a && transpose_b) {
|
||||
bm = 64;
|
||||
bn = (out.dtype() == float32) ? 64 : 32;
|
||||
bk = (out.dtype() == float32) ? 16 : 32;
|
||||
} else {
|
||||
bm = 64;
|
||||
bn = 64;
|
||||
}
|
||||
}
|
||||
|
||||
std::ostringstream kname;
|
||||
kname << "gemm_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n')
|
||||
<< "_" << type_to_name(a) << "_" << type_to_name(out) << "_bm" << bm
|
||||
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_MN_"
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
|
||||
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Launch only 1 kernel in the case of simple batching / broadcasting
|
||||
if (batch_size_out == std::max(batch_size_a, batch_size_b) &&
|
||||
(batch_size_a == batch_size_b ||
|
||||
std::min(batch_size_a, batch_size_b) == 1)) {
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims =
|
||||
MTL::Size((N + bn - 1) / bn, (M + bm - 1) / bm, batch_size_out);
|
||||
|
||||
set_array_buffer(compute_encoder, a, 0);
|
||||
set_array_buffer(compute_encoder, b, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
|
||||
compute_encoder->setBytes(&M, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&N, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&K, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&matrix_stride_a, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&matrix_stride_b, sizeof(int), 7);
|
||||
compute_encoder->setBytes(&matrix_stride_out, sizeof(int), 8);
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
} else { // Otherwise, launch one kernel per matrix with explicit buffer offsets
|
||||
|
||||
for (int i = 0; i < batch_size_out; ++i) {
|
||||
auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides());
|
||||
auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides());
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size((N + bn - 1) / bn, (M + bm - 1) / bm, 1);
|
||||
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
|
||||
auto out_buf = static_cast<const MTL::Buffer*>(out.buffer().ptr());
|
||||
|
||||
compute_encoder->setBuffer(a_buf, a_off * a.itemsize(), 0);
|
||||
compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1);
|
||||
compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 2);
|
||||
|
||||
compute_encoder->setBytes(&M, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&N, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&K, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&matrix_stride_a, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&matrix_stride_b, sizeof(int), 7);
|
||||
compute_encoder->setBytes(&matrix_stride_out, sizeof(int), 8);
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
return;
|
||||
}
|
||||
|
||||
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
if (!is_floating_point(out.dtype())) {
|
||||
throw std::runtime_error(
|
||||
"[matmul] Does not yet support non-floating point types.");
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
|
||||
// Keep a vector of copies; it is cleared in the command buffer's completed
// handler so the temporary arrays stay alive until the GPU has used them
|
||||
std::vector<array> copies;
|
||||
auto check_transpose = [&copies, &s](const array& arr) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (stx == arr.shape(-1) && sty == 1) {
|
||||
return std::make_tuple(false, stx, arr);
|
||||
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||
copies.push_back(arr_copy);
|
||||
size_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
};
|
||||
|
||||
auto [a_transposed, a_cols, a] = check_transpose(a_pre);
|
||||
auto [b_transposed, b_cols, b] = check_transpose(b_pre);
|
||||
|
||||
int M = a.shape(-2);
|
||||
int N = b.shape(-1);
|
||||
int K = a.shape(-1);
|
||||
|
||||
auto batch_size_out = out.size() / (M * N);
|
||||
|
||||
// Route to gemv if needed
|
||||
if (std::min(M, N) == 1) {
|
||||
// Collect problem info
|
||||
bool is_b_matrix = N != 1;
|
||||
|
||||
auto& mat = is_b_matrix ? b : a;
|
||||
auto& vec = is_b_matrix ? a : b;
|
||||
bool transpose_mat = is_b_matrix ? !b_transposed : a_transposed;
|
||||
int in_vector_len = K;
|
||||
int out_vector_len = is_b_matrix ? N : M;
|
||||
|
||||
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
|
||||
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
|
||||
|
||||
int batch_size_mat = mat.data_size() / (mat_cols * mat_rows);
|
||||
int stride_mat = batch_size_mat == batch_size_out ? mat_cols * mat_rows : 0;
|
||||
|
||||
int batch_size_vec = vec.data_size() / in_vector_len;
|
||||
int stride_vec = batch_size_vec == batch_size_out ? in_vector_len : 0;
|
||||
|
||||
// Determine dispatch kernel
|
||||
int tm = 4, tn = 4;
|
||||
int bm, bn, n_out_per_tgp;
|
||||
std::ostringstream kname;
|
||||
|
||||
if (transpose_mat) {
|
||||
bm = 8;
|
||||
bn = 8;
|
||||
if (out_vector_len >= 24576) {
|
||||
bn = 128;
|
||||
} else if (out_vector_len >= 16384) {
|
||||
bn = 64;
|
||||
} else if (out_vector_len >= 8192) {
|
||||
bn = 16;
|
||||
}
|
||||
|
||||
// Specialized kernel for very small outputs
|
||||
tn = out_vector_len < tn ? 1 : tn;
|
||||
|
||||
n_out_per_tgp = bn * tn;
|
||||
kname << "gemv_t_" << type_to_name(out);
|
||||
|
||||
} else {
|
||||
bm = out_vector_len >= 4096 ? 8 : 4;
|
||||
bn = 32;
|
||||
|
||||
// Specialized kernel for very small outputs
|
||||
tm = out_vector_len < tm ? 1 : tm;
|
||||
|
||||
n_out_per_tgp = bm * tm;
|
||||
kname << "gemv_" << type_to_name(out);
|
||||
}
|
||||
|
||||
kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
|
||||
MTL::Size group_dims = MTL::Size(bn, bm, 1);
|
||||
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
|
||||
|
||||
set_array_buffer(compute_encoder, mat, 0);
|
||||
set_array_buffer(compute_encoder, vec, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
|
||||
compute_encoder->setBytes(&in_vector_len, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&out_vector_len, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&stride_vec, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&stride_mat, sizeof(int), 6);
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
return;
|
||||
}
|
||||
|
||||
d.end_encoding(s.index);
|
||||
|
||||
if (use_mps()) {
|
||||
mps_matmul(
|
||||
s,
|
||||
d,
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
batch_size_out,
|
||||
a_cols,
|
||||
b_cols,
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
copies);
|
||||
return;
|
||||
}
|
||||
|
||||
mlx_matmul(
|
||||
s,
|
||||
d,
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
batch_size_out,
|
||||
a_cols,
|
||||
b_cols,
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
copies);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
29
mlx/backend/metal/matmul.h
Normal file
@@ -0,0 +1,29 @@
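// Note on the declaration below: M, N, K are the logical GEMM sizes, lda/ldb
// are the leading dimensions of a and b as laid out in memory,
// transpose_a/transpose_b mark operands stored transposed, and `copies` holds
// temporary arrays that must outlive the encoded command buffer (they are
// released in its completed handler in matmul.cpp).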
#include <algorithm>
#include <cassert>
#include <sstream>

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/utils.h"

namespace mlx::core {

void mlx_matmul(
    const Stream& s,
    metal::Device& d,
    const array& a,
    const array& b,
    array& out,
    int M,
    int N,
    int K,
    int batch_size_out,
    int lda,
    int ldb,
    bool transpose_a,
    bool transpose_b,
    std::vector<array>& copies);

} // namespace mlx::core
|
||||
368
mlx/backend/metal/mps/gemm.h
Normal file
@@ -0,0 +1,368 @@
|
||||
#pragma once
|
||||
|
||||
#include <Metal/Metal.hpp>
|
||||
|
||||
#define _MPS_PRIVATE_CLS(symbol) (MTL::Private::Class::s_k##symbol)
|
||||
#define _MPS_PRIVATE_SEL(accessor) (MTL::Private::Selector::s_k##accessor)
|
||||
|
||||
namespace MTL::Private::Class {
|
||||
_MTL_PRIVATE_DEF_CLS(MPSMatrixDescriptor);
|
||||
_MTL_PRIVATE_DEF_CLS(MPSMatrix);
|
||||
_MTL_PRIVATE_DEF_CLS(MPSVectorDescriptor);
|
||||
_MTL_PRIVATE_DEF_CLS(MPSVector);
|
||||
_MTL_PRIVATE_DEF_CLS(MPSKernel);
|
||||
_MTL_PRIVATE_DEF_CLS(MPSMatrixMultiplication);
|
||||
_MTL_PRIVATE_DEF_CLS(MPSMatrixVectorMultiplication);
|
||||
} // namespace MTL::Private::Class
|
||||
|
||||
namespace MTL::Private::Selector {
|
||||
_MTL_PRIVATE_DEF_SEL(
|
||||
matrixDescriptorWithRows_columns_rowBytes_dataType,
|
||||
"matrixDescriptorWithRows:columns:rowBytes:dataType:");
|
||||
_MTL_PRIVATE_DEF_SEL(
|
||||
matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType,
|
||||
"matrixDescriptorWithRows:columns:matrices:rowBytes:matrixBytes:dataType:");
|
||||
_MTL_PRIVATE_DEF_SEL(rows, "rows");
|
||||
_MTL_PRIVATE_DEF_SEL(initWithBuffer_descriptor, "initWithBuffer:descriptor:");
|
||||
_MTL_PRIVATE_DEF_SEL(
|
||||
initWithDevice_,
|
||||
"initWithDevice:transposeLeft:transposeRight:"
|
||||
"resultRows:resultColumns:interiorColumns:alpha:beta:");
|
||||
_MTL_PRIVATE_DEF_SEL(
|
||||
encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix,
|
||||
"encodeToCommandBuffer:leftMatrix:rightMatrix:resultMatrix:");
|
||||
_MTL_PRIVATE_DEF_SEL(setLeftMatrixOrigin_, "setLeftMatrixOrigin:");
|
||||
_MTL_PRIVATE_DEF_SEL(setRightMatrixOrigin_, "setRightMatrixOrigin:");
|
||||
_MTL_PRIVATE_DEF_SEL(setResultMatrixOrigin_, "setResultMatrixOrigin:");
|
||||
_MTL_PRIVATE_DEF_SEL(setBatchStart_, "setBatchStart:");
|
||||
_MTL_PRIVATE_DEF_SEL(setBatchSize_, "setBatchSize:");
|
||||
_MTL_PRIVATE_DEF_SEL(
|
||||
vectorDescriptorWithLength_dataType,
|
||||
"vectorDescriptorWithLength:dataType:");
|
||||
_MTL_PRIVATE_DEF_SEL(
|
||||
vectorDescriptorWithLength_vectors_vectorBytes_dataType,
|
||||
"vectorDescriptorWithLength:vectors:vectorBytes:dataType:");
|
||||
_MTL_PRIVATE_DEF_SEL(
|
||||
initWithDevice_transpose_rows_columns_alpha_beta,
|
||||
"initWithDevice:transpose:rows:columns:alpha:beta:");
|
||||
_MTL_PRIVATE_DEF_SEL(
|
||||
encodeToCommandBuffer_inputMatrix_inputVector_resultVector,
|
||||
"encodeToCommandBuffer:inputMatrix:inputVector:resultVector:");
|
||||
} // namespace MTL::Private::Selector
|
||||
|
||||
namespace MPS {
|
||||
|
||||
typedef enum DataType : uint32_t {
|
||||
DataTypeFloatBit = 0x10000000,
|
||||
DataTypeAlternateEncodingBit = 0x80000000,
|
||||
DataTypeFloat16 = DataTypeFloatBit | 16,
|
||||
DataTypeFloat32 = DataTypeFloatBit | 32,
|
||||
DataTypeBFloat16 = DataTypeAlternateEncodingBit | DataTypeFloat16
|
||||
} DataType;
|
||||
|
||||
class MatrixDescriptor : public NS::Copying<MatrixDescriptor> {
|
||||
public:
|
||||
static class MatrixDescriptor* matrixDescriptor(
|
||||
NS::UInteger rows,
|
||||
NS::UInteger columns,
|
||||
NS::UInteger rowBytes,
|
||||
NS::UInteger dataType);
|
||||
static class MatrixDescriptor* matrixDescriptor(
|
||||
NS::UInteger rows,
|
||||
NS::UInteger columns,
|
||||
NS::UInteger matrices,
|
||||
NS::UInteger rowBytes,
|
||||
NS::UInteger matrixBytes,
|
||||
NS::UInteger dataType);
|
||||
NS::UInteger rows() const;
|
||||
};
|
||||
|
||||
class Matrix : public NS::Referencing<Matrix> {
|
||||
public:
|
||||
static class Matrix* alloc();
|
||||
Matrix* init(MTL::Buffer* buffer, MatrixDescriptor* descriptor);
|
||||
Matrix* init(const MTL::Buffer* buffer, MatrixDescriptor* descriptor);
|
||||
};
|
||||
|
||||
class Kernel : public NS::Referencing<Kernel> {
|
||||
public:
|
||||
NS::String* label() const;
|
||||
MTL::Device* device() const;
|
||||
};
|
||||
|
||||
class MatrixMultiplication
|
||||
: public NS::Referencing<MatrixMultiplication, Kernel> {
|
||||
public:
|
||||
static class MatrixMultiplication* alloc();
|
||||
|
||||
MatrixMultiplication* init(
|
||||
MTL::Device* device,
|
||||
bool transposeLeft,
|
||||
bool transposeRight,
|
||||
NS::UInteger resultRows,
|
||||
NS::UInteger resultColumns,
|
||||
NS::UInteger interiorColumns,
|
||||
double alpha,
|
||||
double beta);
|
||||
|
||||
void encodeToCommandBuffer(
|
||||
MTL::CommandBuffer* commandBuffer,
|
||||
Matrix* leftMatrix,
|
||||
Matrix* rightMatrix,
|
||||
Matrix* resultMatrix);
|
||||
|
||||
void setLeftMatrixOrigin(MTL::Origin origin);
|
||||
void setRightMatrixOrigin(MTL::Origin origin);
|
||||
void setResultMatrixOrigin(MTL::Origin origin);
|
||||
void setBatchStart(NS::UInteger batchStart);
|
||||
void setBatchSize(NS::UInteger batchSize);
|
||||
};
|
||||
|
||||
class VectorDescriptor : public NS::Copying<VectorDescriptor> {
|
||||
public:
|
||||
static class VectorDescriptor* vectorDescriptor(
|
||||
NS::UInteger length,
|
||||
NS::UInteger dataType);
|
||||
static class VectorDescriptor* vectorDescriptor(
|
||||
NS::UInteger length,
|
||||
NS::UInteger vectors,
|
||||
NS::UInteger vectorBytes,
|
||||
NS::UInteger dataType);
|
||||
};
|
||||
|
||||
class Vector : public NS::Referencing<Vector> {
|
||||
public:
|
||||
static class Vector* alloc();
|
||||
Vector* init(MTL::Buffer* buffer, VectorDescriptor* descriptor);
|
||||
Vector* init(const MTL::Buffer* buffer, VectorDescriptor* descriptor);
|
||||
};
|
||||
|
||||
class MatrixVectorMultiplication
|
||||
: public NS::Referencing<MatrixVectorMultiplication, Kernel> {
|
||||
public:
|
||||
static class MatrixVectorMultiplication* alloc();
|
||||
|
||||
MatrixVectorMultiplication* init(
|
||||
MTL::Device* device,
|
||||
bool transpose,
|
||||
NS::UInteger rows,
|
||||
NS::UInteger columns,
|
||||
double alpha,
|
||||
double beta);
|
||||
|
||||
void encodeToCommandBuffer(
|
||||
MTL::CommandBuffer* commandBuffer,
|
||||
Matrix* inputMatrix,
|
||||
Vector* inputVector,
|
||||
Vector* resultVector);
|
||||
};
|
||||
|
||||
_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
|
||||
NS::UInteger rows,
|
||||
NS::UInteger columns,
|
||||
NS::UInteger rowBytes,
|
||||
NS::UInteger dataType) {
|
||||
return Object::sendMessage<MatrixDescriptor*>(
|
||||
_MPS_PRIVATE_CLS(MPSMatrixDescriptor),
|
||||
_MPS_PRIVATE_SEL(matrixDescriptorWithRows_columns_rowBytes_dataType),
|
||||
rows,
|
||||
columns,
|
||||
rowBytes,
|
||||
dataType);
|
||||
}
|
||||
|
||||
_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
|
||||
NS::UInteger rows,
|
||||
NS::UInteger columns,
|
||||
NS::UInteger matrices,
|
||||
NS::UInteger rowBytes,
|
||||
NS::UInteger matrixBytes,
|
||||
NS::UInteger dataType) {
|
||||
return Object::sendMessage<MatrixDescriptor*>(
|
||||
_MPS_PRIVATE_CLS(MPSMatrixDescriptor),
|
||||
_MPS_PRIVATE_SEL(
|
||||
matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType),
|
||||
rows,
|
||||
columns,
|
||||
matrices,
|
||||
rowBytes,
|
||||
matrixBytes,
|
||||
dataType);
|
||||
}
|
||||
|
||||
_MTL_INLINE NS::UInteger MatrixDescriptor::rows() const {
|
||||
return Object::sendMessage<NS::UInteger>(this, _MPS_PRIVATE_SEL(rows));
|
||||
}
|
||||
|
||||
_MTL_INLINE Matrix* Matrix::alloc() {
|
||||
return NS::Object::alloc<Matrix>(_MPS_PRIVATE_CLS(MPSMatrix));
|
||||
}
|
||||
|
||||
_MTL_INLINE Matrix* Matrix::init(
|
||||
MTL::Buffer* buffer,
|
||||
MatrixDescriptor* descriptor) {
|
||||
return Object::sendMessage<Matrix*>(
|
||||
this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
|
||||
}
|
||||
|
||||
_MTL_INLINE Matrix* Matrix::init(
|
||||
const MTL::Buffer* buffer,
|
||||
MatrixDescriptor* descriptor) {
|
||||
return init(const_cast<MTL::Buffer*>(buffer), descriptor);
|
||||
}
|
||||
|
||||
_MTL_INLINE NS::String* Kernel::label() const {
|
||||
return Object::sendMessage<NS::String*>(this, _MPS_PRIVATE_SEL(label));
|
||||
}
|
||||
|
||||
_MTL_INLINE MTL::Device* Kernel::device() const {
|
||||
return Object::sendMessage<MTL::Device*>(this, _MPS_PRIVATE_SEL(device));
|
||||
}
|
||||
|
||||
_MTL_INLINE MatrixMultiplication* MatrixMultiplication::alloc() {
|
||||
return NS::Object::alloc<MatrixMultiplication>(
|
||||
_MPS_PRIVATE_CLS(MPSMatrixMultiplication));
|
||||
}
|
||||
|
||||
_MTL_INLINE MatrixMultiplication* MatrixMultiplication::init(
|
||||
MTL::Device* device,
|
||||
bool transposeLeft,
|
||||
bool transposeRight,
|
||||
NS::UInteger resultRows,
|
||||
NS::UInteger resultColumns,
|
||||
NS::UInteger interiorColumns,
|
||||
double alpha,
|
||||
double beta) {
|
||||
return Object::sendMessage<MatrixMultiplication*>(
|
||||
this,
|
||||
_MPS_PRIVATE_SEL(initWithDevice_),
|
||||
device,
|
||||
transposeLeft,
|
||||
transposeRight,
|
||||
resultRows,
|
||||
resultColumns,
|
||||
interiorColumns,
|
||||
alpha,
|
||||
beta);
|
||||
}
|
||||
|
||||
_MTL_INLINE void MatrixMultiplication::encodeToCommandBuffer(
|
||||
MTL::CommandBuffer* commandBuffer,
|
||||
Matrix* leftMatrix,
|
||||
Matrix* rightMatrix,
|
||||
Matrix* resultMatrix) {
|
||||
return Object::sendMessage<void>(
|
||||
this,
|
||||
_MPS_PRIVATE_SEL(
|
||||
encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix),
|
||||
commandBuffer,
|
||||
leftMatrix,
|
||||
rightMatrix,
|
||||
resultMatrix);
|
||||
}
|
||||
|
||||
_MTL_INLINE void MatrixMultiplication::setLeftMatrixOrigin(MTL::Origin origin) {
|
||||
Object::sendMessage<void>(
|
||||
this, _MPS_PRIVATE_SEL(setLeftMatrixOrigin_), origin);
|
||||
}
|
||||
|
||||
_MTL_INLINE void MatrixMultiplication::setRightMatrixOrigin(
|
||||
MTL::Origin origin) {
|
||||
Object::sendMessage<void>(
|
||||
this, _MPS_PRIVATE_SEL(setRightMatrixOrigin_), origin);
|
||||
}
|
||||
|
||||
_MTL_INLINE void MatrixMultiplication::setResultMatrixOrigin(
|
||||
MTL::Origin origin) {
|
||||
Object::sendMessage<void>(
|
||||
this, _MPS_PRIVATE_SEL(setResultMatrixOrigin_), origin);
|
||||
}
|
||||
|
||||
_MTL_INLINE void MatrixMultiplication::setBatchStart(NS::UInteger batchStart) {
|
||||
Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchStart_), batchStart);
|
||||
}
|
||||
|
||||
_MTL_INLINE void MatrixMultiplication::setBatchSize(NS::UInteger batchSize) {
|
||||
Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchSize_), batchSize);
|
||||
}
|
||||
|
||||
_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
|
||||
NS::UInteger length,
|
||||
NS::UInteger dataType) {
|
||||
return Object::sendMessage<VectorDescriptor*>(
|
||||
_MPS_PRIVATE_CLS(MPSVectorDescriptor),
|
||||
_MPS_PRIVATE_SEL(vectorDescriptorWithLength_dataType),
|
||||
length,
|
||||
dataType);
|
||||
}
|
||||
|
||||
_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
|
||||
NS::UInteger length,
|
||||
NS::UInteger vectors,
|
||||
NS::UInteger vectorBytes,
|
||||
NS::UInteger dataType) {
|
||||
return Object::sendMessage<VectorDescriptor*>(
|
||||
_MPS_PRIVATE_CLS(MPSVectorDescriptor),
|
||||
_MPS_PRIVATE_SEL(vectorDescriptorWithLength_vectors_vectorBytes_dataType),
|
||||
length,
|
||||
vectors,
|
||||
vectorBytes,
|
||||
dataType);
|
||||
}
|
||||
|
||||
_MTL_INLINE Vector* Vector::alloc() {
|
||||
return NS::Object::alloc<Vector>(_MPS_PRIVATE_CLS(MPSVector));
|
||||
}
|
||||
|
||||
_MTL_INLINE Vector* Vector::init(
|
||||
MTL::Buffer* buffer,
|
||||
VectorDescriptor* descriptor) {
|
||||
return Object::sendMessage<Vector*>(
|
||||
this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
|
||||
}
|
||||
|
||||
_MTL_INLINE Vector* Vector::init(
|
||||
const MTL::Buffer* buffer,
|
||||
VectorDescriptor* descriptor) {
|
||||
return init(const_cast<MTL::Buffer*>(buffer), descriptor);
|
||||
}
|
||||
|
||||
_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::alloc() {
|
||||
return NS::Object::alloc<MatrixVectorMultiplication>(
|
||||
_MPS_PRIVATE_CLS(MPSMatrixVectorMultiplication));
|
||||
}
|
||||
|
||||
_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::init(
|
||||
MTL::Device* device,
|
||||
bool transpose,
|
||||
NS::UInteger rows,
|
||||
NS::UInteger columns,
|
||||
double alpha,
|
||||
double beta) {
|
||||
return Object::sendMessage<MatrixVectorMultiplication*>(
|
||||
this,
|
||||
_MPS_PRIVATE_SEL(initWithDevice_transpose_rows_columns_alpha_beta),
|
||||
device,
|
||||
transpose,
|
||||
rows,
|
||||
columns,
|
||||
alpha,
|
||||
beta);
|
||||
}
|
||||
|
||||
_MTL_INLINE void MatrixVectorMultiplication::encodeToCommandBuffer(
|
||||
MTL::CommandBuffer* commandBuffer,
|
||||
Matrix* inputMatrix,
|
||||
Vector* inputVector,
|
||||
Vector* resultVector) {
|
||||
return Object::sendMessage<void>(
|
||||
this,
|
||||
_MPS_PRIVATE_SEL(
|
||||
encodeToCommandBuffer_inputMatrix_inputVector_resultVector),
|
||||
commandBuffer,
|
||||
inputMatrix,
|
||||
inputVector,
|
||||
resultVector);
|
||||
}
|
||||
|
||||
} // namespace MPS
|
||||
604
mlx/backend/metal/primitives.cpp
Normal file
@@ -0,0 +1,604 @@
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
static constexpr int METAL_MAX_INDEX_ARRAYS = 10;
|
||||
|
||||
void binary_op(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
// Try to collapse contiguous dims
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
||||
auto& strides_a = strides[0];
|
||||
auto& strides_b = strides[1];
|
||||
auto& strides_out = strides[2];
|
||||
|
||||
std::ostringstream kname;
|
||||
switch (bopt) {
|
||||
case ScalarScalar:
|
||||
kname << "ss";
|
||||
break;
|
||||
case ScalarVector:
|
||||
kname << "sv";
|
||||
break;
|
||||
case VectorScalar:
|
||||
kname << "vs";
|
||||
break;
|
||||
case VectorVector:
|
||||
kname << "vv";
|
||||
break;
|
||||
case General:
|
||||
kname << "g";
|
||||
break;
|
||||
}
|
||||
kname << op << type_to_name(a);
|
||||
if (bopt == General && out.ndim() <= MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
kname << "_" << shape.size();
|
||||
}
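
// Example names produced here (assuming type_to_name() yields "float32"):
// two contiguous same-shape float32 inputs select the vector-vector variant
// "vvaddfloat32", while a broadcasting add whose shapes collapse to 2 dims
// selects "gaddfloat32_2".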
|
||||
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, a, 0);
|
||||
set_array_buffer(compute_encoder, b, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
|
||||
if (bopt == General) {
|
||||
auto ndim = shape.size();
|
||||
if (ndim > 3) {
|
||||
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
|
||||
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
|
||||
} else {
|
||||
// The shape is implicit in the grid for <= 3D
|
||||
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);
|
||||
}
|
||||
|
||||
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 6);
|
||||
}
|
||||
|
||||
// Launch up to 3D grid of threads
|
||||
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
int rest = out.size() / (dim0 * dim1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
||||
}
|
||||
auto group_dims = get_block_dims(dim0, dim1, rest);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
} else {
|
||||
// Launch a 1D grid of threads
|
||||
size_t nthreads = bopt == General ? out.size() : out.data_size();
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
void unary_op(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op) {
|
||||
auto& in = inputs[0];
|
||||
bool contig = in.flags().contiguous;
|
||||
if (contig) {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
std::string tname = type_to_name(in);
|
||||
std::string opt_name = contig ? "v" : "g";
|
||||
auto kernel = d.get_kernel(opt_name + op + tname);
|
||||
|
||||
size_t nthreads = contig ? in.data_size() : in.size();
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
if (!contig) {
|
||||
compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2);
|
||||
compute_encoder->setBytes(
|
||||
in.strides().data(), in.ndim() * sizeof(size_t), 3);
|
||||
int ndim = in.ndim();
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 4);
|
||||
}
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Abs::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "abs");
|
||||
}
|
||||
|
||||
void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "add");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void arange_set_scalars(T start, T next, MTL::ComputeCommandEncoder* enc) {
|
||||
enc->setBytes(&start, sizeof(T), 0);
|
||||
T step = next - start;
|
||||
enc->setBytes(&step, sizeof(T), 1);
|
||||
}
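
// The step is re-derived in the output dtype as T(next) - T(start); the arange
// kernel itself (defined elsewhere) is assumed to fill out[i] = start + i * step.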
|
||||
|
||||
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto kernel = d.get_kernel("arange" + type_to_name(out));
|
||||
size_t nthreads = out.size();
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size group_dims = MTL::Size(
|
||||
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
switch (out.dtype()) {
|
||||
case bool_: // unsupported
|
||||
throw std::runtime_error("[Arange::eval_gpu] Does not support bool");
|
||||
case uint8:
|
||||
arange_set_scalars<uint8_t>(start_, start_ + step_, compute_encoder);
|
||||
break;
|
||||
case uint16:
|
||||
arange_set_scalars<uint16_t>(start_, start_ + step_, compute_encoder);
|
||||
break;
|
||||
case uint32:
|
||||
arange_set_scalars<uint32_t>(start_, start_ + step_, compute_encoder);
|
||||
break;
|
||||
case uint64:
|
||||
arange_set_scalars<uint64_t>(start_, start_ + step_, compute_encoder);
|
||||
break;
|
||||
case int8:
|
||||
arange_set_scalars<int8_t>(start_, start_ + step_, compute_encoder);
|
||||
break;
|
||||
case int16:
|
||||
arange_set_scalars<int16_t>(start_, start_ + step_, compute_encoder);
|
||||
break;
|
||||
case int32:
|
||||
arange_set_scalars<int32_t>(start_, start_ + step_, compute_encoder);
|
||||
break;
|
||||
case int64:
|
||||
arange_set_scalars<int64_t>(start_, start_ + step_, compute_encoder);
|
||||
break;
|
||||
case float16:
|
||||
arange_set_scalars<float16_t>(start_, start_ + step_, compute_encoder);
|
||||
break;
|
||||
case float32:
|
||||
arange_set_scalars<float>(start_, start_ + step_, compute_encoder);
|
||||
break;
|
||||
case bfloat16:
|
||||
throw std::runtime_error("[Arange::eval_gpu] Does not support bfloat16");
|
||||
case complex64:
|
||||
throw std::runtime_error("[Arange::eval_gpu] Does not support complex64");
|
||||
}
|
||||
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void ArcCos::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "arccos");
|
||||
}
|
||||
|
||||
void ArcCosh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "arccosh");
|
||||
}
|
||||
|
||||
void ArcSin::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "arcsin");
|
||||
}
|
||||
|
||||
void ArcSinh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "arcsinh");
|
||||
}
|
||||
|
||||
void ArcTan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "arctan");
|
||||
}
|
||||
|
||||
void ArcTanh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "arctanh");
|
||||
}
|
||||
|
||||
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
std::string op_name;
|
||||
switch (reduce_type_) {
|
||||
case ArgReduce::ArgMin:
|
||||
op_name = "argmin_";
|
||||
break;
|
||||
case ArgReduce::ArgMax:
|
||||
op_name = "argmax_";
|
||||
break;
|
||||
}
|
||||
|
||||
// Prepare the shapes, strides and axis arguments.
|
||||
std::vector<size_t> in_strides = in.strides();
|
||||
std::vector<int> shape = in.shape();
|
||||
std::vector<size_t> out_strides = out.strides();
|
||||
size_t axis_stride = in_strides[axis_];
|
||||
size_t axis_size = shape[axis_];
|
||||
if (out_strides.size() == in_strides.size()) {
|
||||
out_strides.erase(out_strides.begin() + axis_);
|
||||
}
|
||||
in_strides.erase(in_strides.begin() + axis_);
|
||||
shape.erase(shape.begin() + axis_);
|
||||
size_t ndim = shape.size();
|
||||
|
||||
// ArgReduce
|
||||
int simd_size = 32;
|
||||
int n_reads = 4;
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
{
|
||||
auto kernel = d.get_kernel(op_name + type_to_name(in));
|
||||
NS::UInteger thread_group_size = std::min(
|
||||
(axis_size + n_reads - 1) / n_reads,
|
||||
kernel->maxTotalThreadsPerThreadgroup());
|
||||
// round up to the closest number divisible by simd_size
|
||||
thread_group_size =
|
||||
(thread_group_size + simd_size - 1) / simd_size * simd_size;
|
||||
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
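
// Worked example: for axis_size = 1000 and n_reads = 4, ceil(1000 / 4) = 250
// threads are requested, rounded up to 256 (the next multiple of
// simd_size = 32); one such threadgroup is launched per output element.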
|
||||
|
||||
size_t n_threads = out.size() * thread_group_size;
|
||||
MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
|
||||
compute_encoder->setBytes(in_strides.data(), ndim * sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);
|
||||
compute_encoder->setThreadgroupMemoryLength(
|
||||
simd_size * (sizeof(uint32_t) + in.itemsize()), 0);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
CopyType ctype =
|
||||
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy_gpu(inputs[0], out, ctype);
|
||||
}
|
||||
|
||||
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
std::vector<int> sizes;
|
||||
sizes.push_back(0);
|
||||
for (auto& p : inputs) {
|
||||
sizes.push_back(p.shape(axis_));
|
||||
}
|
||||
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto strides = out.strides();
  auto flags = out.flags();
  flags.row_contiguous = false;
  flags.col_contiguous = false;
  flags.contiguous = false;
  for (int i = 0; i < inputs.size(); i++) {
    array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
    size_t data_offset = strides[axis_] * sizes[i];
    out_slice.copy_shared_buffer(
        out, strides, flags, out_slice.size(), data_offset);
    copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
  }
}

void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
  eval(inputs, out);
}

void Cos::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "cos");
}

void Cosh::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "cosh");
}

void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "div");
}

void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
}

void Erf::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "erf");
}

void ErfInv::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "erfinv");
}

void Exp::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "exp");
}

void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto in = inputs[0];
  CopyType ctype;
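  // Added note: a data_size of 1 means the fill value is a broadcast scalar,
  // a contiguous input can be copied as a flat vector, and anything else
  // falls back to the general strided copy.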
  if (in.data_size() == 1) {
    ctype = CopyType::Scalar;
  } else if (in.flags().contiguous) {
    ctype = CopyType::Vector;
  } else {
    ctype = CopyType::General;
  }
  copy_gpu(in, out, ctype);
}

void Greater::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "ge");
}

void GreaterEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "geq");
}

void Less::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "le");
}

void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "leq");
}

void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
  eval(inputs, out);
}

void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
  switch (base_) {
    case Base::e:
      unary_op(inputs, out, "log");
      break;
    case Base::two:
      unary_op(inputs, out, "log2");
      break;
    case Base::ten:
      unary_op(inputs, out, "log10");
      break;
  }
}

void Log1p::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "log1p");
}

void LogicalNot::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "lnot");
}

void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "lae");
}

void Maximum::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "max");
}

void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "min");
}

void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "mul");
}

void Negative::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "neg");
}

void NotEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "neq");
}

void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
  // Inputs must be base input array and scalar val array
  assert(inputs.size() == 2);
  auto& in = inputs[0];
  auto& val = inputs[1];

  // Padding value must be a scalar
  assert(val.size() == 1);

  // Padding value, input and output must be of the same type
  assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());

  // Fill output with val
  copy_gpu(val, out, CopyType::Scalar, stream());

  // Find offset for start of input values
  size_t data_offset = 0;
  for (int i = 0; i < axes_.size(); i++) {
    auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
    data_offset += out.strides()[ax] * low_pad_size_[i];
  }
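  // Added note (illustrative): padding a float32 (3, 4) input with one
  // leading row and two leading columns (and no trailing pad) gives an output
  // with strides {6, 1}, so data_offset = 6 * 1 + 1 * 2 = 8 elements.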
  // Extract slice from output where input will be pasted
  array out_slice(in.shape(), out.dtype(), nullptr, {});
  out_slice.copy_shared_buffer(
      out, out.strides(), out.flags(), out_slice.size(), data_offset);

  // Copy input values into the slice
  copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
}

void Power::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "pow");
}

void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);

  // keys has shape (N1, ..., NK, 2)
  // out has shape (N1, ..., NK, M1, M2, ...)
  auto& keys = inputs[0];
  size_t num_keys = keys.size() / 2;

  size_t elems_per_key = out.size() / num_keys;
  size_t bytes_per_key = out.itemsize() * elems_per_key;
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  size_t out_per_key = (bytes_per_key + 4 - 1) / 4;
  size_t half_size = out_per_key / 2;
  bool odd = out_per_key % 2;
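  // Added note (illustrative): for 10 bytes per key, out_per_key =
  // ceil(10 / 4) = 3 four-byte words, half_size = 1, and odd = 1; the grid
  // below is num_keys x (half_size + odd), so each grid cell covers a pair of
  // words, with the odd flag marking a trailing unpaired word.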
  auto& s = stream();
  auto& d = metal::device(s.device);
  std::string kname = keys.flags().row_contiguous ? "rbitsc" : "rbits";
  auto kernel = d.get_kernel(kname);

  // organize into grid nkeys x elem_per_key
  MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1);
  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  auto nthreads = std::min(num_keys * (half_size + odd), thread_group_size);
  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
  auto compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
  set_array_buffer(compute_encoder, keys, 0);
  set_array_buffer(compute_encoder, out, 1);
  compute_encoder->setBytes(&odd, sizeof(bool), 2);
  compute_encoder->setBytes(&bytes_per_key, sizeof(size_t), 3);

  if (!keys.flags().row_contiguous) {
    int ndim = keys.ndim();
    compute_encoder->setBytes(&ndim, sizeof(int), 4);
    compute_encoder->setBytes(
        keys.shape().data(), keys.ndim() * sizeof(int), 5);
    compute_encoder->setBytes(
        keys.strides().data(), keys.ndim() * sizeof(size_t), 6);
  }

  compute_encoder->dispatchThreads(grid_dims, group_dims);
}

void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (in.flags().row_contiguous) {
    auto flags = in.flags();
    auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
    flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
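    // Added note: col_contiguous is set only when the output is effectively
    // one-dimensional (every other axis has size one), since a row-contiguous
    // 1-D array is also column-contiguous.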
    out.copy_shared_buffer(in, out.strides(), flags, in.data_size());
  } else {
    copy_gpu(in, out, CopyType::General);
  }
}

void Sigmoid::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "sigmoid");
}

void Sign::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "sign");
}

void Sin::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "sin");
}

void Sinh::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "sinh");
}

void Square::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "square");
}

void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
  if (recip_) {
    unary_op(inputs, out, "rsqrt");
  } else {
    unary_op(inputs, out, "sqrt");
  }
}

void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
  eval(inputs, out);
}

void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
  eval(inputs, out);
}

void Subtract::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "sub");
}

void Tan::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "tan");
}

void Tanh::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "tanh");
}

void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
  eval(inputs, out);
}

} // namespace mlx::core
130
mlx/backend/metal/scan.cpp
Normal file
@@ -0,0 +1,130 @@
#include <cassert>
#include <sstream>

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);

  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& s = stream();
  auto& d = metal::device(s.device);

  // Ensure contiguity
  std::vector<array> copies;
  auto in = inputs[0];
  if (!in.flags().row_contiguous) {
    array arr_copy(in.shape(), in.dtype(), nullptr, {});
    copy_gpu(in, arr_copy, CopyType::General, s);
    copies.push_back(arr_copy);
    in = arr_copy;
  }

  std::ostringstream kname;
  if (in.strides()[axis_] == 1) {
    kname << "contiguous_scan_";
    if (reverse_) {
      kname << "reverse_";
    }
    kname << ((inclusive_) ? "inclusive_" : "exclusive_");
    switch (reduce_type_) {
      case Scan::Sum:
        kname << "sum_";
        break;
      case Scan::Prod:
        kname << "prod_";
        break;
      case Scan::Max:
        kname << "max_";
        break;
      case Scan::Min:
        kname << "min_";
        break;
    }
    kname << type_to_name(in) << "_" << type_to_name(out);

    auto kernel = d.get_kernel(kname.str());
    auto compute_encoder = d.get_command_encoder(s.index);
    compute_encoder->setComputePipelineState(kernel);
    set_array_buffer(compute_encoder, in, 0);
    set_array_buffer(compute_encoder, out, 1);
    size_t size = in.shape(axis_);
    compute_encoder->setBytes(&size, sizeof(size_t), 2);

    // Compute the thread grid
    int n_reads = (in.itemsize() <= 4) ? 4 : 2;
    int elements_per_simd = n_reads * 32;
    int thread_groups = in.size() / size;
    int thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    if (size < n_reads * 1024) {
      thread_group_size = ((size + elements_per_simd - 1) / elements_per_simd) *
          elements_per_simd;
    } else if (size < n_reads * 2048) {
      thread_group_size =
          ((size / 2 + elements_per_simd - 1) / elements_per_simd) *
          elements_per_simd;
    }
    thread_group_size = std::min(
        thread_group_size,
        static_cast<int>(kernel->maxTotalThreadsPerThreadgroup()));
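    // Added note (illustrative numbers): for a float32 scan with size = 3000,
    // n_reads = 4 and elements_per_simd = 128, the first branch rounds up to
    // 24 * 128 = 3072 threads, which std::min then caps at the pipeline's
    // maxTotalThreadsPerThreadgroup.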
    MTL::Size grid_dims = MTL::Size(thread_groups * thread_group_size, 1, 1);
    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  } else {
    kname << "strided_scan_";
    if (reverse_) {
      kname << "reverse_";
    }
    kname << ((inclusive_) ? "inclusive_" : "exclusive_");
    switch (reduce_type_) {
      case Scan::Sum:
        kname << "sum_";
        break;
      case Scan::Prod:
        kname << "prod_";
        break;
      case Scan::Max:
        kname << "max_";
        break;
      case Scan::Min:
        kname << "min_";
        break;
    }
    kname << type_to_name(in) << "_" << type_to_name(out);

    auto kernel = d.get_kernel(kname.str());
    auto compute_encoder = d.get_command_encoder(s.index);
    compute_encoder->setComputePipelineState(kernel);
    set_array_buffer(compute_encoder, in, 0);
    set_array_buffer(compute_encoder, out, 1);
    size_t size = in.shape(axis_);
    size_t stride = in.strides()[axis_];
    compute_encoder->setBytes(&size, sizeof(size_t), 2);
    compute_encoder->setBytes(&stride, sizeof(size_t), 3);

    // Compute the thread grid
    int n_reads = (in.itemsize() <= 4) ? 4 : 2;
    int tile_x = 32;
    int tile_y = 32;
    int elements_per_tile_x = tile_x * n_reads;
    int grid_y = in.size() / size / stride;
    int grid_x = (stride + elements_per_tile_x - 1) / elements_per_tile_x;
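    // Added note (illustrative): the strided kernel is tiled 32 x 32 with
    // n_reads elements per thread along the strided axis, so a float32 scan
    // over stride = 300 needs grid_x = ceil(300 / 128) = 3 tiles.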
    MTL::Size grid_dims = MTL::Size(grid_x * tile_x, grid_y * tile_y, 1);
    MTL::Size group_dims = MTL::Size(tile_x, tile_y, 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  }

  if (copies.size() > 0) {
    auto command_buffer = d.get_command_buffer(s.index);
    command_buffer->addCompletedHandler(
        [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
  }
}

} // namespace mlx::core
167
mlx/backend/metal/utils.h
Normal file
@@ -0,0 +1,167 @@
#pragma once

#include "mlx/array.h"
#include "mlx/backend/metal/device.h"

namespace mlx::core {

namespace {

void set_array_buffer(
    MTL::ComputeCommandEncoder* compute_encoder,
    MTL::ArgumentEncoder* enc,
    const array& a,
    int idx) {
  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
  auto offset = a.data<char>() -
      static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
  enc->setBuffer(a_buf, offset, idx);
  // MTL::Resource usage through an argument buffer needs to be explicitly
  // flagged to enable hazard tracking
  compute_encoder->useResource(a_buf, MTL::ResourceUsageRead);
}

void set_array_buffer(
    MTL::ComputeCommandEncoder* enc,
    const array& a,
    int idx) {
  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
  auto offset = a.data<char>() -
      static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
  enc->setBuffer(a_buf, offset, idx);
}

std::string type_to_name(const array& a) {
  std::string tname;
  switch (a.dtype()) {
    case bool_:
      tname = "bool_";
      break;
    case uint8:
      tname = "uint8";
      break;
    case uint16:
      tname = "uint16";
      break;
    case uint32:
      tname = "uint32";
      break;
    case uint64:
      tname = "uint64";
      break;
    case int8:
      tname = "int8";
      break;
    case int16:
      tname = "int16";
      break;
    case int32:
      tname = "int32";
      break;
    case int64:
      tname = "int64";
      break;
    case float16:
      tname = "float16";
      break;
    case float32:
      tname = "float32";
      break;
    case bfloat16:
      tname = "bfloat16";
      break;
    case complex64:
      tname = "complex64";
      break;
  }
  return tname;
}
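// Added note: get_block_dims greedily doubles whichever of the three block
// dimensions still fits inside its corresponding grid extent, stopping once
// the block reaches 2^10 = 1024 threads or no dimension can grow further.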
MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
  int pows[3] = {0, 0, 0};
  int sum = 0;
  while (true) {
    int presum = sum;
    // Check all the pows
    if (dim0 >= (1 << (pows[0] + 1))) {
      pows[0]++;
      sum++;
    }
    if (sum == 10) {
      break;
    }
    if (dim1 >= (1 << (pows[1] + 1))) {
      pows[1]++;
      sum++;
    }
    if (sum == 10) {
      break;
    }
    if (dim2 >= (1 << (pows[2] + 1))) {
      pows[2]++;
      sum++;
    }
    if (sum == presum || sum == 10) {
      break;
    }
  }
  return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
}

// Collapse dims that are contiguous to possibly route to a better kernel
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
// should return {{2, 4}, {{1, 2}}}.
//
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(const std::vector<array>& xs) {
  // Make a vector that has axes separated with -1. Collapse all axes between
  // -1.
  std::vector<int> to_collapse;
  if (xs[0].ndim() > 0) {
    to_collapse.push_back(0);
    for (int i = 1; i < xs[0].ndim(); i++) {
      bool contiguous = true;
      for (auto& x : xs) {
        if (x.strides()[i] * x.shape()[i] != x.strides()[i - 1]) {
          contiguous = false;
        }
        if (!contiguous) {
          break;
        }
      }
      if (!contiguous) {
        to_collapse.push_back(-1);
      }
      to_collapse.push_back(i);
    }
    to_collapse.push_back(-1);
  }

  std::vector<int> out_shape;
  std::vector<std::vector<size_t>> out_strides(xs.size());
  for (int i = 0; i < to_collapse.size(); i++) {
    int current_shape = xs[0].shape()[to_collapse[i]];
    while (to_collapse[++i] != -1) {
      current_shape *= xs[0].shape()[to_collapse[i]];
    }
    out_shape.push_back(current_shape);
    for (int j = 0; j < xs.size(); j++) {
      out_strides[j].push_back(xs[j].strides()[to_collapse[i - 1]]);
    }
  }

  return std::make_tuple(out_shape, out_strides);
}

template <typename... Arrays>
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(Arrays... xs) {
  return collapse_contiguous_dims(
      std::vector<array>{std::forward<Arrays>(xs)...});
}
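// Usage sketch (added, illustrative): for a row-contiguous array x with shape
// {2, 3, 4} and strides {12, 4, 1}, collapse_contiguous_dims(x) returns shape
// {24} with strides {{1}}, since every adjacent pair of axes satisfies
// strides[i] * shape[i] == strides[i - 1].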
} // namespace

} // namespace mlx::core
18
mlx/backend/no_metal/metal.cpp
Normal file
@@ -0,0 +1,18 @@
#include <stdexcept>

#include "mlx/backend/metal/metal.h"

namespace mlx::core::metal {

void new_stream(Stream) {}

std::function<void()> make_task(
    array& arr,
    std::vector<std::shared_future<void>> deps,
    std::shared_ptr<std::promise<void>> p,
    bool retain_graph) {
  throw std::runtime_error(
      "[metal::make_task] Cannot make GPU task without metal backend");
}

} // namespace mlx::core::metal