angelos's commit files

This commit is contained in:
Angelos Katharopoulos
2023-11-29 10:42:59 -08:00
parent 8ca7f9e8e9
commit d1f86272a2
56 changed files with 12350 additions and 0 deletions

View File

@@ -0,0 +1,323 @@
#include <cassert>
#include <limits>
#include <arm_neon.h>
#include <simd/math.h>
#include <simd/vector.h>
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
/**
* Compute exp(x) in an optimizer friendly way as follows:
*
* First change the problem to computing 2**y where y = x / ln(2).
*
* Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part
* `ipart` and y2 is fractional part. For the integer part we perform bit
* shifting and for the fractional part we use a polynomial approximation.
*
* The algorithm and constants of the polynomial taken from
* https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them
* from Cephes math library.
*
* Note: The implementation below is a general fast exp. There could be faster
* implementations for numbers strictly < 0.
*/
inline simd_float16 simd_fast_exp(simd_float16 x) {
x *= 1.442695; // multiply with log_2(e)
simd_float16 ipart, fpart;
simd_int16 epart;
x = simd_clamp(x, -80, 80);
ipart = simd::floor(x + 0.5);
fpart = x - ipart;
x = 1.535336188319500e-4f;
x = x * fpart + 1.339887440266574e-3f;
x = x * fpart + 9.618437357674640e-3f;
x = x * fpart + 5.550332471162809e-2f;
x = x * fpart + 2.402264791363012e-1f;
x = x * fpart + 6.931472028550421e-1f;
x = x * fpart + 1.000000000000000f;
// generate 2**ipart in the floating point representation using integer
// bitshifting
epart = (simd_int(ipart) + 127) << 23;
return (*(simd_float16*)&epart) * x;
}
/**
* The ARM neon equivalent of the fast exp above.
*/
inline float16x8_t neon_fast_exp(float16x8_t x) {
x = vmulq_f16(x, vdupq_n_f16(1.442695)); // multiply with log_2(e)
x = vmaxq_f16(x, vdupq_n_f16(-14)); // clamp under with -14
x = vminq_f16(x, vdupq_n_f16(14)); // clamp over with 14
float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(0.5)));
float16x8_t fpart = vsubq_f16(x, ipart);
x = vdupq_n_f16(1.535336188319500e-4f);
x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(9.618437357674640e-3f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(5.550332471162809e-2f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(2.402264791363012e-1f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(6.931472028550421e-1f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(1.000000000000000f), x, fpart);
// generate 2**ipart in the floating point representation using integer
// bitshifting
int16x8_t epart = vcvtq_s16_f16(ipart);
epart = vaddq_s16(epart, vdupq_n_s16(15));
epart = vshlq_n_s16(epart, 10);
return vmulq_f16(vreinterpretq_f16_s16(epart), x);
}
/**
* Implementation of folding maximum for ARM neon. This should possibly be
* refactored out of softmax.cpp at some point.
*/
inline float16_t neon_reduce_max(float16x8_t x) {
float16x4_t y;
y = vpmax_f16(vget_low_f16(x), vget_high_f16(x));
y = vpmax_f16(y, y);
y = vpmax_f16(y, y);
return vget_lane_f16(y, 0);
}
/**
* Implementation of folding sum for ARM neon. This should possibly be
* refactored out of softmax.cpp at some point.
*/
inline float16_t neon_reduce_add(float16x8_t x) {
float16x4_t y;
float16x4_t zero = vdup_n_f16(0);
y = vpadd_f16(vget_low_f16(x), vget_high_f16(x));
y = vpadd_f16(y, zero);
y = vpadd_f16(y, zero);
return vget_lane_f16(y, 0);
}
template <typename T, typename VT>
struct AccelerateSimdOps {
VT init(T a) {
return a;
}
VT load(const T* a) {
return *(VT*)a;
}
void store(T* dst, VT x) {
*(VT*)dst = x;
}
VT max(VT a, VT b) {
return simd_max(a, b);
};
VT exp(VT x) {
return simd_fast_exp(x);
}
VT add(VT a, VT b) {
return a + b;
}
VT sub(VT a, T b) {
return a - b;
}
VT mul(VT a, VT b) {
return a * b;
}
VT mul(VT a, T b) {
return a * b;
}
T reduce_max(VT x) {
return simd_reduce_max(x);
}
T reduce_add(VT x) {
return simd_reduce_add(x);
}
};
template <typename T, typename VT>
struct NeonFp16SimdOps {
VT init(T a) {
return vdupq_n_f16(a);
}
VT load(const T* a) {
return vld1q_f16(a);
}
void store(T* dst, VT x) {
vst1q_f16(dst, x);
}
VT max(VT a, VT b) {
return vmaxq_f16(a, b);
};
VT exp(VT x) {
return neon_fast_exp(x);
}
VT add(VT a, VT b) {
return vaddq_f16(a, b);
}
VT sub(VT a, T b) {
return vsubq_f16(a, vdupq_n_f16(b));
}
VT mul(VT a, VT b) {
return vmulq_f16(a, b);
}
VT mul(VT a, T b) {
return vmulq_f16(a, vdupq_n_f16(b));
}
T reduce_max(VT x) {
return neon_reduce_max(x);
}
T reduce_add(VT x) {
return neon_reduce_add(x);
}
};
template <typename T, typename VT, typename Ops, int N>
void softmax(const array& in, array& out) {
Ops ops;
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
int M = in.shape().back();
int L = in.data_size() / M;
const T* current_in_ptr;
T* current_out_ptr;
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
// Find the maximum
current_in_ptr = in_ptr;
VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
size_t s = M;
while (s >= N) {
vmaximum = ops.max(ops.load(current_in_ptr), vmaximum);
current_in_ptr += N;
s -= N;
}
T maximum = ops.reduce_max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, *current_in_ptr);
current_in_ptr++;
}
// Compute the normalizer and the exponentials
VT vnormalizer = ops.init(0.0);
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
VT vexp = ops.exp(ops.sub(*(VT*)current_in_ptr, maximum));
ops.store(current_out_ptr, vexp);
*(VT*)current_out_ptr = vexp;
vnormalizer = ops.add(vnormalizer, vexp);
current_in_ptr += N;
current_out_ptr += N;
s -= N;
}
T normalizer = ops.reduce_add(vnormalizer);
while (s-- > 0) {
T _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = _exp;
normalizer += _exp;
current_in_ptr++;
current_out_ptr++;
}
normalizer = 1 / normalizer;
// Normalize
current_out_ptr = out_ptr;
s = M;
while (s >= N) {
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
current_out_ptr += N;
s -= N;
}
while (s-- > 0) {
*current_out_ptr *= normalizer;
current_out_ptr++;
}
}
}
} // namespace
void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Make sure that the last dimension is contiguous
auto check_input = [](array x) {
if (x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General);
return x_copy;
}
};
array in = check_input(std::move(inputs[0]));
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
switch (in.dtype()) {
case bool_:
case uint8:
case uint16:
case uint32:
case uint64:
case int8:
case int16:
case int32:
case int64:
throw std::invalid_argument(
"Softmax is defined only for floating point types");
break;
case float32:
softmax<float, simd_float16, AccelerateSimdOps<float, simd_float16>, 16>(
in, out);
break;
case float16:
softmax<
float16_t,
float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out);
break;
case bfloat16:
eval(inputs, out);
break;
case complex64:
eval(inputs, out);
break;
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,216 @@
#include <cassert>
#include <cmath>
#include <sstream>
#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
template <typename T, typename U, typename Op>
void comparison_op(const array& a, const array& b, array& out, Op op) {
DefaultScalarVector<T, U, Op> opsv(op);
DefaultVectorScalar<T, U, Op> opvs(op);
DefaultVectorVector<T, U, Op> opvv(op);
binary_op<T, U>(a, b, out, op, opsv, opvs, opvv);
}
template <typename Op>
void comparison_op(const array& a, const array& b, array& out, Op op) {
switch (a.dtype()) {
case bool_:
comparison_op<bool, bool>(a, b, out, op);
break;
case uint8:
comparison_op<uint8_t, bool>(a, b, out, op);
break;
case uint16:
comparison_op<uint16_t, bool>(a, b, out, op);
break;
case uint32:
comparison_op<uint32_t, bool>(a, b, out, op);
break;
case uint64:
comparison_op<uint64_t, bool>(a, b, out, op);
break;
case int8:
comparison_op<int8_t, bool>(a, b, out, op);
break;
case int16:
comparison_op<int16_t, bool>(a, b, out, op);
break;
case int32:
comparison_op<int32_t, bool>(a, b, out, op);
break;
case int64:
comparison_op<int64_t, bool>(a, b, out, op);
break;
case float16:
comparison_op<float16_t, bool>(a, b, out, op);
break;
case float32:
comparison_op<float, bool>(a, b, out, op);
break;
case bfloat16:
comparison_op<bfloat16_t, bool>(a, b, out, op);
break;
case complex64:
comparison_op<complex64_t, bool>(a, b, out, op);
break;
}
}
} // namespace
void Add::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, [](auto x, auto y) { return x + y; });
}
void Divide::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, [](auto x, auto y) { return x / y; });
}
void Equal::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (equal_nan_) {
comparison_op(inputs[0], inputs[1], out, [](auto x, auto y) {
return x == y || (std::isnan(x) && std::isnan(y));
});
} else {
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x == y; });
}
}
void Greater::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x > y; });
}
void GreaterEqual::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x >= y; });
}
void Less::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x < y; });
}
void LessEqual::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x <= y; });
}
void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto op = [](auto x, auto y) {
constexpr float inf = std::numeric_limits<float>::infinity();
auto maxval = (x > y) ? x : y;
auto minval = (x > y) ? y : x;
return (minval == -inf || maxval == inf)
? maxval
: static_cast<decltype(x)>(
maxval + std::log1p(std::exp(minval - maxval)));
};
if (is_floating_point(out.dtype())) {
if (out.dtype() == float32) {
binary_op<float>(a, b, out, op);
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, op);
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, op);
} else {
std::ostringstream err;
err << "[logaddexp] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
}
} else {
throw std::invalid_argument(
"[logaddexp] Cannot compute logaddexp for arrays with"
" non floating point type.");
}
}
void Maximum::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
}
void Minimum::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
}
void Multiply::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, [](auto x, auto y) { return x * y; });
}
void NotEqual::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x != y; });
}
struct PowerFn {
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T base, T exp) {
return std::pow(base, exp);
}
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(T base, T exp) {
if (exp < 0) {
throw std::invalid_argument(
"Integers cannot be raise to negative powers");
}
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
};
void Power::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, PowerFn{});
}
void Subtract::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, [](auto x, auto y) { return x - y; });
}
} // namespace mlx::core

554
mlx/backend/common/binary.h Normal file
View File

