feat: metal formatting and pre-commit bump (#1038)

* feat: metal formatting and pre-commit bump

* add guards

* update

* more guards

* more guards

* small fix

* Refactor instantiation of ternary types in ternary.metal

* fix scan.metal
Nripesh Niketan 2024-04-30 18:18:09 +04:00 committed by GitHub
parent 8db7161c94
commit a30e7ed2da
45 changed files with 3822 additions and 3337 deletions
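Context for the diffs below: most of the churn is clang-format 18 being run over the Metal kernel sources, plus a bump of the pre-commit hook versions. The "guards" from the commit message are // clang-format off / // clang-format on comment pairs fencing the kernel-instantiation macro lists, which clang-format would otherwise reflow into an unreadable shape. A minimal sketch of the pattern (the kernel and macro names here are made up for illustration, and the kernel template itself is assumed to be defined above):

// clang-format off
#define instantiate_my_kernel(tname, type)                    \
  template [[host_name("my_kernel_" #tname)]] [[kernel]] void \
  my_kernel<type>(                                            \
      device const type* in,                                  \
      device type* out,                                       \
      uint index [[thread_position_in_grid]]);
instantiate_my_kernel(float32, float)
instantiate_my_kernel(float16, half)
// clang-format on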

.pre-commit-config.yaml

@@ -1,11 +1,11 @@
 repos:
   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v18.1.3
+    rev: v18.1.4
     hooks:
       - id: clang-format
   # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 24.3.0
+    rev: 24.4.2
     hooks:
       - id: black
   - repo: https://github.com/pycqa/isort
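The hook bumps above (clang-format v18.1.3 to v18.1.4, black 24.3.0 to 24.4.2) are routine version updates; this is the kind of change typically generated with "pre-commit autoupdate" and then verified with "pre-commit run --all-files", though the commit itself does not record which workflow was used.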

axpby.metal

@@ -36,8 +36,8 @@ template <typename T>
 }
 #define instantiate_axpby(type_name, type) \
-  template [[host_name("axpby_general_" #type_name)]] \
-  [[kernel]] void axpby_general<type>( \
+  template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \
+  axpby_general<type>( \
       device const type* x [[buffer(0)]], \
       device const type* y [[buffer(1)]], \
       device type* out [[buffer(2)]], \
@@ -48,8 +48,8 @@ template <typename T>
       constant const size_t* y_strides [[buffer(7)]], \
       constant const int& ndim [[buffer(8)]], \
       uint index [[thread_position_in_grid]]); \
-  template [[host_name("axpby_contiguous_" #type_name)]] \
-  [[kernel]] void axpby_contiguous<type>( \
+  template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
+  axpby_contiguous<type>( \
       device const type* x [[buffer(0)]], \
       device const type* y [[buffer(1)]], \
       device type* out [[buffer(2)]], \
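(These axpby kernels appear to be from MLX's custom-extension example; both variants compute out = alpha * x + beta * y, with the "general" form resolving arbitrarily strided inputs through the strides/ndim buffers and the "contiguous" form indexing directly.)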

arange.metal

@@ -12,13 +12,13 @@ template <typename T>
 }
 #define instantiate_arange(tname, type) \
-  template [[host_name("arange" #tname)]] \
-  [[kernel]] void arange<type>( \
+  template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
       constant const type& start, \
       constant const type& step, \
       device type* out, \
       uint index [[thread_position_in_grid]]);
+// clang-format off
 instantiate_arange(uint8, uint8_t)
 instantiate_arange(uint16, uint16_t)
 instantiate_arange(uint32, uint32_t)
@@ -29,4 +29,4 @@ instantiate_arange(int32, int32_t)
 instantiate_arange(int64, int64_t)
 instantiate_arange(float16, half)
 instantiate_arange(float32, float)
-instantiate_arange(bfloat16, bfloat16_t)
+instantiate_arange(bfloat16, bfloat16_t) // clang-format on
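For reference, each instantiation line above expands, by direct macro substitution of the definition in this same hunk, into an explicit template instantiation with a unique Metal host name. For example, instantiate_arange(float32, float) becomes:

template [[host_name("arangefloat32")]] [[kernel]] void arange<float>(
    constant const float& start,
    constant const float& step,
    device float* out,
    uint index [[thread_position_in_grid]]);

The [[host_name(...)]] string is the symbol the host side looks up when building the compute pipeline (e.g. via MTLLibrary's newFunctionWithName:), which is why the attribute has to stay attached to the template keyword while clang-format rewraps these macros.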

arg_reduce.metal

@@ -18,7 +18,8 @@ struct ArgMin {
   static constexpr constant U init = Limits<U>::max;
   IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
-    if (best.val > current.val || (best.val == current.val && best.index > current.index)) {
+    if (best.val > current.val ||
+        (best.val == current.val && best.index > current.index)) {
       return current;
     } else {
       return best;
@@ -26,7 +27,8 @@ struct ArgMin {
   }
   template <int N>
-  IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
+  IndexValPair<U>
+  reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
     for (int i = 0; i < N; i++) {
       if (vals[i] < best.val) {
         best.val = vals[i];
@@ -42,7 +44,8 @@ struct ArgMax {
   static constexpr constant U init = Limits<U>::min;
   IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
-    if (best.val < current.val || (best.val == current.val && best.index > current.index)) {
+    if (best.val < current.val ||
+        (best.val == current.val && best.index > current.index)) {
       return current;
     } else {
       return best;
@@ -50,7 +53,8 @@ struct ArgMax {
   }
   template <int N>
-  IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
+  IndexValPair<U>
+  reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
     for (int i = 0; i < N; i++) {
       if (vals[i] > best.val) {
         best.val = vals[i];
@@ -64,12 +68,9 @@ struct ArgMax {
 template <typename U>
 IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
   return IndexValPair<U>{
-      simd_shuffle_down(data.index, delta),
-      simd_shuffle_down(data.val, delta)
-  };
+      simd_shuffle_down(data.index, delta), simd_shuffle_down(data.val, delta)};
 }
-
 template <typename T, typename Op, int N_READS>
 [[kernel]] void arg_reduce_general(
     const device T* in [[buffer(0)]],
@@ -86,7 +87,6 @@ template <typename T, typename Op, int N_READS>
     uint simd_size [[threads_per_simdgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
-
   // Shapes and strides *do not* contain the reduction axis. The reduction size
   // and stride are provided in axis_stride and axis_size.
   //
@@ -161,8 +161,8 @@ template <typename T, typename Op, int N_READS>
 }
 #define instantiate_arg_reduce_helper(name, itype, op) \
-  template [[host_name(name)]] \
-  [[kernel]] void arg_reduce_general<itype, op<itype>, 4>( \
+  template [[host_name(name)]] [[kernel]] void \
+  arg_reduce_general<itype, op<itype>, 4>( \
      const device itype* in [[buffer(0)]], \
      device uint32_t* out [[buffer(1)]], \
      const device int* shape [[buffer(2)]], \
@@ -178,6 +178,7 @@ template <typename T, typename Op, int N_READS>
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
+// clang-format off
 #define instantiate_arg_reduce(name, itype) \
   instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin) \
   instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax)
@@ -193,4 +194,4 @@ instantiate_arg_reduce(int32, int32_t)
 instantiate_arg_reduce(int64, int64_t)
 instantiate_arg_reduce(float16, half)
 instantiate_arg_reduce(float32, float)
-instantiate_arg_reduce(bfloat16, bfloat16_t)
+instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on

binary.metal

@@ -77,7 +77,8 @@ template <typename T, typename U, typename Op>
     uint3 grid_dim [[threads_per_grid]]) {
   auto a_idx = elem_to_loc_3(index, a_strides);
   auto b_idx = elem_to_loc_3(index, b_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  size_t out_idx =
+      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
   c[out_idx] = Op()(a[a_idx], b[b_idx]);
 }
@@ -92,7 +93,8 @@ template <typename T, typename U, typename Op, int DIM>
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
   auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  size_t out_idx =
+      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
   c[out_idx] = Op()(a[idx.x], b[idx.y]);
 }
@@ -113,16 +115,16 @@ template <typename T, typename U, typename Op>
 }
 #define instantiate_binary(name, itype, otype, op, bopt) \
-  template [[host_name(name)]] \
-  [[kernel]] void binary_op_##bopt<itype, otype, op>( \
+  template \
+      [[host_name(name)]] [[kernel]] void binary_op_##bopt<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      uint index [[thread_position_in_grid]]);
 #define instantiate_binary_g_dim(name, itype, otype, op, dims) \
-  template [[host_name(name "_" #dims)]] \
-  [[kernel]] void binary_op_g_nd<itype, otype, op, dims>( \
+  template [[host_name(name "_" #dims)]] [[kernel]] void \
+  binary_op_g_nd<itype, otype, op, dims>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -133,16 +135,16 @@ template <typename T, typename U, typename Op>
      uint3 grid_dim [[threads_per_grid]]);
 #define instantiate_binary_g_nd(name, itype, otype, op) \
-  template [[host_name(name "_1")]] \
-  [[kernel]] void binary_op_g_nd1<itype, otype, op>( \
+  template [[host_name(name "_1")]] [[kernel]] void \
+  binary_op_g_nd1<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      constant const size_t& a_stride, \
      constant const size_t& b_stride, \
      uint index [[thread_position_in_grid]]); \
-  template [[host_name(name "_2")]] \
-  [[kernel]] void binary_op_g_nd2<itype, otype, op>( \
+  template [[host_name(name "_2")]] [[kernel]] void \
+  binary_op_g_nd2<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -150,8 +152,8 @@ template <typename T, typename U, typename Op>
      constant const size_t b_strides[2], \
      uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]); \
-  template [[host_name(name "_3")]] \
-  [[kernel]] void binary_op_g_nd3<itype, otype, op>( \
+  template [[host_name(name "_3")]] [[kernel]] void \
+  binary_op_g_nd3<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -162,10 +164,8 @@ template <typename T, typename U, typename Op>
   instantiate_binary_g_dim(name, itype, otype, op, 4) \
   instantiate_binary_g_dim(name, itype, otype, op, 5)
-
 #define instantiate_binary_g(name, itype, otype, op) \
-  template [[host_name(name)]] \
-  [[kernel]] void binary_op_g<itype, otype, op>( \
+  template [[host_name(name)]] [[kernel]] void binary_op_g<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -176,14 +176,16 @@ template <typename T, typename U, typename Op>
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]);
+// clang-format off
 #define instantiate_binary_all(name, tname, itype, otype, op) \
   instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
   instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
   instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
   instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
   instantiate_binary_g("g" #name #tname, itype, otype, op) \
-  instantiate_binary_g_nd("g" #name #tname, itype, otype, op)
+  instantiate_binary_g_nd("g" #name #tname, itype, otype, op) // clang-format on
+// clang-format off
 #define instantiate_binary_integer(name, op) \
   instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
   instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
@@ -192,19 +194,22 @@ template <typename T, typename U, typename Op>
   instantiate_binary_all(name, int8, int8_t, int8_t, op) \
   instantiate_binary_all(name, int16, int16_t, int16_t, op) \
   instantiate_binary_all(name, int32, int32_t, int32_t, op) \
-  instantiate_binary_all(name, int64, int64_t, int64_t, op) \
+  instantiate_binary_all(name, int64, int64_t, int64_t, op) // clang-format on
+// clang-format off
 #define instantiate_binary_float(name, op) \
   instantiate_binary_all(name, float16, half, half, op) \
   instantiate_binary_all(name, float32, float, float, op) \
-  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
+  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op) // clang-format on
+// clang-format off
 #define instantiate_binary_types(name, op) \
   instantiate_binary_all(name, bool_, bool, bool, op) \
   instantiate_binary_integer(name, op) \
   instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
-  instantiate_binary_float(name, op)
+  instantiate_binary_float(name, op) // clang-format on
+// clang-format off
 #define instantiate_binary_types_bool(name, op) \
   instantiate_binary_all(name, bool_, bool, bool, op) \
   instantiate_binary_all(name, uint8, uint8_t, bool, op) \
@@ -218,8 +223,9 @@ template <typename T, typename U, typename Op>
   instantiate_binary_all(name, float16, half, bool, op) \
   instantiate_binary_all(name, float32, float, bool, op) \
   instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
-  instantiate_binary_all(name, complex64, complex64_t, bool, op)
+  instantiate_binary_all(name, complex64, complex64_t, bool, op) // clang-format on
+// clang-format off
 instantiate_binary_types(add, Add)
 instantiate_binary_types(div, Divide)
 instantiate_binary_types_bool(eq, Equal)
@@ -253,4 +259,4 @@ instantiate_binary_all(bitwise_or, bool_, bool, bool, BitwiseOr)
 instantiate_binary_integer(bitwise_xor, BitwiseXor)
 instantiate_binary_all(bitwise_xor, bool_, bool, bool, BitwiseXor)
 instantiate_binary_integer(left_shift, LeftShift)
-instantiate_binary_integer(right_shift, RightShift)
+instantiate_binary_integer(right_shift, RightShift) // clang-format on
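(A note on the naming scheme being instantiated above: the bopt suffixes ss, sv, vs, and vv select the operand layout, scalar/scalar, scalar/vector, vector/scalar, and vector/vector respectively, while the g and g_nd variants handle general strided inputs, with the _1 through _5 suffixes fixing the dimensionality.)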

binary_two.metal

@@ -3,23 +3,37 @@
 #include <metal_integer>
 #include <metal_math>
-#include "mlx/backend/metal/kernels/utils.h"
 #include "mlx/backend/metal/kernels/bf16.h"
+#include "mlx/backend/metal/kernels/utils.h"
 struct FloorDivide {
-  template <typename T> T operator()(T x, T y) { return x / y; }
-  template <> float operator()(float x, float y) { return trunc(x / y); }
-  template <> half operator()(half x, half y) { return trunc(x / y); }
-  template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return trunc(x / y); }
+  template <typename T>
+  T operator()(T x, T y) {
+    return x / y;
+  }
+  template <>
+  float operator()(float x, float y) {
+    return trunc(x / y);
+  }
+  template <>
+  half operator()(half x, half y) {
+    return trunc(x / y);
+  }
+  template <>
+  bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
+    return trunc(x / y);
+  }
 };
 struct Remainder {
   template <typename T>
-  metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T> operator()(T x, T y) {
+  metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
+  operator()(T x, T y) {
     return x % y;
   }
   template <typename T>
-  metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T> operator()(T x, T y) {
+  metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
+  operator()(T x, T y) {
     auto r = x % y;
     if (r != 0 && (r < 0 != y < 0)) {
       r += y;
@@ -34,7 +48,8 @@ struct Remainder {
     }
     return r;
   }
-  template <> complex64_t operator()(complex64_t x, complex64_t y) {
+  template <>
+  complex64_t operator()(complex64_t x, complex64_t y) {
     return x % y;
   }
 };
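The signed-integer Remainder overload above turns Metal's truncated % into a floored (Python-style) remainder: when the truncated result and the divisor have opposite signs, adding y once shifts the result to take the divisor's sign. A small host-side C++ sketch of the same logic (the values in the comments are just worked examples):

int floored_remainder(int x, int y) {
  int r = x % y; // truncated remainder: -7 % 3 == -1
  if (r != 0 && ((r < 0) != (y < 0))) {
    r += y; // -1 + 3 == 2, matching floored division
  }
  return r;
}
// floored_remainder(-7, 3) == 2, floored_remainder(7, -3) == -2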
@@ -50,7 +65,6 @@ template <typename T, typename U, typename Op1, typename Op2>
   d[index] = Op2()(a[0], b[0]);
 }
-
 template <typename T, typename U, typename Op1, typename Op2>
 [[kernel]] void binary_op_ss(
     device const T* a,
@@ -139,7 +153,8 @@ template <typename T, typename U, typename Op1, typename Op2>
     uint3 grid_dim [[threads_per_grid]]) {
   auto a_idx = elem_to_loc_3(index, a_strides);
   auto b_idx = elem_to_loc_3(index, b_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  size_t out_idx =
+      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
   c[out_idx] = Op1()(a[a_idx], b[b_idx]);
   d[out_idx] = Op2()(a[a_idx], b[b_idx]);
 }
@@ -156,7 +171,8 @@ template <typename T, typename U, typename Op1, typename Op2, int DIM>
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
   auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  size_t out_idx =
+      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
   c[out_idx] = Op1()(a[idx.x], b[idx.y]);
   d[out_idx] = Op2()(a[idx.x], b[idx.y]);
 }
@@ -180,8 +196,8 @@ template <typename T, typename U, typename Op1, typename Op2>
 }
 #define instantiate_binary(name, itype, otype, op1, op2, bopt) \
-  template [[host_name(name)]] \
-  [[kernel]] void binary_op_##bopt<itype, otype, op1, op2>( \
+  template [[host_name(name)]] [[kernel]] void \
+  binary_op_##bopt<itype, otype, op1, op2>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -189,8 +205,8 @@ template <typename T, typename U, typename Op1, typename Op2>
      uint index [[thread_position_in_grid]]);
 #define instantiate_binary_g_dim(name, itype, otype, op1, op2, dims) \
-  template [[host_name(name "_" #dims)]] \
-  [[kernel]] void binary_op_g_nd<itype, otype, op1, op2, dims>( \
+  template [[host_name(name "_" #dims)]] [[kernel]] void \
+  binary_op_g_nd<itype, otype, op1, op2, dims>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -201,9 +217,10 @@ template <typename T, typename U, typename Op1, typename Op2>
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]);
+// clang-format off
 #define instantiate_binary_g_nd(name, itype, otype, op1, op2) \
-  template [[host_name(name "_1")]] \
-  [[kernel]] void binary_op_g_nd1<itype, otype, op1, op2>( \
+  template [[host_name(name "_1")]] [[kernel]] void \
+  binary_op_g_nd1<itype, otype, op1, op2>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -211,8 +228,8 @@ template <typename T, typename U, typename Op1, typename Op2>
      constant const size_t& a_stride, \
      constant const size_t& b_stride, \
      uint index [[thread_position_in_grid]]); \
-  template [[host_name(name "_2")]] \
-  [[kernel]] void binary_op_g_nd2<itype, otype, op1, op2>( \
+  template [[host_name(name "_2")]] [[kernel]] void \
+  binary_op_g_nd2<itype, otype, op1, op2>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -221,8 +238,8 @@ template <typename T, typename U, typename Op1, typename Op2>
      constant const size_t b_strides[2], \
      uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]); \
-  template [[host_name(name "_3")]] \
-  [[kernel]] void binary_op_g_nd3<itype, otype, op1, op2>( \
+  template [[host_name(name "_3")]] [[kernel]] void \
+  binary_op_g_nd3<itype, otype, op1, op2>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -232,12 +249,11 @@ template <typename T, typename U, typename Op1, typename Op2>
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
   instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \
-  instantiate_binary_g_dim(name, itype, otype, op1, op2, 5)
+  instantiate_binary_g_dim(name, itype, otype, op1, op2, 5) // clang-format on
 #define instantiate_binary_g(name, itype, otype, op1, op2) \
-  template [[host_name(name)]] \
-  [[kernel]] void binary_op_g<itype, otype, op2, op2>( \
+  template [[host_name(name)]] [[kernel]] void \
+  binary_op_g<itype, otype, op2, op2>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -249,19 +265,22 @@ template <typename T, typename U, typename Op1, typename Op2>
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]);
+// clang-format off
 #define instantiate_binary_all(name, tname, itype, otype, op1, op2) \
   instantiate_binary("ss" #name #tname, itype, otype, op1, op2, ss) \
   instantiate_binary("sv" #name #tname, itype, otype, op1, op2, sv) \
   instantiate_binary("vs" #name #tname, itype, otype, op1, op2, vs) \
   instantiate_binary("vv" #name #tname, itype, otype, op1, op2, vv) \
   instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \
-  instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2)
+  instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2) // clang-format on
+// clang-format off
 #define instantiate_binary_float(name, op1, op2) \
   instantiate_binary_all(name, float16, half, half, op1, op2) \
   instantiate_binary_all(name, float32, float, float, op1, op2) \
-  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2)
+  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2) // clang-format on
+// clang-format off
 #define instantiate_binary_types(name, op1, op2) \
   instantiate_binary_all(name, bool_, bool, bool, op1, op2) \
   instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \
@@ -275,4 +294,4 @@ template <typename T, typename U, typename Op1, typename Op2>
   instantiate_binary_all(name, complex64, complex64_t, complex64_t, op1, op2) \
   instantiate_binary_float(name, op1, op2)
-instantiate_binary_types(divmod, FloorDivide, Remainder)
+instantiate_binary_types(divmod, FloorDivide, Remainder) // clang-format on

conv.metal

@@ -1,13 +1,11 @@
 // Copyright © 2023-2024 Apple Inc.
-#include <metal_stdlib>
 #include <metal_simdgroup>
 #include <metal_simdgroup_matrix>
+#include <metal_stdlib>
-#include "mlx/backend/metal/kernels/steel/conv/params.h"
 #include "mlx/backend/metal/kernels/bf16.h"
+#include "mlx/backend/metal/kernels/steel/conv/params.h"
 #define MLX_MTL_CONST static constant constexpr const
@@ -23,12 +21,13 @@ template <typename T, int N>
     device T* out [[buffer(1)]],
     const constant MLXConvParams<N>* params [[buffer(2)]],
     uint3 gid [[thread_position_in_grid]]) {
   int filter_size = params->C;
-  for(short i = 0; i < N; i++) filter_size *= params->wS[i];
+  for (short i = 0; i < N; i++)
+    filter_size *= params->wS[i];
   int out_pixels = 1;
-  for(short i = 0; i < N; i++) out_pixels *= params->oS[i];
+  for (short i = 0; i < N; i++)
+    out_pixels *= params->oS[i];
   // Set out
   out += gid.z * filter_size + gid.y * (params->C);
@@ -85,12 +84,13 @@ template <typename T, int N>
     device T* out [[buffer(1)]],
     const constant MLXConvParams<N>* params [[buffer(2)]],
     uint3 gid [[thread_position_in_grid]]) {
   int filter_size = params->C;
-  for(short i = 0; i < N; i++) filter_size *= params->wS[i];
+  for (short i = 0; i < N; i++)
+    filter_size *= params->wS[i];
   int out_pixels = 1;
-  for(short i = 0; i < N; i++) out_pixels *= params->oS[i];
+  for (short i = 0; i < N; i++)
+    out_pixels *= params->oS[i];
   // Set out
   out += gid.z * filter_size + gid.x * (filter_size / params->C);
@@ -142,23 +142,23 @@ template <typename T, int N>
 }
 #define instantiate_naive_unfold_nd(name, itype, n) \
-  template [[host_name("naive_unfold_nd_" #name "_" #n)]] \
-  [[kernel]] void naive_unfold_Nd( \
+  template [[host_name("naive_unfold_nd_" #name "_" #n)]] [[kernel]] void \
+  naive_unfold_Nd( \
      const device itype* in [[buffer(0)]], \
      device itype* out [[buffer(1)]], \
      const constant MLXConvParams<n>* params [[buffer(2)]], \
      uint3 gid [[thread_position_in_grid]]); \
-  template [[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] \
-  [[kernel]] void naive_unfold_transpose_Nd( \
+  template \
+      [[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] [[kernel]] void \
+      naive_unfold_transpose_Nd( \
      const device itype* in [[buffer(0)]], \
      device itype* out [[buffer(1)]], \
      const constant MLXConvParams<n>* params [[buffer(2)]], \
      uint3 gid [[thread_position_in_grid]]);
 #define instantiate_naive_unfold_nd_dims(name, itype) \
-  instantiate_naive_unfold_nd(name, itype, 1) \
-  instantiate_naive_unfold_nd(name, itype, 2) \
-  instantiate_naive_unfold_nd(name, itype, 3)
+  instantiate_naive_unfold_nd(name, itype, 1) instantiate_naive_unfold_nd( \
+      name, itype, 2) instantiate_naive_unfold_nd(name, itype, 3)
 instantiate_naive_unfold_nd_dims(float32, float);
 instantiate_naive_unfold_nd_dims(float16, half);
@@ -168,7 +168,8 @@ instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t);
 /// Slow and naive conv2d kernels
 ///////////////////////////////////////////////////////////////////////////////
-template <typename T,
+template <
+    typename T,
     const int BM, /* Threadgroup rows (in threads) */
     const int BN, /* Threadgroup cols (in threads) */
     const int TM, /* Thread rows (in elements) */
@@ -183,7 +184,6 @@ template <typename T,
     uint3 lid [[thread_position_in_threadgroup]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-
   (void)simd_gid;
   (void)simd_lid;
@@ -202,7 +202,6 @@ template <typename T,
     out_w[m] = mm % params.oS[1];
   }
-
   T in_local[TM];
   T wt_local[TN];
   T out_local[TM * TN] = {T(0)};
@@ -210,22 +209,24 @@ template <typename T,
   for (int h = 0; h < params.wS[0]; ++h) {
     for (int w = 0; w < params.wS[1]; ++w) {
       for (int c = 0; c < params.C; ++c) {
         // Local in
         for (int m = 0; m < TM; m++) {
           int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0];
           int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1];
           bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
-          in_local[m] = valid ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] : T(0);
+          in_local[m] = valid
+              ? in[i * params.in_strides[1] + j * params.in_strides[2] + c]
+              : T(0);
         }
         // Load weight
         for (int n = 0; n < TN; ++n) {
           int o = out_o + n;
-          wt_local[n] = o < params.O ? wt[o * params.wt_strides[0] +
-                                          h * params.wt_strides[1] +
-                                          w * params.wt_strides[2] + c] : T(0);
+          wt_local[n] = o < params.O
+              ? wt[o * params.wt_strides[0] + h * params.wt_strides[1] +
+                   w * params.wt_strides[2] + c]
+              : T(0);
         }
         // Accumulate
@@ -234,26 +235,27 @@ template <typename T,
           out_local[m * TN + n] += in_local[m] * wt_local[n];
         }
       }
      }
    }
  }
   for (int m = 0; m < TM; ++m) {
     for (int n = 0; n < TN; ++n) {
-      if(out_h[m] < params.oS[0] && out_w[m] < params.oS[1] && (out_o + n) < params.O)
+      if (out_h[m] < params.oS[0] && out_w[m] < params.oS[1] &&
+          (out_o + n) < params.O)
         out[out_h[m] * params.out_strides[1] +
-            out_w[m] * params.out_strides[2] + out_o + n] = out_local[m * TN + n];
+            out_w[m] * params.out_strides[2] + out_o + n] =
+            out_local[m * TN + n];
     }
   }
 }
 // Instantiations
 #define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
-  template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
-  [[kernel]] void naive_conv_2d<itype, bm, bn, tm, tn>( \
+  template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm \
+                       "_tn" #tn)]] [[kernel]] void \
+  naive_conv_2d<itype, bm, bn, tm, tn>( \
      const device itype* in [[buffer(0)]], \
      const device itype* wt [[buffer(1)]], \
      device itype* out [[buffer(2)]], \
@@ -276,9 +278,7 @@ instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
 ///////////////////////////////////////////////////////////////////////////////
 template <int M, int R, int S>
-struct WinogradTransforms {
-};
+struct WinogradTransforms {};
 template <>
 struct WinogradTransforms<6, 3, 8> {
@@ -324,12 +324,9 @@ constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8];
 constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8];
 constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8];
-template <typename T,
-    int BC = 32,
-    int BO = 4,
-    int M = 6,
-    int R = 3>
-[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void winograd_conv_2d_weight_transform(
+template <typename T, int BC = 32, int BO = 4, int M = 6, int R = 3>
+[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void
+winograd_conv_2d_weight_transform(
     const device T* wt_in [[buffer(0)]],
     device T* wt_out [[buffer(1)]],
     const constant int& C [[buffer(2)]],
@@ -337,7 +334,6 @@ template <typename T,
     uint tid [[threadgroup_position_in_grid]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]]) {
-
   using WGT = WinogradTransforms<M, R, 8>;
   // Get lane position in simdgroup
@@ -384,8 +380,10 @@ template <typename T,
   // Do transform and store the result
   for (int c = 0; c < BC; ++c) {
     simdgroup_matrix<T, 8, 8> g;
-    g.thread_elements()[0] = sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
-    g.thread_elements()[1] = sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);
+    g.thread_elements()[0] =
+        sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
+    g.thread_elements()[1] =
+        sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);
     simdgroup_matrix<T, 8, 8> g_out = (G * g) * Gt;
     wt_out_0[c * O] = g_out.thread_elements()[0];
@@ -396,12 +394,12 @@ template <typename T,
     wt_out_0 += BC * O;
     wt_out_1 += BC * O;
   }
 }
 #define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \
-  template [[host_name("winograd_conv_2d_weight_transform_" #name "_bc" #bc)]] \
-  [[kernel]] void winograd_conv_2d_weight_transform<itype, bc>( \
+  template [[host_name("winograd_conv_2d_weight_transform_" #name \
+                       "_bc" #bc)]] [[kernel]] void \
+  winograd_conv_2d_weight_transform<itype, bc>( \
      const device itype* wt_in [[buffer(0)]], \
      device itype* wt_out [[buffer(1)]], \
      const constant int& C [[buffer(2)]], \
@@ -410,13 +408,9 @@ template <typename T,
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]]);
-template <typename T,
-    int BC,
-    int WM,
-    int WN,
-    int M = 6,
-    int R = 3>
-[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_input_transform(
+template <typename T, int BC, int WM, int WN, int M = 6, int R = 3>
+[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
+winograd_conv_2d_input_transform(
    const device T* inp_in [[buffer(0)]],
    device T* inp_out [[buffer(1)]],
    const constant MLXConvParams<2>& params [[buffer(2)]],
@@ -425,7 +419,6 @@ template <typename T,
    uint3 tgp_per_grid [[threadgroups_per_grid]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]]) {
-
   (void)lid;
   using WGT = WinogradTransforms<M, R, 8>;
@@ -456,9 +449,8 @@ template <typename T,
   int bw = M * tid.x + kw;
   // Move to the correct input tile
-  inp_in += tid.z * params.in_strides[0]
-          + bh * params.in_strides[1]
-          + bw * params.in_strides[2];
+  inp_in += tid.z * params.in_strides[0] + bh * params.in_strides[1] +
+      bw * params.in_strides[2];
   // Pre compute strides
   int jump_in[TH][TW];
@@ -471,11 +463,14 @@ template <typename T,
   // inp_out is stored interleaved (A x A x tiles x C)
   size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
-  size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
+  size_t tile_id =
+      tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
   size_t ohw_0 = sm * 8 + sn;
   size_t ohw_1 = sm * 8 + sn + 1;
-  device T* inp_out_0 = inp_out + ohw_0 * N_TILES * params.C + tile_id * params.C;
-  device T* inp_out_1 = inp_out + ohw_1 * N_TILES * params.C + tile_id * params.C;
+  device T* inp_out_0 =
+      inp_out + ohw_0 * N_TILES * params.C + tile_id * params.C;
+  device T* inp_out_1 =
+      inp_out + ohw_1 * N_TILES * params.C + tile_id * params.C;
   // Prepare shared memory
   threadgroup T Is[A][A][BC];
@@ -509,12 +504,12 @@ template <typename T,
     inp_out_0 += BC;
     inp_out_1 += BC;
   }
 }
 #define instantiate_winograd_conv_2d_input_transform(name, itype, bc) \
-  template [[host_name("winograd_conv_2d_input_transform_" #name "_bc" #bc)]] \
-  [[kernel]] void winograd_conv_2d_input_transform<itype, bc, 2, 2>( \
+  template [[host_name("winograd_conv_2d_input_transform_" #name \
+                       "_bc" #bc)]] [[kernel]] void \
+  winograd_conv_2d_input_transform<itype, bc, 2, 2>( \
      const device itype* inp_in [[buffer(0)]], \
      device itype* inp_out [[buffer(1)]], \
      const constant MLXConvParams<2>& params [[buffer(2)]], \
@@ -524,13 +519,9 @@ template <typename T,
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]]);
-template <typename T,
-    int BO,
-    int WM,
-    int WN,
-    int M = 6,
-    int R = 3>
-[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_output_transform(
+template <typename T, int BO, int WM, int WN, int M = 6, int R = 3>
+[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
+winograd_conv_2d_output_transform(
    const device T* out_in [[buffer(0)]],
    device T* out_out [[buffer(1)]],
    const constant MLXConvParams<2>& params [[buffer(2)]],
@@ -539,7 +530,6 @@ template <typename T,
    uint3 tgp_per_grid [[threadgroups_per_grid]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]]) {
-
   (void)lid;
   using WGT = WinogradTransforms<M, R, 8>;
@@ -572,9 +562,8 @@ template <typename T,
   int bw = M * tid.x + kw;
   // Move to the correct input tile
-  out_out += tid.z * params.out_strides[0]
-           + bh * params.out_strides[1]
-           + bw * params.out_strides[2];
+  out_out += tid.z * params.out_strides[0] + bh * params.out_strides[1] +
+      bw * params.out_strides[2];
   // Pre compute strides
   int jump_in[TH][TW];
@@ -582,24 +571,27 @@ template <typename T,
   for (int h = 0; h < TH; h++) {
     for (int w = 0; w < TW; w++) {
       bool valid = ((bh + h) < params.oS[0]) && ((bw + w) < params.oS[1]);
-      jump_in[h][w] = valid ? h * params.out_strides[1] + w * params.out_strides[2] : -1;
+      jump_in[h][w] =
+          valid ? h * params.out_strides[1] + w * params.out_strides[2] : -1;
     }
   }
   // out_in is stored interleaved (A x A x tiles x O)
   size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
-  size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
+  size_t tile_id =
+      tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
   size_t ohw_0 = sm * 8 + sn;
   size_t ohw_1 = sm * 8 + sn + 1;
-  const device T* out_in_0 = out_in + ohw_0 * N_TILES * params.O + tile_id * params.O;
-  const device T* out_in_1 = out_in + ohw_1 * N_TILES * params.O + tile_id * params.O;
+  const device T* out_in_0 =
+      out_in + ohw_0 * N_TILES * params.O + tile_id * params.O;
+  const device T* out_in_1 =
+      out_in + ohw_1 * N_TILES * params.O + tile_id * params.O;
   // Prepare shared memory
   threadgroup T Os[M][M][BO];
   // Loop over O
   for (int bo = 0; bo < params.O; bo += BO) {
     threadgroup_barrier(mem_flags::mem_threadgroup);
     // Do transform and store the result
     for (int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
@@ -633,12 +625,12 @@ template <typename T,
     out_in_0 += BO;
     out_in_1 += BO;
   }
 }
 #define instantiate_winograd_conv_2d_output_transform(name, itype, bo) \
-  template [[host_name("winograd_conv_2d_output_transform_" #name "_bo" #bo)]] \
-  [[kernel]] void winograd_conv_2d_output_transform<itype, bo, 2, 2>( \
+  template [[host_name("winograd_conv_2d_output_transform_" #name \
+                       "_bo" #bo)]] [[kernel]] void \
+  winograd_conv_2d_output_transform<itype, bo, 2, 2>( \
      const device itype* out_in [[buffer(0)]], \
      device itype* out_out [[buffer(1)]], \
      const constant MLXConvParams<2>& params [[buffer(2)]], \
@@ -648,10 +640,12 @@ template <typename T,
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]]);
+// clang-format off
 #define instantiate_winograd_conv_2d(name, itype) \
   instantiate_winograd_conv_2d_weight_transform_base(name, itype, 32) \
   instantiate_winograd_conv_2d_input_transform(name, itype, 32) \
-  instantiate_winograd_conv_2d_output_transform(name, itype, 32)
+  instantiate_winograd_conv_2d_output_transform(name, itype, 32) // clang-format on
+// clang-format off
 instantiate_winograd_conv_2d(float32, float);
-instantiate_winograd_conv_2d(float16, half);
+instantiate_winograd_conv_2d(float16, half); // clang-format on

copy.metal

@ -49,7 +49,8 @@ template <typename T, typename U>
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides); auto src_idx = elem_to_loc_3(index, src_strides);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]); dst[dst_idx] = static_cast<U>(src[src_idx]);
} }
@ -62,7 +63,8 @@ template <typename T, typename U, int DIM>
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides); auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]); dst[dst_idx] = static_cast<U>(src[src_idx]);
} }
@ -76,7 +78,8 @@ template <typename T, typename U>
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim); auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]); dst[dst_idx] = static_cast<U>(src[src_idx]);
} }
@ -144,23 +147,22 @@ template <typename T, typename U>
} }
#define instantiate_copy(name, itype, otype, ctype) \ #define instantiate_copy(name, itype, otype, ctype) \
template [[host_name(name)]] \ template [[host_name(name)]] [[kernel]] void copy_##ctype<itype, otype>( \
[[kernel]] void copy_##ctype<itype, otype>( \
device const itype* src [[buffer(0)]], \ device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \ device otype* dst [[buffer(1)]], \
uint index [[thread_position_in_grid]]); uint index [[thread_position_in_grid]]);
#define instantiate_copy_g_dim(name, itype, otype, dims) \ #define instantiate_copy_g_dim(name, itype, otype, dims) \
template [[host_name(name "_" #dims)]] \ template [[host_name(name "_" #dims)]] [[kernel]] void \
[[kernel]] void copy_g_nd<itype, otype, dims>( \ copy_g_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \ device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \ device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \ constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \ constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \ uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \ uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_" #dims)]] \ template [[host_name("g" name "_" #dims)]] [[kernel]] void \
[[kernel]] void copy_gg_nd<itype, otype, dims>( \ copy_gg_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \ device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \ device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \ constant const int* src_shape [[buffer(2)]], \
@ -168,44 +170,40 @@ template <typename T, typename U>
constant const int64_t* dst_strides [[buffer(4)]], \ constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]); uint3 index [[thread_position_in_grid]]);
#define instantiate_copy_g_nd(name, itype, otype) \ #define instantiate_copy_g_nd(name, itype, otype) \
template [[host_name(name "_1")]] \ template [[host_name(name "_1")]] [[kernel]] void copy_g_nd1<itype, otype>( \
[[kernel]] void copy_g_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \ device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \ device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \ constant const int64_t& src_stride [[buffer(3)]], \
uint index [[thread_position_in_grid]]); \ uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] \ template [[host_name(name "_2")]] [[kernel]] void copy_g_nd2<itype, otype>( \
[[kernel]] void copy_g_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \ device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \ device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \ constant const int64_t* src_strides [[buffer(3)]], \
uint2 index [[thread_position_in_grid]], \ uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \ uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] \ template [[host_name(name "_3")]] [[kernel]] void copy_g_nd3<itype, otype>( \
[[kernel]] void copy_g_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \ device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \ device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \ constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \ uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \ uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_1")]] \ template [[host_name("g" name "_1")]] [[kernel]] void \
[[kernel]] void copy_gg_nd1<itype, otype>( \ copy_gg_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \ device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \ device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \ constant const int64_t& src_stride [[buffer(3)]], \
constant const int64_t& dst_stride [[buffer(4)]], \ constant const int64_t& dst_stride [[buffer(4)]], \
uint index [[thread_position_in_grid]]); \ uint index [[thread_position_in_grid]]); \
template [[host_name("g" name "_2")]] \ template [[host_name("g" name "_2")]] [[kernel]] void \
[[kernel]] void copy_gg_nd2<itype, otype>( \ copy_gg_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \ device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \ device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \ constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \ constant const int64_t* dst_strides [[buffer(4)]], \
uint2 index [[thread_position_in_grid]]); \ uint2 index [[thread_position_in_grid]]); \
template [[host_name("g" name "_3")]] \ template [[host_name("g" name "_3")]] [[kernel]] void \
[[kernel]] void copy_gg_nd3<itype, otype>( \ copy_gg_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \ device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \ device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \ constant const int64_t* src_strides [[buffer(3)]], \
@ -214,10 +212,8 @@ template <typename T, typename U>
instantiate_copy_g_dim(name, itype, otype, 4) \ instantiate_copy_g_dim(name, itype, otype, 4) \
instantiate_copy_g_dim(name, itype, otype, 5) instantiate_copy_g_dim(name, itype, otype, 5)
#define instantiate_copy_g(name, itype, otype) \ #define instantiate_copy_g(name, itype, otype) \
template [[host_name(name)]] \ template [[host_name(name)]] [[kernel]] void copy_g<itype, otype>( \
[[kernel]] void copy_g<itype, otype>( \
device const itype* src [[buffer(0)]], \ device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \ device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \ constant const int* src_shape [[buffer(2)]], \
@ -225,8 +221,7 @@ template <typename T, typename U>
constant const int& ndim [[buffer(5)]], \ constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]], \ uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \ uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name)]] \ template [[host_name("g" name)]] [[kernel]] void copy_gg<itype, otype>( \
[[kernel]] void copy_gg<itype, otype>( \
device const itype* src [[buffer(0)]], \ device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \ device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \ constant const int* src_shape [[buffer(2)]], \
@ -235,12 +230,14 @@ template <typename T, typename U>
constant const int& ndim [[buffer(5)]], \ constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]]); uint3 index [[thread_position_in_grid]]);
// clang-format off
#define instantiate_copy_all(tname, itype, otype) \ #define instantiate_copy_all(tname, itype, otype) \
instantiate_copy("scopy" #tname, itype, otype, s) \ instantiate_copy("scopy" #tname, itype, otype, s) \
instantiate_copy("vcopy" #tname, itype, otype, v) \ instantiate_copy("vcopy" #tname, itype, otype, v) \
instantiate_copy_g("gcopy" #tname, itype, otype) \ instantiate_copy_g("gcopy" #tname, itype, otype) \
instantiate_copy_g_nd("gcopy" #tname, itype, otype) instantiate_copy_g_nd("gcopy" #tname, itype, otype) // clang-format on
// clang-format off
#define instantiate_copy_itype(itname, itype) \ #define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \ instantiate_copy_all(itname ##bool_, itype, bool) \
instantiate_copy_all(itname ##uint8, itype, uint8_t) \ instantiate_copy_all(itname ##uint8, itype, uint8_t) \
@ -268,4 +265,4 @@ instantiate_copy_itype(int64, int64_t)
instantiate_copy_itype(float16, half) instantiate_copy_itype(float16, half)
instantiate_copy_itype(float32, float) instantiate_copy_itype(float32, float)
instantiate_copy_itype(bfloat16, bfloat16_t) instantiate_copy_itype(bfloat16, bfloat16_t)
instantiate_copy_itype(complex64, complex64_t) instantiate_copy_itype(complex64, complex64_t) // clang-format on
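As a reading aid for the fan-out above: for a single (itype, otype) pair these macros stamp a whole family of host-visible kernel names. A host-side sketch printing the expected set for a hypothetical float32 -> float16 pair (the _4/_5 variants come from instantiate_copy_g_dim, whose body is outside this hunk, so the list is inferred rather than quoted):

#include <cstdio>

int main() {
  // tname pastes the itype and otype names together, e.g. "float32float16".
  const char* t = "float32float16";
  printf("scopy%s vcopy%s\n", t, t);   // contiguous scalar/vector kernels
  printf("gcopy%s ggcopy%s\n", t, t);  // general copy_g / copy_gg kernels
  for (int d = 1; d <= 5; ++d)         // fixed-rank nd specializations
    printf("gcopy%s_%d ggcopy%s_%d\n", t, d, t, d);
  return 0;
}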
@ -6,9 +6,8 @@
// - VkFFT (https://github.com/DTolm/VkFFT) // - VkFFT (https://github.com/DTolm/VkFFT)
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html) // - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
#include <metal_math>
#include <metal_common> #include <metal_common>
#include <metal_math>
#include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
@ -32,7 +31,12 @@ float2 get_twiddle(int k, int p) {
} }
// single threaded radix2 implementation // single threaded radix2 implementation
void radix2(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) { void radix2(
int i,
int p,
int m,
threadgroup float2* read_buf,
threadgroup float2* write_buf) {
float2 x_0 = read_buf[i]; float2 x_0 = read_buf[i];
float2 x_1 = read_buf[i + m]; float2 x_1 = read_buf[i + m];
@ -53,7 +57,12 @@ void radix2(int i, int p, int m, threadgroup float2* read_buf, threadgroup float
} }
// single threaded radix4 implementation // single threaded radix4 implementation
void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) { void radix4(
int i,
int p,
int m,
threadgroup float2* read_buf,
threadgroup float2* write_buf) {
float2 x_0 = read_buf[i]; float2 x_0 = read_buf[i];
float2 x_1 = read_buf[i + m]; float2 x_1 = read_buf[i + m];
float2 x_2 = read_buf[i + 2 * m]; float2 x_2 = read_buf[i + 2 * m];
@ -94,7 +103,6 @@ void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float
write_buf[j + 3 * p] = y_3; write_buf[j + 3 * p] = y_3;
} }
// Each FFT is computed entirely in shared GPU memory. // Each FFT is computed entirely in shared GPU memory.
// //
// N is decomposed into radix-2 and radix-4 DFTs: // N is decomposed into radix-2 and radix-4 DFTs:
@ -111,7 +119,6 @@ template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
device float2* out [[buffer(1)]], device float2* out [[buffer(1)]],
uint3 thread_position_in_grid [[thread_position_in_grid]], uint3 thread_position_in_grid [[thread_position_in_grid]],
uint3 threads_per_grid [[threads_per_grid]]) { uint3 threads_per_grid [[threads_per_grid]]) {
// Index of the DFT in batch // Index of the DFT in batch
int batch_idx = thread_position_in_grid.x * n; int batch_idx = thread_position_in_grid.x * n;
// The index in the DFT we're working on // The index in the DFT we're working on
@ -172,24 +179,21 @@ template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
} }
#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \ #define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \
template [[host_name("fft_" #name)]] \ template [[host_name("fft_" #name)]] [[kernel]] void \
[[kernel]] void fft<n, radix_2_steps, radix_4_steps>( \ fft<n, radix_2_steps, radix_4_steps>( \
const device float2* in [[buffer(0)]], \ const device float2* in [[buffer(0)]], \
device float2* out [[buffer(1)]], \ device float2* out [[buffer(1)]], \
uint3 thread_position_in_grid [[thread_position_in_grid]], \ uint3 thread_position_in_grid [[thread_position_in_grid]], \
uint3 threads_per_grid [[threads_per_grid]]); uint3 threads_per_grid [[threads_per_grid]]);
// Explicitly define kernels for each power of 2. // Explicitly define kernels for each power of 2.
// clang-format off
instantiate_fft(4, /* n= */ 4, /* radix_2_steps= */ 0, /* radix_4_steps= */ 1) instantiate_fft(4, /* n= */ 4, /* radix_2_steps= */ 0, /* radix_4_steps= */ 1)
instantiate_fft(8, 8, 1, 1) instantiate_fft(8, 8, 1, 1) instantiate_fft(16, 16, 0, 2)
instantiate_fft(16, 16, 0, 2) instantiate_fft(32, 32, 1, 2) instantiate_fft(64, 64, 0, 3)
instantiate_fft(32, 32, 1, 2) instantiate_fft(128, 128, 1, 3) instantiate_fft(256, 256, 0, 4)
instantiate_fft(64, 64, 0, 3)
instantiate_fft(128, 128, 1, 3)
instantiate_fft(256, 256, 0, 4)
instantiate_fft(512, 512, 1, 4) instantiate_fft(512, 512, 1, 4)
instantiate_fft(1024, 1024, 0, 5) instantiate_fft(1024, 1024, 0, 5)
// 2048 is the max that will fit into 32KB of threadgroup memory. // 2048 is the max that will fit into 32KB of threadgroup memory.
// TODO: implement 4 step FFT for larger n. // TODO: implement 4 step FFT for larger n.
instantiate_fft(2048, 2048, 1, 5) instantiate_fft(2048, 2048, 1, 5) // clang-format on
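A quick host-side sanity check on these configurations (my own sketch, not part of the kernel source): every instantiation satisfies n == 2^radix_2_steps * 4^radix_4_steps, and the 2048 ceiling is exactly where the two n-point float2 ping-pong buffers (read_buf and write_buf above) fill the 32KB of threadgroup memory:

#include <cstddef>

// n must factor as 2^r2 * 4^r4 for the radix steps to cover it.
constexpr size_t fft_n(size_t r2, size_t r4) {
  return (size_t(1) << r2) << (2 * r4);
}
static_assert(fft_n(1, 4) == 512, "512 = 2^1 * 4^4");
static_assert(fft_n(0, 5) == 1024, "1024 = 2^0 * 4^5");
static_assert(fft_n(1, 5) == 2048, "2048 = 2^1 * 4^5");

// Two float2 (8-byte) buffers of 2048 points: 2 * 2048 * 8 bytes = 32KB.
static_assert(2 * 2048 * 2 * sizeof(float) == 32768, "fills 32KB threadgroup memory");

int main() { return 0; }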
@ -24,7 +24,6 @@ METAL_FUNC void gather_impl(
const thread Indices<IdxT, NIDX>& indices, const thread Indices<IdxT, NIDX>& indices,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto ind_idx = index.x; auto ind_idx = index.x;
auto ind_offset = index.y; auto ind_offset = index.y;
@ -43,17 +42,14 @@ METAL_FUNC void gather_impl(
indices.ndim); indices.ndim);
} }
auto ax = axes[i]; auto ax = axes[i];
auto idx_val = offset_neg_idx( auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax]; src_idx += idx_val * src_strides[ax];
} }
auto src_offset = elem_to_loc( auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim);
ind_offset, slice_sizes, src_strides, src_ndim);
size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x; size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
out[out_idx] = src[src_offset + src_idx]; out[out_idx] = src[src_offset + src_idx];
} }
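For readers tracing the indexing: gather_impl relies on elem_to_loc and offset_neg_idx from the shared headers. A minimal host-side sketch of what those helpers are assumed to compute (the real definitions live in utils.h, not in this diff):

#include <cstddef>

// Flat row-major index -> memory offset under arbitrary strides
// (assumed to mirror the elem_to_loc used above).
size_t elem_to_loc(size_t elem, const int* shape, const size_t* strides, int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

// Negative indices wrap around the axis length (assumed behavior of offset_neg_idx).
int offset_neg_idx(int idx, int axis_size) {
  return idx < 0 ? idx + axis_size : idx;
}

int main() { return 0; }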
#define make_gather_impl(IDX_ARG, IDX_ARR) \ #define make_gather_impl(IDX_ARG, IDX_ARR) \
@ -69,15 +65,10 @@ template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
const constant int* idx_shapes [[buffer(7)]], \ const constant int* idx_shapes [[buffer(7)]], \
const constant size_t* idx_strides [[buffer(8)]], \ const constant size_t* idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \ const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(IdxT) \ IDX_ARG(IdxT) uint2 index [[thread_position_in_grid]], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]) { \ uint2 grid_dim [[threads_per_grid]]) { \
\
Indices<IdxT, NIDX> idxs{ \ Indices<IdxT, NIDX> idxs{ \
{{IDX_ARR()}}, \ {{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \
idx_shapes, \
idx_strides, \
idx_ndim}; \
\ \
return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \ return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \
src, \ src, \
@ -94,16 +85,8 @@ template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
#define make_gather(n) make_gather_impl(IDX_ARG_##n, IDX_ARR_##n) #define make_gather(n) make_gather_impl(IDX_ARG_##n, IDX_ARR_##n)
make_gather(0) make_gather(0) make_gather(1) make_gather(2) make_gather(3) make_gather(4)
make_gather(1) make_gather(5) make_gather(6) make_gather(7) make_gather(8) make_gather(9)
make_gather(2)
make_gather(3)
make_gather(4)
make_gather(5)
make_gather(6)
make_gather(7)
make_gather(8)
make_gather(9)
make_gather(10) make_gather(10)
///////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////
@ -111,8 +94,8 @@ make_gather(10)
///////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////
#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \ #define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \
template [[host_name("gather" name "_" #nidx "" #nd_name)]] \ template [[host_name("gather" name "_" #nidx "" #nd_name)]] [[kernel]] void \
[[kernel]] void gather<src_t, idx_t, nidx, nd>( \ gather<src_t, idx_t, nidx, nd>( \
const device src_t* src [[buffer(0)]], \ const device src_t* src [[buffer(0)]], \
device src_t* out [[buffer(1)]], \ device src_t* out [[buffer(1)]], \
const constant int* src_shape [[buffer(2)]], \ const constant int* src_shape [[buffer(2)]], \
@ -123,13 +106,14 @@ template [[host_name("gather" name "_" #nidx "" #nd_name)]] \
const constant int* idx_shapes [[buffer(7)]], \ const constant int* idx_shapes [[buffer(7)]], \
const constant size_t* idx_strides [[buffer(8)]], \ const constant size_t* idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \ const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(idx_t) \ IDX_ARG(idx_t) uint2 index [[thread_position_in_grid]], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); uint2 grid_dim [[threads_per_grid]]);
// clang-format off
#define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \ #define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \
instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name) instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name) // clang-format on
// clang-format off
#define instantiate_gather4(name, src_t, idx_t, nidx) \ #define instantiate_gather4(name, src_t, idx_t, nidx) \
instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \ instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \
instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \ instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \
@ -148,8 +132,9 @@ instantiate_gather4("int32", int32_t, bool, 0)
instantiate_gather4("int64", int64_t, bool, 0) instantiate_gather4("int64", int64_t, bool, 0)
instantiate_gather4("float16", half, bool, 0) instantiate_gather4("float16", half, bool, 0)
instantiate_gather4("float32", float, bool, 0) instantiate_gather4("float32", float, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0) instantiate_gather4("bfloat16", bfloat16_t, bool, 0) // clang-format on
// clang-format off
#define instantiate_gather3(name, src_type, ind_type) \ #define instantiate_gather3(name, src_type, ind_type) \
instantiate_gather4(name, src_type, ind_type, 1) \ instantiate_gather4(name, src_type, ind_type, 1) \
instantiate_gather4(name, src_type, ind_type, 2) \ instantiate_gather4(name, src_type, ind_type, 2) \
@ -160,8 +145,9 @@ instantiate_gather4("bfloat16", bfloat16_t, bool, 0)
instantiate_gather4(name, src_type, ind_type, 7) \ instantiate_gather4(name, src_type, ind_type, 7) \
instantiate_gather4(name, src_type, ind_type, 8) \ instantiate_gather4(name, src_type, ind_type, 8) \
instantiate_gather4(name, src_type, ind_type, 9) \ instantiate_gather4(name, src_type, ind_type, 9) \
instantiate_gather4(name, src_type, ind_type, 10) instantiate_gather4(name, src_type, ind_type, 10) // clang-format on
// clang-format off
#define instantiate_gather(name, src_type) \ #define instantiate_gather(name, src_type) \
instantiate_gather3(#name "bool_", src_type, bool) \ instantiate_gather3(#name "bool_", src_type, bool) \
instantiate_gather3(#name "uint8", src_type, uint8_t) \ instantiate_gather3(#name "uint8", src_type, uint8_t) \
@ -184,4 +170,4 @@ instantiate_gather(int32, int32_t)
instantiate_gather(int64, int64_t) instantiate_gather(int64, int64_t)
instantiate_gather(float16, half) instantiate_gather(float16, half)
instantiate_gather(float32, float) instantiate_gather(float32, float)
instantiate_gather(bfloat16, bfloat16_t) instantiate_gather(bfloat16, bfloat16_t) // clang-format on
@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <metal_stdlib>
#include <metal_simdgroup> #include <metal_simdgroup>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/defines.h"
@ -25,7 +25,6 @@ template <
const int TN, /* Thread cols (in elements) */ const int TN, /* Thread cols (in elements) */
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
struct GEMVKernel { struct GEMVKernel {
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE"); static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up // - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
@ -35,15 +34,20 @@ struct GEMVKernel {
// //
// 1. A thread loads TN elements each from mat along TM contiguous rows // 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector // and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for the block // 2. The thread then multiplies and adds to accumulate its local result for
// 3. At the end, each thread has accumulated results over all blocks across the rows // the block
// 3. At the end, each thread has accumulated results over all blocks across
// the rows
// These are then summed up across the threadgroup // These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs // 4. Each threadgroup writes its accumulated BN * TN outputs
// //
// Edge case handling: // Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix // - The threadgroup with the largest tid will have blocks that exceed the
// * The blocks that start outside the matrix are never read (thread results remain zero) // matrix
// * The last thread that partially overlaps with the matrix is shifted inwards // * The blocks that start outside the matrix are never read (thread results
// remain zero)
// * The last thread that partially overlaps with the matrix is shifted
// inwards
// such that the thread block fits exactly in the matrix // such that the thread block fits exactly in the matrix
MLX_MTL_CONST short tgp_mem_size = BN * TN * 2; MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;
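Semantically, all of the blocking above computes a plain axpby-style GEMV. A scalar host-side reference capturing only the contract (mat_ld here stands in for the kernel's leading-dimension argument; a sketch, not the kernel's code path):

// out_vec[m] = alpha * sum_n mat[m, n] * in_vec[n] + beta * bias[m * bias_stride],
// with the alpha/beta/bias path active only when kDoAxpby is true.
template <typename T>
void gemv_reference(const T* mat, const T* in_vec, const T* bias,
                    int out_vec_size, int in_vec_size, int mat_ld,
                    float alpha, float beta, int bias_stride,
                    bool do_axpby, T* out_vec) {
  for (int m = 0; m < out_vec_size; ++m) {
    float acc = 0.0f;
    for (int n = 0; n < in_vec_size; ++n)
      acc += float(mat[m * mat_ld + n]) * float(in_vec[n]);
    out_vec[m] = do_axpby
        ? static_cast<T>(alpha * acc + beta * float(bias[m * bias_stride]))
        : static_cast<T>(acc);
  }
}

int main() {
  float mat[4] = {1, 2, 3, 4}, v[2] = {1, 1}, bias[2] = {0, 0}, out[2];
  gemv_reference(mat, v, bias, 2, 2, 2, 1.f, 0.f, 1, true, out);
  return (out[0] == 3.f && out[1] == 7.f) ? 0 : 1;
}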
@ -64,7 +68,6 @@ struct GEMVKernel {
uint3 lid [[thread_position_in_threadgroup]], uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
// Appease compiler // Appease compiler
(void)lid; (void)lid;
@ -91,14 +94,12 @@ struct GEMVKernel {
// Loop over in_vec in blocks of BN * TN // Loop over in_vec in blocks of BN * TN
for (int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) { for (int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
// Prefetch in_vector for threadgroup use // Prefetch in_vector for threadgroup use
if (simd_gid == 0) { if (simd_gid == 0) {
// Main load loop // Main load loop
if (bn + TN <= in_vec_size) { if (bn + TN <= in_vec_size) {
#pragma clang loop unroll(full) #pragma clang loop unroll(full)
for (int tn = 0; tn < TN; tn++) { for (int tn = 0; tn < TN; tn++) {
in_vec_block[tn] = in_vec[bn + tn]; in_vec_block[tn] = in_vec[bn + tn];
@ -110,7 +111,6 @@ struct GEMVKernel {
for (int tn = 0; tn < TN; tn++) { for (int tn = 0; tn < TN; tn++) {
in_vec_block[tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0; in_vec_block[tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0;
} }
} }
} }
@ -125,7 +125,6 @@ struct GEMVKernel {
// Per thread work loop // Per thread work loop
#pragma clang loop unroll(full) #pragma clang loop unroll(full)
for (int tm = 0; tm < TM; tm++) { for (int tm = 0; tm < TM; tm++) {
// Load for the row // Load for the row
if (bn + TN <= in_vec_size) { if (bn + TN <= in_vec_size) {
#pragma clang loop unroll(full) #pragma clang loop unroll(full)
@ -136,7 +135,8 @@ struct GEMVKernel {
} else { // Edge case } else { // Edge case
#pragma clang loop unroll(full) #pragma clang loop unroll(full)
for (int tn = 0; tn < TN; tn++) { for (int tn = 0; tn < TN; tn++) {
int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1); int col_idx =
(bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
inter[tn] = mat[tm * marix_ld + col_idx]; inter[tn] = mat[tm * marix_ld + col_idx];
} }
} }
@ -145,7 +145,6 @@ struct GEMVKernel {
for (int tn = 0; tn < TN; tn++) { for (int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn]; result[tm] += inter[tn] * v_coeff[tn];
} }
} }
} }
@ -157,22 +156,17 @@ struct GEMVKernel {
// Write outputs // Write outputs
if (simd_lid == 0) { if (simd_lid == 0) {
#pragma clang loop unroll(full) #pragma clang loop unroll(full)
for (int tm = 0; tm < TM; tm++) { for (int tm = 0; tm < TM; tm++) {
if (kDoAxpby) { if (kDoAxpby) {
out_vec[out_row + tm] = out_vec[out_row + tm] = static_cast<T>(alpha) * result[tm] +
static_cast<T>(alpha) * result[tm] +
static_cast<T>(beta) * bias[(out_row + tm) * bias_stride]; static_cast<T>(beta) * bias[(out_row + tm) * bias_stride];
} else { } else {
out_vec[out_row + tm] = result[tm]; out_vec[out_row + tm] = result[tm];
} }
} }
} }
} }
}; };
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@ -187,7 +181,6 @@ template <
const int TN, /* Thread cols (in elements) */ const int TN, /* Thread cols (in elements) */
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
struct GEMVTKernel { struct GEMVTKernel {
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
// into blocks of (BM * TM, BN * TN) divided among threadgroups // into blocks of (BM * TM, BN * TN) divided among threadgroups
// - Every thread works on a block of (TM, TN) // - Every thread works on a block of (TM, TN)
@ -195,18 +188,22 @@ struct GEMVTKernel {
// //
// 1. A thread loads TN elements each from mat along TM contiguous rows // 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector // and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for the block // 2. The thread then multiplies and adds to accumulate its local result for
// 3. At the end, each thread has accumulated results over all blocks across the rows // the block
// 3. At the end, each thread has accumulated results over all blocks across
// the rows
// These are then summed up across the threadgroup // These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs // 4. Each threadgroup writes its accumulated BN * TN outputs
// //
// Edge case handling: // Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix // - The threadgroup with the largest tid will have blocks that exceed the
// * The blocks that start outside the matrix are never read (thread results remain zero) // matrix
// * The last thread that partially overlaps with the matrix is shifted inwards // * The blocks that start outside the matrix are never read (thread results
// remain zero)
// * The last thread that partially overlaps with the matrix is shifted
// inwards
// such that the thread block fits exactly in the matrix // such that the thread block fits exactly in the matrix
MLX_MTL_CONST short tgp_mem_size = BN * BM * TN; MLX_MTL_CONST short tgp_mem_size = BN * BM * TN;
static METAL_FUNC void run( static METAL_FUNC void run(
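The transposed kernel is the column-wise counterpart; in scalar form (again a semantic sketch only, with mat_ld standing in for the leading dimension):

// out_vec[n] = alpha * sum_m mat[m, n] * in_vec[m] + beta * bias[n * bias_stride],
// i.e. out = alpha * (mat^T @ in_vec) + beta * bias, matching the
// (M = in_vec_size, N = out_vec_size) orientation documented above.
template <typename T>
void gemv_t_reference(const T* mat, const T* in_vec, const T* bias,
                      int in_vec_size, int out_vec_size, int mat_ld,
                      float alpha, float beta, int bias_stride, T* out_vec) {
  for (int n = 0; n < out_vec_size; ++n) {
    float acc = 0.0f;
    for (int m = 0; m < in_vec_size; ++m)
      acc += float(mat[m * mat_ld + n]) * float(in_vec[m]);
    out_vec[n] = static_cast<T>(alpha * acc + beta * float(bias[n * bias_stride]));
  }
}

int main() {
  float mat[4] = {1, 2, 3, 4}, v[2] = {1, 1}, bias[2] = {0, 0}, out[2];
  gemv_t_reference(mat, v, bias, 2, 2, 2, 1.f, 0.f, 1, out);
  return (out[0] == 4.f && out[1] == 6.f) ? 0 : 1;
}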
@ -225,7 +222,6 @@ struct GEMVTKernel {
uint3 lid [[thread_position_in_threadgroup]], uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
// Appease compiler // Appease compiler
(void)simd_gid; (void)simd_gid;
(void)simd_lid; (void)simd_lid;
@ -243,7 +239,6 @@ struct GEMVTKernel {
// Edge case handling // Edge case handling
if (out_col < out_vec_size) { if (out_col < out_vec_size) {
out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN; out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;
// Per thread accumulation main loop // Per thread accumulation main loop
@ -254,7 +249,6 @@ struct GEMVTKernel {
threadgroup_barrier(mem_flags::mem_none); threadgroup_barrier(mem_flags::mem_none);
if (bm + TM <= in_vec_size) { if (bm + TM <= in_vec_size) {
#pragma clang loop unroll(full) #pragma clang loop unroll(full)
for (int tm = 0; tm < TM; tm++) { for (int tm = 0; tm < TM; tm++) {
v_coeff[tm] = in_vec[bm + tm]; v_coeff[tm] = in_vec[bm + tm];
@ -280,11 +274,9 @@ struct GEMVTKernel {
for (int tn = 0; tn < TN; tn++) { for (int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn]; result[tn] += v_coeff[tm] * inter[tn];
} }
} }
} }
} }
} }
// Threadgroup collection // Threadgroup collection
@ -298,10 +290,8 @@ struct GEMVTKernel {
// Threadgroup accumulation and writing out results // Threadgroup accumulation and writing out results
if (lid.y == 0 && out_col < out_vec_size) { if (lid.y == 0 && out_col < out_vec_size) {
#pragma clang loop unroll(full) #pragma clang loop unroll(full)
for (int i = 1; i < BM; i++) { for (int i = 1; i < BM; i++) {
#pragma clang loop unroll(full) #pragma clang loop unroll(full)
for (int j = 0; j < TN; j++) { for (int j = 0; j < TN; j++) {
result[j] += tgp_results[i * TN + j]; result[j] += tgp_results[i * TN + j];
@ -310,10 +300,8 @@ struct GEMVTKernel {
#pragma clang loop unroll(full) #pragma clang loop unroll(full)
for (int j = 0; j < TN; j++) { for (int j = 0; j < TN; j++) {
if (kDoAxpby) { if (kDoAxpby) {
out_vec[out_col + j] = out_vec[out_col + j] = static_cast<T>(alpha) * result[j] +
static_cast<T>(alpha) * result[j] +
static_cast<T>(beta) * bias[(out_col + j) * bias_stride]; static_cast<T>(beta) * bias[(out_col + j) * bias_stride];
} else { } else {
out_vec[out_col + j] = result[j]; out_vec[out_col + j] = result[j];
@ -355,7 +343,6 @@ template <
uint3 lid [[thread_position_in_threadgroup]], uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN, kDoAxpby>; using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN, kDoAxpby>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size]; threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
@ -394,15 +381,13 @@ template <
tid, tid,
lid, lid,
simd_gid, simd_gid,
simd_lid simd_lid);
);
} }
#define instantiate_gemv_helper(name, itype, bm, bn, tm, tn, nc, axpby) \ #define instantiate_gemv_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby)]] \ template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn \
[[kernel]] void gemv<itype, bm, bn, tm, tn, nc, axpby>( \ "_nc" #nc "_axpby" #axpby)]] [[kernel]] void \
gemv<itype, bm, bn, tm, tn, nc, axpby>( \
const device itype* mat [[buffer(0)]], \ const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \ const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \ const device itype* bias [[buffer(2)]], \
@ -430,9 +415,8 @@ template <
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1) instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1)
#define instantiate_gemv_blocks(name, itype) \ #define instantiate_gemv_blocks(name, itype) \
instantiate_gemv(name, itype, 4, 32, 1, 4) \ instantiate_gemv(name, itype, 4, 32, 1, 4) instantiate_gemv( \
instantiate_gemv(name, itype, 4, 32, 4, 4) \ name, itype, 4, 32, 4, 4) instantiate_gemv(name, itype, 8, 32, 4, 4)
instantiate_gemv(name, itype, 8, 32, 4, 4)
instantiate_gemv_blocks(float32, float); instantiate_gemv_blocks(float32, float);
instantiate_gemv_blocks(float16, half); instantiate_gemv_blocks(float16, half);
@ -470,7 +454,6 @@ template <
uint3 lid [[thread_position_in_threadgroup]], uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN, kDoAxpby>; using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN, kDoAxpby>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size]; threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
@ -509,14 +492,13 @@ template <
tid, tid,
lid, lid,
simd_gid, simd_gid,
simd_lid simd_lid);
);
} }
#define instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, nc, axpby) \ #define instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby)]] \ template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn \
[[kernel]] void gemv_t<itype, bm, bn, tm, tn, nc, axpby>( \ "_nc" #nc "_axpby" #axpby)]] [[kernel]] void \
gemv_t<itype, bm, bn, tm, tn, nc, axpby>( \
const device itype* mat [[buffer(0)]], \ const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \ const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \ const device itype* bias [[buffer(2)]], \
@ -537,20 +519,23 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]], \ uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]); uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off
#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \ #define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 0) \ instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 0) \
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 1) \ instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 1) \
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 0) \ instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 0) \
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 1) instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 1) // clang-format on
// clang-format off
#define instantiate_gemv_t_blocks(name, itype) \ #define instantiate_gemv_t_blocks(name, itype) \
instantiate_gemv_t(name, itype, 8, 8, 4, 1) \ instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
instantiate_gemv_t(name, itype, 8, 8, 4, 4) \ instantiate_gemv_t(name, itype, 8, 8, 4, 4) \
instantiate_gemv_t(name, itype, 8, 16, 4, 4) \ instantiate_gemv_t(name, itype, 8, 16, 4, 4) \
instantiate_gemv_t(name, itype, 8, 32, 4, 4) \ instantiate_gemv_t(name, itype, 8, 32, 4, 4) \
instantiate_gemv_t(name, itype, 8, 64, 4, 4) \ instantiate_gemv_t(name, itype, 8, 64, 4, 4) \
instantiate_gemv_t(name, itype, 8, 128, 4, 4) instantiate_gemv_t(name, itype, 8, 128, 4, 4) // clang-format on
// clang-format off
instantiate_gemv_t_blocks(float32, float); instantiate_gemv_t_blocks(float32, float);
instantiate_gemv_t_blocks(float16, half); instantiate_gemv_t_blocks(float16, half);
instantiate_gemv_t_blocks(bfloat16, bfloat16_t); instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on
@ -99,7 +99,8 @@ template <typename T, int N_READS = RMS_N_READS>
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) { if ((lid * N_READS + i) < axis_size) {
thread_x[i] = (thread_x[i] - mean) * normalizer; thread_x[i] = (thread_x[i] - mean) * normalizer;
out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i]; out[i] =
w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
} }
} }
} }
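In math terms the store above is the usual affine layer norm, with mu the row mean and normalizer presumably 1/sqrt(var + eps), both computed earlier in the kernel:

\[
\mathrm{out}_i = w_i\,\hat{x}_i + b_i,
\qquad
\hat{x}_i = (x_i - \mu)\cdot\mathrm{normalizer}.
\]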
@ -192,13 +193,15 @@ template <typename T, int N_READS = RMS_N_READS>
if (r + lid * N_READS + N_READS <= axis_size) { if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
float xi = (x[r + i] - mean) * normalizer; float xi = (x[r + i] - mean) * normalizer;
out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)]; out[r + i] =
w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
} }
} else { } else {
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) { if ((r + lid * N_READS + i) < axis_size) {
float xi = (x[r + i] - mean) * normalizer; float xi = (x[r + i] - mean) * normalizer;
out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)]; out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) +
b[b_stride * (i + r)];
} }
} }
} }
@ -323,7 +326,8 @@ template <typename T, int N_READS = RMS_N_READS>
if (lid * N_READS + N_READS <= axis_size) { if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
thread_x[i] = (thread_x[i] - mean) * normalizer; thread_x[i] = (thread_x[i] - mean) * normalizer;
gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) - gx[i] = static_cast<T>(
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
thread_x[i] * meanwgxc * normalizer2); thread_x[i] * meanwgxc * normalizer2);
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]); gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
} }
@ -331,7 +335,8 @@ template <typename T, int N_READS = RMS_N_READS>
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) { if ((lid * N_READS + i) < axis_size) {
thread_x[i] = (thread_x[i] - mean) * normalizer; thread_x[i] = (thread_x[i] - mean) * normalizer;
gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) - gx[i] = static_cast<T>(
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
thread_x[i] * meanwgxc * normalizer2); thread_x[i] * meanwgxc * normalizer2);
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]); gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
} }
@ -460,8 +465,8 @@ template <typename T, int N_READS = RMS_N_READS>
float xi = (x[i + r] - mean) * normalizer; float xi = (x[i + r] - mean) * normalizer;
float wi = w[(i + r) * w_stride]; float wi = w[(i + r) * w_stride];
float gi = g[i + r]; float gi = g[i + r];
gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) - gx[i + r] = static_cast<T>(
xi * meanwgxc * normalizer2); normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
gw[i + r] = static_cast<T>(gi * xi); gw[i + r] = static_cast<T>(gi * xi);
} }
} else { } else {
@ -470,8 +475,8 @@ template <typename T, int N_READS = RMS_N_READS>
float xi = (x[i + r] - mean) * normalizer; float xi = (x[i + r] - mean) * normalizer;
float wi = w[(i + r) * w_stride]; float wi = w[(i + r) * w_stride];
float gi = g[i + r]; float gi = g[i + r];
gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) - gx[i + r] = static_cast<T>(
xi * meanwgxc * normalizer2); normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
gw[i + r] = static_cast<T>(gi * xi); gw[i + r] = static_cast<T>(gi * xi);
} }
} }
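The backward stores match the standard layer-norm VJP. Writing n for normalizer and reading meanwg and meanwgxc as the row means of w*g and of w*g*(x - mu) (their definitions sit outside this hunk, so this is a reconstruction, not a quote):

\[
g_{x,i} = n\,(w_i g_i - \overline{wg}) - \hat{x}_i\,\overline{wg(x-\mu)}\,n^{2},
\qquad
g_{w,i} = g_i\,\hat{x}_i,
\quad
\hat{x}_i = (x_i - \mu)\,n.
\]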
@ -548,6 +553,4 @@ template <typename T, int N_READS = RMS_N_READS>
instantiate_layer_norm(float32, float) instantiate_layer_norm(float32, float)
instantiate_layer_norm(float16, half) instantiate_layer_norm(float16, half)
instantiate_layer_norm(bfloat16, bfloat16_t) instantiate_layer_norm(bfloat16, bfloat16_t) // clang-format on
// clang-format on
@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <metal_stdlib>
#include <metal_simdgroup> #include <metal_simdgroup>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/defines.h"
@ -15,10 +15,11 @@ using namespace metal;
MLX_MTL_CONST int SIMD_SIZE = 32; MLX_MTL_CONST int SIMD_SIZE = 32;
template <typename T, typename U, int values_per_thread, int bits> template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T* x, thread U* x_thread) { inline U load_vector(const device T* x, thread U* x_thread) {
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}"); static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
U sum = 0; U sum = 0;
@ -54,7 +55,9 @@ inline U load_vector(const device T *x, thread U *x_thread) {
template <typename T, typename U, int values_per_thread, int bits> template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}"); static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
U sum = 0; U sum = 0;
@ -98,29 +101,36 @@ inline U load_vector_safe(const device T *x, thread U *x_thread, int N) {
} }
template <typename U, int values_per_thread, int bits> template <typename U, int values_per_thread, int bits>
inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum) { inline U qdot(
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}"); const device uint8_t* w,
const thread U* x_thread,
U scale,
U bias,
U sum) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
U accum = 0; U accum = 0;
if (bits == 2) { if (bits == 2) {
for (int i = 0; i < (values_per_thread / 4); i++) { for (int i = 0; i < (values_per_thread / 4); i++) {
accum += ( accum +=
x_thread[4*i] * (w[i] & 0x03) (x_thread[4 * i] * (w[i] & 0x03) +
+ x_thread[4*i+1] * (w[i] & 0x0c) x_thread[4 * i + 1] * (w[i] & 0x0c) +
+ x_thread[4*i+2] * (w[i] & 0x30) x_thread[4 * i + 2] * (w[i] & 0x30) +
+ x_thread[4*i+3] * (w[i] & 0xc0)); x_thread[4 * i + 3] * (w[i] & 0xc0));
} }
} }
else if (bits == 4) { else if (bits == 4) {
const device uint16_t* ws = (const device uint16_t*)w; const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (values_per_thread / 4); i++) { for (int i = 0; i < (values_per_thread / 4); i++) {
accum += ( accum +=
x_thread[4*i] * (ws[i] & 0x000f) (x_thread[4 * i] * (ws[i] & 0x000f) +
+ x_thread[4*i+1] * (ws[i] & 0x00f0) x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
+ x_thread[4*i+2] * (ws[i] & 0x0f00) x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
+ x_thread[4*i+3] * (ws[i] & 0xf000)); x_thread[4 * i + 3] * (ws[i] & 0xf000));
} }
} }
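This also explains why qdot threads a scale, a bias, and the running sum from load_vector through every group: with affine-quantized weights the dot product factors as

\[
\sum_i w_i x_i \;=\; \sum_i (s\,q_i + b)\,x_i \;=\; s\sum_i q_i x_i \;+\; b\sum_i x_i,
\]

so the inner loop only needs the masked integer accumulation, while the b times sum-of-x correction is applied once per group (my reading of the code; the return statement sits outside this hunk).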
@ -134,29 +144,37 @@ inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias
} }
template <typename U, int values_per_thread, int bits> template <typename U, int values_per_thread, int bits>
inline U qdot_safe(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum, int N) { inline U qdot_safe(
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}"); const device uint8_t* w,
const thread U* x_thread,
U scale,
U bias,
U sum,
int N) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
U accum = 0; U accum = 0;
if (bits == 2) { if (bits == 2) {
for (int i = 0; i < (N / 4); i++) { for (int i = 0; i < (N / 4); i++) {
accum += ( accum +=
x_thread[4*i] * (w[i] & 0x03) (x_thread[4 * i] * (w[i] & 0x03) +
+ x_thread[4*i+1] * (w[i] & 0x0c) x_thread[4 * i + 1] * (w[i] & 0x0c) +
+ x_thread[4*i+2] * (w[i] & 0x30) x_thread[4 * i + 2] * (w[i] & 0x30) +
+ x_thread[4*i+3] * (w[i] & 0xc0)); x_thread[4 * i + 3] * (w[i] & 0xc0));
} }
} }
else if (bits == 4) { else if (bits == 4) {
const device uint16_t* ws = (const device uint16_t*)w; const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (N / 4); i++) { for (int i = 0; i < (N / 4); i++) {
accum += ( accum +=
x_thread[4*i] * (ws[i] & 0x000f) (x_thread[4 * i] * (ws[i] & 0x000f) +
+ x_thread[4*i+1] * (ws[i] & 0x00f0) x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
+ x_thread[4*i+2] * (ws[i] & 0x0f00) x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
+ x_thread[4*i+3] * (ws[i] & 0xf000)); x_thread[4 * i + 3] * (ws[i] & 0xf000));
} }
} }
@ -170,8 +188,11 @@ inline U qdot_safe(const device uint8_t* w, const thread U *x_thread, U scale, U
} }
template <typename U, int values_per_thread, int bits> template <typename U, int values_per_thread, int bits>
inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { inline void
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}"); qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
if (bits == 2) { if (bits == 2) {
U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
@ -202,11 +223,18 @@ inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* resu
} }
template <typename U, int N, int bits> template <typename U, int N, int bits>
inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { inline void
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}"); dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
if (bits == 2) { if (bits == 2) {
U s[4] = {scale, scale / static_cast<U>(4.0f), scale / static_cast<U>(16.0f), scale / static_cast<U>(64.0f)}; U s[4] = {
scale,
scale / static_cast<U>(4.0f),
scale / static_cast<U>(16.0f),
scale / static_cast<U>(64.0f)};
for (int i = 0; i < (N / 4); i++) { for (int i = 0; i < (N / 4); i++) {
w_local[4 * i] = s[0] * (w[i] & 0x03) + bias; w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias; w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
@ -217,7 +245,11 @@ inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U*
else if (bits == 4) { else if (bits == 4) {
const device uint16_t* ws = (const device uint16_t*)w; const device uint16_t* ws = (const device uint16_t*)w;
U s[4] = {scale, scale / static_cast<U>(16.0f), scale / static_cast<U>(256.0f), scale / static_cast<U>(4096.0f)}; U s[4] = {
scale,
scale / static_cast<U>(16.0f),
scale / static_cast<U>(256.0f),
scale / static_cast<U>(4096.0f)};
for (int i = 0; i < (N / 4); i++) { for (int i = 0; i < (N / 4); i++) {
w_local[4 * i] = s[0] * (ws[i] & 0x000f) + bias; w_local[4 * i] = s[0] * (ws[i] & 0x000f) + bias;
w_local[4 * i + 1] = s[1] * (ws[i] & 0x00f0) + bias; w_local[4 * i + 1] = s[1] * (ws[i] & 0x00f0) + bias;
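The s[4] tables above fold the missing bit-shift into the scale: masking without shifting leaves the k-th field multiplied by 4^k (2-bit) or 16^k (4-bit), and the pre-divided scale cancels it exactly (powers of two, so no rounding). A host-side check of the 2-bit case, with the field order assumed low-bits-first as the masks suggest:

#include <cassert>
#include <cstdint>

int main() {
  uint8_t w = 0b11'10'01'00; // packed 2-bit values 0, 1, 2, 3 (low bits first)
  float scale = 0.5f, bias = -1.0f;
  float s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
  for (int k = 0; k < 4; ++k) {
    float masked = s[k] * (w & (0x03 << (2 * k))) + bias;   // kernel's form
    float shifted = scale * ((w >> (2 * k)) & 0x03) + bias; // conventional form
    assert(masked == shifted);
  }
  return 0;
}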
@ -243,13 +275,20 @@ template <
short group_size, short group_size,
short bits> short bits>
struct QuantizedBlockLoader { struct QuantizedBlockLoader {
static_assert(BCOLS <= group_size, "The group size should be larger than the columns"); static_assert(
static_assert(group_size % BCOLS == 0, "The group size should be divisible by the columns"); BCOLS <= group_size,
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}"); "The group size should be larger than the columns");
static_assert(
group_size % BCOLS == 0,
"The group size should be divisible by the columns");
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
MLX_MTL_CONST short pack_factor = 32 / bits; MLX_MTL_CONST short pack_factor = 32 / bits;
MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
MLX_MTL_CONST short n_reads = (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; MLX_MTL_CONST short n_reads =
(BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
MLX_MTL_CONST short group_steps = group_size / BCOLS; MLX_MTL_CONST short group_steps = group_size / BCOLS;
const int src_ld; const int src_ld;
@ -275,7 +314,8 @@ struct QuantizedBlockLoader {
ushort simd_group_id [[simdgroup_index_in_threadgroup]], ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]]) ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_), : src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor), tile_stride(
reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
group_step_cnt(0), group_step_cnt(0),
group_stride(BROWS * src_ld / group_size), group_stride(BROWS * src_ld / group_size),
thread_idx(simd_group_id * 32 + simd_lane_id), thread_idx(simd_group_id * 32 + simd_lane_id),
@ -294,7 +334,8 @@ struct QuantizedBlockLoader {
T scale = *scales; T scale = *scales;
T bias = *biases; T bias = *biases;
for (int i = 0; i < n_reads; i++) { for (int i = 0; i < n_reads; i++) {
dequantize<T, pack_factor, bits>((device uint8_t*)(src + i), scale, bias, dst + i * pack_factor); dequantize<T, pack_factor, bits>(
(device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
} }
} }
@ -320,7 +361,8 @@ struct QuantizedBlockLoader {
T scale = *scales; T scale = *scales;
T bias = *biases; T bias = *biases;
for (int i = 0; i < n_reads; i++) { for (int i = 0; i < n_reads; i++) {
dequantize<T, pack_factor, bits>((device uint8_t*)(src + i), scale, bias, dst + i * pack_factor); dequantize<T, pack_factor, bits>(
(device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
} }
} }
@ -357,7 +399,6 @@ template <typename T, int group_size, int bits, int packs_per_thread>
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int num_simdgroups = 2; constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4; constexpr int results_per_simdgroup = 4;
constexpr int pack_factor = 32 / bits; constexpr int pack_factor = 32 / bits;
@ -373,7 +414,8 @@ template <typename T, int group_size, int bits, int packs_per_thread>
// Adjust positions // Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor; const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size / group_size; const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup; const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup;
w += out_row * in_vec_size_w + simd_lid * packs_per_thread; w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
@ -384,7 +426,8 @@ template <typename T, int group_size, int bits, int packs_per_thread>
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread); U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_simdgroup; row++) { for (int row = 0; row < results_per_simdgroup; row++) {
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w); const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g; const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g;
@ -407,7 +450,6 @@ template <typename T, int group_size, int bits, int packs_per_thread>
} }
} }
template <typename T, const int group_size, const int bits> template <typename T, const int group_size, const int bits>
[[kernel]] void qmv( [[kernel]] void qmv(
const device uint32_t* w [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
@ -420,7 +462,6 @@ template <typename T, const int group_size, const int bits>
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int num_simdgroups = 2; constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4; constexpr int results_per_simdgroup = 4;
constexpr int packs_per_thread = 1; constexpr int packs_per_thread = 1;
@ -437,7 +478,8 @@ template <typename T, const int group_size, const int bits>
// Adjust positions // Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor; const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size / group_size; const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup; const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup;
const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row); const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
if (out_row >= out_vec_size) { if (out_row >= out_vec_size) {
@ -458,13 +500,15 @@ template <typename T, const int group_size, const int bits>
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread); U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; out_row + row < out_vec_size; row++) { for (int row = 0; out_row + row < out_vec_size; row++) {
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w); const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g; const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g;
U s = sl[0]; U s = sl[0];
U b = bl[0]; U b = bl[0];
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum); result[row] +=
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
} }
w += block_size / pack_factor; w += block_size / pack_factor;
@ -472,11 +516,16 @@ template <typename T, const int group_size, const int bits>
biases += block_size / group_size; biases += block_size / group_size;
x += block_size; x += block_size;
} }
const int remaining = clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread); const int remaining = clamp(
U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining); static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
0,
values_per_thread);
U sum =
load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
for (int row = 0; out_row + row < out_vec_size; row++) { for (int row = 0; out_row + row < out_vec_size; row++) {
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w); const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g; const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g;
@ -506,13 +555,15 @@ template <typename T, const int group_size, const int bits>
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread); U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_simdgroup; row++) { for (int row = 0; row < results_per_simdgroup; row++) {
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w); const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g; const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g;
U s = sl[0]; U s = sl[0];
U b = bl[0]; U b = bl[0];
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum); result[row] +=
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
} }
w += block_size / pack_factor; w += block_size / pack_factor;
@ -520,17 +571,23 @@ template <typename T, const int group_size, const int bits>
biases += block_size / group_size; biases += block_size / group_size;
x += block_size; x += block_size;
} }
const int remaining = clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread); const int remaining = clamp(
U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining); static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
0,
values_per_thread);
U sum =
load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
for (int row = 0; row < results_per_simdgroup; row++) { for (int row = 0; row < results_per_simdgroup; row++) {
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w); const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g; const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g;
U s = sl[0]; U s = sl[0];
U b = bl[0]; U b = bl[0];
result[row] += qdot_safe<U, values_per_thread, bits>(wl, x_thread, s, b, sum, remaining); result[row] += qdot_safe<U, values_per_thread, bits>(
wl, x_thread, s, b, sum, remaining);
} }
for (int row = 0; row < results_per_simdgroup; row++) { for (int row = 0; row < results_per_simdgroup; row++) {
@ -542,7 +599,6 @@ template <typename T, const int group_size, const int bits>
} }
} }
template <typename T, const int group_size, const int bits> template <typename T, const int group_size, const int bits>
[[kernel]] void qvm( [[kernel]] void qvm(
const device T* x [[buffer(0)]], const device T* x [[buffer(0)]],
@ -555,7 +611,6 @@ template <typename T, const int group_size, const int bits>
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int num_simdgroups = 8; constexpr int num_simdgroups = 8;
constexpr int pack_factor = 32 / bits; constexpr int pack_factor = 32 / bits;
constexpr int blocksize = SIMD_SIZE; constexpr int blocksize = SIMD_SIZE;
@ -590,7 +645,8 @@ template <typename T, const int group_size, const int bits>
bias = biases[(i + simd_lid) * out_vec_size_g]; bias = biases[(i + simd_lid) * out_vec_size_g];
w_local = w[(i + simd_lid) * out_vec_size_w]; w_local = w[(i + simd_lid) * out_vec_size_w];
qouter<U, pack_factor, bits>((thread uint8_t *)&w_local, x_local, scale, bias, result); qouter<U, pack_factor, bits>(
(thread uint8_t*)&w_local, x_local, scale, bias, result);
} }
if (static_cast<int>(i + simd_lid) < in_vec_size) { if (static_cast<int>(i + simd_lid) < in_vec_size) {
x_local = x[i + simd_lid]; x_local = x[i + simd_lid];
@ -603,7 +659,8 @@ template <typename T, const int group_size, const int bits>
bias = 0; bias = 0;
w_local = 0; w_local = 0;
} }
qouter<U, pack_factor, bits>((thread uint8_t *)&w_local, x_local, scale, bias, result); qouter<U, pack_factor, bits>(
(thread uint8_t*)&w_local, x_local, scale, bias, result);
// Accumulate in the simdgroup // Accumulate in the simdgroup
#pragma clang loop unroll(full) #pragma clang loop unroll(full)
@ -620,8 +677,14 @@ template <typename T, const int group_size, const int bits>
} }
} }
template <
template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits, const bool aligned_N> typename T,
const int BM,
const int BK,
const int BN,
const int group_size,
const int bits,
const bool aligned_N>
[[kernel]] void qmm_t( [[kernel]] void qmm_t(
const device T* x [[buffer(0)]], const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]], const device uint32_t* w [[buffer(1)]],
@ -635,7 +698,6 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
uint lid [[thread_index_in_threadgroup]], uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
@ -647,9 +709,19 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BK_padded = (BK + 16 / sizeof(T));
// Instantiate the appropriate BlockMMA and Loader // Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>; using mma_t = mlx::steel::
using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>; BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
using loader_w_t = QuantizedBlockLoader<T, BN, BK, BK_padded, 1, WM * WN * SIMD_SIZE, group_size, bits>; using loader_x_t =
mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
using loader_w_t = QuantizedBlockLoader<
T,
BN,
BK,
BK_padded,
1,
WM * WN * SIMD_SIZE,
group_size,
bits>;
threadgroup T Xs[BM * BK_padded]; threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BN * BK_padded]; threadgroup T Ws[BN * BK_padded];
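A note on the tile padding (an inference on my part, not stated in the diff): the + 16 / sizeof(T) in BK_padded, and likewise BN_padded in qmm_n below, skews consecutive rows of the threadgroup tiles by 16 bytes, the usual guard against shared-memory bank conflicts:

#include <cstddef>

// 2-byte T (half/bfloat16): +8 elements per row; 4-byte T (float): +4 elements.
// Either way consecutive rows are offset by an extra 16 bytes.
static_assert(16 / sizeof(short) == 8, "2-byte element types");
static_assert(16 / sizeof(float) == 4, "4-byte element types");

int main() { return 0; }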
@ -728,8 +800,13 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
} }
} }
template <
template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits> typename T,
const int BM,
const int BK,
const int BN,
const int group_size,
const int bits>
[[kernel]] void qmm_n( [[kernel]] void qmm_n(
const device T* x [[buffer(0)]], const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]], const device uint32_t* w [[buffer(1)]],
@ -743,7 +820,6 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
uint lid [[thread_index_in_threadgroup]], uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
@ -756,9 +832,19 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
constexpr int BN_padded = (BN + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T));
// Instantiate the appropriate BlockMMA and Loader // Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>; using mma_t = mlx::steel::
using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>; BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
using loader_w_t = QuantizedBlockLoader<T, BK, BN, BN_padded, 0, WM * WN * SIMD_SIZE, group_size, bits>; using loader_x_t = mlx::steel::
BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
using loader_w_t = QuantizedBlockLoader<
T,
BK,
BN,
BN_padded,
0,
WM * WN * SIMD_SIZE,
group_size,
bits>;
threadgroup T Xs[BM * BK_padded]; threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BK * BN_padded]; threadgroup T Ws[BK * BN_padded];
@ -847,10 +933,10 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
} }
} }
#define instantiate_qmv_fast(name, itype, group_size, bits, packs_per_thread) \ #define instantiate_qmv_fast(name, itype, group_size, bits, packs_per_thread) \
template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits "_fast")]] \ template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits \
[[kernel]] void qmv_fast<itype, group_size, bits, packs_per_thread>( \ "_fast")]] [[kernel]] void \
qmv_fast<itype, group_size, bits, packs_per_thread>( \
const device uint32_t* w [[buffer(0)]], \ const device uint32_t* w [[buffer(0)]], \
const device itype* scales [[buffer(1)]], \ const device itype* scales [[buffer(1)]], \
const device itype* biases [[buffer(2)]], \ const device itype* biases [[buffer(2)]], \
@ -862,11 +948,13 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
uint simd_gid [[simdgroup_index_in_threadgroup]], \ uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]); uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off
#define instantiate_qmv_fast_types(group_size, bits, packs_per_thread) \ #define instantiate_qmv_fast_types(group_size, bits, packs_per_thread) \
instantiate_qmv_fast(float32, float, group_size, bits, packs_per_thread) \ instantiate_qmv_fast(float32, float, group_size, bits, packs_per_thread) \
instantiate_qmv_fast(float16, half, group_size, bits, packs_per_thread) \ instantiate_qmv_fast(float16, half, group_size, bits, packs_per_thread) \
instantiate_qmv_fast(bfloat16, bfloat16_t, group_size, bits, packs_per_thread) instantiate_qmv_fast(bfloat16, bfloat16_t, group_size, bits, packs_per_thread) // clang-format on
// clang-format off
instantiate_qmv_fast_types(128, 2, 1) instantiate_qmv_fast_types(128, 2, 1)
instantiate_qmv_fast_types(128, 4, 2) instantiate_qmv_fast_types(128, 4, 2)
instantiate_qmv_fast_types(128, 8, 2) instantiate_qmv_fast_types(128, 8, 2)
@ -875,11 +963,12 @@ instantiate_qmv_fast_types( 64, 4, 2)
instantiate_qmv_fast_types( 64, 8, 2) instantiate_qmv_fast_types( 64, 8, 2)
instantiate_qmv_fast_types( 32, 2, 1) instantiate_qmv_fast_types( 32, 2, 1)
instantiate_qmv_fast_types( 32, 4, 2) instantiate_qmv_fast_types( 32, 4, 2)
instantiate_qmv_fast_types( 32, 8, 2) instantiate_qmv_fast_types( 32, 8, 2) // clang-format on
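A quick note on the kernel names these macros produce: `#` stringizes each macro argument and adjacent string literals concatenate, so every expansion registers one kernel whose host name encodes type, group size, bit width, and variant. A minimal sketch of the mechanism (plain C++, not MLX code; QMV_FAST_HOST_NAME is a hypothetical stand-in for the attribute string above):

#include <cstdio>

// Mirrors the stringizing in instantiate_qmv_fast: adjacent literals
// such as "qmv_" "float16" "_gs_" "64" ... fuse into one name.
#define QMV_FAST_HOST_NAME(name, group_size, bits) \
  "qmv_" #name "_gs_" #group_size "_b_" #bits "_fast"

int main() {
  std::puts(QMV_FAST_HOST_NAME(float16, 64, 4)); // qmv_float16_gs_64_b_4_fast
  return 0;
}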
#define instantiate_qmv(name, itype, group_size, bits) \
-  template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits)]] \
-  [[kernel]] void qmv<itype, group_size, bits>( \
+  template [[host_name("qmv_" #name "_gs_" #group_size \
+                       "_b_" #bits)]] [[kernel]] void \
+  qmv<itype, group_size, bits>( \
      const device uint32_t* w [[buffer(0)]], \
      const device itype* scales [[buffer(1)]], \
      const device itype* biases [[buffer(2)]], \
@@ -891,11 +980,13 @@ instantiate_qmv_fast_types( 32, 8, 2)
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);

+// clang-format off
#define instantiate_qmv_types(group_size, bits) \
  instantiate_qmv(float32, float, group_size, bits) \
  instantiate_qmv(float16, half, group_size, bits) \
-  instantiate_qmv(bfloat16, bfloat16_t, group_size, bits)
+  instantiate_qmv(bfloat16, bfloat16_t, group_size, bits) // clang-format on

+// clang-format off
instantiate_qmv_types(128, 2)
instantiate_qmv_types(128, 4)
instantiate_qmv_types(128, 8)
@@ -904,11 +995,12 @@ instantiate_qmv_types( 64, 4)
instantiate_qmv_types( 64, 8)
instantiate_qmv_types( 32, 2)
instantiate_qmv_types( 32, 4)
-instantiate_qmv_types( 32, 8)
+instantiate_qmv_types( 32, 8) // clang-format on

#define instantiate_qvm(name, itype, group_size, bits) \
-  template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \
-  [[kernel]] void qvm<itype, group_size, bits>( \
+  template [[host_name("qvm_" #name "_gs_" #group_size \
+                       "_b_" #bits)]] [[kernel]] void \
+  qvm<itype, group_size, bits>( \
      const device itype* x [[buffer(0)]], \
      const device uint32_t* w [[buffer(1)]], \
      const device itype* scales [[buffer(2)]], \
@@ -920,11 +1012,13 @@ instantiate_qmv_types( 32, 8)
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);

+// clang-format off
#define instantiate_qvm_types(group_size, bits) \
  instantiate_qvm(float32, float, group_size, bits) \
  instantiate_qvm(float16, half, group_size, bits) \
-  instantiate_qvm(bfloat16, bfloat16_t, group_size, bits)
+  instantiate_qvm(bfloat16, bfloat16_t, group_size, bits) // clang-format on

+// clang-format off
instantiate_qvm_types(128, 2)
instantiate_qvm_types(128, 4)
instantiate_qvm_types(128, 8)
@@ -933,11 +1027,12 @@ instantiate_qvm_types( 64, 4)
instantiate_qvm_types( 64, 8)
instantiate_qvm_types( 32, 2)
instantiate_qvm_types( 32, 4)
-instantiate_qvm_types( 32, 8)
+instantiate_qvm_types( 32, 8) // clang-format on

#define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
-  template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N)]] \
-  [[kernel]] void qmm_t<itype, 32, 32, 32, group_size, bits, aligned_N>( \
+  template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits \
+                       "_alN_" #aligned_N)]] [[kernel]] void \
+  qmm_t<itype, 32, 32, 32, group_size, bits, aligned_N>( \
      const device itype* x [[buffer(0)]], \
      const device uint32_t* w [[buffer(1)]], \
      const device itype* scales [[buffer(2)]], \
@@ -951,14 +1046,16 @@ instantiate_qvm_types( 32, 8)
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);

+// clang-format off
#define instantiate_qmm_t_types(group_size, bits) \
  instantiate_qmm_t(float32, float, group_size, bits, false) \
  instantiate_qmm_t(float16, half, group_size, bits, false) \
  instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, false) \
  instantiate_qmm_t(float32, float, group_size, bits, true) \
  instantiate_qmm_t(float16, half, group_size, bits, true) \
-  instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, true)
+  instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, true) // clang-format on

+// clang-format off
instantiate_qmm_t_types(128, 2)
instantiate_qmm_t_types(128, 4)
instantiate_qmm_t_types(128, 8)
@@ -967,11 +1064,12 @@ instantiate_qmm_t_types( 64, 4)
instantiate_qmm_t_types( 64, 8)
instantiate_qmm_t_types( 32, 2)
instantiate_qmm_t_types( 32, 4)
-instantiate_qmm_t_types( 32, 8)
+instantiate_qmm_t_types( 32, 8) // clang-format on

#define instantiate_qmm_n(name, itype, group_size, bits) \
-  template [[host_name("qmm_n_" #name "_gs_" #group_size "_b_" #bits)]] \
-  [[kernel]] void qmm_n<itype, 32, 32, 32, group_size, bits>( \
+  template [[host_name("qmm_n_" #name "_gs_" #group_size \
+                       "_b_" #bits)]] [[kernel]] void \
+  qmm_n<itype, 32, 32, 32, group_size, bits>( \
      const device itype* x [[buffer(0)]], \
      const device uint32_t* w [[buffer(1)]], \
      const device itype* scales [[buffer(2)]], \
@@ -985,11 +1083,13 @@ instantiate_qmm_t_types( 32, 8)
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);

+// clang-format off
#define instantiate_qmm_n_types(group_size, bits) \
  instantiate_qmm_n(float32, float, group_size, bits) \
  instantiate_qmm_n(float16, half, group_size, bits) \
-  instantiate_qmm_n(bfloat16, bfloat16_t, group_size, bits)
+  instantiate_qmm_n(bfloat16, bfloat16_t, group_size, bits) // clang-format on

+// clang-format off
instantiate_qmm_n_types(128, 2)
instantiate_qmm_n_types(128, 4)
instantiate_qmm_n_types(128, 8)
@@ -998,4 +1098,4 @@ instantiate_qmm_n_types( 64, 4)
instantiate_qmm_n_types( 64, 8)
instantiate_qmm_n_types( 32, 2)
instantiate_qmm_n_types( 32, 4)
-instantiate_qmm_n_types( 32, 8)
+instantiate_qmm_n_types( 32, 8) // clang-format on

View File

@@ -4,8 +4,7 @@
static constexpr constant uint32_t rotations[2][4] = {
    {13, 15, 26, 6},
-    {17, 29, 16, 24}
-};
+    {17, 29, 16, 24}};

union rbits {
  uint2 val;
@@ -13,7 +12,6 @@ union rbits {
};

rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
  uint4 ks = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA};

  rbits v;
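For orientation in this hunk: the key schedule above follows Threefry-2x32, where the third key word is the XOR of the two key words with the parity constant 0x1BD11BDA (the low word of the constant from the Skein specification). A small sketch of that relation (plain C++, assuming the reference schedule):

#include <cstdint>

// ks.z in the kernel above: XORing both key words with the parity
// constant lets the rounds rotate cheaply through three key words.
uint32_t threefry2x32_third_key_word(uint32_t k0, uint32_t k1) {
  return k0 ^ k1 ^ 0x1BD11BDAu;
}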

View File

@@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc.

-#include "mlx/backend/metal/kernels/reduction/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
+#include "mlx/backend/metal/kernels/reduction/utils.h"

using namespace metal;
@@ -60,7 +60,6 @@ METAL_FUNC U per_thread_all_reduce(
// All reduce kernel
///////////////////////////////////////////////////////////////////////////////

// NB: This kernel assumes threads_per_threadgroup is at most
// 1024. This way with a simd_size of 32, we are guaranteed to
// complete the reduction in two steps of simd-level reductions.
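The arithmetic behind that two-step guarantee, spelled out as a sketch (the 1024 and 32 come straight from the comment above):

// Step one reduces within each simdgroup, leaving one partial value per
// simdgroup; step two reduces those partials in a single simdgroup.
constexpr int kMaxThreadsPerThreadgroup = 1024;
constexpr int kSimdSize = 32;
static_assert(
    kMaxThreadsPerThreadgroup / kSimdSize <= kSimdSize,
    "at most 32 partials remain, so one more simd reduction finishes");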
@@ -75,11 +74,11 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  threadgroup U local_vals[simd_size];
-  U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
+  U total_val =
+      per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);

  // Reduction within simd group
  total_val = op.simd_reduce(total_val);
@@ -110,14 +109,16 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint thread_group_id [[threadgroup_position_in_grid]]) {
  Op op;
  threadgroup U local_vals[simd_size];
-  U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
+  U total_val =
+      per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);

-  // Reduction within simd group (simd_add isn't supported for uint64/int64 types)
-  for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
+  // Reduction within simd group (simd_add isn't supported for uint64/int64
+  // types)
+  for (uint16_t lane_offset = simd_size / 2; lane_offset > 0;
+       lane_offset /= 2) {
    total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
  }
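This loop is the standard shuffle-down tree reduction: each pass folds lanes an offset apart, so a 32-wide group finishes in five passes (offsets 16, 8, 4, 2, 1). A CPU model of the same idea (a sketch, not the Metal API):

#include <cstdint>

// Lane 0 ends up holding op applied across all 32 values, mirroring
// simd_shuffle_down: each pass combines values `offset` lanes apart.
template <typename U, typename Op>
U tree_reduce_model(U vals[32], Op op) {
  for (uint16_t offset = 32 / 2; offset > 0; offset /= 2) {
    for (int lane = 0; lane + offset < 32; lane++) {
      vals[lane] = op(vals[lane], vals[lane + offset]);
    }
  }
  return vals[0];
}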
  // Write simd group reduction results to local memory
@@ -128,7 +129,8 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
  // Reduction of simdgroup reduction results within threadgroup.
  total_val = lid < simd_per_group ? local_vals[lid] : op.init;
-  for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
+  for (uint16_t lane_offset = simd_size / 2; lane_offset > 0;
+       lane_offset /= 2) {
    total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
  }
@@ -139,8 +141,8 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
}

#define instantiate_all_reduce(name, itype, otype, op) \
-  template [[host_name("all_reduce_" #name)]] \
-  [[kernel]] void all_reduce<itype, otype, op>( \
+  template [[host_name("all_reduce_" #name)]] [[kernel]] void \
+  all_reduce<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device mlx_atomic<otype>* out [[buffer(1)]], \
      const device size_t& in_size [[buffer(2)]], \
@@ -152,8 +154,8 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
-  template [[host_name("all_reduce_no_atomics_" #name)]] \
-  [[kernel]] void all_reduce_no_atomics<itype, otype, op>( \
+  template [[host_name("all_reduce_no_atomics_" #name)]] [[kernel]] void \
+  all_reduce_no_atomics<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const device size_t& in_size [[buffer(2)]], \
@@ -175,6 +177,7 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
  instantiate_all_reduce_no_atomics(name##tname, type, type, op<type>)

+// clang-format off
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)
@@ -182,4 +185,4 @@ instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And)
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or)

// special case bool with larger output type
-instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
+instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>) // clang-format on

View File

@@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc.

-#include "mlx/backend/metal/kernels/reduction/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
+#include "mlx/backend/metal/kernels/reduction/utils.h"

using namespace metal;
@@ -25,7 +25,6 @@ template <typename T, typename U, typename Op>
    const constant size_t* non_col_strides [[buffer(10)]],
    const constant int& non_col_ndim [[buffer(11)]],
    uint tid [[thread_position_in_grid]]) {
  // Appease the compiler
  (void)out_size;
@@ -41,7 +40,8 @@ template <typename T, typename U, typename Op>
      ndim - non_col_ndim);
  for (uint i = 0; i < non_col_reductions; i++) {
-    size_t in_idx = elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);
+    size_t in_idx =
+        elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);
    for (uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) {
      U val = static_cast<U>(in[in_idx]);
@@ -53,8 +53,8 @@ template <typename T, typename U, typename Op>
}

#define instantiate_col_reduce_small(name, itype, otype, op) \
-  template [[host_name("col_reduce_small_" #name)]] \
-  [[kernel]] void col_reduce_small<itype, otype, op>( \
+  template [[host_name("col_reduce_small_" #name)]] [[kernel]] void \
+  col_reduce_small<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
@@ -125,12 +125,7 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]]) {
  auto out_idx = tid.x * lsize.x + lid.x;
-  auto in_idx = elem_to_loc(
-      out_idx + tid.z * out_size,
-      shape,
-      strides,
-      ndim
-  );
+  auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
  Op op;
  if (out_idx < out_size) {
@@ -144,7 +139,8 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
        lid.xy,
        lsize.xy);
-    // Write out reduction results generated by threadgroups working on specific output element, contiguously.
+    // Write out reduction results generated by threadgroups working on specific
+    // output element, contiguously.
    if (lid.y == 0) {
      op.atomic_update(out, val, out_idx);
    }
@@ -168,12 +164,7 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
    uint3 lsize [[threads_per_threadgroup]],
    uint3 gsize [[threads_per_grid]]) {
  auto out_idx = tid.x * lsize.x + lid.x;
-  auto in_idx = elem_to_loc(
-      out_idx + tid.z * out_size,
-      shape,
-      strides,
-      ndim
-  );
+  auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
  if (out_idx < out_size) {
    U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
@@ -186,7 +177,8 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
        lid.xy,
        lsize.xy);
-    // Write out reduction results generated by threadgroups working on specific output element, contiguously.
+    // Write out reduction results generated by threadgroups working on specific
+    // output element, contiguously.
    if (lid.y == 0) {
      uint tgsize_y = ceildiv(gsize.y, lsize.y);
      uint tgsize_z = ceildiv(gsize.z, lsize.z);
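ceildiv here counts how many threadgroups tile each axis. MLX defines its own helper, but the conventional integer form it presumably matches is (a sketch under that assumption):

// Number of blocks needed to cover n items at `block` items apiece,
// rounding up: ceildiv(10, 4) == 3.
constexpr unsigned ceildiv(unsigned n, unsigned block) {
  return (n + block - 1) / block;
}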
@@ -196,8 +188,8 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
}

#define instantiate_col_reduce_general(name, itype, otype, op) \
-  template [[host_name("col_reduce_general_" #name)]] \
-  [[kernel]] void col_reduce_general<itype, otype, op>( \
+  template [[host_name("col_reduce_general_" #name)]] [[kernel]] void \
+  col_reduce_general<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device mlx_atomic<otype>* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
@@ -212,8 +204,9 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
      uint3 lsize [[threads_per_threadgroup]]);

#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
-  template [[host_name("col_reduce_general_no_atomics_" #name)]] \
-  [[kernel]] void col_reduce_general_no_atomics<itype, otype, op>( \
+  template \
+      [[host_name("col_reduce_general_no_atomics_" #name)]] [[kernel]] void \
+      col_reduce_general_no_atomics<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
@@ -233,14 +226,17 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
// Instantiations
///////////////////////////////////////////////////////////////////////////////

+// clang-format off
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
  instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
-  instantiate_col_reduce_general(name ##tname, type, type, op<type>)
+  instantiate_col_reduce_general(name ##tname, type, type, op<type>) // clang-format on

+// clang-format off
#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
  instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
-  instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>)
+  instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>) // clang-format on

+// clang-format off
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)
@@ -250,4 +246,4 @@ instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)

instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And)
-instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or)
+instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or) // clang-format on

View File

@@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc.

-#include "mlx/backend/metal/kernels/reduction/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
+#include "mlx/backend/metal/kernels/reduction/utils.h"

using namespace metal;
@@ -18,16 +18,15 @@ template <typename T, typename Op>
}

#define instantiate_init_reduce(name, otype, op) \
-  template [[host_name("i" #name)]] \
-  [[kernel]] void init_reduce<otype, op>( \
-      device otype *out [[buffer(1)]], \
-      uint tid [[thread_position_in_grid]]);
+  template [[host_name("i" #name)]] [[kernel]] void init_reduce<otype, op>( \
+      device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]);

#define instantiate_init_reduce_helper(name, tname, type, op) \
  instantiate_init_reduce(name##tname, type, op<type>)

+// clang-format off
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)

instantiate_init_reduce(andbool_, bool, And)
-instantiate_init_reduce(orbool_, bool, Or)
+instantiate_init_reduce(orbool_, bool, Or) // clang-format on

View File

@@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc.

-#include "mlx/backend/metal/kernels/reduction/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
+#include "mlx/backend/metal/kernels/reduction/utils.h"

using namespace metal;
@@ -22,7 +22,6 @@ template <typename T, typename U, typename Op>
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint lid [[thread_position_in_grid]]) {
  Op op;
  uint out_idx = lid;
@@ -60,7 +59,6 @@ template <typename T, typename U, typename Op>
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  uint out_idx = simd_per_group * tid + simd_group_id;
@@ -81,24 +79,22 @@ template <typename T, typename U, typename Op>
  }
  else if (short(non_row_reductions) >= 32) {
    for (short r = simd_lane_id; r < short(non_row_reductions); r += 32) {
      uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
      const device T* in_row = in + in_idx;
      for (short i = 0; i < short(reduction_size); i++) {
        total_val = op(static_cast<U>(in_row[i]), total_val);
      }
    }
  }
  else {
-    const short n_reductions = short(reduction_size) * short(non_row_reductions);
-    const short reductions_per_thread = (n_reductions + simd_size - 1) / simd_size;
+    const short n_reductions =
+        short(reduction_size) * short(non_row_reductions);
+    const short reductions_per_thread =
+        (n_reductions + simd_size - 1) / simd_size;
    const short r_st = simd_lane_id / reductions_per_thread;
    const short r_ed = short(non_row_reductions);
@@ -110,20 +106,16 @@ template <typename T, typename U, typename Op>
    if (r_st < r_jump) {
      for (short r = r_st; r < r_ed; r += r_jump) {
        uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
        const device T* in_row = in + in_idx;
        for (short i = i_st; i < i_ed; i += i_jump) {
          total_val = op(static_cast<U>(in_row[i]), total_val);
        }
      }
    }
  }
  total_val = op.simd_reduce(total_val);
  if (simd_lane_id == 0) {
@@ -132,8 +124,8 @@ template <typename T, typename U, typename Op>
}

#define instantiate_row_reduce_small(name, itype, otype, op) \
-  template[[host_name("row_reduce_general_small_" #name)]] \
-  [[kernel]] void row_reduce_general_small<itype, otype, op>( \
+  template [[host_name("row_reduce_general_small_" #name)]] [[kernel]] void \
+  row_reduce_general_small<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
@@ -143,8 +135,8 @@ template <typename T, typename U, typename Op>
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      uint lid [[thread_position_in_grid]]); \
-  template[[host_name("row_reduce_general_med_" #name)]] \
-  [[kernel]] void row_reduce_general_med<itype, otype, op>( \
+  template [[host_name("row_reduce_general_med_" #name)]] [[kernel]] void \
+  row_reduce_general_med<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
@@ -233,13 +225,21 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  (void)non_row_reductions;
  Op op;
  threadgroup U local_vals[simd_size];
-  U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
+  U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
+      in,
+      reduction_size,
+      out_size,
+      shape,
+      strides,
+      ndim,
+      lsize.x,
+      lid.x,
+      tid.xy);

  total_val = op.simd_reduce(total_val);
@@ -278,13 +278,21 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  (void)non_row_reductions;
  Op op;
  threadgroup U local_vals[simd_size];
-  U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
+  U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
+      in,
+      reduction_size,
+      out_size,
+      shape,
+      strides,
+      ndim,
+      lsize.x,
+      lid.x,
+      tid.xy);

  // Reduction within simd group - simd_add isn't supported for int64 types
  for (uint16_t i = simd_size / 2; i > 0; i /= 2) {
@@ -312,9 +320,9 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
}

#define instantiate_row_reduce_general(name, itype, otype, op) \
-  instantiate_row_reduce_small(name, itype, otype, op) \
-  template [[host_name("row_reduce_general_" #name)]] \
-  [[kernel]] void row_reduce_general<itype, otype, op>( \
+  instantiate_row_reduce_small(name, itype, otype, op) template \
+      [[host_name("row_reduce_general_" #name)]] [[kernel]] void \
+      row_reduce_general<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device mlx_atomic<otype>* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
@@ -331,9 +339,9 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
-  instantiate_row_reduce_small(name, itype, otype, op) \
-  template [[host_name("row_reduce_general_no_atomics_" #name)]] \
-  [[kernel]] void row_reduce_general_no_atomics<itype, otype, op>( \
+  instantiate_row_reduce_small(name, itype, otype, op) template \
+      [[host_name("row_reduce_general_no_atomics_" #name)]] [[kernel]] void \
+      row_reduce_general_no_atomics<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
@@ -350,7 +358,6 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
@@ -361,11 +368,11 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \
  instantiate_row_reduce_general_no_atomics(name##tname, type, type, op<type>)

+// clang-format off
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b)

instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And)
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or)

-instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
+instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>) // clang-format on

View File

@@ -237,13 +237,17 @@ template <typename T, int N_READS = RMS_N_READS>
  gw += gid * axis_size + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
-      gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
+      gx[i] = static_cast<T>(
+          thread_g[i] * thread_w[i] * normalizer -
+          thread_x[i] * meangwx * normalizer3);
      gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
-        gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
+        gx[i] = static_cast<T>(
+            thread_g[i] * thread_w[i] * normalizer -
+            thread_x[i] * meangwx * normalizer3);
        gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
      }
    }
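Restated as math (a hedged reading of the code, taking `normalizer` to be 1/sigma for the RMS value sigma and `meangwx` to be the mean of g*w*x over the axis):

\[
\frac{\partial L}{\partial x_i}
  = \frac{g_i w_i}{\sigma}
  - \frac{x_i}{\sigma^{3}} \cdot \frac{1}{N}\sum_{j} g_j w_j x_j,
\qquad
\frac{\partial L}{\partial w_i} = \frac{g_i x_i}{\sigma}.
\]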
@@ -342,7 +346,8 @@ template <typename T, int N_READS = RMS_N_READS>
      float wi = w[w_stride * (i + r)];
      float gi = g[i + r];
-      gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
+      gx[i + r] =
+          static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
      gw[i + r] = static_cast<T>(gi * xi * normalizer);
    }
  } else {
@@ -352,7 +357,8 @@ template <typename T, int N_READS = RMS_N_READS>
      float wi = w[w_stride * (i + r)];
      float gi = g[i + r];
-      gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
+      gx[i + r] =
+          static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
      gw[i + r] = static_cast<T>(gi * xi * normalizer);
    }
  }
@@ -431,5 +437,4 @@ template <typename T, int N_READS = RMS_N_READS>
instantiate_rms(float32, float)
instantiate_rms(float16, half)
-instantiate_rms(bfloat16, bfloat16_t)
-// clang-format on
+instantiate_rms(bfloat16, bfloat16_t) // clang-format on

View File

@@ -20,12 +20,15 @@ template <typename T, bool traditional, bool forward>
  uint in_index_1, in_index_2;
  uint out_index_1, out_index_2;
  if (traditional) {
-    out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0];
+    out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
+        pos.z * out_strides[0];
    out_index_2 = out_index_1 + 1;
-    in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
+    in_index_1 =
+        2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
    in_index_2 = in_index_1 + strides[2];
  } else {
-    out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0];
+    out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
+        pos.z * out_strides[0];
    out_index_2 = out_index_1 + grid.x * out_strides[2];
    in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
    in_index_2 = in_index_1 + grid.x * strides[2];
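A worked example of the two addressing modes above, for a contiguous row (all strides 1, pos.y = pos.z = 0): traditional RoPE rotates interleaved pairs (2k, 2k+1), while the other layout pairs element k with element k + grid.x (split halves). A sketch in plain C++:

// Index pairs selected by the kernel above for one feature row.
struct RopePair {
  unsigned i1, i2;
};

RopePair rope_indices(bool traditional, unsigned k, unsigned grid_x) {
  if (traditional) {
    return {2 * k, 2 * k + 1}; // interleaved: (0,1), (2,3), ...
  }
  return {k, k + grid_x}; // split halves: (0,N), (1,N+1), ...
}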
@@ -57,8 +60,8 @@ template <typename T, bool traditional, bool forward>
}

#define instantiate_rope(name, type, traditional, forward) \
-  template [[host_name("rope_" #name)]] \
-  [[kernel]] void rope<type, traditional, forward>( \
+  template [[host_name("rope_" #name)]] [[kernel]] void \
+  rope<type, traditional, forward>( \
      const device type* in [[buffer(0)]], \
      device type* out [[buffer(1)]], \
      constant const size_t strides[3], \
@@ -69,6 +72,7 @@ template <typename T, bool traditional, bool forward>
      uint3 pos [[thread_position_in_grid]], \
      uint3 grid [[threads_per_grid]]);

+// clang-format off
instantiate_rope(traditional_float16, half, true, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
instantiate_rope(traditional_float32, float, true, true)
@@ -80,4 +84,4 @@ instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false)
instantiate_rope(vjp_traditional_float32, float, true, false)
instantiate_rope(vjp_float16, half, false, false)
instantiate_rope(vjp_bfloat16, bfloat16_t, false, false)
-instantiate_rope(vjp_float32, float, false, false)
+instantiate_rope(vjp_float32, float, false, false) // clang-format on

View File

@@ -1,11 +1,17 @@
-#include <metal_stdlib>
#include <metal_simdgroup>
+#include <metal_stdlib>

#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"

using namespace metal;

-template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_t NSIMDGROUPS>
-[[kernel]] void fast_inference_sdpa_compute_partials_template(const device T *Q [[buffer(0)]],
+template <
+    typename T,
+    typename T2,
+    typename T4,
+    uint16_t TILE_SIZE_CONST,
+    uint16_t NSIMDGROUPS>
+[[kernel]] void fast_inference_sdpa_compute_partials_template(
+    const device T* Q [[buffer(0)]],
    const device T* K [[buffer(1)]],
    const device T* V [[buffer(2)]],
    const device uint64_t& L [[buffer(3)]],
@@ -28,24 +34,31 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
    kv_head_offset_factor = tid.x / q_kv_head_ratio;
  }
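The head mapping just above is the grouped-query-attention pattern: consecutive query heads share one KV head. A sketch of the mapping, assuming (as the surrounding code implies) q_kv_head_ratio = N_Q_HEADS / N_KV_HEADS:

// Integer division groups q_kv_head_ratio consecutive query heads onto
// one KV head, e.g. 32 query / 8 KV heads: query heads 0..3 -> KV head 0.
unsigned kv_head_for(unsigned q_head, unsigned n_q_heads, unsigned n_kv_heads) {
  const unsigned q_kv_head_ratio = n_q_heads / n_kv_heads;
  return q_head / q_kv_head_ratio;
}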
  constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / NSIMDGROUPS / 4;
-  constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP = TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS);
+  constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP =
+      TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS);
  constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR;
-  constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR * SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) * NSIMDGROUPS;
+  constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR *
+      SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) *
+      NSIMDGROUPS;

  threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block;
#pragma clang loop unroll(full)
  for (uint i = 0; i < 8; i++) {
-    smemFlush[simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP + i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f);
+    smemFlush
+        [simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP +
+         i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f);
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // TODO: multiple query sequence length for speculative decoding
-  const uint tgroup_query_head_offset = tid.x * DK + tid.z * (params.N_Q_HEADS * DK);
+  const uint tgroup_query_head_offset =
+      tid.x * DK + tid.z * (params.N_Q_HEADS * DK);

  const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L;
  const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK;
  const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK;

-  const device T* baseK = K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset;
+  const device T* baseK =
+      K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset;
  const device T* baseQ = Q + tgroup_query_head_offset;

  device T4* simdgroupQueryData = (device T4*)baseQ;
@@ -54,7 +67,8 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
  float threadAccum[ACCUM_PER_GROUP];

#pragma clang loop unroll(full)
-  for(size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP; threadAccumIndex++) {
+  for (size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP;
+       threadAccumIndex++) {
    threadAccum[threadAccumIndex] = -INFINITY;
  }
@@ -62,14 +76,16 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
  const int32_t SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST;
  const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L;
-  const bool LAST_TILE_ALIGNED = (SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST));
+  const bool LAST_TILE_ALIGNED =
+      (SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST));

  T4 thread_data_x4;
  T4 thread_data_y4;
  if (!LAST_TILE || LAST_TILE_ALIGNED) {
    thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
#pragma clang loop unroll(full)
-    for(size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST; KROW += NSIMDGROUPS) {
+    for (size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST;
+         KROW += NSIMDGROUPS) {
      const uint KROW_OFFSET = KROW * DK;
      const device T* baseKRow = baseK + KROW_OFFSET;
      device T4* keysData = (device T4*)baseKRow;
@@ -81,9 +97,11 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
  } else {
    thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
    const uint START_ROW = tid.y * TILE_SIZE_CONST;
-    const device T* baseKThisHead = K + tgroup_k_batch_offset + tgroup_k_head_offset;
+    const device T* baseKThisHead =
+        K + tgroup_k_batch_offset + tgroup_k_head_offset;

-    for(size_t KROW = START_ROW + simd_group_id; KROW < L; KROW += NSIMDGROUPS) {
+    for (size_t KROW = START_ROW + simd_group_id; KROW < L;
+         KROW += NSIMDGROUPS) {
      const uint KROW_OFFSET = KROW * DK;
      const device T* baseKRow = baseKThisHead + KROW_OFFSET;
      device T4* keysData = (device T4*)baseKRow;
@@ -97,7 +115,11 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
#pragma clang loop unroll(full)
  for (size_t i = 0; i < P_VEC4; i++) {
-    thread_data_x4 = T4(threadAccum[4 * i], threadAccum[4 * i + 1], threadAccum[4 * i + 2], threadAccum[4 * i + 3]);
+    thread_data_x4 =
+        T4(threadAccum[4 * i],
+           threadAccum[4 * i + 1],
+           threadAccum[4 * i + 2],
+           threadAccum[4 * i + 3]);
    simdgroup_barrier(mem_flags::mem_none);
    thread_data_y4 = simd_sum(thread_data_x4);
    if (simd_lane_id == 0) {
@@ -115,11 +137,13 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
  float lse = 0.f;

  constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32;
-  constexpr const size_t ACCUM_ARRAY_LENGTH = TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1;
+  constexpr const size_t ACCUM_ARRAY_LENGTH =
+      TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1;
  float4 pvals[ACCUM_ARRAY_LENGTH];

#pragma clang loop unroll(full)
-  for(uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH; accum_array_iter++) {
+  for (uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH;
+       accum_array_iter++) {
    pvals[accum_array_iter] = float4(-INFINITY);
  }
@@ -192,10 +216,15 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
  uint matrix_load_loop_iter = 0;
  constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8;

-    for(size_t tile_start = simd_group_id; tile_start < TILE_SIZE_CONST_DIV_8; tile_start += NSIMDGROUPS) {
+    for (size_t tile_start = simd_group_id;
+         tile_start < TILE_SIZE_CONST_DIV_8;
+         tile_start += NSIMDGROUPS) {
      simdgroup_matrix<T, 8, 8> tmp;
-      ulong simdgroup_matrix_offset = matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
-      ulong2 matrixOrigin = ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset);
+      ulong simdgroup_matrix_offset =
+          matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR +
+          simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
+      ulong2 matrixOrigin =
+          ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset);
      simdgroup_load(tmp, baseV, DK, matrixOrigin, true);
      const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0);
      const ulong elemsPerRowSmem = TILE_SIZE_CONST;
@@ -208,10 +237,12 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
    if (TILE_SIZE_CONST == 64) {
      T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
      uint loop_iter = 0;
-      threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
+      threadgroup float* oPartialSmem =
+          smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
#pragma clang loop unroll(full)
-      for(size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; row += NSIMDGROUPS) {
+      for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR;
+           row += NSIMDGROUPS) {
        threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
        threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
        T2 v_local = *(smemV2 + simd_lane_id);
@@ -220,16 +251,20 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
        simdgroup_barrier(mem_flags::mem_none);
        T row_sum = simd_sum(val);
-        oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = float(row_sum);
+        oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] =
+            float(row_sum);
        loop_iter++;
      }
    }
    if (TILE_SIZE_CONST > 64) {
-      constexpr const size_t TILE_SIZE_CONST_DIV_128 = (TILE_SIZE_CONST + 1) / 128;
-      threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
+      constexpr const size_t TILE_SIZE_CONST_DIV_128 =
+          (TILE_SIZE_CONST + 1) / 128;
+      threadgroup float* oPartialSmem =
+          smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
      uint loop_iter = 0;
-      for(size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; row += NSIMDGROUPS) {
+      for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR;
+           row += NSIMDGROUPS) {
        threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
        T row_sum = 0.f;
@@ -242,7 +277,8 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
        }
        simdgroup_barrier(mem_flags::mem_none);
        row_sum = simd_sum(row_sum);
-        oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = float(row_sum);
+        oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] =
+            float(row_sum);
        loop_iter++;
      }
    }
@@ -256,31 +292,46 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
  for (size_t col = 0; col < MATRIX_COLS; col++) {
    uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
    int32_t tile_start;
-    for(tile_start = START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; tile_start < MAX_START_ROW; tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) {
+    for (tile_start =
+             START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
+         tile_start < MAX_START_ROW;
+         tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) {
      simdgroup_matrix<T, 8, 8> tmp;
-      ulong2 matrixOrigin = ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start);
-      simdgroup_load(tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true);
+      ulong2 matrixOrigin =
+          ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start);
+      simdgroup_load(
+          tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true);
      const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0);
      constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST;
-      simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, /* transpose */ false);
+      simdgroup_store(
+          tmp,
+          smemV,
+          elemsPerRowSmem,
+          matrixOriginSmem,
+          /* transpose */ false);
      smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR;
    };
-    tile_start = ((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR);
+    tile_start =
+        ((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR);
    const int32_t INT_L = int32_t(L);
-    for(int row_index = tile_start + simd_group_id ; row_index < INT_L; row_index += NSIMDGROUPS) {
+    for (int row_index = tile_start + simd_group_id; row_index < INT_L;
+         row_index += NSIMDGROUPS) {
      if (simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) {
        const uint elems_per_row_gmem = DK;
-        const uint col_index_v_gmem = col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id;
+        const uint col_index_v_gmem =
+            col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id;
        const uint row_index_v_gmem = row_index;
        const uint elems_per_row_smem = TILE_SIZE_CONST;
        const uint col_index_v_smem = row_index % TILE_SIZE_CONST;
        const uint row_index_v_smem = simd_lane_id;
-        const uint scalar_offset_gmem = row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem;
-        const uint scalar_offset_smem = row_index_v_smem * elems_per_row_smem + col_index_v_smem;
+        const uint scalar_offset_gmem =
+            row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem;
+        const uint scalar_offset_smem =
+            row_index_v_smem * elems_per_row_smem + col_index_v_smem;
        T vdata = T(*(baseVThisHead + scalar_offset_gmem));
        smemV[scalar_offset_smem] = vdata;
        smem_col_index += NSIMDGROUPS;
@@ -291,9 +342,11 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
      if (TILE_SIZE_CONST == 64) {
        T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
        threadgroup float* oPartialSmem =
            smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
        for (size_t smem_row_index = simd_group_id;
             smem_row_index < ROWS_PER_ITER;
             smem_row_index += NSIMDGROUPS) {
          threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index);
          threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
          T2 v_local = *(smemV2 + simd_lane_id);
@@ -305,22 +358,25 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
      }
      if (TILE_SIZE_CONST > 64) {
        threadgroup float* oPartialSmem =
            smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
        uint loop_count = 0;
        for (size_t row_index = simd_group_id; row_index < ROWS_PER_ITER;
             row_index += NSIMDGROUPS) {
          T row_sum = 0.f;
          for (size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128;
               tile_iters++) {
            threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index);
            threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
            T4 v_local =
                *(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP);
            T4 p_local = T4(pvals[tile_iters]);
            row_sum += dot(p_local, v_local);
          }
          simdgroup_barrier(mem_flags::mem_none);
          row_sum = simd_sum(row_sum);
          oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] =
              float(row_sum);
          loop_count++;
        }
      }
@@ -332,22 +388,32 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
  if (simd_group_id == 0) {
    threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial;
    float4 vals = *(oPartialVec4 + simd_lane_id);
    device float* oPartialGmem =
        O_partials + tid.x * DK * params.KV_TILES + tid.y * DK;
    device float4* oPartialGmemVec4 = (device float4*)oPartialGmem;
    oPartialGmemVec4[simd_lane_id] = vals;
  }

  if (simd_group_id == 0 && simd_lane_id == 0) {
    const uint tileIndex = tid.y;
    const uint gmem_partial_scalar_offset =
        tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES +
        tileIndex;
    p_lse[gmem_partial_scalar_offset] = lse;
    p_maxes[gmem_partial_scalar_offset] = groupMax;
  }
}
#define instantiate_fast_inference_sdpa_to_partials_kernel( \
    itype, itype2, itype4, tile_size, nsimdgroups) \
  template [[host_name("fast_inference_sdpa_compute_partials_" #itype \
                       "_" #tile_size "_" #nsimdgroups)]] [[kernel]] void \
  fast_inference_sdpa_compute_partials_template< \
      itype, \
      itype2, \
      itype4, \
      tile_size, \
      nsimdgroups>( \
      const device itype* Q [[buffer(0)]], \
      const device itype* K [[buffer(1)]], \
      const device itype* V [[buffer(2)]], \
@@ -361,21 +427,55 @@ template [[host_name("fast_inference_sdpa_compute_partials_" #itype "_" #tile_si
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint3 tid [[threadgroup_position_in_grid]]);
// clang-format off
#define instantiate_fast_inference_sdpa_to_partials_shapes_helper( \
    itype, itype2, itype4, tile_size)                              \
  instantiate_fast_inference_sdpa_to_partials_kernel(              \
      itype, itype2, itype4, tile_size, 4)                         \
  instantiate_fast_inference_sdpa_to_partials_kernel(              \
      itype, itype2, itype4, tile_size, 8) // clang-format on

instantiate_fast_inference_sdpa_to_partials_shapes_helper(
    float,
    float2,
    float4,
    64);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
    float,
    float2,
    float4,
    128);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
    float,
    float2,
    float4,
    256);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
    float,
    float2,
    float4,
    512);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
    half,
    half2,
    half4,
    64);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
    half,
    half2,
    half4,
    128);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
    half,
    half2,
    half4,
    256);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
    half,
    half2,
    half4,
    512);
template <typename T>
void fast_inference_sdpa_reduce_tiles_template(
@@ -386,12 +486,13 @@ void fast_inference_sdpa_reduce_tiles_template(
    device T* O [[buffer(4)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  constexpr const int DK = 128;
  const ulong offset_rows =
      tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES;
  const device float* p_lse_row = p_lse + offset_rows;
  const device float* p_rowmax_row = p_maxes + offset_rows;
  // reserve some number of registers. this constitutes an assumption on max
  // value of KV TILES.
  constexpr const uint8_t reserve = 128;
  float p_lse_regs[reserve];
  float p_rowmax_regs[reserve];
@@ -411,7 +512,9 @@ void fast_inference_sdpa_reduce_tiles_template(
    denom += weights[i];
  }

  const device float* O_partials_with_offset = O_partials +
      tid.z * params.N_Q_HEADS * DK * params.KV_TILES +
      tid.x * DK * params.KV_TILES;

  float o_value = 0.f;
  for (size_t i = 0; i < params.KV_TILES; i++) {
@@ -423,7 +526,6 @@ void fast_inference_sdpa_reduce_tiles_template(
  return;
}

kernel void fast_inference_sdpa_reduce_tiles_float(
    const device float* O_partials [[buffer(0)]],
    const device float* p_lse [[buffer(1)]],
@@ -431,10 +533,9 @@ kernel void fast_inference_sdpa_reduce_tiles_float(
    const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
    device float* O [[buffer(4)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  fast_inference_sdpa_reduce_tiles_template<float>(
      O_partials, p_lse, p_maxes, params, O, tid, lid);
}

kernel void fast_inference_sdpa_reduce_tiles_half(
@@ -444,8 +545,7 @@ kernel void fast_inference_sdpa_reduce_tiles_half(
    const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
    device half* O [[buffer(4)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  fast_inference_sdpa_reduce_tiles_template<half>(
      O_partials, p_lse, p_maxes, params, O, tid, lid);
}
View File
@@ -127,10 +127,16 @@ inline void load_unsafe(U values[N_READS], const device T * input) {
  }
}

template <typename T, typename U, int N_READS, bool reverse>
inline void load_safe(
    U values[N_READS],
    const device T* input,
    int start,
    int total,
    U init) {
  if (reverse) {
    for (int i = 0; i < N_READS; i++) {
      values[N_READS - i - 1] =
          (start + N_READS - i - 1 < total) ? input[i] : init;
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
@@ -169,8 +175,13 @@ inline void write_safe(U values[N_READS], device U * out, int start, int total)
  }
}

template <
    typename T,
    typename U,
    typename Op,
    int N_READS,
    bool inclusive,
    bool reverse>
[[kernel]] void contiguous_scan(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
@@ -195,14 +206,16 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
  U values[N_READS];
  threadgroup U simdgroup_sums[32];

  // Loop over the reduced axis in blocks of size ceildiv(axis_size,
  // N_READS*lsize)
  // Read block
  // Compute inclusive scan of the block
  //   Compute inclusive scan per thread
  //   Compute exclusive scan of thread sums in simdgroup
  //   Write simdgroup sums in SM
  //   Compute exclusive scan of simdgroup sums
  //   Compute the output by scanning prefix, prev_simdgroup, prev_thread,
  //   value
  // Write block

  for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
@@ -212,15 +225,22 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
    // Read the values
    if (reverse) {
      if ((offset + N_READS) < axis_size) {
        load_unsafe<T, U, N_READS, reverse>(
            values, in + axis_size - offset - N_READS);
      } else {
        load_safe<T, U, N_READS, reverse>(
            values,
            in + axis_size - offset - N_READS,
            offset,
            axis_size,
            Op::init);
      }
    } else {
      if ((offset + N_READS) < axis_size) {
        load_unsafe<T, U, N_READS, reverse>(values, in + offset);
      } else {
        load_safe<T, U, N_READS, reverse>(
            values, in + offset, offset, axis_size, Op::init);
      }
    }
@@ -256,18 +276,25 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
    if (reverse) {
      if (inclusive) {
        if ((offset + N_READS) < axis_size) {
          write_unsafe<U, N_READS, reverse>(
              values, out + axis_size - offset - N_READS);
        } else {
          write_safe<U, N_READS, reverse>(
              values, out + axis_size - offset - N_READS, offset, axis_size);
        }
      } else {
        if (lid == 0 && offset == 0) {
          out[axis_size - 1] = Op::init;
        }
        if ((offset + N_READS + 1) < axis_size) {
          write_unsafe<U, N_READS, reverse>(
              values, out + axis_size - offset - 1 - N_READS);
        } else {
          write_safe<U, N_READS, reverse>(
              values,
              out + axis_size - offset - 1 - N_READS,
              offset + 1,
              axis_size);
        }
      }
    } else {
@@ -275,7 +302,8 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
      if ((offset + N_READS) < axis_size) {
        write_unsafe<U, N_READS, reverse>(values, out + offset);
      } else {
        write_safe<U, N_READS, reverse>(
            values, out + offset, offset, axis_size);
      }
    } else {
      if (lid == 0 && offset == 0) {
@@ -284,7 +312,8 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
      if ((offset + N_READS + 1) < axis_size) {
        write_unsafe<U, N_READS, reverse>(values, out + offset + 1);
      } else {
        write_safe<U, N_READS, reverse>(
            values, out + offset + 1, offset + 1, axis_size);
      }
    }
  }
@@ -298,7 +327,13 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
  }
}

template <
    typename T,
    typename U,
    typename Op,
    int N_READS,
    bool inclusive,
    bool reverse>
[[kernel]] void strided_scan(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
@@ -334,14 +369,17 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
  // Read in SM
  if (check_index_y < axis_size && (index_x + N_READS) < stride) {
    for (int i = 0; i < N_READS; i++) {
      read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
          in[offset + index_y * stride + index_x + i];
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if (check_index_y < axis_size && (index_x + i) < stride) {
        read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
            in[offset + index_y * stride + index_x + i];
      } else {
        read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
            Op::init;
      }
    }
  }
@@ -349,9 +387,11 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
  // Read strided into registers
  for (int i = 0; i < N_READS; i++) {
    values[i] =
        read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
  }
  // Do we need the following barrier? Shouldn't all simd threads execute
  // simultaneously?
  simdgroup_barrier(mem_flags::mem_threadgroup);

  // Perform the scan
@@ -363,7 +403,8 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
  // Write to SM
  for (int i = 0; i < N_READS; i++) {
    read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] =
        values[i];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -392,21 +433,24 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
  }
  if (check_index_y < axis_size && (index_x + N_READS) < stride) {
    for (int i = 0; i < N_READS; i++) {
      out[offset + index_y * stride + index_x + i] =
          read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if (check_index_y < axis_size && (index_x + i) < stride) {
        out[offset + index_y * stride + index_x + i] =
            read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
      }
    }
  }
}
#define instantiate_contiguous_scan( \
    name, itype, otype, op, inclusive, reverse, nreads) \
  template [[host_name("contiguous_scan_" #name)]] [[kernel]] void \
  contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& axis_size [[buffer(2)]], \
@@ -417,9 +461,10 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_strided_scan( \
    name, itype, otype, op, inclusive, reverse, nreads) \
  template [[host_name("strided_scan_" #name)]] [[kernel]] void \
  strided_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& axis_size [[buffer(2)]], \
@@ -429,7 +474,7 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
      uint2 lsize [[threads_per_threadgroup]], \
      uint simd_size [[threads_per_simdgroup]]);
// clang-format off
#define instantiate_scan_helper(name, itype, otype, op, nreads) \
  instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
  instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
@@ -438,8 +483,9 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
  instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
  instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
  instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
  instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) // clang-format on

// clang-format off
instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4)
instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4)
instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4)
@@ -491,4 +537,4 @@ instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMi
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin) // clang-format on
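The inclusive/exclusive and forward/reverse variants instantiated above all share the same scan semantics; only the threadgroup tiling differs. A sequential C++ reference, with op the binary operator (e.g. + for CumSum) and init its identity:

#include <algorithm>
#include <cstddef>
#include <vector>

template <typename U, typename Op>
std::vector<U> reference_scan(
    std::vector<U> x, Op op, U init, bool inclusive, bool reverse) {
  if (reverse)
    std::reverse(x.begin(), x.end());
  std::vector<U> out(x.size());
  U acc = init;
  for (size_t i = 0; i < x.size(); i++) {
    if (inclusive) {
      acc = op(acc, x[i]); // element i includes itself
      out[i] = acc;
    } else {
      out[i] = acc; // element i sees only its predecessors
      acc = op(acc, x[i]);
    }
  }
  if (reverse)
    std::reverse(out.begin(), out.end());
  return out;
}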
View File
@@ -13,7 +13,7 @@ using namespace metal;
// Scatter kernel
/////////////////////////////////////////////////////////////////////

template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_1d_index_impl(
    const device T* updates [[buffer(1)]],
    device mlx_atomic<T>* out [[buffer(2)]],
@@ -22,13 +22,11 @@ METAL_FUNC void scatter_1d_index_impl(
    const constant size_t& upd_size [[buffer(5)]],
    const thread array<const device IdxT*, NIDX>& idx_buffers,
    uint2 gid [[thread_position_in_grid]]) {
  Op op;

  uint out_idx = 0;
  for (int i = 0; i < NIDX; i++) {
    auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
    out_idx += idx_val * out_strides[i];
  }
@@ -43,20 +41,11 @@ template <typename T, typename IdxT, typename Op, int NIDX> \
      const constant int* out_shape [[buffer(3)]], \
      const constant size_t* out_strides [[buffer(4)]], \
      const constant size_t& upd_size [[buffer(5)]], \
      IDX_ARG(IdxT) uint2 gid [[thread_position_in_grid]]) { \
    const array<const device IdxT*, NIDX> idx_buffers = {IDX_ARR()}; \
                                                                     \
    return scatter_1d_index_impl<T, IdxT, Op, NIDX>( \
        updates, out, out_shape, out_strides, upd_size, idx_buffers, gid); \
  }
template <typename T, typename IdxT, typename Op, int NIDX>
@@ -73,7 +62,6 @@ METAL_FUNC void scatter_impl(
    const constant int* axes [[buffer(10)]],
    const thread Indices<IdxT, NIDX>& indices,
    uint2 gid [[thread_position_in_grid]]) {
  Op op;
  auto ind_idx = gid.y;
  auto ind_offset = gid.x;
@@ -86,8 +74,7 @@ METAL_FUNC void scatter_impl(
        &indices.strides[indices.ndim * i],
        indices.ndim);
    auto ax = axes[i];
    auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
    out_idx += idx_val * out_strides[ax];
  }
@@ -97,7 +84,8 @@ METAL_FUNC void scatter_impl(
    out_idx += out_offset;
  }

  auto upd_idx =
      elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
  op.atomic_update(out, updates[upd_idx], out_idx);
}
@@ -117,14 +105,9 @@ template <typename T, typename IdxT, typename Op, int NIDX> \
      const constant int* idx_shapes [[buffer(11)]], \
      const constant size_t* idx_strides [[buffer(12)]], \
      const constant int& idx_ndim [[buffer(13)]], \
      IDX_ARG(IdxT) uint2 gid [[thread_position_in_grid]]) { \
    Indices<IdxT, NIDX> idxs{ \
        {{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \
                                                           \
    return scatter_impl<T, IdxT, Op, NIDX>( \
        updates, \
@@ -145,25 +128,17 @@ template <typename T, typename IdxT, typename Op, int NIDX> \
  make_scatter_impl(IDX_ARG_##n, IDX_ARR_##n) \
  make_scatter_1d_index(IDX_ARG_##n, IDX_ARR_##n)

make_scatter(0) make_scatter(1) make_scatter(2) make_scatter(3) make_scatter(4)
    make_scatter(5) make_scatter(6) make_scatter(7) make_scatter(8)
        make_scatter(9) make_scatter(10)
/////////////////////////////////////////////////////////////////////
// Scatter instantiations
/////////////////////////////////////////////////////////////////////

#define instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
  template [[host_name("scatter" name "_" #nidx)]] [[kernel]] void \
  scatter<src_t, idx_t, op_t, nidx>( \
      const device src_t* updates [[buffer(1)]], \
      device mlx_atomic<src_t>* out [[buffer(2)]], \
      const constant int* upd_shape [[buffer(3)]], \
@@ -177,32 +152,33 @@ template [[host_name("scatter" name "_" #nidx)]] \
      const constant int* idx_shapes [[buffer(11)]], \
      const constant size_t* idx_strides [[buffer(12)]], \
      const constant int& idx_ndim [[buffer(13)]], \
      IDX_ARG(idx_t) uint2 gid [[thread_position_in_grid]]);

#define instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
  template [[host_name("scatter_1d_index" name "_" #nidx)]] [[kernel]] void \
  scatter_1d_index<src_t, idx_t, op_t, nidx>( \
      const device src_t* updates [[buffer(1)]], \
      device mlx_atomic<src_t>* out [[buffer(2)]], \
      const constant int* out_shape [[buffer(3)]], \
      const constant size_t* out_strides [[buffer(4)]], \
      const constant size_t& upd_size [[buffer(5)]], \
      IDX_ARG(idx_t) uint2 gid [[thread_position_in_grid]]);
// clang-format off
#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \
  instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) \
  instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) // clang-format on

// clang-format off
// Special case NINDEX=0
#define instantiate_scatter_nd0(name, type) \
  instantiate_scatter4(#name "none", type, bool, None, 0) \
  instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
  instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
  instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
  instantiate_scatter4(#name "_min", type, bool, Min<type>, 0) // clang-format on

// clang-format off
#define instantiate_scatter3(name, type, ind_type, op_type) \
  instantiate_scatter4(name, type, ind_type, op_type, 1) \
  instantiate_scatter4(name, type, ind_type, op_type, 2) \
@@ -213,15 +189,17 @@ template [[host_name("scatter_1d_index" name "_" #nidx)]] \
  instantiate_scatter4(name, type, ind_type, op_type, 7) \
  instantiate_scatter4(name, type, ind_type, op_type, 8) \
  instantiate_scatter4(name, type, ind_type, op_type, 9) \
  instantiate_scatter4(name, type, ind_type, op_type, 10) // clang-format on

// clang-format off
#define instantiate_scatter2(name, type, ind_type) \
  instantiate_scatter3(name "_none", type, ind_type, None) \
  instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
  instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
  instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
  instantiate_scatter3(name "_min", type, ind_type, Min<type>) // clang-format on

// clang-format off
#define instantiate_scatter(name, type) \
  instantiate_scatter2(#name "bool_", type, bool) \
  instantiate_scatter2(#name "uint8", type, uint8_t) \
@@ -231,8 +209,9 @@ template [[host_name("scatter_1d_index" name "_" #nidx)]] \
  instantiate_scatter2(#name "int8", type, int8_t) \
  instantiate_scatter2(#name "int16", type, int16_t) \
  instantiate_scatter2(#name "int32", type, int32_t) \
  instantiate_scatter2(#name "int64", type, int64_t) // clang-format on

// clang-format off
// TODO uint64 and int64 unsupported
instantiate_scatter_nd0(bool_, bool)
instantiate_scatter_nd0(uint8, uint8_t)
@@ -254,4 +233,4 @@ instantiate_scatter(int16, int16_t)
instantiate_scatter(int32, int32_t)
instantiate_scatter(float16, half)
instantiate_scatter(float32, float)
instantiate_scatter(bfloat16, bfloat16_t) // clang-format on
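Every scatter variant above resolves an update's destination the same way: one gathered index per indexed axis, wrapped if negative, then scaled by that axis's output stride. A scalar C++ sketch of the address computation in scatter_impl; offset_neg_idx mirrors the device helper, the other names are illustrative:

#include <cstddef>
#include <vector>

// Wrap a possibly negative index into [0, size).
inline int offset_neg_idx(int idx, int size) {
  return idx < 0 ? idx + size : idx;
}

// Linear output offset for one update, mirroring the out_idx loop above.
size_t scatter_out_index(
    const std::vector<int>& idx_vals, // one index per indexed axis
    const std::vector<int>& axes, // which output axes are indexed
    const std::vector<int>& out_shape,
    const std::vector<size_t>& out_strides) {
  size_t out_idx = 0;
  for (size_t i = 0; i < idx_vals.size(); i++) {
    int ax = axes[i];
    out_idx +=
        size_t(offset_neg_idx(idx_vals[i], out_shape[ax])) * out_strides[ax];
  }
  return out_idx;
}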
View File
@@ -198,7 +198,6 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
  }
}

#define instantiate_softmax(name, itype) \
  template [[host_name("softmax_" #name)]] [[kernel]] void \
  softmax_single_row<itype>( \
@@ -241,9 +240,9 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

// clang-format off
instantiate_softmax(float32, float)
instantiate_softmax(float16, half)
instantiate_softmax(bfloat16, bfloat16_t)
instantiate_softmax_precise(float16, half)
instantiate_softmax_precise(bfloat16, bfloat16_t) // clang-format on
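The instantiations above are all the max-subtracted softmax, with AccT in the template context serving as the accumulation type. A scalar C++ reference of the same computation:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable softmax: subtracting the row max bounds every exponent
// by zero, so exp cannot overflow.
template <typename T, typename AccT = T>
std::vector<T> softmax_row(const std::vector<T>& x) {
  AccT maxval = AccT(x[0]);
  for (T v : x)
    maxval = std::max(maxval, AccT(v));
  AccT denom = 0;
  std::vector<AccT> e(x.size());
  for (size_t i = 0; i < x.size(); i++) {
    e[i] = std::exp(AccT(x[i]) - maxval);
    denom += e[i];
  }
  std::vector<T> out(x.size());
  for (size_t i = 0; i < x.size(); i++)
    out[i] = T(e[i] / denom);
  return out;
}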
View File
@@ -11,7 +11,8 @@
using namespace metal;

// Based on GPU merge sort algorithm at
// https://github.com/NVIDIA/cccl/tree/main/cub/cub

///////////////////////////////////////////////////////////////////////////////
// Thread-level sort
@@ -43,7 +44,6 @@ struct ThreadSort {
  static METAL_FUNC void sort(
      thread val_t (&vals)[N_PER_THREAD],
      thread idx_t (&idxs)[N_PER_THREAD]) {
    CompareOp op;
    MLX_MTL_LOOP_UNROLL
@@ -56,7 +56,6 @@
        }
      }
    }
  }
};
@@ -72,14 +71,14 @@
    short N_PER_THREAD,
    typename CompareOp>
struct BlockMergeSort {
  using thread_sort_t =
      ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>;
  static METAL_FUNC int merge_partition(
      const threadgroup val_t* As,
      const threadgroup val_t* Bs,
      short A_sz,
      short B_sz,
      short sort_md) {
    CompareOp op;

    short A_st = max(0, sort_md - B_sz);
@@ -98,7 +97,6 @@ struct BlockMergeSort {
    }

    return A_ed;
  }

  static METAL_FUNC void merge_step(
@@ -110,7 +108,6 @@ struct BlockMergeSort {
      short B_sz,
      thread val_t (&vals)[N_PER_THREAD],
      thread idx_t (&idxs)[N_PER_THREAD]) {
    CompareOp op;
    short a_idx = 0;
    short b_idx = 0;
@@ -126,7 +123,6 @@ struct BlockMergeSort {
      b_idx += short(pred);
      a_idx += short(!pred);
    }
  }
  static METAL_FUNC void sort(
@@ -134,7 +130,6 @@ struct BlockMergeSort {
      threadgroup idx_t* tgp_idxs [[threadgroup(1)]],
      int size_sorted_axis,
      uint3 lid [[thread_position_in_threadgroup]]) {
    // Get thread location
    int idx = lid.x * N_PER_THREAD;
@@ -154,7 +149,8 @@ struct BlockMergeSort {
    }

    // Do merges using threadgroup memory
    for (int merge_threads = 2; merge_threads <= BLOCK_THREADS;
         merge_threads *= 2) {
      // Update threadgroup memory
      threadgroup_barrier(mem_flags::mem_threadgroup);
      for (int i = 0; i < N_PER_THREAD; ++i) {
@@ -189,12 +185,7 @@ struct BlockMergeSort {
      // of size N_PER_THREAD for each merge lane i
      // C = [Ci] is sorted
      int sort_md = N_PER_THREAD * merge_lane;
      int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md);

      As += partition;
      Bs += sort_md - partition;
@@ -202,20 +193,13 @@ struct BlockMergeSort {
      A_sz -= partition;
      B_sz -= sort_md - partition;

      const threadgroup idx_t* As_idx =
          ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
      const threadgroup idx_t* Bs_idx =
          ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;

      // Merge starting at the partition and store results in thread registers
      merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs);
    }

    // Write out to shared memory
@@ -263,14 +247,14 @@ struct KernelMergeSort {
      threadgroup idx_t* tgp_idxs,
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]]) {
    // tid.y tells us the segment index
    inp += tid.y * stride_segment_axis;
    out += tid.y * stride_segment_axis;

    // Copy into threadgroup memory
    for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
      tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis]
                                         : val_t(CompareOp::init);
      if (ARG_SORT) {
        tgp_idxs[i] = i;
      }
@@ -308,8 +292,8 @@ template <
    const constant int& stride_segment_axis [[buffer(4)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel =
      KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
  using val_t = typename sort_kernel::val_t;
  using idx_t = typename sort_kernel::idx_t;
@@ -339,7 +323,6 @@ template <
        tid,
        lid);
  }
}

constant constexpr const int zero_helper = 0;
@@ -360,8 +343,8 @@ template <
    const device size_t* nc_strides [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel =
      KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
  using val_t = typename sort_kernel::val_t;
  using idx_t = typename sort_kernel::idx_t;
@@ -395,17 +378,17 @@ template <
        tid,
        lid);
  }
}

///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_block_sort( \
    name, itname, itype, otname, otype, arg_sort, bn, tn) \
  template [[host_name(#name "_" #itname "_" #otname "_bn" #bn \
                       "_tn" #tn)]] [[kernel]] void \
  block_sort<itype, otype, arg_sort, bn, tn>( \
      const device itype* inp [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant int& size_sorted_axis [[buffer(2)]], \
@@ -413,8 +396,9 @@ template <
      const constant int& stride_segment_axis [[buffer(4)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]); \
  template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn \
                       "_nc")]] [[kernel]] void \
  block_sort_nc<itype, otype, arg_sort, bn, tn>( \
      const device itype* inp [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant int& size_sorted_axis [[buffer(2)]], \
@@ -426,15 +410,19 @@ template <
      uint3 lid [[thread_position_in_threadgroup]]);

#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
  instantiate_block_sort( \
      arg_block_merge_sort, itname, itype, uint32, uint32_t, true, bn, tn)

#define instantiate_block_sort_base(itname, itype, bn, tn) \
  instantiate_block_sort( \
      block_merge_sort, itname, itype, itname, itype, false, bn, tn)

// clang-format off
#define instantiate_block_sort_tn(itname, itype, bn) \
  instantiate_block_sort_base(itname, itype, bn, 8) \
  instantiate_arg_block_sort_base(itname, itype, bn, 8) // clang-format on

// clang-format off
#define instantiate_block_sort_bn(itname, itype) \
  instantiate_block_sort_tn(itname, itype, 128) \
  instantiate_block_sort_tn(itname, itype, 256) \
@@ -448,14 +436,14 @@ instantiate_block_sort_bn(int16, int16_t)
instantiate_block_sort_bn(int32, int32_t)
instantiate_block_sort_bn(float16, half)
instantiate_block_sort_bn(float32, float)
instantiate_block_sort_bn(bfloat16, bfloat16_t) // clang-format on

// clang-format off
#define instantiate_block_sort_long(itname, itype) \
  instantiate_block_sort_tn(itname, itype, 128) \
  instantiate_block_sort_tn(itname, itype, 256)

instantiate_block_sort_long(uint64, uint64_t)
instantiate_block_sort_long(int64, int64_t) // clang-format on
///////////////////////////////////////////////////////////////////////////////
// Multi block merge sort
@@ -489,14 +477,14 @@ struct KernelMultiBlockMergeSort {
      threadgroup idx_t* tgp_idxs,
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]]) {
    // tid.y tells us the segment index
    int base_idx = tid.x * N_PER_BLOCK;

    // Copy into threadgroup memory
    for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
      int idx = base_idx + i;
      tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis]
                                           : val_t(CompareOp::init);
      tgp_idxs[i] = idx;
    }
@@ -523,7 +511,6 @@ struct KernelMultiBlockMergeSort {
      int A_sz,
      int B_sz,
      int sort_md) {
    CompareOp op;

    int A_st = max(0, sort_md - B_sz);
@@ -542,7 +529,6 @@ struct KernelMultiBlockMergeSort {
    }

    return A_ed;
  }
};
@@ -563,8 +549,12 @@ template <
    const device size_t* nc_strides [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel = KernelMultiBlockMergeSort<
      val_t,
      idx_t,
      ARG_SORT,
      BLOCK_THREADS,
      N_PER_THREAD>;

  auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
  inp += block_idx;
@@ -592,7 +582,8 @@ template <
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
mb_block_partition(
    device idx_t* block_partitions [[buffer(0)]],
    const device val_t* dev_vals [[buffer(1)]],
    const device idx_t* dev_idxs [[buffer(2)]],
@@ -601,7 +592,6 @@ template <
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 tgp_dims [[threads_per_threadgroup]]) {
  using sort_kernel = KernelMultiBlockMergeSort<
      val_t,
      idx_t,
@@ -627,14 +617,9 @@ template <
  int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
  int partition = sort_kernel::merge_partition(
      dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at);

  block_partitions[lid.x] = A_st + partition;
}
template <
@@ -644,7 +629,8 @@ template <
    short BLOCK_THREADS,
    short N_PER_THREAD,
    typename CompareOp = LessThan<val_t>>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
mb_block_merge(
    const device idx_t* block_partitions [[buffer(0)]],
    const device val_t* dev_vals_in [[buffer(1)]],
    const device idx_t* dev_idxs_in [[buffer(2)]],
@@ -655,7 +641,6 @@ template <
    const constant int& num_tiles [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel = KernelMultiBlockMergeSort<
      val_t,
      idx_t,
@@ -681,7 +666,9 @@ template <
  int A_st = block_partitions[block_idx + 0];
  int A_ed = block_partitions[block_idx + 1];
  int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);
  int B_ed = min(
      size_sorted_axis,
      2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);

  if ((block_idx % merge_tiles) == merge_tiles - 1) {
    A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
@ -697,8 +684,10 @@ template <
for (int i = 0; i < N_PER_THREAD; i++) { for (int i = 0; i < N_PER_THREAD; i++) {
int idx = BLOCK_THREADS * i + lid.x; int idx = BLOCK_THREADS * i + lid.x;
if (idx < (A_sz + B_sz)) { if (idx < (A_sz + B_sz)) {
thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx] : dev_vals_in[B_st + idx - A_sz]; thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]
thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx] : dev_idxs_in[B_st + idx - A_sz]; : dev_vals_in[B_st + idx - A_sz];
thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]
: dev_idxs_in[B_st + idx - A_sz];
} else { } else {
thread_vals[i] = CompareOp::init; thread_vals[i] = CompareOp::init;
thread_idxs[i] = 0; thread_idxs[i] = 0;
@ -720,11 +709,7 @@ template <
int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x)); int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x));
int A_st_local = block_sort_t::merge_partition( int A_st_local = block_sort_t::merge_partition(
tgp_vals, tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);
tgp_vals + A_sz,
A_sz,
B_sz,
sort_md_local);
int A_ed_local = A_sz; int A_ed_local = A_sz;
int B_st_local = sort_md_local - A_st_local; int B_st_local = sort_md_local - A_st_local;
@ -761,12 +746,13 @@ template <
dev_idxs_out[idx] = tgp_idxs[i]; dev_idxs_out[idx] = tgp_idxs[i];
} }
} }
} }
#define instantiate_multi_block_sort(vtname, vtype, itname, itype, arg_sort, bn, tn) \ #define instantiate_multi_block_sort( \
template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \ vtname, vtype, itname, itype, arg_sort, bn, tn) \
[[kernel]] void mb_block_sort<vtype, itype, arg_sort, bn, tn>( \ template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
const device vtype* inp [[buffer(0)]], \ const device vtype* inp [[buffer(0)]], \
device vtype* out_vals [[buffer(1)]], \ device vtype* out_vals [[buffer(1)]], \
device itype* out_idxs [[buffer(2)]], \ device itype* out_idxs [[buffer(2)]], \
@ -777,8 +763,9 @@ template <
const device size_t* nc_strides [[buffer(7)]], \ const device size_t* nc_strides [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \ uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \ uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \ template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn \
[[kernel]] void mb_block_partition<vtype, itype, arg_sort, bn, tn>( \ "_tn" #tn)]] [[kernel]] void \
mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
device itype * block_partitions [[buffer(0)]], \ device itype * block_partitions [[buffer(0)]], \
const device vtype* dev_vals [[buffer(1)]], \ const device vtype* dev_vals [[buffer(1)]], \
const device itype* dev_idxs [[buffer(2)]], \ const device itype* dev_idxs [[buffer(2)]], \
@ -787,8 +774,9 @@ template <
uint3 tid [[threadgroup_position_in_grid]], \ uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \ uint3 lid [[thread_position_in_threadgroup]], \
uint3 tgp_dims [[threads_per_threadgroup]]); \ uint3 tgp_dims [[threads_per_threadgroup]]); \
template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \ template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn \
[[kernel]] void mb_block_merge<vtype, itype, arg_sort, bn, tn>( \ "_tn" #tn)]] [[kernel]] void \
mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
const device itype* block_partitions [[buffer(0)]], \ const device itype* block_partitions [[buffer(0)]], \
const device vtype* dev_vals_in [[buffer(1)]], \ const device vtype* dev_vals_in [[buffer(1)]], \
const device itype* dev_idxs_in [[buffer(2)]], \ const device itype* dev_idxs_in [[buffer(2)]], \
@ -800,6 +788,7 @@ template <
uint3 tid [[threadgroup_position_in_grid]], \ uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); uint3 lid [[thread_position_in_threadgroup]]);
// clang-format off
#define instantiate_multi_block_sort_base(vtname, vtype) \ #define instantiate_multi_block_sort_base(vtname, vtype) \
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8) instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8)
@ -811,10 +800,11 @@ instantiate_multi_block_sort_base(int16, int16_t)
instantiate_multi_block_sort_base(int32, int32_t) instantiate_multi_block_sort_base(int32, int32_t)
instantiate_multi_block_sort_base(float16, half) instantiate_multi_block_sort_base(float16, half)
instantiate_multi_block_sort_base(float32, float) instantiate_multi_block_sort_base(float32, float)
instantiate_multi_block_sort_base(bfloat16, bfloat16_t) instantiate_multi_block_sort_base(bfloat16, bfloat16_t) // clang-format on
// clang-format off
#define instantiate_multi_block_sort_long(vtname, vtype) \ #define instantiate_multi_block_sort_long(vtname, vtype) \
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8) instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8)
instantiate_multi_block_sort_long(uint64, uint64_t) instantiate_multi_block_sort_long(uint64, uint64_t)
instantiate_multi_block_sort_long(int64, int64_t) instantiate_multi_block_sort_long(int64, int64_t) // clang-format on
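The instantiation macros above build each kernel's host_name by stringizing the macro arguments and letting the preprocessor concatenate the adjacent string literals; the reformat only re-wraps those literal pieces across lines. A minimal C++ sketch of the mechanism (HOST_NAME is illustrative, not part of the commit):

#include <cstdio>
// #vtname etc. expand to string literals; adjacent literals fuse into one.
#define HOST_NAME(vtname, itname, bn, tn) \
  "mb_block_sort_" #vtname "_" #itname "_bn" #bn "_tn" #tn
int main() {
  std::puts(HOST_NAME(float32, uint32, 512, 8));
  // prints: mb_block_sort_float32_uint32_bn512_tn8
}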
View File
@ -4,13 +4,14 @@
#include "mlx/backend/metal/kernels/steel/gemm/mma.h" #include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h" #include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h" #include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/bf16.h"
using namespace metal; using namespace metal;
template <typename T, template <
typename T,
int BM, int BM,
int BN, int BN,
int BK, int BK,
@ -18,7 +19,8 @@ template <typename T,
int WN, int WN,
int N_CHANNELS = 0, int N_CHANNELS = 0,
bool SMALL_FILTER = false> bool SMALL_FILTER = false>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d( [[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
implicit_gemm_conv_2d(
const device T* A [[buffer(0)]], const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]], const device T* B [[buffer(1)]],
device T* C [[buffer(2)]], device T* C [[buffer(2)]],
@ -28,8 +30,6 @@ template <typename T,
uint3 lid [[thread_position_in_threadgroup]], uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
using namespace mlx::steel; using namespace mlx::steel;
(void)lid; (void)lid;
@ -56,7 +56,13 @@ template <typename T,
// Go to small channel specialization // Go to small channel specialization
Conv2DInputBlockLoaderSmallChannels< Conv2DInputBlockLoaderSmallChannels<
T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_a>, T,
BM,
BN,
BK,
tgp_size,
N_CHANNELS,
tgp_padding_a>,
// Else go to general loader // Else go to general loader
typename metal::conditional_t< typename metal::conditional_t<
@ -65,14 +71,21 @@ template <typename T,
// Go to small filter specialization // Go to small filter specialization
Conv2DInputBlockLoaderSmallFilter< Conv2DInputBlockLoaderSmallFilter<
T, BM, BN, BK, tgp_size, tgp_padding_a>, T,
BM,
BN,
BK,
tgp_size,
tgp_padding_a>,
// Else go to large filter generalization // Else go to large filter generalization
Conv2DInputBlockLoaderLargeFilter< Conv2DInputBlockLoaderLargeFilter<
T, BM, BN, BK, tgp_size, tgp_padding_a> T,
> BM,
>; BN,
BK,
tgp_size,
tgp_padding_a>>>;
// Weight loader // Weight loader
using loader_b_t = typename metal::conditional_t< using loader_b_t = typename metal::conditional_t<
@ -81,11 +94,16 @@ template <typename T,
// Go to small channel specialization // Go to small channel specialization
Conv2DWeightBlockLoaderSmallChannels< Conv2DWeightBlockLoaderSmallChannels<
T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_b>, T,
BM,
BN,
BK,
tgp_size,
N_CHANNELS,
tgp_padding_b>,
// Else go to general loader // Else go to general loader
Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b> Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>>;
>;
using mma_t = BlockMMA< using mma_t = BlockMMA<
T, T,
@ -123,8 +141,10 @@ template <typename T,
const int2 offsets_b(0, c_col); const int2 offsets_b(0, c_col);
// Prepare threadgroup loading operations // Prepare threadgroup loading operations
loader_a_t loader_a(A, As, offsets_a, params, gemm_params, simd_gid, simd_lid); loader_a_t loader_a(
loader_b_t loader_b(B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid); A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
loader_b_t loader_b(
B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);
// Prepare threadgroup mma operation // Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid); mma_t mma_op(simd_gid, simd_lid);
@ -152,12 +172,24 @@ template <typename T,
short tgp_bm = min(BM, gemm_params->M - c_row); short tgp_bm = min(BM, gemm_params->M - c_row);
short tgp_bn = min(BN, gemm_params->N - c_col); short tgp_bn = min(BN, gemm_params->N - c_col);
mma_op.store_result_safe(C, N, short2(tgp_bn, tgp_bm)); mma_op.store_result_safe(C, N, short2(tgp_bn, tgp_bm));
} }
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, channel_name, n_channels, filter_name, small_filter) \ #define instantiate_implicit_conv_2d( \
template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_channel_" #channel_name "_filter_" #filter_name)]] \ name, \
[[kernel]] void implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn, n_channels, small_filter>( \ itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
channel_name, \
n_channels, \
filter_name, \
small_filter) \
template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn "_channel_" #channel_name \
"_filter_" #filter_name)]] [[kernel]] void \
implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn, n_channels, small_filter>( \
const device itype* A [[buffer(0)]], \ const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \ const device itype* B [[buffer(1)]], \
device itype* C [[buffer(2)]], \ device itype* C [[buffer(2)]], \
@ -168,22 +200,25 @@ template <typename T,
uint simd_gid [[simdgroup_index_in_threadgroup]], \ uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]); uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \ #define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \ instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, l, false) \ instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 1, 1, l, false) \ instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 1, 1, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 2, 2, l, false) \ instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 2, 2, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 3, 3, l, false) \ instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 3, 3, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false) instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false) // clang-format on
// clang-format off
#define instantiate_implicit_2d_blocks(name, itype) \ #define instantiate_implicit_2d_blocks(name, itype) \
instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \ instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \ instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \ instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \ instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \ instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on
// clang-format off
instantiate_implicit_2d_blocks(float32, float); instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half); instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on
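The loader selection that clang-format expanded above is a nested compile-time conditional: small-channel, then small-filter, then the general loader. A host-side sketch with std::conditional_t standing in for metal::conditional_t and hypothetical loader type names:

#include <type_traits>
struct SmallChannelLoader {};
struct SmallFilterLoader {};
struct LargeFilterLoader {};
// Mirrors the nesting in implicit_gemm_conv_2d's loader_a_t selection.
template <int N_CHANNELS, bool SMALL_FILTER>
using loader_a_t = std::conditional_t<
    (N_CHANNELS > 0),
    SmallChannelLoader,
    std::conditional_t<SMALL_FILTER, SmallFilterLoader, LargeFilterLoader>>;
static_assert(std::is_same_v<loader_a_t<2, false>, SmallChannelLoader>);
static_assert(std::is_same_v<loader_a_t<0, true>, SmallFilterLoader>);
static_assert(std::is_same_v<loader_a_t<0, false>, LargeFilterLoader>);
int main() {}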
View File
@ -4,15 +4,16 @@
#include "mlx/backend/metal/kernels/steel/gemm/mma.h" #include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
#include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
using namespace metal; using namespace metal;
using namespace mlx::steel; using namespace mlx::steel;
template <typename T, template <
typename T,
int BM, int BM,
int BN, int BN,
int BK, int BK,
@ -20,7 +21,8 @@ template <typename T,
int WN, int WN,
typename AccumType = float, typename AccumType = float,
typename Epilogue = TransformNone<T, AccumType>> typename Epilogue = TransformNone<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d_general( [[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
implicit_gemm_conv_2d_general(
const device T* A [[buffer(0)]], const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]], const device T* B [[buffer(1)]],
device T* C [[buffer(2)]], device T* C [[buffer(2)]],
@ -33,7 +35,6 @@ template <typename T,
uint3 lid [[thread_position_in_threadgroup]], uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid; (void)lid;
constexpr bool transpose_a = false; constexpr bool transpose_a = false;
@ -51,12 +52,12 @@ template <typename T,
constexpr short tgp_size = WM * WN * 32; constexpr short tgp_size = WM * WN * 32;
// Input loader // Input loader
using loader_a_t = Conv2DInputBlockLoaderGeneral< using loader_a_t =
T, BM, BN, BK, tgp_size, tgp_padding_a>; Conv2DInputBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_a>;
// Weight loader // Weight loader
using loader_b_t = Conv2DWeightBlockLoaderGeneral< using loader_b_t =
T, BM, BN, BK, tgp_size, tgp_padding_b>; Conv2DWeightBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_b>;
using mma_t = BlockMMA< using mma_t = BlockMMA<
T, T,
@ -103,13 +104,32 @@ template <typename T,
const int2 offsets_b(0, c_col); const int2 offsets_b(0, c_col);
// Prepare threadgroup loading operations // Prepare threadgroup loading operations
loader_a_t loader_a(A, As, offsets_a, params, jump_params, base_wh, base_ww, simd_gid, simd_lid); loader_a_t loader_a(
loader_b_t loader_b(B, Bs, offsets_b, params, jump_params, base_wh, base_ww, simd_gid, simd_lid); A,
As,
offsets_a,
params,
jump_params,
base_wh,
base_ww,
simd_gid,
simd_lid);
loader_b_t loader_b(
B,
Bs,
offsets_b,
params,
jump_params,
base_wh,
base_ww,
simd_gid,
simd_lid);
// Prepare threadgroup mma operation // Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid); mma_t mma_op(simd_gid, simd_lid);
int gemm_k_iterations = base_wh_size * base_ww_size * gemm_params->gemm_k_iterations; int gemm_k_iterations =
base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
for (int k = 0; k < gemm_k_iterations; k++) { for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
@ -143,22 +163,24 @@ template <typename T,
STEEL_PRAGMA_UNROLL STEEL_PRAGMA_UNROLL
for (int i = 0; i < mma_t::TM; i++) { for (int i = 0; i < mma_t::TM; i++) {
int cm = offset_m + i * mma_t::TM_stride; int cm = offset_m + i * mma_t::TM_stride;
int n = cm / jump_params->adj_out_hw; int n = cm / jump_params->adj_out_hw;
int hw = cm % jump_params->adj_out_hw; int hw = cm % jump_params->adj_out_hw;
int oh = (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh; int oh =
int ow = (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow; (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
int ow =
(hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) { if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
int offset_cm = n * params->out_strides[0] +
int offset_cm = n * params->out_strides[0] + oh * params->out_strides[1] + ow * params->out_strides[2]; oh * params->out_strides[1] + ow * params->out_strides[2];
STEEL_PRAGMA_UNROLL STEEL_PRAGMA_UNROLL
for (int j = 0; j < mma_t::TN; j++) { for (int j = 0; j < mma_t::TN; j++) {
// Get accumulated result and associated offset in C // Get accumulated result and associated offset in C
thread const auto& accum = mma_op.results[i * mma_t::TN + j].thread_elements(); thread const auto& accum =
mma_op.results[i * mma_t::TN + j].thread_elements();
int offset = offset_cm + (j * mma_t::TN_stride); int offset = offset_cm + (j * mma_t::TN_stride);
// Apply epilogue and output C // Apply epilogue and output C
@ -170,16 +192,16 @@ template <typename T,
C[offset + 1] = Epilogue::apply(accum[1]); C[offset + 1] = Epilogue::apply(accum[1]);
} }
} }
} }
} }
} }
} }
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \ #define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
template [[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \ template \
[[kernel]] void implicit_gemm_conv_2d_general<itype, bm, bn, bk, wm, wn>( \ [[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn)]] [[kernel]] void \
implicit_gemm_conv_2d_general<itype, bm, bn, bk, wm, wn>( \
const device itype* A [[buffer(0)]], \ const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \ const device itype* B [[buffer(1)]], \
device itype* C [[buffer(2)]], \ device itype* C [[buffer(2)]], \
@ -196,14 +218,16 @@ template <typename T,
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \ #define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn)
// clang-format off
#define instantiate_implicit_2d_blocks(name, itype) \ #define instantiate_implicit_2d_blocks(name, itype) \
instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \ instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \ instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \ instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \ instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \ instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on
// clang-format off
instantiate_implicit_2d_blocks(float32, float); instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half); instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on
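The re-wrapped index arithmetic above splits a flat GEMM row cm into a batch index n and spatial coordinates (oh, ow). A host-side sketch of the same arithmetic; the constants below (adj_out_hw, jump factors, bases) are illustrative values, not taken from the kernel:

#include <cstdio>
int main() {
  const int adj_out_hw = 64, adj_out_w = 8;
  const int f_out_jump_h = 1, f_out_jump_w = 1, base_oh = 0, base_ow = 0;
  int cm = 137;
  int n = cm / adj_out_hw;  // which batch element
  int hw = cm % adj_out_hw; // flat spatial position
  int oh = (hw / adj_out_w) * f_out_jump_h + base_oh;
  int ow = (hw % adj_out_w) * f_out_jump_w + base_ow;
  std::printf("n=%d oh=%d ow=%d\n", n, oh, ow); // n=2 oh=1 ow=1
}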
View File
@ -1,8 +1,8 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal; using namespace metal;
using namespace mlx::steel; using namespace mlx::steel;
@ -11,7 +11,8 @@ using namespace mlx::steel;
// GEMM kernels // GEMM kernels
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
template <typename T, template <
typename T,
int BM, int BM,
int BN, int BN,
int BK, int BK,
@ -32,8 +33,18 @@ template <typename T,
uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { uint3 lid [[thread_position_in_threadgroup]]) {
using gemm_kernel = GEMMKernel<
using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>; T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned>;
threadgroup T As[gemm_kernel::tgp_mem_size_a]; threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
@ -57,20 +68,34 @@ template <typename T,
D += params->batch_stride_d * tid.z; D += params->batch_stride_d * tid.z;
gemm_kernel::run( gemm_kernel::run(
A, B, D, A, B, D, params, As, Bs, simd_lane_id, simd_group_id, tid, lid);
params,
As, Bs,
simd_lane_id, simd_group_id, tid, lid
);
} }
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations // GEMM kernel initializations
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ #define instantiate_gemm( \
template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \ tname, \
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \ trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned) \
template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm \
"_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname \
"_K_" #kname)]] [[kernel]] void \
gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
const device itype* A [[buffer(0)]], \ const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \ const device itype* B [[buffer(1)]], \
device itype* D [[buffer(3)]], \ device itype* D [[buffer(3)]], \
@ -82,26 +107,30 @@ template <typename T,
uint3 tid [[threadgroup_position_in_grid]], \ uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); uint3 lid [[thread_position_in_threadgroup]]);
// clang-format off
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ #define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \ instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \ instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ #define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ #define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
// clang-format off
instantiate_gemm_shapes_helper(float16, half, float16, half); instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float); instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
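Each helper layer above multiplies the set of emitted instantiations: 4 transpose variants x 4 alignment variants x 5 shapes x 3 dtype pairs, i.e. 240 gemm kernels from three top-level lines, assuming every helper expands fully. A minimal C++ sketch of the fan-out pattern with illustrative macro names and a shortened host_name:

#include <cstdio>
#define EMIT(tname, aname, kname) \
  std::puts("steel_gemm_" #tname "_MN_" #aname "_K_" #kname);
#define EMIT_ALIGNED(tname)                                       \
  EMIT(tname, taligned, taligned) EMIT(tname, taligned, naligned) \
  EMIT(tname, naligned, taligned) EMIT(tname, naligned, naligned)
#define EMIT_TRANSPOSE() \
  EMIT_ALIGNED(nn) EMIT_ALIGNED(nt) EMIT_ALIGNED(tn) EMIT_ALIGNED(tt)
int main() {
  EMIT_TRANSPOSE() // prints the 16 names one shape/dtype pair would emit
}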
View File
@ -10,7 +10,8 @@ using namespace mlx::steel;
// GEMM kernels // GEMM kernels
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
template <typename T, template <
typename T,
int BM, int BM,
int BN, int BN,
int BK, int BK,
@ -35,15 +36,23 @@ template <typename T,
uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler // Pacifying compiler
(void)lid; (void)lid;
using gemm_kernel = using gemm_kernel = GEMMKernel<
GEMMKernel<T, T, BM, BN, BK, WM, WN, T,
transpose_a, transpose_b, T,
MN_aligned, K_aligned, BM,
AccumType, Epilogue>; BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned,
AccumType,
Epilogue>;
using loader_a_t = typename gemm_kernel::loader_a_t; using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t; using loader_b_t = typename gemm_kernel::loader_b_t;
@ -59,7 +68,12 @@ template <typename T,
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim; const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
ulong3 batch_offsets = elem_to_loc_broadcast( ulong3 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, C_bstrides, params->batch_ndim); tid.z,
batch_shape,
A_bstrides,
B_bstrides,
C_bstrides,
params->batch_ndim);
A += batch_offsets.x; A += batch_offsets.x;
B += batch_offsets.y; B += batch_offsets.y;
@ -140,7 +154,8 @@ template <typename T,
} }
// Store results to device memory // Store results to device memory
mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op); mma_op.store_result(
D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
return; return;
} }
@ -164,7 +179,8 @@ template <typename T,
leftover_bk, leftover_bk,
LoopAlignment<true, true, K_aligned>{}); LoopAlignment<true, true, K_aligned>{});
mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op); mma_op.store_result(
D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
return; return;
} else if (tgp_bn == BN) { } else if (tgp_bn == BN) {
@ -181,8 +197,11 @@ template <typename T,
LoopAlignment<false, true, K_aligned>{}); LoopAlignment<false, true, K_aligned>{});
return mma_op.store_result_safe( return mma_op.store_result_safe(
D, params->ldd, D,
C, addmm_params->ldc, addmm_params->fdc, params->ldd,
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm), short2(tgp_bn, tgp_bm),
epilogue_op); epilogue_op);
@ -200,8 +219,11 @@ template <typename T,
LoopAlignment<true, false, K_aligned>{}); LoopAlignment<true, false, K_aligned>{});
return mma_op.store_result_safe( return mma_op.store_result_safe(
D, params->ldd, D,
C, addmm_params->ldc, addmm_params->fdc, params->ldd,
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm), short2(tgp_bn, tgp_bm),
epilogue_op); epilogue_op);
@ -219,8 +241,11 @@ template <typename T,
LoopAlignment<false, false, K_aligned>{}); LoopAlignment<false, false, K_aligned>{});
return mma_op.store_result_safe( return mma_op.store_result_safe(
D, params->ldd, D,
C, addmm_params->ldc, addmm_params->fdc, params->ldd,
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm), short2(tgp_bn, tgp_bm),
epilogue_op); epilogue_op);
} }
@ -231,9 +256,41 @@ template <typename T,
// GEMM kernel initializations // GEMM kernel initializations
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, ep_name, epilogue) \ #define instantiate_gemm( \
template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_" #ep_name)]] \ tname, \
[[kernel]] void addmm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, float, epilogue<itype, float>>( \ trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned, \
ep_name, \
epilogue) \
template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm \
"_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname \
"_K_" #kname "_" #ep_name)]] [[kernel]] void \
addmm< \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
mn_aligned, \
k_aligned, \
float, \
epilogue<itype, float>>( \
const device itype* A [[buffer(0)]], \ const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \ const device itype* B [[buffer(1)]], \
const device itype* C [[buffer(2)]], \ const device itype* C [[buffer(2)]], \
@ -247,30 +304,35 @@ template <typename T,
uint3 tid [[threadgroup_position_in_grid]], \ uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); uint3 lid [[thread_position_in_threadgroup]]);
// clang-format off
#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ #define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \ instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby) instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby) // clang-format on
// clang-format off
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ #define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \ instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \ instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ #define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ #define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
// clang-format off
instantiate_gemm_shapes_helper(float16, half, float16, half); instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float); instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
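addmm differs from plain gemm by its epilogue parameter: the accumulated tile d is combined with C as it is stored. A simplified host-side sketch of the two transforms instantiated above; these are stand-ins, not the steel definitions:

#include <cstdio>
template <typename T>
struct TransformAdd {
  static T apply(T d, T c) { return d + c; } // D = GEMM result + C
};
template <typename T>
struct TransformAxpby {
  T alpha, beta;
  T apply(T d, T c) const { return alpha * d + beta * c; } // D = a*GEMM + b*C
};
int main() {
  TransformAxpby<float> op{2.0f, 0.5f};
  std::printf("add=%g axpby=%g\n",
              TransformAdd<float>::apply(1.f, 3.f), op.apply(1.f, 3.f));
}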
View File
@ -1,8 +1,8 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal; using namespace metal;
using namespace mlx::steel; using namespace mlx::steel;
@ -11,7 +11,8 @@ using namespace mlx::steel;
// GEMM kernels // GEMM kernels
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
template <typename T, template <
typename T,
int BM, int BM,
int BN, int BN,
int BK, int BK,
@ -22,7 +23,8 @@ template <typename T,
bool MN_aligned, bool MN_aligned,
bool K_aligned, bool K_aligned,
bool has_operand_mask = false> bool has_operand_mask = false>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void block_masked_gemm( [[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
block_masked_gemm(
const device T* A [[buffer(0)]], const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]], const device T* B [[buffer(1)]],
device T* D [[buffer(3)]], device T* D [[buffer(3)]],
@ -37,11 +39,21 @@ template <typename T,
uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { uint3 lid [[thread_position_in_threadgroup]]) {
// Appease the compiler // Appease the compiler
(void)lid; (void)lid;
using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>; using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned>;
const int tid_y = ((tid.y) << params->swizzle_log) + const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1)); ((tid.x) & ((1 << params->swizzle_log) - 1));
@ -52,15 +64,23 @@ template <typename T,
} }
if (params->batch_ndim > 1) { if (params->batch_ndim > 1) {
const constant size_t* mask_batch_strides = batch_strides + 2 * params->batch_ndim; const constant size_t* mask_batch_strides =
out_mask += elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim); batch_strides + 2 * params->batch_ndim;
out_mask +=
elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
if (has_operand_mask) { if (has_operand_mask) {
const constant size_t* mask_strides_lhs = mask_batch_strides + params->batch_ndim; const constant size_t* mask_strides_lhs =
const constant size_t* mask_strides_rhs = mask_strides_lhs + params->batch_ndim; mask_batch_strides + params->batch_ndim;
const constant size_t* mask_strides_rhs =
mask_strides_lhs + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast( ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, mask_strides_lhs, mask_strides_rhs, params->batch_ndim); tid.z,
batch_shape,
mask_strides_lhs,
mask_strides_rhs,
params->batch_ndim);
lhs_mask += batch_offsets.x; lhs_mask += batch_offsets.x;
rhs_mask += batch_offsets.y; rhs_mask += batch_offsets.y;
@ -99,7 +119,6 @@ template <typename T,
B += transpose_b ? c_col * params->ldb : c_col; B += transpose_b ? c_col * params->ldb : c_col;
D += c_row * params->ldd + c_col; D += c_row * params->ldd + c_col;
bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]]; bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];
// Write zeros and return // Write zeros and return
@ -151,8 +170,10 @@ template <typename T,
threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Prepare threadgroup loading operations // Prepare threadgroup loading operations
thread typename gemm_kernel::loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); thread typename gemm_kernel::loader_a_t loader_a(
thread typename gemm_kernel::loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); A, params->lda, As, simd_group_id, simd_lane_id);
thread typename gemm_kernel::loader_b_t loader_b(
B, params->ldb, Bs, simd_group_id, simd_lane_id);
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop // MNK aligned loop
@ -161,9 +182,10 @@ template <typename T,
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask || if (!has_operand_mask ||
(lhs_mask[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] && (lhs_mask
rhs_mask[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
rhs_mask
[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
// Load elements into threadgroup // Load elements into threadgroup
loader_a.load_unsafe(); loader_a.load_unsafe();
loader_b.load_unsafe(); loader_b.load_unsafe();
@ -172,7 +194,6 @@ template <typename T,
// Multiply and accumulate threadgroup elements // Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs); mma_op.mma(As, Bs);
} }
// Prepare for next iteration // Prepare for next iteration
@ -184,11 +205,12 @@ template <typename T,
// Loop tail // Loop tail
if (!K_aligned) { if (!K_aligned) {
if (!has_operand_mask || if (!has_operand_mask ||
(lhs_mask[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] && (lhs_mask
rhs_mask[(params->K / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
rhs_mask
[(params->K / BM) * mask_strides[5] +
tid_x * mask_strides[4]])) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK; int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk); short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
@ -199,7 +221,6 @@ template <typename T,
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs); mma_op.mma(As, Bs);
} }
} }
@ -224,9 +245,10 @@ template <typename T,
for (int k = 0; k < gemm_k_iterations; k++) { for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask || if (!has_operand_mask ||
(lhs_mask[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] && (lhs_mask
rhs_mask[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
rhs_mask
[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
// Load elements into threadgroup // Load elements into threadgroup
if (M_aligned) { if (M_aligned) {
loader_a.load_unsafe(); loader_a.load_unsafe();
@ -244,7 +266,6 @@ template <typename T,
// Multiply and accumulate threadgroup elements // Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs); mma_op.mma(As, Bs);
} }
// Prepare for next iteration // Prepare for next iteration
@ -256,9 +277,11 @@ template <typename T,
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask || if (!has_operand_mask ||
(lhs_mask[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] && (lhs_mask
rhs_mask[(params->K / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
rhs_mask
[(params->K / BM) * mask_strides[5] +
tid_x * mask_strides[4]])) {
short2 tile_dims_A_last = short2 tile_dims_A_last =
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
short2 tile_dims_B_last = short2 tile_dims_B_last =
@ -270,7 +293,6 @@ template <typename T,
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs); mma_op.mma(As, Bs);
} }
} }
@ -286,9 +308,41 @@ template <typename T,
// GEMM kernel initializations // GEMM kernel initializations
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, omname, op_mask) \ #define instantiate_gemm( \
template [[host_name("steel_block_masked_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_op_mask_" #omname)]] \ tname, \
[[kernel]] void block_masked_gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, op_mask>( \ trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned, \
omname, \
op_mask) \
template [[host_name("steel_block_masked_gemm_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
"_MN_" #aname "_K_" #kname \
"_op_mask_" #omname)]] [[kernel]] void \
block_masked_gemm< \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
mn_aligned, \
k_aligned, \
op_mask>( \
const device itype* A [[buffer(0)]], \ const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \ const device itype* B [[buffer(1)]], \
device itype* D [[buffer(3)]], \ device itype* D [[buffer(3)]], \
@ -304,26 +358,31 @@ template <typename T,
uint3 tid [[threadgroup_position_in_grid]], \ uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); uint3 lid [[thread_position_in_threadgroup]]);
// clang-format off
#define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ #define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, N, false) \ instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, N, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, T, true) instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, T, true) // clang-format on
// clang-format off
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ #define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \ instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \ instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ #define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ #define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) // clang-format on
// clang-format off
instantiate_gemm_shapes_helper(float16, half, float16, half); instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float); instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
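The masked variant gates each threadgroup on boolean tile masks; the reflowed lhs_mask/rhs_mask expressions above are plain strided 2-D lookups. A host-side sketch of the out_mask check with an illustrative 2x2 tile grid and made-up strides:

#include <cstdio>
int main() {
  const bool out_mask[4] = {true, false, false, true}; // 2x2 tile grid
  const int mask_strides[2] = {1, 2};                  // {x, y} strides
  for (int ty = 0; ty < 2; ty++)
    for (int tx = 0; tx < 2; tx++) {
      // Same indexing shape as out_mask[tid_y * ms[1] + tid_x * ms[0]].
      bool compute = out_mask[ty * mask_strides[1] + tx * mask_strides[0]];
      std::printf("tile(%d,%d): %s\n", ty, tx,
                  compute ? "compute" : "write zeros");
    }
}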
View File
@ -10,7 +10,8 @@ using namespace mlx::steel;
// GEMM kernels // GEMM kernels
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
template <typename T, template <
typename T,
typename U, typename U,
int BM, int BM,
int BN, int BN,
@ -30,10 +31,20 @@ template <typename T,
uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { uint3 lid [[thread_position_in_threadgroup]]) {
(void)lid; (void)lid;
using gemm_kernel = GEMMKernel<T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>; using gemm_kernel = GEMMKernel<
T,
U,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned>;
using loader_a_t = typename gemm_kernel::loader_a_t; using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t; using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t; using mma_t = typename gemm_kernel::mma_t;
@ -54,9 +65,12 @@ template <typename T,
const int c_col = tid_x * BN; const int c_col = tid_x * BN;
const int k_start = params->split_k_partition_size * tid_z; const int k_start = params->split_k_partition_size * tid_z;
A += transpose_a ? (c_row + k_start * params->lda) : (k_start + c_row * params->lda); A += transpose_a ? (c_row + k_start * params->lda)
B += transpose_b ? (k_start + c_col * params->ldb) : (c_col + k_start * params->ldb); : (k_start + c_row * params->lda);
C += (params->split_k_partition_stride * tid_z) + (c_row * params->ldc + c_col); B += transpose_b ? (k_start + c_col * params->ldb)
: (c_col + k_start * params->ldb);
C += (params->split_k_partition_stride * tid_z) +
(c_row * params->ldc + c_col);
// Prepare threadgroup loading operations // Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
@@ -124,7 +138,8 @@ template <typename T,
   threadgroup_barrier(mem_flags::mem_threadgroup);

   if ((tid_z + 1) == (params->split_k_partitions)) {
-    int gemm_k_iter_remaining = (params->K - (k_start + params->split_k_partition_size)) / BK;
+    int gemm_k_iter_remaining =
+        (params->K - (k_start + params->split_k_partition_size)) / BK;
     if (!K_aligned || gemm_k_iter_remaining > 0)
       gemm_kernel::gemm_loop(
           As,
@@ -150,9 +165,38 @@ template <typename T,
 // GEMM kernel initializations
 ///////////////////////////////////////////////////////////////////////////////

-#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
-  template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
-  [[kernel]] void gemm_splitk<itype, otype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
+#define instantiate_gemm( \
+    tname, \
+    trans_a, \
+    trans_b, \
+    iname, \
+    itype, \
+    oname, \
+    otype, \
+    bm, \
+    bn, \
+    bk, \
+    wm, \
+    wn, \
+    aname, \
+    mn_aligned, \
+    kname, \
+    k_aligned) \
+  template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname \
+                       "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
+                       "_MN_" #aname "_K_" #kname)]] [[kernel]] void \
+  gemm_splitk< \
+      itype, \
+      otype, \
+      bm, \
+      bn, \
+      bk, \
+      wm, \
+      wn, \
+      trans_a, \
+      trans_b, \
+      mn_aligned, \
+      k_aligned>( \
       const device itype* A [[buffer(0)]], \
       const device itype* B [[buffer(1)]], \
       device otype* C [[buffer(2)]], \
@@ -162,34 +206,39 @@ template <typename T,
       uint3 tid [[threadgroup_position_in_grid]], \
       uint3 lid [[thread_position_in_threadgroup]]);

+// clang-format off
 #define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
   instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
   instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
   instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
-  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
+  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on

+// clang-format off
 #define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
   instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
   instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
   instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
-  instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
+  instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on

+// clang-format off
 #define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
   instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \
   instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \
   instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \
-  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
+  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) // clang-format on

+// clang-format off
 instantiate_gemm_shapes_helper(float16, half, float32, float);
 instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
-instantiate_gemm_shapes_helper(float32, float, float32, float);
+instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
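The reflow above is purely cosmetic: the pasted host names are unchanged. For instance, the `nn`, fully aligned float16 instantiation at block size 16x16x16 still expands to roughly (parameter list elided):

    template [[host_name("steel_gemm_splitk_nn_float16_float32_bm16_bn16_bk16_wm2_wn2_MN_taligned_K_taligned")]]
    [[kernel]] void gemm_splitk<half, float, 16, 16, 16, 2, 2, false, false, true, true>(...);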
 ///////////////////////////////////////////////////////////////////////////////
 // Split k accumulation kernel
 ///////////////////////////////////////////////////////////////////////////////

-template <typename AccT,
+template <
+    typename AccT,
     typename OutT,
     typename Epilogue = TransformNone<OutT, AccT>>
 [[kernel]] void gemm_splitk_accum(
@@ -199,7 +248,6 @@ template <typename AccT,
     const constant int& partition_stride [[buffer(3)]],
     const constant int& ldd [[buffer(4)]],
     uint2 gid [[thread_position_in_grid]]) {
-
   // Adjust D and C
   D += gid.x + gid.y * ldd;
   C_split += gid.x + gid.y * ldd;
@@ -214,10 +262,10 @@ template <typename AccT,
   // Write output
   D[0] = Epilogue::apply(out);
 }

-template <typename AccT,
+template <
+    typename AccT,
     typename OutT,
     typename Epilogue = TransformAxpby<OutT, AccT>>
 [[kernel]] void gemm_splitk_accum_axpby(
@@ -232,7 +280,6 @@ template <typename AccT,
     const constant float& alpha [[buffer(8)]],
     const constant float& beta [[buffer(9)]],
     uint2 gid [[thread_position_in_grid]]) {
-
   // Adjust D and C
   C += gid.x * fdc + gid.y * ldc;
   D += gid.x + gid.y * ldd;
@@ -249,20 +296,21 @@ template <typename AccT,
   // Write output
   Epilogue op(alpha, beta);
   D[0] = op.apply(out, *C);
 }
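The two kernels above perform the final split-k reduction: each output element sums its `k_partitions` partial values, spaced `partition_stride` apart in `C_split`, then applies the epilogue. In plain C++ the semantics are (a reference sketch; `M`, `N`, the float types, and the identity epilogue are assumptions, not MLX source):

    // Reference semantics of gemm_splitk_accum (sketch only).
    void splitk_accum_ref(const float* C_split, float* D,
                          int k_partitions, int partition_stride,
                          int ldd, int M, int N) {
      for (int y = 0; y < M; ++y) {
        for (int x = 0; x < N; ++x) {
          float out = 0.0f;
          const float* c = C_split + x + y * ldd; // same offsets as the kernel
          for (int i = 0; i < k_partitions; ++i)
            out += c[i * partition_stride];       // sum the K-slice partials
          D[x + y * ldd] = out;                   // TransformNone epilogue
        }
      }
    }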
 #define instantiate_accum(oname, otype, aname, atype) \
-  template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname)]] \
-  [[kernel]] void gemm_splitk_accum<atype, otype>( \
+  template [[host_name("steel_gemm_splitk_accum_" #oname \
+                       "_" #aname)]] [[kernel]] void \
+  gemm_splitk_accum<atype, otype>( \
       const device atype* C_split [[buffer(0)]], \
       device otype* D [[buffer(1)]], \
       const constant int& k_partitions [[buffer(2)]], \
       const constant int& partition_stride [[buffer(3)]], \
       const constant int& ldd [[buffer(4)]], \
       uint2 gid [[thread_position_in_grid]]); \
-  template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname "_axpby")]] \
-  [[kernel]] void gemm_splitk_accum_axpby<atype, otype>( \
+  template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname \
+                       "_axpby")]] [[kernel]] void \
+  gemm_splitk_accum_axpby<atype, otype>( \
       const device atype* C_split [[buffer(0)]], \
       device otype* D [[buffer(1)]], \
       const constant int& k_partitions [[buffer(2)]], \
@@ -275,6 +323,7 @@ template <typename AccT,
       const constant float& beta [[buffer(9)]], \
       uint2 gid [[thread_position_in_grid]]);

+// clang-format off
 instantiate_accum(bfloat16, bfloat16_t, float32, float);
 instantiate_accum(float16, half, float32, float);
-instantiate_accum(float32, float, float32, float);
+instantiate_accum(float32, float, float32, float); // clang-format on

View File

@@ -3,9 +3,9 @@
 #include <metal_integer>
 #include <metal_math>

-#include "mlx/backend/metal/kernels/utils.h"
 #include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/ternary.h"
+#include "mlx/backend/metal/kernels/utils.h"

 template <typename T, typename Op>
 [[kernel]] void ternary_op_v(
@@ -65,7 +65,8 @@ template <typename T, typename Op>
   auto a_idx = elem_to_loc_3(index, a_strides);
   auto b_idx = elem_to_loc_3(index, b_strides);
   auto c_idx = elem_to_loc_3(index, c_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  size_t out_idx =
+      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
   d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
 }

@@ -81,8 +82,10 @@ template <typename T, typename Op, int DIM>
     constant const size_t c_strides[DIM],
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
-  auto idx = elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  auto idx =
+      elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
+  size_t out_idx =
+      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
   d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
 }

@@ -99,23 +102,22 @@ template <typename T, typename Op>
     constant const int& ndim,
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
-  auto idx = elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
+  auto idx =
+      elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
   size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
   d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
 }
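Throughout these kernels the output index is the row-major linearization of the 3-D dispatch grid, and the `(size_t)` casts in the first two variants keep the multiplications from overflowing 32 bits on large arrays. In isolation (an illustrative helper, not part of the diff):

    #include <cstddef>
    #include <cstdint>

    // x varies fastest: out_idx = x + gx * (y + gy * z).
    // Widening to size_t before multiplying avoids 32-bit overflow.
    std::size_t linear_index(uint32_t x, uint32_t y, uint32_t z,
                             uint32_t gx, uint32_t gy) {
      return (std::size_t)x +
          (std::size_t)gx * ((std::size_t)y + (std::size_t)gy * (std::size_t)z);
    }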
 #define instantiate_ternary_v(name, type, op) \
-  template [[host_name(name)]] \
-  [[kernel]] void ternary_op_v<type, op>( \
+  template [[host_name(name)]] [[kernel]] void ternary_op_v<type, op>( \
       device const bool* a, \
       device const type* b, \
       device const type* c, \
       device type* d, \
-      uint index [[thread_position_in_grid]]); \
+      uint index [[thread_position_in_grid]]);

 #define instantiate_ternary_g(name, type, op) \
-  template [[host_name(name)]] \
-  [[kernel]] void ternary_op_g<type, op>( \
+  template [[host_name(name)]] [[kernel]] void ternary_op_g<type, op>( \
       device const bool* a, \
       device const type* b, \
       device const type* c, \
@@ -126,11 +128,11 @@ template <typename T, typename Op>
       constant const size_t* c_strides, \
       constant const int& ndim, \
       uint3 index [[thread_position_in_grid]], \
-      uint3 grid_dim [[threads_per_grid]]); \
+      uint3 grid_dim [[threads_per_grid]]);

 #define instantiate_ternary_g_dim(name, type, op, dims) \
-  template [[host_name(name "_" #dims)]] \
-  [[kernel]] void ternary_op_g_nd<type, op, dims>( \
+  template [[host_name(name "_" #dims)]] [[kernel]] void \
+  ternary_op_g_nd<type, op, dims>( \
       device const bool* a, \
       device const type* b, \
       device const type* c, \
@@ -140,11 +142,11 @@ template <typename T, typename Op>
       constant const size_t b_strides[dims], \
       constant const size_t c_strides[dims], \
       uint3 index [[thread_position_in_grid]], \
-      uint3 grid_dim [[threads_per_grid]]); \
+      uint3 grid_dim [[threads_per_grid]]);

 #define instantiate_ternary_g_nd(name, type, op) \
-  template [[host_name(name "_1")]] \
-  [[kernel]] void ternary_op_g_nd1<type, op>( \
+  template [[host_name(name "_1")]] [[kernel]] void \
+  ternary_op_g_nd1<type, op>( \
       device const bool* a, \
       device const type* b, \
       device const type* c, \
@@ -153,8 +155,8 @@ template <typename T, typename Op>
       constant const size_t& b_strides, \
       constant const size_t& c_strides, \
       uint index [[thread_position_in_grid]]); \
-  template [[host_name(name "_2")]] \
-  [[kernel]] void ternary_op_g_nd2<type, op>( \
+  template [[host_name(name "_2")]] [[kernel]] void \
+  ternary_op_g_nd2<type, op>( \
       device const bool* a, \
       device const type* b, \
       device const type* c, \
@@ -164,8 +166,8 @@ template <typename T, typename Op>
       constant const size_t c_strides[2], \
       uint2 index [[thread_position_in_grid]], \
       uint2 grid_dim [[threads_per_grid]]); \
-  template [[host_name(name "_3")]] \
-  [[kernel]] void ternary_op_g_nd3<type, op>( \
+  template [[host_name(name "_3")]] [[kernel]] void \
+  ternary_op_g_nd3<type, op>( \
       device const bool* a, \
       device const type* b, \
       device const type* c, \
@@ -176,13 +178,15 @@ template <typename T, typename Op>
       uint3 index [[thread_position_in_grid]], \
       uint3 grid_dim [[threads_per_grid]]); \
   instantiate_ternary_g_dim(name, type, op, 4) \
-  instantiate_ternary_g_dim(name, type, op, 5) \
+  instantiate_ternary_g_dim(name, type, op, 5)

+// clang-format off
 #define instantiate_ternary_all(name, tname, type, op) \
   instantiate_ternary_v("v" #name #tname, type, op) \
   instantiate_ternary_g("g" #name #tname, type, op) \
-  instantiate_ternary_g_nd("g" #name #tname, type, op) \
+  instantiate_ternary_g_nd("g" #name #tname, type, op) // clang-format on

+// clang-format off
 #define instantiate_ternary_types(name, op) \
   instantiate_ternary_all(name, bool_, bool, op) \
   instantiate_ternary_all(name, uint8, uint8_t, op) \
@@ -196,6 +200,6 @@ template <typename T, typename Op>
   instantiate_ternary_all(name, float16, half, op) \
   instantiate_ternary_all(name, float32, float, op) \
   instantiate_ternary_all(name, bfloat16, bfloat16_t, op) \
-  instantiate_ternary_all(name, complex64, complex64_t, op) \
+  instantiate_ternary_all(name, complex64, complex64_t, op) // clang-format on

 instantiate_ternary_types(select, Select)
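For reference, `Select` is the only ternary op instantiated here, and the host-visible names concatenate the scheme prefix, op name, and type name with no separators. Derived from the macros above, the float16 row of `instantiate_ternary_types` yields:

    vselectfloat16                          // contiguous variant
    gselectfloat16                          // general strided variant (runtime ndim)
    gselectfloat16_1 ... gselectfloat16_5   // fixed-rank specializations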

View File

@@ -23,15 +23,13 @@ template <typename T, typename Op>
 }

 #define instantiate_unary_v(name, type, op) \
-  template [[host_name(name)]] \
-  [[kernel]] void unary_op_v<type, op>( \
+  template [[host_name(name)]] [[kernel]] void unary_op_v<type, op>( \
       device const type* in, \
       device type* out, \
       uint index [[thread_position_in_grid]]);

 #define instantiate_unary_g(name, type, op) \
-  template [[host_name(name)]] \
-  [[kernel]] void unary_op_g<type, op>( \
+  template [[host_name(name)]] [[kernel]] void unary_op_g<type, op>( \
       device const type* in, \
       device type* out, \
       device const int* in_shape, \
@@ -39,15 +37,18 @@ template <typename T, typename Op>
       device const int& ndim, \
       uint index [[thread_position_in_grid]]);

+// clang-format off
 #define instantiate_unary_all(name, tname, type, op) \
   instantiate_unary_v("v" #name #tname, type, op) \
-  instantiate_unary_g("g" #name #tname, type, op)
+  instantiate_unary_g("g" #name #tname, type, op) // clang-format on

+// clang-format off
 #define instantiate_unary_float(name, op) \
   instantiate_unary_all(name, float16, half, op) \
   instantiate_unary_all(name, float32, float, op) \
-  instantiate_unary_all(name, bfloat16, bfloat16_t, op)
+  instantiate_unary_all(name, bfloat16, bfloat16_t, op) // clang-format on

+// clang-format off
 #define instantiate_unary_types(name, op) \
   instantiate_unary_all(name, bool_, bool, op) \
   instantiate_unary_all(name, uint8, uint8_t, op) \
@@ -58,8 +59,9 @@ template <typename T, typename Op>
   instantiate_unary_all(name, int16, int16_t, op) \
   instantiate_unary_all(name, int32, int32_t, op) \
   instantiate_unary_all(name, int64, int64_t, op) \
-  instantiate_unary_float(name, op)
+  instantiate_unary_float(name, op) // clang-format on

+// clang-format off
 instantiate_unary_types(abs, Abs)
 instantiate_unary_float(arccos, ArcCos)
 instantiate_unary_float(arccosh, ArcCosh)
@@ -102,4 +104,4 @@ instantiate_unary_all(tan, complex64, complex64_t, Tan)
 instantiate_unary_all(tanh, complex64, complex64_t, Tanh)
 instantiate_unary_all(round, complex64, complex64_t, Round)
-instantiate_unary_all(lnot, bool_, bool, LogicalNot)
+instantiate_unary_all(lnot, bool_, bool, LogicalNot) // clang-format on
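A closing note on the `// clang-format off` / `// clang-format on` pairs added throughout this commit: clang-format would otherwise re-wrap these statement-like macro invocation lists, so each block is fenced to preserve its one-instantiation-per-line table layout. The pattern in miniature, taken from the unary file above:

    // clang-format off
    // Guarded so the formatter leaves the table layout alone.
    instantiate_unary_float(arccos, ArcCos)
    instantiate_unary_float(arccosh, ArcCosh)
    // clang-format on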