@@ -0,0 +1,554 @@
#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
namespace {
enum BinaryOpType {
ScalarScalar,
ScalarVector,
VectorScalar,
VectorVector,
General,
};
BinaryOpType get_binary_op_type(const array& a, const array& b) {
BinaryOpType bopt;
if (a.data_size() == 1 && b.data_size() == 1) {
bopt = ScalarScalar;
} else if (a.data_size() == 1 && b.flags().contiguous) {
bopt = ScalarVector;
} else if (b.data_size() == 1 && a.flags().contiguous) {
bopt = VectorScalar;
} else if (
a.flags().row_contiguous && b.flags().row_contiguous ||
a.flags().col_contiguous && b.flags().col_contiguous) {
bopt = VectorVector;
} else {
bopt = General;
}
return bopt;
}
void set_binary_op_output_data(
const array& a,
const array& b,
array& out,
BinaryOpType bopt) {
switch (bopt) {
case ScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
break;
case ScalarVector:
out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
break;
case VectorScalar:
case VectorVector:
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
break;
case General:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
break;
}
}
struct UseDefaultBinaryOp {
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst, int size) {
// Should we throw? This should normally never be called.
assert(false);
}
};
template <typename T, typename U, typename Op>
struct DefaultVectorScalar {
Op op;
DefaultVectorScalar(Op op_) : op(op_) {}
void operator()(const T* a, const T* b, U* dst, int size) {
T scalar = *b;
while (size-- > 0) {
*dst = op(*a, scalar);
dst++;
a++;
}
}
};
template <typename T, typename U, typename Op>
struct DefaultScalarVector {
Op op;
DefaultScalarVector(Op op_) : op(op_) {}
void operator()(const T* a, const T* b, U* dst, int size) {
T scalar = *a;
while (size-- > 0) {
*dst = op(scalar, *b);
dst++;
b++;
}
}
};
template <typename T, typename U, typename Op>
struct DefaultVectorVector {
Op op;
DefaultVectorVector(Op op_) : op(op_) {}
void operator()(const T* a, const T* b, U* dst, int size) {
while (size-- > 0) {
*dst = op(*a, *b);
dst++;
a++;
b++;
}
}
};
template <typename T, typename U, typename Op>
void binary_op_dims1(const array& a, const array& b, array& out, Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < out.size(); ++i) {
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
a_idx += a.strides()[0];
b_idx += b.strides()[0];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims1(
const array& a,
const array& b,
array& out,
Op op,
int stride) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < a.shape()[0]; i++) {
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
a_idx += a.strides()[0];
b_idx += b.strides()[0];
dst += stride;
}
}
template <typename T, typename U, typename Op>
void binary_op_dims2(const array& a, const array& b, array& out, Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
a_idx += a.strides()[1];
b_idx += b.strides()[1];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims2(
const array& a,
const array& b,
array& out,
Op op,
int stride) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
a_idx += a.strides()[1];
b_idx += b.strides()[1];
dst += stride;
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims3(const array& a, const array& b, array& out, Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
a_idx += a.strides()[2];
b_idx += b.strides()[2];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims4(const array& a, const array& b, array& out, Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
a_idx += a.strides()[3];
b_idx += b.strides()[3];
}
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
const array& a,
const array& b,
array& out,
Op op) {
switch (out.ndim()) {
case 1:
binary_op_dims1<T, U, Op>(a, b, out, op);
return;
case 2:
binary_op_dims2<T, U, Op>(a, b, out, op);
return;
case 3:
binary_op_dims3<T, U, Op>(a, b, out, op);
return;
case 4:
binary_op_dims4<T, U, Op>(a, b, out, op);
return;
}
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
for (size_t i = 0; i < out.size(); i++) {
int a_idx = elem_to_loc(i, a.shape(), a.strides());
int b_idx = elem_to_loc(i, b.shape(), b.strides());
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
}
}
template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
const array& a,
const array& b,
array& out,
Op op,
int dim,
int stride) {
// Number of dimensions to loop over for vectorized ops
switch (dim) {
case 1:
binary_op_dims1<T, U, Op>(a, b, out, op, stride);
return;
case 2:
binary_op_dims2<T, U, Op>(a, b, out, op, stride);
return;
}
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
for (size_t i = 0; i < out.size(); i += stride) {
int a_idx = elem_to_loc(i, a.shape(), a.strides());
int b_idx = elem_to_loc(i, b.shape(), b.strides());
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
dst += stride;
}
}
template <
typename T,
typename U,
typename Op,
typename OpSV,
typename OpVS,
typename OpVV>
void binary_op(
const array& a,
const array& b,
array& out,
Op op,
OpSV opsv,
OpVS opvs,
OpVV opvv) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
// The full computation is scalar scalar so call the base op once
if (bopt == ScalarScalar) {
*(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
return;
}
// The full computation is scalar vector so delegate to the op
if (bopt == ScalarVector) {
opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
return;
}
// The full computation is vector scalar so delegate to the op
if (bopt == VectorScalar) {
opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
return;
}
// The full computation is vector vector so delegate to the op
if (bopt == VectorVector) {
opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
return;
}
// General computation so let's try to optimize
// Get the left-most dim such that the array is row contiguous after
auto& strides = out.strides();
auto leftmost_rc_dim = [&strides](const array& arr) {
int d = arr.ndim() - 1;
for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
}
return d + 1;
};
auto a_rc_dim = leftmost_rc_dim(a);
auto b_rc_dim = leftmost_rc_dim(b);
// Get the left-most dim such that the array is a broadcasted "scalar" after
auto leftmost_s_dim = [](const array& arr) {
int d = arr.ndim() - 1;
for (; d >= 0 && arr.strides()[d] == 0; d--) {
}
return d + 1;
};
auto a_s_dim = leftmost_s_dim(a);
auto b_s_dim = leftmost_s_dim(b);
auto ndim = out.ndim();
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = ScalarVector;
dim = d;
}
// Can be sure dim > 0 since otherwise we would have used one of the fully
// contiguous methods above. Except for the case that the flags do not
// correspond to the underlying contiguity.
size_t stride;
if (dim == 0 || strides[dim - 1] < 16) {
stride = 1;
bopt = General;
dim = ndim;
} else {
stride = strides[dim - 1];
}
switch (bopt) {
case VectorVector:
binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
break;
case VectorScalar:
binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
break;
case ScalarVector:
binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
break;
default:
binary_op_dispatch_dims<T, U>(a, b, out, op);
break;
}
}
template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
void binary_op(
const array& a,
const array& b,
array& out,
Op op,
OpSV opsv,
OpVS opvs,
OpVV opvv) {
// TODO: The following mess of constexpr evaluations can probably be achieved
// with template specializations and overloading. Would it be simpler?
if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// All ops are UseDefaultBinaryOp (why oh why would someone call that?)
binary_op<T, T>(
a,
b,
out,
op,
DefaultScalarVector<T, T, Op>(op),
DefaultVectorScalar<T, T, Op>(op),
DefaultVectorVector<T, T, Op>(op));
} else {
// opsv and opvs were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
out,
op,
DefaultScalarVector<T, T, Op>(op),
DefaultVectorScalar<T, T, Op>(op),
opvv);
}
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opsv and opvv were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
out,
op,
DefaultScalarVector<T, T, Op>(op),
opvs,
DefaultVectorVector<T, T, Op>(op));
} else {
// opsv was UseDefaultBinaryOp
binary_op<T, T>(
a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
}
} else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opvs and opvv were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
out,
op,
opsv,
DefaultVectorScalar<T, T, Op>(op),
DefaultVectorVector<T, T, Op>(op));
} else {
// opvs was UseDefaultBinaryOp
binary_op<T, T>(
a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
}
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opvv was UseDefaultBinaryOp
binary_op<T, T>(
a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
} else {
// All ops provided
binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
}
}
template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
DefaultScalarVector<T, T, Op> opsv(op);
DefaultVectorScalar<T, T, Op> opvs(op);
DefaultVectorVector<T, T, Op> opvv(op);
binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
}
template <typename... Ops>
void binary(const array& a, const array& b, array& out, Ops... ops) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, ops...);
break;
case uint8:
binary_op<uint8_t>(a, b, out, ops...);
break;
case uint16:
binary_op<uint16_t>(a, b, out, ops...);
break;
case uint32:
binary_op<uint32_t>(a, b, out, ops...);
break;
case uint64:
binary_op<uint64_t>(a, b, out, ops...);
break;
case int8:
binary_op<int8_t>(a, b, out, ops...);
break;
case int16:
binary_op<int16_t>(a, b, out, ops...);
break;
case int32:
binary_op<int32_t>(a, b, out, ops...);
break;
case int64:
binary_op<int64_t>(a, b, out, ops...);
break;
case float16:
binary_op<float16_t>(a, b, out, ops...);
break;
case float32:
binary_op<float>(a, b, out, ops...);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, ops...);
break;
case complex64:
binary_op<complex64_t>(a, b, out, ops...);
break;
}
}
} // namespace
} // namespace mlx::core

View File

@@ -0,0 +1,85 @@
#include <numeric>
#include "mlx/3rdparty/pocketfft.h"
#include "mlx/allocator.h"
#include "mlx/primitives.h"
namespace mlx::core {
void FFT::eval(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
std::vector<std::ptrdiff_t> strides_in(
in.strides().begin(), in.strides().end());
for (auto& s : strides_in) {
s *= in.itemsize();
}
std::vector<std::ptrdiff_t> strides_out(
out.strides().begin(), out.strides().end());
for (auto& s : strides_out) {
s *= out.itemsize();
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
std::vector<size_t> shape;
if (out.dtype() == float32) {
shape.insert(shape.end(), out.shape().begin(), out.shape().end());
} else {
shape.insert(shape.end(), in.shape().begin(), in.shape().end());
}
float scale = 1.0f;
if (inverse_) {
size_t nelem = std::accumulate(
axes_.begin(), axes_.end(), 1, [&shape](auto x, auto y) {
return x * shape[y];
});
scale /= nelem;
}
if (in.dtype() == complex64 && out.dtype() == complex64) {
auto in_ptr =
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
auto out_ptr =
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
pocketfft::c2c(
shape,
strides_in,
strides_out,
axes_,
!inverse_,
in_ptr,
out_ptr,
scale);
} else if (in.dtype() == float32 && out.dtype() == complex64) {
auto in_ptr = in.data<float>();
auto out_ptr =
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
pocketfft::r2c(
shape,
strides_in,
strides_out,
axes_,
!inverse_,
in_ptr,
out_ptr,
scale);
} else if (in.dtype() == complex64 && out.dtype() == float32) {
auto in_ptr =
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
auto out_ptr = out.data<float>();
pocketfft::c2r(
shape,
strides_in,
strides_out,
axes_,
!inverse_,
in_ptr,
out_ptr,
scale);
} else {
throw std::runtime_error(
"[FFT] Received unexpected input and output type combination.");
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,98 @@
#include <cassert>
#include <cmath>
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T>
void softmax(const array& in, array& out) {
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
int N = in.shape().back();
int M = in.data_size() / N;
const T* current_in_ptr;
T* current_out_ptr;
for (int i = 0; i < M; i++, in_ptr += N, out_ptr += N) {
// Find the maximum
current_in_ptr = in_ptr;
T maximum = *current_in_ptr;
for (int j = 0; j < N; j++, current_in_ptr++) {
maximum = (maximum < *current_in_ptr) ? *current_in_ptr : maximum;
}
// Compute the normalizer and the exponentials
T normalizer = 0;
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
for (int j = 0; j < N; j++, current_out_ptr++, current_in_ptr++) {
T expv = std::exp(*current_in_ptr - maximum);
normalizer += expv;
*current_out_ptr = expv;
}
normalizer = 1 / normalizer;
// Normalize
current_out_ptr = out_ptr;
for (int j = 0; j < N; j++, current_out_ptr++) {
*current_out_ptr *= normalizer;
}
}
}
} // namespace
void Softmax::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Make sure that the last dimension is contiguous
auto check_input = [](array x) {
if (x.strides().back() == 1) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General);
return x_copy;
}
};
array in = check_input(std::move(inputs[0]));
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
switch (in.dtype()) {
case bool_:
case uint8:
case uint16:
case uint32:
case uint64:
case int8:
case int16:
case int32:
case int64:
throw std::invalid_argument(
"Softmax is defined only for floating point types");
break;
case float32:
softmax<float>(in, out);
break;
case float16:
softmax<float16_t>(in, out);
break;
case bfloat16:
softmax<bfloat16_t>(in, out);
break;
case complex64:
throw std::invalid_argument(
"[Softmax] Not yet implemented for complex64");
break;
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,200 @@
#include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/metal.h"
#include <mach/vm_page_size.h>
#include <unistd.h>
#include <cstdlib>
namespace mlx::core {
namespace allocator {
Allocator& allocator() {
return metal::allocator();
}
void* Buffer::raw_ptr() {
return static_cast<MTL::Buffer*>(ptr_)->contents();
}
} // namespace allocator
namespace metal {
namespace {
BufferCache::BufferCache(MTL::Device* device)
: device_(device),
head_(nullptr),
tail_(nullptr),
pool_size_(0),
gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {}
BufferCache::~BufferCache() {
clear();
}
void BufferCache::clear() {
std::lock_guard<std::mutex> lk(cache_mutex_);
for (auto& [size, holder] : buffer_pool_) {
if (holder->buf)
holder->buf->release();
delete holder;
}
buffer_pool_.clear();
pool_size_ = 0;
head_ = nullptr;
tail_ = nullptr;
}
MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
std::lock_guard<std::mutex> lk(cache_mutex_);
// Find the closest buffer in pool
MTL::Buffer* pbuf = nullptr;
auto it = buffer_pool_.lower_bound(size);
// Make sure we use > 50% of the available memory
while (!pbuf && it != buffer_pool_.end() && it->first < 2 * size) {
// Collect from the cache
pbuf = it->second->buf;
// Remove from cache
remove_from_list(it->second);
delete it->second;
it = buffer_pool_.erase(it);
}
if (pbuf) {
pool_size_ -= pbuf->length();
}
return pbuf;
}
void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
std::lock_guard<std::mutex> lk(cache_mutex_);
// Add to cache
if (buf) {
BufferHolder* bh = new BufferHolder(buf);
add_at_head(bh);
pool_size_ += buf->length();
buffer_pool_.insert({buf->length(), bh});
}
}
size_t BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
min_bytes_to_free += device_->currentAllocatedSize() - gc_limit_;
if (min_bytes_to_free >= 0.9 * pool_size_) {
size_t old_pool_size = pool_size_;
clear();
return old_pool_size;
} else {
std::lock_guard<std::mutex> lk(cache_mutex_);
size_t total_bytes_freed = 0;
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
if (tail_->buf) {
total_bytes_freed += tail_->buf->length();
tail_->buf->release();
tail_->buf = nullptr;
}
remove_from_list(tail_);
}
pool_size_ -= total_bytes_freed;
return total_bytes_freed;
}
}
void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) {
if (!to_add)
return;
if (!head_) {
head_ = to_add;
tail_ = to_add;
} else {
head_->prev = to_add;
to_add->next = head_;
head_ = to_add;
}
}
void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
if (!to_remove)
return;
// If in the middle
if (to_remove->prev && to_remove->next) {
to_remove->prev->next = to_remove->next;
to_remove->next->prev = to_remove->prev;
} else if (to_remove->prev && to_remove == tail_) { // If tail
tail_ = to_remove->prev;
tail_->next = nullptr;
} else if (to_remove == head_ && to_remove->next) { // If head
head_ = to_remove->next;
head_->prev = nullptr;
} else if (to_remove == head_ && to_remove == tail_) { // If only element
head_ = nullptr;
tail_ = nullptr;
}
to_remove->prev = nullptr;
to_remove->next = nullptr;
}
} // namespace
MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()),
buffer_cache_(device_),
peak_allocated_size_(0),
block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()) {}
Buffer MetalAllocator::malloc(size_t size) {
// Align up memory
if (size > vm_page_size) {
size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);
}
MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);
// Prepare to allocate new memory as needed
if (!buf) {
// If we are under very high memoory pressure, we don't allocate further
if (device_->currentAllocatedSize() >= block_limit_) {
return Buffer{nullptr};
}
// If we are still under memory pressure, try cleaning cache
if (buffer_cache_.can_garbage_collect()) {
buffer_cache_.release_cached_buffers(size);
}
// Allocate new buffer if needed
size_t res_opt = MTL::ResourceStorageModeShared;
res_opt |= MTL::ResourceHazardTrackingModeTracked;
buf = device_->newBuffer(size, res_opt);
}
peak_allocated_size_ =
std::max(peak_allocated_size_, device_->currentAllocatedSize());
return Buffer{static_cast<void*>(buf)};
}
void MetalAllocator::free(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
buffer_cache_.recycle_to_cache(buf);
}
MetalAllocator& allocator() {
static MetalAllocator allocator_;
return allocator_;
}
} // namespace metal
} // namespace mlx::core

View File

@@ -0,0 +1,76 @@
#pragma once
#include <map>
#include <mutex>
#include <vector>
#include "mlx/allocator.h"
#include "mlx/backend/metal/device.h"
namespace mlx::core::metal {
using allocator::Buffer;
namespace {
class BufferCache {
public:
BufferCache(MTL::Device* device);
~BufferCache();
void clear();
MTL::Buffer* reuse_from_cache(size_t size);
void recycle_to_cache(MTL::Buffer* buf);
size_t release_cached_buffers(size_t min_bytes_to_free);
bool can_garbage_collect() {
return pool_size_ > 0 && device_->currentAllocatedSize() > gc_limit_;
}
private:
struct BufferHolder {
public:
BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {}
BufferHolder* prev;
BufferHolder* next;
MTL::Buffer* buf;
};
void add_at_head(BufferHolder* to_add);
void remove_from_list(BufferHolder* to_remove);
MTL::Device* device_;
std::mutex cache_mutex_;
std::multimap<size_t, BufferHolder*> buffer_pool_;
BufferHolder* head_;
BufferHolder* tail_;
size_t pool_size_;
size_t gc_limit_;
};
} // namespace
class MetalAllocator : public allocator::Allocator {
/** Allocator for Metal GPUs. */
public:
virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override;
private:
MTL::Device* device_;
MetalAllocator();
friend MetalAllocator& allocator();
// Caching allocator
BufferCache buffer_cache_;
// Allocation stats
size_t peak_allocated_size_;
size_t block_limit_;
};
MetalAllocator& allocator();
} // namespace mlx::core::metal

555
mlx/backend/metal/conv.cpp Normal file
View File

@@ -0,0 +1,555 @@
#include <algorithm>
#include <cassert>
#include <iostream>
#include <numeric>
#include <sstream>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/conv_params.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/matmul.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
void explicit_gemm_conv_1D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<1>& conv_params) {
// Pad input
std::vector<int> padded_shape = {
conv_params.N, conv_params.iS[0] + 2 * conv_params.pad[0], conv_params.C};
array in_padded(padded_shape, in.dtype(), nullptr, {});
// Fill with zeros
copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
// Pick input slice from padded
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
// Copy input values into the slice
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
// Make strided view
std::vector<int> strided_shape = {
conv_params.N, conv_params.oS[0], conv_params.wS[0], conv_params.C};
std::vector<size_t> strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * conv_params.str[0],
in_padded.strides()[1],
in_padded.strides()[2]};
auto flags = in_padded.flags();
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
std::vector<int> strided_reshape = {
conv_params.N * conv_params.oS[0], conv_params.wS[0] * conv_params.C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
// Peform gemm
std::vector<array> copies = {in_padded, in_strided};
mlx_matmul(
s,
d,
/*a = */ in_strided,
/*b = */ wt,
/*c = */ out,
/*M = */ strided_reshape[0],
/*N = */ conv_params.O,
/*K = */ strided_reshape[1],
/*batch_size_out = */ 1,
/*a_cols = */ strided_reshape[1],
/*b_cols = */ strided_reshape[1],
/*a_transposed = */ false,
/*b_transposed = */ true,
/*copies = */ copies);
}
void conv_1D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
// Make conv params
MLXConvParams<1> conv_params{
/* const int N = */ in.shape(0),
/* const int C = */ in.shape(2),
/* const int O = */ wt.shape(0),
/* const int iS[NDIM] = */ {in.shape(1)},
/* const int wS[NDIM] = */ {wt.shape(1)},
/* const int oS[NDIM] = */ {out.shape(1)},
/* const int str[NDIM] = */ {wt_strides[0]},
/* const int pad[NDIM] = */ {padding[0]},
/* const int dil[NDIM] = */ {wt_dilation[0]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], in.strides()[2]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2]},
};
// Direct to explicit gemm conv
if (wt_dilation[0] == 1) {
explicit_gemm_conv_1D_gpu(s, d, in, wt, out, conv_params);
}
// Direct to fallback conv
else {
throw std::invalid_argument("[conv_1D_gpu] Dilation needs to be 1.");
}
}
void slow_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
int bm = 16, bn = 8;
int tm = 4, tn = 4;
std::ostringstream kname;
kname << "naive_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn" << bn
<< "_tm" << tm << "_tn" << tn;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];
size_t grid_dim_x = (n_pixels + (tm * bm) - 1) / (tm * bm);
size_t grid_dim_y = (conv_params.O + (tn * bn) - 1) / (tn * bn);
size_t grid_dim_z = conv_params.N;
MTL::Size group_dims = MTL::Size(bm, bn, 1);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, wt, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
int bm = 32, bn = 32, bk = 16;
int wm = 2, wn = 2;
std::ostringstream kname;
kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn"
<< bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
int implicit_N = conv_params.O;
int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
size_t grid_dim_x = (implicit_N + bn - 1) / bn;
size_t grid_dim_y = (implicit_M + bm - 1) / bm;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, wt, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
void explicit_gemm_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
// Pad input
std::vector<int> padded_shape = {
conv_params.N,
conv_params.iS[0] + 2 * conv_params.pad[0],
conv_params.iS[1] + 2 * conv_params.pad[1],
conv_params.C};
array in_padded(padded_shape, in.dtype(), nullptr, {});
// Fill with zeros
copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
// Pick input slice from padded
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
conv_params.pad[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
// Copy input values into the slice
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
// Make strided view
std::vector<int> strided_shape = {
conv_params.N,
conv_params.oS[0],
conv_params.oS[1],
conv_params.wS[0],
conv_params.wS[1],
conv_params.C};
std::vector<size_t> strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * conv_params.str[0],
in_padded.strides()[2] * conv_params.str[1],
in_padded.strides()[1],
in_padded.strides()[2],
in_padded.strides()[3]};
auto flags = in_padded.flags();
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
std::vector<int> strided_reshape = {
conv_params.N * conv_params.oS[0] * conv_params.oS[1],
conv_params.wS[0] * conv_params.wS[1] * conv_params.C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
// Peform gemm
std::vector<array> copies = {in_padded, in_strided};
mlx_matmul(
s,
d,
/*a = */ in_strided,
/*b = */ wt,
/*c = */ out,
/*M = */ strided_reshape[0],
/*N = */ conv_params.O,
/*K = */ strided_reshape[1],
/*batch_size_out = */ 1,
/*a_cols = */ strided_reshape[1],
/*b_cols = */ strided_reshape[1],
/*a_transposed = */ false,
/*b_transposed = */ true,
/*copies = */ copies);
}
void winograd_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params,
std::vector<array>& copies_w) {
std::vector<int> padded_shape = {
conv_params.N,
conv_params.iS[0] + 2 * conv_params.pad[0],
conv_params.iS[1] + 2 * conv_params.pad[1],
conv_params.C};
padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
array in_padded(padded_shape, in.dtype(), nullptr, {});
// Fill with zeros
array zero_arr = array(0, in.dtype());
copy_gpu(zero_arr, in_padded, CopyType::Scalar, s);
// Pick input slice from padded
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
conv_params.pad[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
// Copy input values into the slice
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
copies_w.push_back(in_padded_slice);
copies_w.push_back(in_padded);
copies_w.push_back(zero_arr);
MLXConvParams<2> conv_params_updated{
/* const int N = */ in_padded.shape(0),
/* const int C = */ in_padded.shape(3),
/* const int O = */ wt.shape(0),
/* const int iS[NDIM] = */ {in_padded.shape(1), in_padded.shape(2)},
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
/* const int str[NDIM] = */ {1, 1},
/* const int pad[NDIM] = */ {0, 0},
/* const int dil[NDIM] = */ {1, 1},
/* const size_t in_strides[NDIM + 2] = */
{in_padded.strides()[0],
in_padded.strides()[1],
in_padded.strides()[2],
in_padded.strides()[3]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
};
int O_c = conv_params.O;
int C_c = conv_params.C;
int N_tiles_n = conv_params.N;
int N_tiles_h = (conv_params.oS[0] + 5) / 6;
int N_tiles_w = (conv_params.oS[1] + 5) / 6;
int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;
// Do filter transform
std::vector<int> filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
array filt_wg(filt_wg_shape, wt.dtype(), nullptr, {});
filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
copies_w.push_back(filt_wg);
{
int bc = 32;
int bo = 4;
std::ostringstream kname;
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
<< bc;
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, wt, 0);
set_array_buffer(compute_encoder, filt_wg, 1);
compute_encoder->setBytes(&C_c, sizeof(int), 2);
compute_encoder->setBytes(&O_c, sizeof(int), 3);
MTL::Size group_dims = MTL::Size(32, bo, 1);
MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
// Do input transform
std::vector<int> inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
array inp_wg(inp_wg_shape, in.dtype(), nullptr, {});
inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
copies_w.push_back(inp_wg);
{
int bc = 32;
int wm = 2;
int wn = 2;
std::ostringstream kname;
kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
<< bc;
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in_padded, 0);
set_array_buffer(compute_encoder, inp_wg, 1);
compute_encoder->setBytes(
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
// Do batched gemm
std::vector<int> out_wg_shape = {8 * 8, N_tiles, conv_params.O};
array out_wg(out_wg_shape, in.dtype(), nullptr, {});
out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
copies_w.push_back(out_wg);
{
std::vector<array> empty_copies;
mlx_matmul(
s,
d,
/*a = */ inp_wg,
/*b = */ filt_wg,
/*c = */ out_wg,
/*M = */ N_tiles,
/*N = */ conv_params.O,
/*K = */ conv_params.C,
/*batch_size_out = */ 8 * 8,
/*a_cols = */ conv_params.C,
/*b_cols = */ conv_params.O,
/*a_transposed = */ false,
/*b_transposed = */ false,
/*copies = */ empty_copies);
}
// Do output transform
{
int bc = 32;
int wm = 2;
int wn = 2;
std::ostringstream kname;
kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
<< bc;
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, out_wg, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
}
void conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
std::vector<array>& copies) {
// Make conv params
MLXConvParams<2> conv_params{
/* const int N = */ in.shape(0),
/* const int C = */ in.shape(3),
/* const int O = */ wt.shape(0),
/* const int iS[NDIM] = */ {in.shape(1), in.shape(2)},
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]},
/* const int pad[NDIM] = */ {padding[0], padding[1]},
/* const int dil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], in.strides()[2], in.strides()[3]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
};
// Direct to winograd conv
if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0 &&
conv_params.C >= 64 && conv_params.O >= 64 && conv_params.wS[0] == 3 &&
conv_params.wS[1] == 3 && conv_params.str[0] == 1 &&
conv_params.str[1] == 1 && conv_params.dil[0] == 1 &&
conv_params.dil[1] == 1) {
winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
}
// Direct to implicit gemm conv
else if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0) {
implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
// Direct to explicit gemm conv
else if (wt_dilation[0] == 1 && wt_dilation[1] == 1) {
explicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
// Direct to fallback conv
else {
slow_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
}
} // namespace
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
// Ensure contiguity
std::vector<array> copies;
auto in = inputs[0];
auto wt = inputs[1];
if (!in.flags().row_contiguous) {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
in = arr_copy;
}
if (!wt.flags().row_contiguous) {
array arr_copy(wt.shape(), wt.dtype(), nullptr, {});
copy_gpu(wt, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
wt = arr_copy;
}
// 2D conv
if (out.ndim() == 4) {
conv_2D_gpu(
s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_, copies);
}
// 1D conv
else if (out.ndim() == 3) {
conv_1D_gpu(s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_);
}
// Throw error
else {
throw std::invalid_argument(
"[Convolution::eval_gpu] Only supports 1D or 2D convolutions.");
}
// Clear copies
if (copies.size() > 0) {
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
}
} // namespace mlx::core

16
mlx/backend/metal/copy.h Normal file
View File

@@ -0,0 +1,16 @@
#pragma once
#include "mlx/backend/common/copy.h"
#include "mlx/stream.h"
namespace mlx::core {
void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
void copy_gpu(const array& src, array& out, CopyType ctype);
void copy_gpu_inplace(
const array& src,
array& out,
CopyType ctype,
const Stream& s);
} // namespace mlx::core

View File

@@ -0,0 +1,30 @@
#include "mlx/backend/metal/kernels/bf16.h"
template <typename T>
[[kernel]] void arange(
constant const T& start,
constant const T& step,
device T* out,
uint index [[thread_position_in_grid]]) {
out[index] = start + index * step;
}
#define instantiate_arange(tname, type) \
template [[host_name("arange" #tname)]] \
[[kernel]] void arange<type>( \
constant const type& start, \
constant const type& step, \
device type* out, \
uint index [[thread_position_in_grid]]);
instantiate_arange(uint8, uint8_t)
instantiate_arange(uint16, uint16_t)
instantiate_arange(uint32, uint32_t)
instantiate_arange(uint64, uint64_t)
instantiate_arange(int8, int8_t)
instantiate_arange(int16, int16_t)
instantiate_arange(int32, int32_t)
instantiate_arange(int64, int64_t)
instantiate_arange(float16, half)
instantiate_arange(float32, float)
instantiate_arange(bfloat16, bfloat16_t)

View File

@@ -0,0 +1,208 @@
#include <metal_atomic>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
template <typename U>
struct IndexValPair {
uint32_t index;
U val;
IndexValPair(uint32_t _index, U _val) : index(_index), val(_val) {}
};
template <typename U>
struct ArgMin {
static constexpr constant U init = Limits<U>::max;
IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
if (best.val > current.val || (best.val == current.val && best.index > current.index)) {
return current;
} else {
return best;
}
}
template <int N>
IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
for (int i=0; i<N; i++) {
if (vals[i] < best.val) {
best.val = vals[i];
best.index = offset+i;
}
}
return best;
}
};
template <typename U>
struct ArgMax {
static constexpr constant U init = Limits<U>::min;
IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
if (best.val < current.val || (best.val == current.val && best.index > current.index)) {
return current;
} else {
return best;
}
}
template <int N>
IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
for (int i=0; i<N; i++) {
if (vals[i] > best.val) {
best.val = vals[i];
best.index = offset+i;
}
}
return best;
}
};
bool simd_shuffle_down(bool data, uint16_t delta) {
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
}
uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
return as_type<uint64_t>(simd_shuffle_down(as_type<uint2>(data), delta));
}
int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
return as_type<int64_t>(simd_shuffle_down(as_type<uint2>(data), delta));
}
template <typename U>
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
return IndexValPair<U>(
simd_shuffle_down(data.index, delta),
simd_shuffle_down(data.val, delta)
);
}
template <typename T, typename Op, int N_READS>
[[kernel]] void arg_reduce_general(
const device T *in [[buffer(0)]],
device uint32_t *out [[buffer(1)]],
const device int *shape [[buffer(2)]],
const device size_t *in_strides [[buffer(3)]],
const device size_t *out_strides [[buffer(4)]],
const device size_t& ndim [[buffer(5)]],
const device size_t& axis_stride [[buffer(6)]],
const device size_t& axis_size [[buffer(7)]],
threadgroup IndexValPair<T> *local_data [[threadgroup(0)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// Shapes and strides *do not* contain the reduction axis. The reduction size
// and stride are provided in axis_stride and axis_size.
//
// Note: in shape == out shape with this convention.
//
// The sketch of the kernel is as follows.
// 1. Launch prod(shape) * thread_group_size threads.
// 2. Loop ceildiv(axis_size / lsize) times
// 3. Read input values
// 4. Reduce among them and go to 3
// 4. Reduce in each simd_group
// 6. Write in the thread local memory
// 6. Reduce them accross thread group
// 7. Write the output without need for atomic
Op op;
// Compute the input/output index. There is one beginning and one output for
// the whole threadgroup.
auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim);
auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim);
IndexValPair<T> best(0, Op::init);
// Loop over the reduction axis in lsize*N_READS buckets
for (uint r=0; r < ceildiv(axis_size, N_READS*lsize); r++) {
// Read the current value
uint32_t current_index = r*lsize*N_READS + lid*N_READS;
uint32_t offset = current_index;
const device T * current_in = in + in_idx + current_index * axis_stride;
T vals[N_READS];
for (int i=0; i<N_READS; i++) {
vals[i] = (current_index < axis_size) ? *current_in : T(Op::init);
current_index++;
current_in += axis_stride;
}
best = op.template reduce_many<N_READS>(best, vals, offset);
}
// At this point we have reduced the axis into thread group best values so we
// need to reduce across the thread group.
// First per simd reduction.
for (uint offset=simd_size/2; offset>0; offset/=2) {
IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
best = op.reduce(best, neighbor);
}
// Write to the threadgroup memory
if (simd_lane_id == 0) {
local_data[simd_group_id] = best;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id != 0) {
return;
}
// Read the appropriate value from local data and perform one simd reduction
uint simd_groups = ceildiv(lsize, simd_size);
if (simd_lane_id < simd_groups) {
best = local_data[simd_lane_id];
}
for (uint offset=simd_size/2; offset>0; offset/=2) {
IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
best = op.reduce(best, neighbor);
}
// Finally write the output
if (lid == 0) {
out[out_idx] = best.index;
}
}
#define instantiate_arg_reduce_helper(name, itype, op) \
template [[host_name(name)]] \
[[kernel]] void arg_reduce_general<itype, op<itype>, 4>( \
const device itype *in [[buffer(0)]], \
device uint32_t * out [[buffer(1)]], \
const device int *shape [[buffer(2)]], \
const device size_t *in_strides [[buffer(3)]], \
const device size_t *out_strides [[buffer(4)]], \
const device size_t& ndim [[buffer(5)]], \
const device size_t& axis_stride [[buffer(6)]], \
const device size_t& axis_size [[buffer(7)]], \
threadgroup IndexValPair<itype> *local_data [[threadgroup(0)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_arg_reduce(name, itype) \
instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin) \
instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax)
instantiate_arg_reduce(bool_, bool)
instantiate_arg_reduce(uint8, uint8_t)
instantiate_arg_reduce(uint16, uint16_t)
instantiate_arg_reduce(uint32, uint32_t)
instantiate_arg_reduce(uint64, uint64_t)
instantiate_arg_reduce(int8, int8_t)
instantiate_arg_reduce(int16, int16_t)
instantiate_arg_reduce(int32, int32_t)
instantiate_arg_reduce(int64, int64_t)
instantiate_arg_reduce(float16, half)
instantiate_arg_reduce(float32, float)
instantiate_arg_reduce(bfloat16, bfloat16_t)

View File

@@ -0,0 +1,392 @@
#pragma once
#include "mlx/backend/metal/kernels/bf16.h"
///////////////////////////////////////////////////////////////////////////////
// Metal math for bfloat16
///////////////////////////////////////////////////////////////////////////////
/*
Following the Metal Shading Language Specification (Metal 3.1)
"bfloat is an extended itypeing point type that only allows implicit conversion
to a type of greater itypeing point rank. While bfloat can be implicitly
converted to itype, it cannot be implicitly converted to half, and neither
itype nor half can be implicitly converted to bfloat."
Further, as far as I can tell, the stdlib math/simd functions are not defined
for bfloat and calling with an argument of type bfloat will result in that
argument getting implicitly converted to itype which then returns an output
that is (likely) a itype which cannot be implicitly converted into a bfloat
This leads to situations where
bfloat a = 5.0bf;
bfloat b = metal::abs(a); // this will throw an error since abs return itype
bfloat c = static_cast<bfloat>(metal::abs(a)); // this is fine
For the moment, I will be adding overloaded instantiations of the math
functions to accordingly automatically handle the casting
*/
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \
\
METAL_FUNC otype abs(itype x) { \
return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype acos(itype x) { \
return static_cast<otype>(__metal_acos(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype acosh(itype x) { \
return static_cast<otype>(__metal_acosh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype asin(itype x) { \
return static_cast<otype>(__metal_asin(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype asinh(itype x) { \
return static_cast<otype>(__metal_asinh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype atan(itype y_over_x) { \
return static_cast<otype>( \
__metal_atan(static_cast<ctype>(y_over_x), mfast)); \
} \
METAL_FUNC otype atan2(itype y, itype x) { \
return static_cast<otype>( \
__metal_atan2(static_cast<ctype>(y), static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype atanh(itype x) { \
return static_cast<otype>(__metal_atanh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype ceil(itype x) { \
return static_cast<otype>(__metal_ceil(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype cos(itype x) { \
return static_cast<otype>(__metal_cos(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype cosh(itype x) { \
return static_cast<otype>(__metal_cosh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype cospi(itype x) { \
return static_cast<otype>(__metal_cospi(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype divide(itype x, itype y) { \
return static_cast<otype>( \
__metal_divide(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype exp(itype x) { \
return static_cast<otype>(__metal_exp(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype exp10(itype x) { \
return static_cast<otype>(__metal_exp10(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype exp2(itype x) { \
return static_cast<otype>(__metal_exp2(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype fabs(itype x) { \
return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype fdim(itype x, itype y) { \
ctype t = static_cast<ctype>(x - y); \
return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y)); \
} \
METAL_FUNC otype floor(itype x) { \
return static_cast<otype>(__metal_floor(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype fma(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fma( \
static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \
} \
METAL_FUNC otype fmax(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype fmax3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmax3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmedian3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype fmin(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype fmin3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmin3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype fmod(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmod(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype fract(itype x) { \
return static_cast<otype>(__metal_fract(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype frexp(itype x, thread int& exp) { \
return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp)); \
} \
METAL_FUNC otype ldexp(itype x, int k) { \
return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \
} \
METAL_FUNC otype log(itype x) { \
return static_cast<otype>(__metal_log(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype log10(itype x) { \
return static_cast<otype>(__metal_log10(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype log2(itype x) { \
return static_cast<otype>(__metal_log2(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype max(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype max3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmax3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype median3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmedian3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype min(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype min3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmin3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype nextafter(itype x, itype y) { \
return static_cast<otype>( \
__metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y))); \
} \
METAL_FUNC otype pow(itype x, itype y) { \
return static_cast<otype>( \
__metal_pow(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype powr(itype x, itype y) { \
return static_cast<otype>( \
__metal_powr(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype rint(itype x) { \
return static_cast<otype>(__metal_rint(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype round(itype x) { \
return static_cast<otype>(__metal_round(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype rsqrt(itype x) { \
return static_cast<otype>(__metal_rsqrt(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype sin(itype x) { \
return static_cast<otype>(__metal_sin(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype sinh(itype x) { \
return static_cast<otype>(__metal_sinh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype sinpi(itype x) { \
return static_cast<otype>(__metal_sinpi(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype sqrt(itype x) { \
return static_cast<otype>(__metal_sqrt(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype tan(itype x) { \
return static_cast<otype>(__metal_tan(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype tanh(itype x) { \
return static_cast<otype>(__metal_tanh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype tanpi(itype x) { \
return static_cast<otype>(__metal_tanpi(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype trunc(itype x) { \
return static_cast<otype>(__metal_trunc(static_cast<ctype>(x), mfast)); \
}
namespace metal {
instantiate_metal_math_funcs(
bfloat16_t,
bfloat16_t,
float,
__METAL_MAYBE_FAST_MATH__);
namespace fast {
instantiate_metal_math_funcs(
bfloat16_t,
bfloat16_t,
float,
__METAL_FAST_MATH__);
} // namespace fast
namespace precise {
instantiate_metal_math_funcs(
bfloat16_t,
bfloat16_t,
float,
__METAL_PRECISE_MATH__);
} // namespace precise
} // namespace metal
///////////////////////////////////////////////////////////////////////////////
// Metal simd for bfloat16
///////////////////////////////////////////////////////////////////////////////
#define instantiate_metal_simd_comm_funcs( \
itype, otype, ctype, itype_to_ctype, ctype_to_otype) \
\
METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \
return ctype_to_otype( \
__metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \
} \
\
METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \
return ctype_to_otype( \
__metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \
} \
\
METAL_FUNC otype simd_shuffle_and_fill_down( \
itype data, itype filling_data, ushort delta, ushort modulo) { \
return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
} \
\
METAL_FUNC otype simd_shuffle_and_fill_down( \
itype data, itype filling_data, ushort delta) { \
return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
itype_to_ctype(data), \
itype_to_ctype(filling_data), \
delta, \
__metal_get_simdgroup_size(ushort()))); \
} \
\
METAL_FUNC otype simd_shuffle_and_fill_up( \
itype data, itype filling_data, ushort delta, ushort modulo) { \
return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
} \
\
METAL_FUNC otype simd_shuffle_and_fill_up( \
itype data, itype filling_data, ushort delta) { \
return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
itype_to_ctype(data), \
itype_to_ctype(filling_data), \
delta, \
__metal_get_simdgroup_size(ushort()))); \
} \
\
METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \
return ctype_to_otype( \
__metal_simd_shuffle_down(itype_to_ctype(data), delta)); \
} \
\
METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \
return ctype_to_otype( \
__metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \
} \
\
METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \
return ctype_to_otype( \
__metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \
} \
\
METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \
return ctype_to_otype( \
__metal_simd_shuffle_up(itype_to_ctype(data), delta)); \
} \
\
METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \
return ctype_to_otype( \
__metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \
}
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \
\
METAL_FUNC otype simd_max(itype data) { \
return static_cast<otype>(__metal_simd_max(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_min(itype data) { \
return static_cast<otype>(__metal_simd_min(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \
return static_cast<otype>( \
__metal_simd_prefix_exclusive_product(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \
return static_cast<otype>( \
__metal_simd_prefix_exclusive_sum(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \
return static_cast<otype>( \
__metal_simd_prefix_inclusive_product(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \
return static_cast<otype>( \
__metal_simd_prefix_inclusive_sum(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_product(itype data) { \
return static_cast<otype>(__metal_simd_product(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_sum(itype data) { \
return static_cast<otype>(__metal_simd_sum(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_xor(itype data) { \
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
}
#if defined(__HAVE_BFLOAT__)
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
#else
#define bfloat16_to_uint16(x) x.bits_
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
#endif
namespace metal {
instantiate_metal_simd_comm_funcs(
bfloat16_t,
bfloat16_t,
uint16_t,
bfloat16_to_uint16,
uint16_to_bfloat16);
instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float);
} // namespace metal

View File

@@ -0,0 +1,553 @@
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/conv_params.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/gemm/conv.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
/// Slow and naive kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const int BC = 16>
[[kernel]] void naive_conv_2d(
const device T* in [[buffer(0)]],
const device T* wt [[buffer(1)]],
device T* out [[buffer(2)]],
const constant MLXConvParams<2>& params [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)simd_gid;
(void)simd_lid;
out += tid.z * params.out_strides[0];
in += tid.z * params.in_strides[0];
int out_o = tid.y * BN * TN + lid.y * TN;
int out_hw = tid.x * BM * TM + lid.x * TM;
int out_h[TM];
int out_w[TN];
for(int m = 0; m < TM; ++m) {
int mm = (out_hw + m);
out_h[m] = mm / params.oS[1];
out_w[m] = mm % params.oS[1];
}
T in_local[TM];
T wt_local[TN];
T out_local[TM * TN] = {T(0)};
for(int h = 0; h < params.wS[0]; ++h) {
for(int w = 0; w < params.wS[1]; ++w) {
for(int c = 0; c < params.C; ++c) {
// Local in
for(int m = 0; m < TM; m++) {
int i = out_h[m] * params.str[0] - params.pad[0] + h * params.dil[0];
int j = out_w[m] * params.str[1] - params.pad[1] + w * params.dil[1];
bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
in_local[m] = valid ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] : T(0);
}
// Load weight
for (int n = 0; n < TN; ++n) {
int o = out_o + n;
wt_local[n] = o < params.O ? wt[o * params.wt_strides[0] +
h * params.wt_strides[1] +
w * params.wt_strides[2] + c] : T(0);
}
// Accumulate
for(int m = 0; m < TM; ++m) {
for(int n = 0; n < TN; ++n) {
out_local[m * TN + n] += in_local[m] * wt_local[n];
}
}
}
}
}
for(int m = 0; m < TM; ++m) {
for(int n = 0; n < TN; ++n) {
if(out_h[m] < params.oS[0] && out_w[m] < params.oS[1] && (out_o + n) < params.O)
out[out_h[m] * params.out_strides[1] +
out_w[m] * params.out_strides[2] + out_o + n] = out_local[m * TN + n];
}
}
}
// Instantiations
#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
[[kernel]] void naive_conv_2d<itype, bm, bn, tm, tn>( \
const device itype* in [[buffer(0)]], \
const device itype* wt [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant MLXConvParams<2>& params [[buffer(3)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_naive_conv_2d_blocks(name, itype) \
instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \
instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4)
instantiate_naive_conv_2d_blocks(float32, float);
instantiate_naive_conv_2d_blocks(float16, half);
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Implicit gemm kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d(
const device T* in [[buffer(0)]],
const device T* wt [[buffer(1)]],
device T* out [[buffer(2)]],
const constant MLXConvParams<2>& params [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemm_kernel = Conv2DImplicitGEMMKernel<T, BM, BN, BK, WM, WN, /*transpose_a*/ false, /*transpose_b*/ true>;
threadgroup T tgp_memory[gemm_kernel::tgp_mem_size];
gemm_kernel::run(
in, wt, out,
params, tgp_memory,
tid, lid, simd_gid, simd_lid
);
}
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
[[kernel]] void implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn>( \
const device itype* in [[buffer(0)]], \
const device itype* wt [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant MLXConvParams<2>& params [[buffer(3)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_implicit_2d_blocks(name, itype) \
instantiate_implicit_conv_2d(name, itype, 32, 32, 32, 2, 2) \
instantiate_implicit_conv_2d(name, itype, 32, 32, 16, 2, 2) \
instantiate_implicit_conv_2d(name, itype, 64, 64, 16, 2, 2)
instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Winograd kernels
///////////////////////////////////////////////////////////////////////////////
template <int M, int R, int S>
struct WinogradTransforms {
};
template <>
struct WinogradTransforms<6, 3, 8> {
MLX_MTL_CONST int OUT_TILE_SIZE = 6;
MLX_MTL_CONST int FILTER_SIZE = 3;
MLX_MTL_CONST int IN_TILE_SIZE = OUT_TILE_SIZE + FILTER_SIZE - 1;
MLX_MTL_CONST int SIMD_MATRIX_SIZE = 8;
MLX_MTL_CONST float in_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
{ 1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
{ 0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f},
{-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f},
{ 0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f},
{ 5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f, 0.00f},
{ 0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f},
{-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f},
{ 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
};
MLX_MTL_CONST float out_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
{ 1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
{ 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f},
{ 1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f},
{ 1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f},
{ 1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f},
{ 1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f},
{ 1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f},
{ 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
};
MLX_MTL_CONST float wt_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
{ 1.00, 0.00, 0.00},
{ -2.0/9.00, -2.0/9.00, -2.0/9.00},
{ -2.0/9.00, 2.0/9.00, -2.0/9.00},
{ 1.0/90.0, 1.0/45.0, 2.0/45.0},
{ 1.0/90.0, -1.0/45.0, 2.0/45.0},
{ 32.0/45.0, 16.0/45.0, 8.0/45.0},
{ 32.0/45.0, -16.0/45.0, 8.0/45.0},
{ 0.00, 0.00, 1.00},
};
};
constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8];
constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8];
constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8];
template <typename T,
int BC = 32,
int BO = 4,
int M = 6,
int R = 3>
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void winograd_conv_2d_weight_transform(
const device T* wt_in [[buffer(0)]],
device T* wt_out [[buffer(1)]],
const constant int& C [[buffer(2)]],
const constant int& O [[buffer(3)]],
uint tid [[threadgroup_position_in_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {
using WGT = WinogradTransforms<M, R, 8>;
// Get lane position in simdgroup
const short qid = simd_lane_id / 4;
const short sm = (qid & 4) + (simd_lane_id / 2) % 4;
const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Initialize G matrix
simdgroup_matrix<T, 8, 8> G;
G.thread_elements()[0] = WGT::wt_transform[sm][sn];
G.thread_elements()[1] = WGT::wt_transform[sm][sn + 1];
// Initialize Gt matrix
simdgroup_matrix<T, 8, 8> Gt;
Gt.thread_elements()[0] = WGT::wt_transform[sn][sm];
Gt.thread_elements()[1] = WGT::wt_transform[sn + 1][sm];
// Move to the correct output filter
size_t ko = BO * tid + simd_group_id;
wt_in += ko * R * R * C;
// wt_out is stored transposed (A x A x C x O)
short ohw_0 = sm * 8 + sn;
short ohw_1 = sm * 8 + sn + 1;
device T* wt_out_0 = wt_out + ohw_0 * C * O + ko;
device T* wt_out_1 = wt_out + ohw_1 * C * O + ko;
// Prepare shared memory
threadgroup T Ws[BO][R][R][BC];
// Loop over C
for(int bc = 0; bc < C; bc += BC) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read into shared memory
for(int kh = 0; kh < R; ++kh) {
for(int kw = 0; kw < R; ++kw) {
for(int kc = simd_lane_id; kc < BC; kc += 32) {
Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc];
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for(int c = 0; c < BC; ++c) {
simdgroup_matrix<T, 8, 8> g;
g.thread_elements()[0] = sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
g.thread_elements()[1] = sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);
simdgroup_matrix<T, 8, 8> g_out = (G * g) * Gt;
wt_out_0[c * O] = g_out.thread_elements()[0];
wt_out_1[c * O] = g_out.thread_elements()[1];
}
wt_in += BC;
wt_out_0 += BC * O;
wt_out_1 += BC * O;
}
}
#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \
template [[host_name("winograd_conv_2d_weight_transform_" #name "_bc" #bc)]]\
[[kernel]] void winograd_conv_2d_weight_transform<itype, bc>(\
const device itype* wt_in [[buffer(0)]],\
device itype* wt_out [[buffer(1)]],\
const constant int& C [[buffer(2)]],\
const constant int& O [[buffer(3)]],\
uint tid [[threadgroup_position_in_grid]],\
uint simd_group_id [[simdgroup_index_in_threadgroup]],\
uint simd_lane_id [[thread_index_in_simdgroup]]);
template <typename T,
int BC,
int WM,
int WN,
int M = 6,
int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_input_transform(
const device T* inp_in [[buffer(0)]],
device T* inp_out [[buffer(1)]],
const constant MLXConvParams<2>& params [[buffer(2)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_per_grid [[threadgroups_per_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {
(void)lid;
using WGT = WinogradTransforms<M, R, 8>;
constexpr int A = WGT::IN_TILE_SIZE;
constexpr int N_SIMD_GROUPS = WM * WN;
// Get lane position in simdgroup
const short qid = simd_lane_id / 4;
const short sm = (qid & 4) + (simd_lane_id / 2) % 4;
const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Initialize B matrix
simdgroup_matrix<T, 8, 8> B;
B.thread_elements()[0] = WGT::in_transform[sm][sn];
B.thread_elements()[1] = WGT::in_transform[sm][sn + 1];
// Initialize Bt matrix
simdgroup_matrix<T, 8, 8> Bt;
Bt.thread_elements()[0] = WGT::in_transform[sn][sm];
Bt.thread_elements()[1] = WGT::in_transform[sn + 1][sm];
// Resolve input tile
constexpr int TH = (A / WM);
constexpr int TW = (A / WN);
int kh = TH * (simd_group_id / WN);
int kw = TW * (simd_group_id % WN);
int bh = M * tid.y + kh;
int bw = M * tid.x + kw;
// Move to the correct input tile
inp_in += tid.z * params.in_strides[0]
+ bh * params.in_strides[1]
+ bw * params.in_strides[2];
// Pre compute strides
int jump_in[TH][TW];
for(int h = 0; h < TH; h++) {
for(int w = 0; w < TW; w++) {
jump_in[h][w] = h * params.in_strides[1] + w * params.in_strides[2];
}
}
// inp_out is stored interleaved (A x A x tiles x C)
size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
size_t ohw_0 = sm * 8 + sn;
size_t ohw_1 = sm * 8 + sn + 1;
device T* inp_out_0 = inp_out + ohw_0 * N_TILES * params.C + tile_id * params.C;
device T* inp_out_1 = inp_out + ohw_1 * N_TILES * params.C + tile_id * params.C;
// Prepare shared memory
threadgroup T Is[A][A][BC];
// Loop over C
for(int bc = 0; bc < params.C; bc += BC) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read into shared memory
for(int h = 0; h < TH; h++) {
for(int w = 0; w < TW; w++) {
const device T* in_ptr = inp_in + jump_in[h][w];
for(int c = simd_lane_id; c < BC; c += 32) {
Is[kh + h][kw + w][c] = in_ptr[c];
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for(int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) {
simdgroup_matrix<T, 8, 8> I;
I.thread_elements()[0] = Is[sm][sn][c];
I.thread_elements()[1] = Is[sm][sn + 1][c];
simdgroup_matrix<T, 8, 8> I_out = (Bt * I) * B;
inp_out_0[c] = I_out.thread_elements()[0];
inp_out_1[c] = I_out.thread_elements()[1];
}
inp_in += BC;
inp_out_0 += BC;
inp_out_1 += BC;
}
}
#define instantiate_winograd_conv_2d_input_transform(name, itype, bc) \
template [[host_name("winograd_conv_2d_input_transform_" #name "_bc" #bc)]]\
[[kernel]] void winograd_conv_2d_input_transform<itype, bc, 2, 2>(\
const device itype* inp_in [[buffer(0)]],\
device itype* inp_out [[buffer(1)]],\
const constant MLXConvParams<2>& params [[buffer(2)]],\
uint3 tid [[threadgroup_position_in_grid]],\
uint3 lid [[thread_position_in_threadgroup]],\
uint3 tgp_per_grid [[threadgroups_per_grid]],\
uint simd_group_id [[simdgroup_index_in_threadgroup]],\
uint simd_lane_id [[thread_index_in_simdgroup]]);
template <typename T,
int BO,
int WM,
int WN,
int M = 6,
int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_output_transform(
const device T* out_in [[buffer(0)]],
device T* out_out [[buffer(1)]],
const constant MLXConvParams<2>& params [[buffer(2)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_per_grid [[threadgroups_per_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {
(void)lid;
using WGT = WinogradTransforms<M, R, 8>;
constexpr int N_SIMD_GROUPS = WM * WN;
// Get lane position in simdgroup
const short qid = simd_lane_id / 4;
const short sm = (qid & 4) + (simd_lane_id / 2) % 4;
const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Initialize A matrix
simdgroup_matrix<T, 8, 8> B;
B.thread_elements()[0] = WGT::out_transform[sm][sn];
B.thread_elements()[1] = WGT::out_transform[sm][sn + 1];
// Initialize At matrix
simdgroup_matrix<T, 8, 8> Bt;
Bt.thread_elements()[0] = WGT::out_transform[sn][sm];
Bt.thread_elements()[1] = WGT::out_transform[sn + 1][sm];
// Out_in comes in shape (A x A x tiles x O)
// We do transform and then write out to out_out in shape (N, H, W, O)
// Resolve output tile
constexpr int TH = (M / WM);
constexpr int TW = (M / WN);
int kh = TH * (simd_group_id / WN);
int kw = TW * (simd_group_id % WN);
int bh = M * tid.y + kh;
int bw = M * tid.x + kw;
// Move to the correct input tile
out_out += tid.z * params.out_strides[0]
+ bh * params.out_strides[1]
+ bw * params.out_strides[2];
// Pre compute strides
int jump_in[TH][TW];
for(int h = 0; h < TH; h++) {
for(int w = 0; w < TW; w++) {
bool valid = ((bh + h) < params.oS[0]) && ((bw + w) < params.oS[1]);
jump_in[h][w] = valid ? h * params.out_strides[1] + w * params.out_strides[2] : -1;
}
}
// out_in is stored interleaved (A x A x tiles x O)
size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
size_t ohw_0 = sm * 8 + sn;
size_t ohw_1 = sm * 8 + sn + 1;
const device T* out_in_0 = out_in + ohw_0 * N_TILES * params.O + tile_id * params.O;
const device T* out_in_1 = out_in + ohw_1 * N_TILES * params.O + tile_id * params.O;
// Prepare shared memory
threadgroup T Os[M][M][BO];
// Loop over O
for(int bo = 0; bo < params.O; bo += BO) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for(int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
simdgroup_matrix<T, 8, 8> O_mat;
O_mat.thread_elements()[0] = out_in_0[c];
O_mat.thread_elements()[1] = out_in_1[c];
simdgroup_matrix<T, 8, 8> O_out = (Bt * (O_mat * B));
if((sm < M) && (sn < M)) {
Os[sm][sn][c] = O_out.thread_elements()[0];
}
if((sm < M) && ((sn + 1) < M)) {
Os[sm][sn + 1][c] = O_out.thread_elements()[1];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read out from shared memory
for(int h = 0; h < TH; h++) {
for(int w = 0; w < TW; w++) {
if(jump_in[h][w] >= 0) {
device T* out_ptr = out_out + jump_in[h][w];
for(int c = simd_lane_id; c < BO; c += 32) {
out_ptr[c] = Os[kh + h][kw + w][c];
}
}
}
}
out_out += BO;
out_in_0 += BO;
out_in_1 += BO;
}
}
#define instantiate_winograd_conv_2d_output_transform(name, itype, bo) \
template [[host_name("winograd_conv_2d_output_transform_" #name "_bo" #bo)]]\
[[kernel]] void winograd_conv_2d_output_transform<itype, bo, 2, 2>(\
const device itype* out_in [[buffer(0)]],\
device itype* out_out [[buffer(1)]],\
const constant MLXConvParams<2>& params [[buffer(2)]],\
uint3 tid [[threadgroup_position_in_grid]],\
uint3 lid [[thread_position_in_threadgroup]],\
uint3 tgp_per_grid [[threadgroups_per_grid]],\
uint simd_group_id [[simdgroup_index_in_threadgroup]],\
uint simd_lane_id [[thread_index_in_simdgroup]]);
#define instantiate_winograd_conv_2d(name, itype) \
instantiate_winograd_conv_2d_weight_transform_base(name, itype, 32) \
instantiate_winograd_conv_2d_input_transform(name, itype, 32) \
instantiate_winograd_conv_2d_output_transform(name, itype, 32)
instantiate_winograd_conv_2d(float32, float);
instantiate_winograd_conv_2d(float16, half);

View File

@@ -0,0 +1,269 @@
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
template <typename T, typename U>
[[kernel]] void copy_s(
device const T* src,
device U* dst,
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[0]);
}
template <typename T, typename U>
[[kernel]] void copy_v(
device const T* src,
device U* dst,
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[index]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd1(
device const T* src,
device U* dst,
constant const size_t& src_stride,
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
dst[index] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd2(
device const T* src,
device U* dst,
constant const size_t src_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
size_t dst_idx = index.x + (size_t)grid_dim.x * index.y;
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd3(
device const T* src,
device U* dst,
constant const size_t src_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int DIM>
[[kernel]] void copy_g_nd(
device const T* src,
device U* dst,
constant const int src_shape[DIM],
constant const size_t src_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g(
device const T* src,
device U* dst,
constant const int* src_shape,
constant const size_t* src_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd1(
device const T* src,
device U* dst,
constant const size_t& src_stride,
constant const size_t& dst_stride,
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
auto dst_idx = elem_to_loc_1(index, dst_stride);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd2(
device const T* src,
device U* dst,
constant const size_t src_strides[2],
constant const size_t dst_strides[2],
uint2 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
auto dst_idx = elem_to_loc_2(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd3(
device const T* src,
device U* dst,
constant const size_t src_strides[3],
constant const size_t dst_strides[3],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
auto dst_idx = elem_to_loc_3(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int DIM>
[[kernel]] void copy_gg_nd(
device const T* src,
device U* dst,
constant const int src_shape[DIM],
constant const size_t src_strides[DIM],
constant const size_t dst_strides[DIM],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg(
device const T* src,
device U* dst,
constant const int* src_shape,
constant const size_t* src_strides,
constant const size_t* dst_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
#define instantiate_copy(name, itype, otype, ctype) \
template [[host_name(name)]] \
[[kernel]] void copy_##ctype<itype, otype>( \
device const itype* src, \
device otype* dst, \
uint index [[thread_position_in_grid]]);
#define instantiate_copy_g_dim(name, itype, otype, dims) \
template [[host_name(name "_" #dims)]] \
[[kernel]] void copy_g_nd<itype, otype, dims>( \
device const itype* src, \
device otype* dst, \
constant const int src_shape[dims], \
constant const size_t src_strides[dims], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_" #dims)]] \
[[kernel]] void copy_gg_nd<itype, otype, dims>( \
device const itype* src, \
device otype* dst, \
constant const int src_shape[dims], \
constant const size_t src_strides[dims], \
constant const size_t dst_strides[dims], \
uint3 index [[thread_position_in_grid]]);
#define instantiate_copy_g_nd(name, itype, otype) \
template [[host_name(name "_1")]] \
[[kernel]] void copy_g_nd1<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const size_t& src_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] \
[[kernel]] void copy_g_nd2<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const size_t src_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] \
[[kernel]] void copy_g_nd3<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const size_t src_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_1")]] \
[[kernel]] void copy_gg_nd1<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const size_t& src_stride, \
constant const size_t& dst_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name("g" name "_2")]] \
[[kernel]] void copy_gg_nd2<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const size_t src_strides[2], \
constant const size_t dst_strides[2], \
uint2 index [[thread_position_in_grid]]); \
template [[host_name("g" name "_3")]] \
[[kernel]] void copy_gg_nd3<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const size_t src_strides[3], \
constant const size_t dst_strides[3], \
uint3 index [[thread_position_in_grid]]); \
instantiate_copy_g_dim(name, itype, otype, 4) \
instantiate_copy_g_dim(name, itype, otype, 5)
#define instantiate_copy_g(name, itype, otype) \
template [[host_name(name)]] \
[[kernel]] void copy_g<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const int* src_shape, \
constant const size_t* src_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name)]] \
[[kernel]] void copy_gg<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const int* src_shape, \
constant const size_t* src_strides, \
constant const size_t* dst_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]]);
#define instantiate_copy_all(tname, itype, otype) \
instantiate_copy("scopy" #tname, itype, otype, s) \
instantiate_copy("vcopy" #tname, itype, otype, v) \
instantiate_copy_g("gcopy" #tname, itype, otype) \
instantiate_copy_g_nd("gcopy" #tname, itype, otype)
#define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \
instantiate_copy_all(itname ##uint8, itype, uint8_t) \
instantiate_copy_all(itname ##uint16, itype, uint16_t) \
instantiate_copy_all(itname ##uint32, itype, uint32_t) \
instantiate_copy_all(itname ##uint64, itype, uint64_t) \
instantiate_copy_all(itname ##int8, itype, int8_t) \
instantiate_copy_all(itname ##int16, itype, int16_t) \
instantiate_copy_all(itname ##int32, itype, int32_t) \
instantiate_copy_all(itname ##int64, itype, int64_t) \
instantiate_copy_all(itname ##float16, itype, half) \
instantiate_copy_all(itname ##float32, itype, float) \
instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \
instantiate_copy_all(itname ##complex64, itype, complex64_t)
instantiate_copy_itype(bool_, bool)
instantiate_copy_itype(uint8, uint8_t)
instantiate_copy_itype(uint16, uint16_t)
instantiate_copy_itype(uint32, uint32_t)
instantiate_copy_itype(uint64, uint64_t)
instantiate_copy_itype(int8, int8_t)
instantiate_copy_itype(int16, int16_t)
instantiate_copy_itype(int32, int32_t)
instantiate_copy_itype(int64, int64_t)
instantiate_copy_itype(float16, half)
instantiate_copy_itype(float32, float)
instantiate_copy_itype(bfloat16, bfloat16_t)
instantiate_copy_itype(complex64, complex64_t)

View File

@@ -0,0 +1,68 @@
#pragma once
#include <metal_math>
/*
* Approximation to the error function.
* Based on code from:
* https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199
*/
float erf(float a) {
float r, s, t, u;
t = metal::abs(a);
s = a * a;
if (t > 0.927734375f) {
// maximum error 0.99527 ulp
r = metal::fma(
-1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
u = metal::fma(
-3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
r = metal::fma(r, s, u);
r = metal::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4
r = metal::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1
r = metal::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3
r = metal::fma(r, t, -t);
// TODO, replace with expm1 when implemented
r = 1.0f - metal::exp(r);
r = metal::copysign(r, a);
} else {
// maximum error 0.98929 ulp
r = -5.96761703e-4f; // -0x1.38e000p-11
r = metal::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8
r = metal::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6
r = metal::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4
r = metal::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2
r = metal::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3
r = metal::fma(r, a, a);
}
return r;
}
float erfinv(float a) {
auto t = metal::fma(a, 0.0f - a, 1.0f);
t = metal::log(t);
float p;
if (metal::abs(t) > 6.125f) { // maximum ulp error = 2.35793
p = 3.03697567e-10f; // 0x1.4deb44p-32
p = metal::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
p = metal::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
p = metal::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
p = metal::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
p = metal::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
p = metal::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
p = metal::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
p = metal::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
} else { // maximum ulp error = 2.35002
p = 5.43877832e-9f; // 0x1.75c000p-28
p = metal::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
p = metal::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
p = metal::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
p = metal::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
p = metal::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
p = metal::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
p = metal::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
p = metal::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
p = metal::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
}
return a * p;
}

View File

@@ -0,0 +1,91 @@
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/gemm/gemm.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
device T *C [[buffer(2)]],
const constant int &M [[buffer(3)]],
const constant int &N [[buffer(4)]],
const constant int &K [[buffer(5)]],
const constant int &batch_stride_a [[buffer(6)]],
const constant int &batch_stride_b [[buffer(7)]],
const constant int &batch_stride_c [[buffer(8)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using gemm_kernel = GEMMKernel<T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
threadgroup T tgp_memory[gemm_kernel::tgp_mem_size];
gemm_kernel::run(
A, B, C,
M, N, K,
batch_stride_a, batch_stride_b, batch_stride_c,
tgp_memory,
simd_lane_id, simd_group_id, tid, lid
);
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
device itype *C [[buffer(2)]], \
const constant int &M [[buffer(3)]], \
const constant int &N [[buffer(4)]], \
const constant int &K [[buffer(5)]], \
const constant int &batch_stride_a [[buffer(6)]], \
const constant int &batch_stride_b [[buffer(7)]], \
const constant int &batch_stride_c [[buffer(8)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2)
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
// TODO: Accumulation in different type

View File

@@ -0,0 +1,99 @@
#include "mlx/backend/metal/kernels/utils.h"
static constexpr constant uint32_t rotations[2][4] = {
{13, 15, 26, 6},
{17, 29, 16, 24}
};
union rbits {
uint2 val;
uchar4 bytes[2];
};
rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
uint4 ks = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA};
rbits v;
v.val.x = count.x + ks[0];
v.val.y = count.y + ks[1];
for (int i = 0; i < 5; ++i) {
for (auto r : rotations[i % 2]) {
v.val.x += v.val.y;
v.val.y = (v.val.y << r) | (v.val.y >> (32 - r));
v.val.y ^= v.val.x;
}
v.val.x += ks[(i + 1) % 3];
v.val.y += ks[(i + 2) % 3] + i + 1;
}
return v;
}
[[kernel]] void rbitsc(
device const uint32_t* keys,
device char* out,
device const bool& odd,
device const uint& bytes_per_key,
uint2 grid_dim [[threads_per_grid]],
uint2 index [[thread_position_in_grid]]) {
auto kidx = 2 * index.x;
auto key = uint2(keys[kidx], keys[kidx + 1]);
auto half_size = grid_dim.y - odd;
out += index.x * bytes_per_key;
bool drop_last = odd && (index.y == half_size);
auto count = uint2(index.y, drop_last ? 0 : index.y + grid_dim.y);
auto bits = threefry2x32_hash(key, count);
for (int i = 0; i < 4; ++i) {
out[4 * count.x + i] = bits.bytes[0][i];
}
if (!drop_last) {
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
int edge_bytes = (bytes_per_key % 4);
for (int i = 0; i < edge_bytes; ++i) {
out[4 * count.y + i] = bits.bytes[1][i];
}
} else {
for (int i = 0; i < 4; ++i) {
out[4 * count.y + i] = bits.bytes[1][i];
}
}
}
}
[[kernel]] void rbits(
device const uint32_t* keys,
device char* out,
device const bool& odd,
device const uint& bytes_per_key,
device const int& ndim,
device const int* key_shape,
device const size_t* key_strides,
uint2 grid_dim [[threads_per_grid]],
uint2 index [[thread_position_in_grid]]) {
auto kidx = 2 * index.x;
auto k1_elem = elem_to_loc(kidx, key_shape, key_strides, ndim);
auto k2_elem = elem_to_loc(kidx + 1, key_shape, key_strides, ndim);
auto key = uint2(keys[k1_elem], keys[k2_elem]);
auto half_size = grid_dim.y - odd;
out += index.x * bytes_per_key;
bool drop_last = odd && (index.y == half_size);
auto count = uint2(index.y, drop_last ? 0 : index.y + grid_dim.y);
auto bits = threefry2x32_hash(key, count);
for (int i = 0; i < 4; ++i) {
out[4 * count.x + i] = bits.bytes[0][i];
}
if (!drop_last) {
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
int edge_bytes = (bytes_per_key % 4);
for (int i = 0; i < edge_bytes; ++i) {
out[4 * count.y + i] = bits.bytes[1][i];
}
} else {
for (int i = 0; i < 4; ++i) {
out[4 * count.y + i] = bits.bytes[1][i];
}
}
}
}

View File

@@ -0,0 +1,536 @@
#include <metal_atomic>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/reduce.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
static constant uint8_t simd_size = 32;
template <typename T, typename Op>
[[kernel]] void init_reduce(
device T *out [[buffer(0)]],
uint tid [[thread_position_in_grid]]) {
out[tid] = Op::init;
}
#define instantiate_init_reduce(name, otype, op) \
template [[host_name("i" #name)]] \
[[kernel]] void init_reduce<otype, op>( \
device otype *out [[buffer(1)]], \
uint tid [[thread_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
// All reduce
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// NB: this kernel assumes threads_per_threadgroup is at most
// 1024. This way with a simd_size of 32, we are guaranteed to
// complete the reduction in two steps of simd-level reductions.
Op op;
threadgroup U local_vals[simd_size];
U total_val = Op::init;
in += gid * N_READS;
int r = 0;
for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
U vals[N_READS] = {op.init};
for(int i = 0; i < N_READS; i++) {
vals[i] = static_cast<U>(in[i]);
}
for(int i = 0; i < N_READS; i++) {
total_val = op(vals[i], total_val);
}
in += grid_size * N_READS;
}
// Sepate case for the last set as we close the reduction size
size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
if (curr_idx < in_size) {
int max_reads = in_size - curr_idx;
T vals[N_READS];
for(int i = 0, idx = 0; i < N_READS; i++, idx++) {
idx = idx < max_reads ? idx : max_reads - 1;
vals[i] = in[idx];
}
for(int i = 0; i < N_READS; i++) {
U val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
// Reduction within simd group
total_val = op.simd_reduce(total_val);
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
// Reduction within thread group
threadgroup_barrier(mem_flags::mem_threadgroup);
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
total_val = op.simd_reduce(total_val);
// Reduction across threadgroups
if (lid == 0) {
op.atomic_update(out, total_val);
}
}
#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] \
[[kernel]] void all_reduce<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
// General reduce
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op>
[[kernel]] void general_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const device int *in_shape [[buffer(2)]],
const device size_t *in_strides [[buffer(3)]],
const device size_t *out_strides [[buffer(4)]],
const device size_t& ndim [[buffer(5)]],
uint gid [[thread_position_in_grid]]) {
Op op;
auto in_idx = elem_to_loc(gid, in_shape, in_strides, ndim);
auto out_idx = elem_to_loc(gid, in_shape, out_strides, ndim);
op.atomic_update(out, static_cast<U>(in[in_idx]), out_idx);
}
template <typename T, typename U, typename Op, int NDIM>
[[kernel]] void general_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const device int *in_shape [[buffer(2)]],
const device size_t *in_strides [[buffer(3)]],
const device size_t *out_strides [[buffer(4)]],
uint gid [[thread_position_in_grid]]) {
Op op;
auto in_idx = elem_to_loc_nd<NDIM>(gid, in_shape, in_strides);
auto out_idx = elem_to_loc_nd<NDIM>(gid, in_shape, out_strides);
op.atomic_update(out, static_cast<U>(in[in_idx]), out_idx);
}
#define instantiate_general_reduce_helper(name, itype, otype, op) \
template [[host_name("general_reduce_" #name)]] \
[[kernel]] void general_reduce<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const device int *in_shape [[buffer(2)]], \
const device size_t *in_strides [[buffer(3)]], \
const device size_t *out_strides [[buffer(4)]], \
const device size_t& ndim [[buffer(5)]], \
uint gid [[thread_position_in_grid]]);
#define instantiate_general_reduce_helper_nd(name, itype, otype, op, n) \
template [[host_name("general_reduce_" #name "_dim_" #n)]] \
[[kernel]] void general_reduce<itype, otype, op, n>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const device int *in_shape [[buffer(2)]], \
const device size_t *in_strides [[buffer(3)]], \
const device size_t *out_strides [[buffer(4)]], \
uint gid [[thread_position_in_grid]]);
#define instantiate_general_reduce(name, itype, otype, op) \
instantiate_general_reduce_helper(name, itype, otype, op) \
instantiate_general_reduce_helper_nd(name, itype, otype, op, 1) \
instantiate_general_reduce_helper_nd(name, itype, otype, op, 2) \
instantiate_general_reduce_helper_nd(name, itype, otype, op, 3) \
instantiate_general_reduce_helper_nd(name, itype, otype, op, 4)
///////////////////////////////////////////////////////////////////////////////
// Row atomics
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device size_t& reduction_size [[buffer(2)]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
// Each threadgroup handles 1 reduction
in += tid * reduction_size + lid * N_READS;
// The reduction is accumulated here
U total_val = Op::init;
threadgroup U local_vals[simd_size];
// Loop over the reduction size within thread group
int r = 0;
for (; r < (int)ceildiv(reduction_size, N_READS*lsize) - 1; r++) {
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
vals[i] = in[i];
}
for(int i = 0; i < N_READS; i++) {
total_val = op(static_cast<U>(vals[i]), total_val);
}
in += lsize * N_READS;
}
// Sepate case for the last set as we close the reduction size
size_t reduction_index = (lid + (size_t)lsize * r) * N_READS;
if(reduction_index < reduction_size) {
int max_reads = reduction_size - reduction_index;
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
int idx = min(i, max_reads - 1);
vals[i] = static_cast<U>(in[idx]);
}
for(int i = 0; i < N_READS; i++) {
T val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
total_val = op.simd_reduce(total_val);
// Prepare next level
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction within thread group
// Only needed if multiple simd groups
if(reduction_size > simd_size) {
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
total_val = op.simd_reduce(total_val);
}
// Update output
if (lid == 0) {
out[tid] = total_val;
}
}
#define instantiate_row_reduce(name, itype, otype, op) \
template [[host_name("row_reduce_" #name)]] \
[[kernel]] void row_reduce<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const device size_t& reduction_size [[buffer(2)]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
// Column reduce
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
inline void _contiguous_strided_reduce(
const device T *in,
device mlx_atomic<U> *out,
threadgroup U *local_data,
uint in_idx,
uint out_idx,
uint reduction_size,
uint reduction_stride,
uint2 tid,
uint2 lid,
uint2 lsize) {
Op op;
T local_vals[N_READS];
uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;
for(uint r = 0; r < N_READS; r++) {
uint offset = base_offset + r;
offset = offset < reduction_size ? offset : reduction_size - 1;
local_vals[r] = in[in_idx + offset * reduction_stride];
}
U total_val = Op::init;
for(uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
total_val = op(static_cast<U>(total_val), local_vals[r]);
}
local_data[lsize.y * lid.x + lid.y] = total_val;
threadgroup_barrier(mem_flags::mem_threadgroup);
if(lid.y == 0) {
U val = op.init;
for(uint i = 0; i < lsize.y; i++) {
val = op(val, local_data[lsize.y * lid.x + i]);
}
op.atomic_update(out, val, out_idx);
}
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
threadgroup U *local_data [[threadgroup(0)]],
uint2 tid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;
if(out_idx < out_size) {
_contiguous_strided_reduce<T, U, Op, N_READS>(
in,
out,
local_data,
out_idx,
out_idx,
reduction_size,
reduction_stride,
tid,
lid,
lsize);
}
}
#define instantiate_col_reduce(name, itype, otype, op) \
template [[host_name("col_reduce_" #name)]] \
[[kernel]] void col_reduce<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint2 tid [[threadgroup_position_in_grid]], \
uint2 lid [[thread_position_in_threadgroup]], \
uint2 lsize [[threads_per_threadgroup]]);
template <typename T, typename U, typename Op, int NDIM, int N_READS = 16>
[[kernel]] void contiguous_strided_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const device int* in_shape [[buffer(5)]],
const device size_t* in_strides [[buffer(6)]],
threadgroup U *local_data [[threadgroup(0)]],
uint2 tid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc_nd<NDIM>(out_idx, in_shape, in_strides);
if(out_idx < out_size) {
_contiguous_strided_reduce<T, U, Op, N_READS>(
in,
out,
local_data,
in_idx,
out_idx,
reduction_size,
reduction_stride,
tid,
lid,
lsize);
}
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void contiguous_strided_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const device int* in_shape [[buffer(5)]],
const device size_t* in_strides [[buffer(6)]],
const device size_t& in_dim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
uint2 tid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(out_idx, in_shape, in_strides, in_dim);
if(out_idx < out_size) {
_contiguous_strided_reduce<T, U, Op, N_READS>(
in,
out,
local_data,
in_idx,
out_idx,
reduction_size,
reduction_stride,
tid,
lid,
lsize);
}
}
#define instantiate_contiguous_strided_helper(name, itype, otype, op) \
template [[host_name("contiguous_strided_reduce_" #name)]] \
[[kernel]] void contiguous_strided_reduce<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const device int* in_shape [[buffer(5)]], \
const device size_t* in_strides [[buffer(6)]], \
const device size_t& in_dim [[buffer(7)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint2 tid [[threadgroup_position_in_grid]], \
uint2 lid [[thread_position_in_threadgroup]], \
uint2 lsize [[threads_per_threadgroup]]);
#define instantiate_contiguous_strided_helper_nd(name, itype, otype, op, n) \
template [[host_name("contiguous_strided_reduce_" #name "_dim_" #n)]] \
[[kernel]] void contiguous_strided_reduce<itype, otype, op, n>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const device int* in_shape [[buffer(5)]], \
const device size_t* in_strides [[buffer(6)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint2 tid [[threadgroup_position_in_grid]], \
uint2 lid [[thread_position_in_threadgroup]], \
uint2 lsize [[threads_per_threadgroup]]);
#define instantiate_contiguous_strided(name, itype, otype, op) \
instantiate_contiguous_strided_helper(name, itype, otype, op) \
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 1) \
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 2) \
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 3) \
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 4)
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_reduce(name, itype, otype, op) \
instantiate_all_reduce(name, itype, otype, op) \
instantiate_row_reduce(name, itype, otype, op) \
instantiate_col_reduce(name, itype, otype, op) \
instantiate_contiguous_strided(name, itype, otype, op) \
instantiate_general_reduce(name, itype, otype, op)
#define instantiate_same_reduce(name, tname, type, op) \
instantiate_init_reduce(name ##tname, type, op<type>) \
instantiate_reduce(name ##tname, type, type, op<type>)
#define instantiate_reduce_from_types_helper(name, tname, itype, otype, op) \
instantiate_reduce(name ##tname, itype, otype, op)
#define instantiate_reduce_from_types(name, otype, op) \
instantiate_reduce_from_types_helper(name, bool_, bool, otype, op) \
instantiate_reduce_from_types_helper(name, uint8, uint8_t, otype, op) \
instantiate_reduce_from_types_helper(name, uint16, uint16_t, otype, op) \
instantiate_reduce_from_types_helper(name, uint32, uint32_t, otype, op) \
instantiate_reduce_from_types_helper(name, int8, int8_t, otype, op) \
instantiate_reduce_from_types_helper(name, int16, int16_t, otype, op) \
instantiate_reduce_from_types_helper(name, int32, int32_t, otype, op) \
instantiate_reduce_from_types_helper(name, int64, int64_t, otype, op) \
instantiate_reduce_from_types_helper(name, float16, half, otype, op) \
instantiate_reduce_from_types_helper(name, float32, float, otype, op) \
instantiate_reduce_from_types_helper(name, bfloat16, bfloat16_t, otype, op)
// special case bool with larger output type
instantiate_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_same_reduce(sum, uint8, uint8_t, Sum)
instantiate_same_reduce(sum, uint16, uint16_t, Sum)
instantiate_same_reduce(sum, uint32, uint32_t, Sum)
instantiate_same_reduce(sum, int8, int8_t, Sum)
instantiate_same_reduce(sum, int16, int16_t, Sum)
instantiate_same_reduce(sum, int32, int32_t, Sum)
instantiate_same_reduce(sum, float16, half, Sum)
instantiate_same_reduce(sum, float32, float, Sum)
instantiate_same_reduce(prod, uint8, uint8_t, Prod)
instantiate_same_reduce(prod, uint16, uint16_t, Prod)
instantiate_same_reduce(prod, uint32, uint32_t, Prod)
instantiate_same_reduce(prod, int8, int8_t, Prod)
instantiate_same_reduce(prod, int16, int16_t, Prod)
instantiate_same_reduce(prod, int32, int32_t, Prod)
instantiate_same_reduce(prod, float16, half, Prod)
instantiate_same_reduce(prod, float32, float, Prod)
instantiate_same_reduce(sum, bfloat16, bfloat16_t, Sum)
instantiate_same_reduce(prod, bfloat16, bfloat16_t, Prod)
instantiate_init_reduce(andbool_, bool, And)
instantiate_reduce_from_types(and, bool, And)
instantiate_init_reduce(orbool_, bool, Or)
instantiate_reduce_from_types(or, bool, Or)
// Compiler segfaulted with the names "min" or "max" ...
instantiate_same_reduce(min_, uint8, uint8_t, Min)
instantiate_same_reduce(min_, uint16, uint16_t, Min)
instantiate_same_reduce(min_, uint32, uint32_t, Min)
instantiate_same_reduce(min_, int8, int8_t, Min)
instantiate_same_reduce(min_, int16, int16_t, Min)
instantiate_same_reduce(min_, int32, int32_t, Min)
instantiate_same_reduce(min_, float16, half, Min)
instantiate_same_reduce(min_, float32, float, Min)
instantiate_same_reduce(max_, uint8, uint8_t, Max)
instantiate_same_reduce(max_, uint16, uint16_t, Max)
instantiate_same_reduce(max_, uint32, uint32_t, Max)
instantiate_same_reduce(max_, int8, int8_t, Max)
instantiate_same_reduce(max_, int16, int16_t, Max)
instantiate_same_reduce(max_, int32, int32_t, Max)
instantiate_same_reduce(max_, float16, half, Max)
instantiate_same_reduce(max_, float32, float, Max)
instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min)
instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max)

View File

@@ -0,0 +1,284 @@
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/erf.h"
#include "mlx/backend/metal/kernels/bf16.h"
struct Abs {
template <typename T> T operator()(T x) { return metal::abs(x); };
template <> uint8_t operator()(uint8_t x) { return x; };
template <> uint16_t operator()(uint16_t x) { return x; };
template <> uint32_t operator()(uint32_t x) { return x; };
template <> uint64_t operator()(uint64_t x) { return x; };
template <> bool operator()(bool x) { return x; };
template <> complex64_t operator()(complex64_t x) {
return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
};
};
struct ArcCos {
template <typename T> T operator()(T x) { return metal::precise::acos(x); };
};
struct ArcCosh {
template <typename T> T operator()(T x) { return metal::precise::acosh(x); };
};
struct ArcSin {
template <typename T> T operator()(T x) { return metal::precise::asin(x); };
};
struct ArcSinh {
template <typename T> T operator()(T x) { return metal::precise::asinh(x); };
};
struct ArcTan {
template <typename T> T operator()(T x) { return metal::precise::atan(x); };
};
struct ArcTanh {
template <typename T> T operator()(T x) { return metal::precise::atanh(x); };
};
struct Cos {
template <typename T> T operator()(T x) { return metal::precise::cos(x); };
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
-metal::precise::sin(x.real) * metal::precise::sinh(x.imag)
};
};
};
struct Cosh {
template <typename T> T operator()(T x) { return metal::precise::cosh(x); };
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
metal::precise::sinh(x.real) * metal::precise::sin(x.imag)
};
};
};
struct Erf {
template <typename T> T operator()(T x) { return static_cast<T>(erf(static_cast<float>(x))); };
};
struct ErfInv {
template <typename T> T operator()(T x) { return static_cast<T>(erfinv(static_cast<float>(x))); };
};
struct Exp {
template <typename T> T operator()(T x) { return metal::precise::exp(x); };
template <> complex64_t operator()(complex64_t x) {
auto m = metal::precise::exp(x.real);
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
}
};
struct Log {
template <typename T> T operator()(T x) { return metal::precise::log(x); };
};
struct Log2 {
template <typename T> T operator()(T x) { return metal::precise::log2(x); };
};
struct Log10 {
template <typename T> T operator()(T x) { return metal::precise::log10(x); };
};
struct Log1p {
template <typename T> T operator()(T x) { return log1p(x); };
};
struct LogicalNot {
template <typename T> T operator()(T x) { return !x; };
};
struct Negative {
template <typename T> T operator()(T x) { return -x; };
};
struct Sigmoid {
template <typename T>
T operator()(T x) {
auto y = 1 / (1 + metal::exp(-metal::abs(x)));
return (x < 0) ? 1 - y : y;
}
};
struct Sign {
template <typename T> T operator()(T x) { return (x > T(0)) - (x < T(0)); };
template <> uint32_t operator()(uint32_t x) { return x != 0; };
};
struct Sin {
template <typename T> T operator()(T x) { return metal::precise::sin(x); };
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
metal::precise::cos(x.real) * metal::precise::sinh(x.imag)
};
};
};
struct Sinh {
template <typename T> T operator()(T x) { return metal::precise::sinh(x); };
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
metal::precise::cosh(x.real) * metal::precise::sin(x.imag)
};
};
};
struct Square {
template <typename T> T operator()(T x) { return x * x; };
};
struct Sqrt {
template <typename T> T operator()(T x) { return metal::precise::sqrt(x); };
};
struct Rsqrt {
template <typename T> T operator()(T x) { return metal::precise::rsqrt(x); };
};
struct Tan {
template <typename T> T operator()(T x) { return metal::precise::tan(x); };
template <>
complex64_t operator()(complex64_t x) {
float tan_a = metal::precise::tan(x.real);
float tanh_b = metal::precise::tanh(x.imag);
float t1 = tan_a * tanh_b;
float denom = 1. + t1 * t1;
return {
(tan_a - tanh_b * t1) / denom,
(tanh_b + tan_a * t1) / denom
};
};
};
struct Tanh {
template <typename T> T operator()(T x) { return metal::precise::tanh(x); };
template <>
complex64_t operator()(complex64_t x) {
float tanh_a = metal::precise::tanh(x.real);
float tan_b = metal::precise::tan(x.imag);
float t1 = tanh_a * tan_b;
float denom = 1. + t1 * t1;
return {
(tanh_a + tan_b * t1) / denom,
(tan_b - tanh_a * t1) / denom
};
};
};
template <typename T, typename Op>
[[kernel]] void unary_op_v(
device const T* in,
device T* out,
uint index [[thread_position_in_grid]]) {
out[index] = Op()(in[index]);
}
template <typename T, typename Op>
[[kernel]] void unary_op_g(
device const T* in,
device T* out,
device const int* in_shape,
device const size_t* in_strides,
device const int& ndim,
uint index [[thread_position_in_grid]]) {
auto idx = elem_to_loc(index, in_shape, in_strides, ndim);
out[index] = Op()(in[idx]);
}
#define instantiate_unary_v(name, type, op) \
template [[host_name(name)]] \
[[kernel]] void unary_op_v<type, op>( \
device const type* in, \
device type* out, \
uint index [[thread_position_in_grid]]);
#define instantiate_unary_g(name, type, op) \
template [[host_name(name)]] \
[[kernel]] void unary_op_g<type, op>( \
device const type* in, \
device type* out, \
device const int* in_shape, \
device const size_t* in_strides, \
device const int& ndim, \
uint index [[thread_position_in_grid]]);
#define instantiate_unary_all(name, tname, type, op) \
instantiate_unary_v("v" #name #tname, type, op) \
instantiate_unary_g("g" #name #tname, type, op)
#define instantiate_unary_float(name, op) \
instantiate_unary_all(name, float16, half, op) \
instantiate_unary_all(name, float32, float, op) \
instantiate_unary_all(name, bfloat16, bfloat16_t, op) \
#define instantiate_unary_types(name, op) \
instantiate_unary_all(name, bool_, bool, op) \
instantiate_unary_all(name, uint8, uint8_t, op) \
instantiate_unary_all(name, uint16, uint16_t, op) \
instantiate_unary_all(name, uint32, uint32_t, op) \
instantiate_unary_all(name, uint64, uint64_t, op) \
instantiate_unary_all(name, int8, int8_t, op) \
instantiate_unary_all(name, int16, int16_t, op) \
instantiate_unary_all(name, int32, int32_t, op) \
instantiate_unary_all(name, int64, int64_t, op) \
instantiate_unary_float(name, op)
instantiate_unary_types(abs, Abs)
instantiate_unary_float(arccos, ArcCos)
instantiate_unary_float(arccosh, ArcCosh)
instantiate_unary_float(arcsin, ArcSin)
instantiate_unary_float(arcsinh, ArcSinh)
instantiate_unary_float(arctan, ArcTan)
instantiate_unary_float(arctanh, ArcTanh)
instantiate_unary_float(cos, Cos)
instantiate_unary_float(cosh, Cosh)
instantiate_unary_float(exp, Exp)
instantiate_unary_float(log, Log)
instantiate_unary_float(log2, Log2)
instantiate_unary_float(log10, Log10)
instantiate_unary_float(log1p, Log1p)
instantiate_unary_types(neg, Negative)
instantiate_unary_float(sigmoid, Sigmoid)
instantiate_unary_float(erf, Erf)
instantiate_unary_float(erfinv, ErfInv)
instantiate_unary_types(sign, Sign)
instantiate_unary_float(sin, Sin)
instantiate_unary_float(sinh, Sinh)
instantiate_unary_types(square, Square)
instantiate_unary_float(sqrt, Sqrt)
instantiate_unary_float(rsqrt, Rsqrt)
instantiate_unary_float(tan, Tan)
instantiate_unary_float(tanh, Tanh)
instantiate_unary_all(abs, complex64, complex64_t, Abs)
instantiate_unary_all(cos, complex64, complex64_t, Cos)
instantiate_unary_all(cosh, complex64, complex64_t, Cosh)
instantiate_unary_all(exp, complex64, complex64_t, Exp)
instantiate_unary_all(neg, complex64, complex64_t, Negative)
instantiate_unary_all(sin, complex64, complex64_t, Sin)
instantiate_unary_all(sinh, complex64, complex64_t, Sinh)
instantiate_unary_all(tan, complex64, complex64_t, Tan)
instantiate_unary_all(tanh, complex64, complex64_t, Tanh)
instantiate_unary_all(lnot, bool_, bool, LogicalNot)

View File

@@ -0,0 +1,369 @@
#include <algorithm>
#include <cassert>
#include <sstream>
#include "mlx/backend/common/reduce.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
//////////////////////////////////////////////////////////////////////
// Case wise reduce dispatch
//////////////////////////////////////////////////////////////////////
namespace {
// All Reduce
void all_reduce_dispatch(
const array& in,
array& out,
const std::string& op_name,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d) {
// Get kernel and encode buffers
size_t in_size = in.size();
auto kernel = d.get_kernel("all_reduce_" + op_name + type_to_name(in));
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
// Set grid dimensions
// We make sure each thread has enough to do by making it read in
// atleast n_reads inputs
int n_reads = REDUCE_N_READS;
// mod_in_size gives us the groups of n_reads needed to go over the entire
// input
uint mod_in_size = (in_size + n_reads - 1) / n_reads;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
mod_in_size > thread_group_size ? thread_group_size : mod_in_size;
// If the number of thread groups needed exceeds 1024, we reuse threads groups
uint n_thread_groups =
(mod_in_size + thread_group_size - 1) / thread_group_size;
n_thread_groups = std::min(n_thread_groups, 1024u);
uint nthreads = n_thread_groups * thread_group_size;
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
void row_reduce_dispatch(
const array& in,
array& out,
const std::string& op_name,
const std::vector<int>& axes_,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d) {
auto kernel = d.get_kernel("row_reduce_" + op_name + type_to_name(in));
int n_reads = REDUCE_N_READS;
size_t reduction_size = in.size() / out.size();
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
// Each thread group is responsible for 1 output
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);
// Align thread group size with simd_size
uint simd_size = kernel->threadExecutionWidth();
thread_group_size =
(thread_group_size + simd_size - 1) / simd_size * simd_size;
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
// Launch enough thread groups for each output
size_t n_threads = out.size() * thread_group_size;
MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
void col_reduce_dispatch(
const array& in,
array& out,
const std::string& op_name,
const std::vector<int>& axes_,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d) {
std::ostringstream kernel_name;
bool encode_in_shape = false;
bool encode_ndim = false;
// If the slowest moving axis can be merged into the reductions,
// we call the column reduce kernel
// In this case, a linear index in the output corresponds to the
// linear index in the input where the reduction starts
if (axes_[axes_.size() - 1] == (axes_.size() - 1)) {
kernel_name << "col_reduce_" << op_name << type_to_name(in);
}
// Otherwise, while all the reduction axes can be merged, the mapping between
// indices in the output and input require resolving using shapes and strides
else {
kernel_name << "contiguous_strided_reduce_" << op_name << type_to_name(in);
encode_in_shape = true;
// We check for a viable template with the required number of dimensions
// we only care about encoding non-reduced shapes and strides in the input
size_t non_reducing_dims = in.ndim() - axes_.size();
if (non_reducing_dims >= 1 &&
non_reducing_dims <= MAX_REDUCE_SPECIALIZED_DIMS) {
kernel_name << "_dim_" << non_reducing_dims;
} else {
encode_ndim = true;
}
}
auto kernel = d.get_kernel(kernel_name.str());
size_t in_size = in.size();
size_t out_size = out.size();
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
// Calculate the number of inputs to reduce and the stride b/w them
size_t reduction_size = 1;
size_t in_ndim = in.ndim();
size_t reduction_stride = in_size;
for (int i : axes_) {
reduction_size *= in.shape(i);
reduction_stride = std::min(reduction_stride, in.strides()[i]);
}
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
if (encode_in_shape) {
// Obtain the non-reducing shape and strides of the input to encode
std::vector<int> inp_shape_mod;
std::vector<size_t> inp_strides_mod;
for (size_t i = 0, j = 0; i < in.ndim(); i++) {
if (j < axes_.size() && axes_[j] == i) {
j++;
} else {
inp_shape_mod.push_back(in.shape(i));
inp_strides_mod.push_back(in.strides()[i]);
}
}
size_t ndim = inp_shape_mod.size();
compute_encoder->setBytes(inp_shape_mod.data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(inp_strides_mod.data(), ndim * sizeof(size_t), 6);
if (encode_ndim) {
compute_encoder->setBytes(&ndim, sizeof(size_t), 7);
}
}
// Select block dimensions
// Each thread reads 16 inputs to give it more work
uint n_inputs_per_thread = REDUCE_N_READS;
uint n_threads_per_output =
(reduction_size + n_inputs_per_thread - 1) / n_inputs_per_thread;
// We spread outputs over the x dimension and inputs over the y dimension
// Threads with the same lid.x in a given threadgroup work on the same
// output and each thread in the y dimension accumlates for that output
uint threadgroup_dim_x = std::min(out_size, 128ul);
uint threadgroup_dim_y =
kernel->maxTotalThreadsPerThreadgroup() / threadgroup_dim_x;
threadgroup_dim_y = std::min(n_threads_per_output, threadgroup_dim_y);
uint n_threadgroups_x =
(out_size + threadgroup_dim_x - 1) / threadgroup_dim_x;
uint n_threadgroups_y =
(n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y;
// Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(n_threadgroups_x, n_threadgroups_y, 1);
MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1);
// We set shared memory to be exploited here for reductions within a
// threadgroup - each thread must be able to update its accumulated output
// Note: Each threadgroup should have 32kB of data in threadgroup memory
// and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design
// This should be fine for floats, but we might need to revisit
// if we ever come to doubles. In that case, we should also cut
// down the number of threads we launch in a threadgroup
compute_encoder->setThreadgroupMemoryLength(
threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 0);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
void general_reduce_dispatch(
const array& in,
array& out,
const std::string& op_name,
const std::vector<int>& axes_,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d) {
bool encode_ndim = true;
std::ostringstream kernel_name;
kernel_name << "general_reduce_" << op_name << type_to_name(in);
// Check for specialzed kernels for input ndim
if (in.ndim() >= 1 && in.ndim() <= MAX_REDUCE_SPECIALIZED_DIMS) {
kernel_name << "_dim_" << in.ndim();
encode_ndim = false;
}
auto kernel = d.get_kernel(kernel_name.str());
size_t in_size = in.size();
size_t ndim = in.ndim();
// We set the reducing strides to 0 to induce collisions for the reduction
std::vector<size_t> out_strides(ndim);
size_t stride = 1;
for (int i = ndim - 1, j = axes_.size() - 1; i >= 0; --i) {
if (j >= 0 && axes_[j] == i) {
out_strides[i] = 0;
--j;
} else {
out_strides[i] = stride;
stride *= in.shape(i);
}
}
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(in.shape().data(), ndim * sizeof(int), 2);
compute_encoder->setBytes(in.strides().data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);
if (encode_ndim) {
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
}
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > in_size) {
thread_group_size = in_size;
}
size_t nthreads = in_size;
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
} // namespace
//////////////////////////////////////////////////////////////////////
// Main reduce dispatch
//////////////////////////////////////////////////////////////////////
void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
// TODO: Allow specific row and column reductions with types disabled
// due to atomics ?
if (size_of(in.dtype()) == 8) {
std::ostringstream msg;
msg << "[Reduce::eval_gpu] Does not support " << in.dtype();
throw std::runtime_error(msg.str());
}
// Make sure no identity reductions trickle down here
assert(!axes_.empty());
// Continue with reduction operation
out.set_data(allocator::malloc_or_wait(out.nbytes()));
std::string op_name;
switch (reduce_type_) {
case Reduce::And:
op_name = "and";
break;
case Reduce::Or:
op_name = "or";
break;
case Reduce::Sum:
op_name = "sum";
break;
case Reduce::Prod:
op_name = out.dtype() == bool_ ? "and" : "prod";
break;
case Reduce::Min:
op_name = out.dtype() == bool_ ? "and" : "min_";
break;
case Reduce::Max:
op_name = out.dtype() == bool_ ? "or" : "max_";
break;
}
// Initialize output
auto& s = stream();
auto& d = metal::device(s.device);
auto compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel("i" + op_name + type_to_name(out));
size_t nthreads = out.size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, out, 0);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
// Reduce
{
// Check for contiguous data
if (in.size() == in.data_size() &&
(in.flags().row_contiguous || in.flags().col_contiguous)) {
// Go to all reduce if reducing over all axes
if (axes_.size() == in.ndim()) {
all_reduce_dispatch(in, out, op_name, compute_encoder, d);
return;
}
// Use specialized kernels if the input is row contiguous and
// the reducing axes can be merged into one
else if (
in.flags().row_contiguous && in.strides().back() == 1 &&
(axes_.back() - axes_.front()) == axes_.size() - 1) {
// If the fastest moving axis is being reduced, go to row reduce
if (axes_[0] == (in.ndim() - axes_.size())) {
row_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
return;
}
// Otherwise go to to generalized strided reduce
// Note: bool isn't support here yet due to the use of atomics
// once that is updated, this should be the else condition of this
// branch
else if (in.dtype() != bool_) {
col_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
return;
}
}
}
// Fall back to the general case
general_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,15 @@
#include "mlx/allocator.h"
namespace mlx::core::allocator {
Allocator& allocator() {
static CommonAllocator allocator_;
return allocator_;
}
void* Buffer::raw_ptr() {
return ptr_;
}
} // namespace mlx::core::allocator