mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
feat: metal formatting and pre-commit bump (#1038)
* feat: metal formatting and pre-commit bump * add guards * update * more guards * more guards * smakk fix * Refactor instantiation of ternary types in ternary.metal * fix scan.metal
This commit is contained in:
parent
8db7161c94
commit
a30e7ed2da
@ -1,11 +1,11 @@
|
|||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||||
rev: v18.1.3
|
rev: v18.1.4
|
||||||
hooks:
|
hooks:
|
||||||
- id: clang-format
|
- id: clang-format
|
||||||
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
||||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||||
rev: 24.3.0
|
rev: 24.4.2
|
||||||
hooks:
|
hooks:
|
||||||
- id: black
|
- id: black
|
||||||
- repo: https://github.com/pycqa/isort
|
- repo: https://github.com/pycqa/isort
|
||||||
|
@ -36,8 +36,8 @@ template <typename T>
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_axpby(type_name, type) \
|
#define instantiate_axpby(type_name, type) \
|
||||||
template [[host_name("axpby_general_" #type_name)]] \
|
template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \
|
||||||
[[kernel]] void axpby_general<type>( \
|
axpby_general<type>( \
|
||||||
device const type* x [[buffer(0)]], \
|
device const type* x [[buffer(0)]], \
|
||||||
device const type* y [[buffer(1)]], \
|
device const type* y [[buffer(1)]], \
|
||||||
device type* out [[buffer(2)]], \
|
device type* out [[buffer(2)]], \
|
||||||
@ -48,8 +48,8 @@ template <typename T>
|
|||||||
constant const size_t* y_strides [[buffer(7)]], \
|
constant const size_t* y_strides [[buffer(7)]], \
|
||||||
constant const int& ndim [[buffer(8)]], \
|
constant const int& ndim [[buffer(8)]], \
|
||||||
uint index [[thread_position_in_grid]]); \
|
uint index [[thread_position_in_grid]]); \
|
||||||
template [[host_name("axpby_contiguous_" #type_name)]] \
|
template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
|
||||||
[[kernel]] void axpby_contiguous<type>( \
|
axpby_contiguous<type>( \
|
||||||
device const type* x [[buffer(0)]], \
|
device const type* x [[buffer(0)]], \
|
||||||
device const type* y [[buffer(1)]], \
|
device const type* y [[buffer(1)]], \
|
||||||
device type* out [[buffer(2)]], \
|
device type* out [[buffer(2)]], \
|
||||||
|
@ -12,13 +12,13 @@ template <typename T>
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_arange(tname, type) \
|
#define instantiate_arange(tname, type) \
|
||||||
template [[host_name("arange" #tname)]] \
|
template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
|
||||||
[[kernel]] void arange<type>( \
|
|
||||||
constant const type& start, \
|
constant const type& start, \
|
||||||
constant const type& step, \
|
constant const type& step, \
|
||||||
device type* out, \
|
device type* out, \
|
||||||
uint index [[thread_position_in_grid]]);
|
uint index [[thread_position_in_grid]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_arange(uint8, uint8_t)
|
instantiate_arange(uint8, uint8_t)
|
||||||
instantiate_arange(uint16, uint16_t)
|
instantiate_arange(uint16, uint16_t)
|
||||||
instantiate_arange(uint32, uint32_t)
|
instantiate_arange(uint32, uint32_t)
|
||||||
@ -29,4 +29,4 @@ instantiate_arange(int32, int32_t)
|
|||||||
instantiate_arange(int64, int64_t)
|
instantiate_arange(int64, int64_t)
|
||||||
instantiate_arange(float16, half)
|
instantiate_arange(float16, half)
|
||||||
instantiate_arange(float32, float)
|
instantiate_arange(float32, float)
|
||||||
instantiate_arange(bfloat16, bfloat16_t)
|
instantiate_arange(bfloat16, bfloat16_t) // clang-format on
|
@ -18,7 +18,8 @@ struct ArgMin {
|
|||||||
static constexpr constant U init = Limits<U>::max;
|
static constexpr constant U init = Limits<U>::max;
|
||||||
|
|
||||||
IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
|
IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
|
||||||
if (best.val > current.val || (best.val == current.val && best.index > current.index)) {
|
if (best.val > current.val ||
|
||||||
|
(best.val == current.val && best.index > current.index)) {
|
||||||
return current;
|
return current;
|
||||||
} else {
|
} else {
|
||||||
return best;
|
return best;
|
||||||
@ -26,7 +27,8 @@ struct ArgMin {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <int N>
|
template <int N>
|
||||||
IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
|
IndexValPair<U>
|
||||||
|
reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
|
||||||
for (int i = 0; i < N; i++) {
|
for (int i = 0; i < N; i++) {
|
||||||
if (vals[i] < best.val) {
|
if (vals[i] < best.val) {
|
||||||
best.val = vals[i];
|
best.val = vals[i];
|
||||||
@ -42,7 +44,8 @@ struct ArgMax {
|
|||||||
static constexpr constant U init = Limits<U>::min;
|
static constexpr constant U init = Limits<U>::min;
|
||||||
|
|
||||||
IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
|
IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
|
||||||
if (best.val < current.val || (best.val == current.val && best.index > current.index)) {
|
if (best.val < current.val ||
|
||||||
|
(best.val == current.val && best.index > current.index)) {
|
||||||
return current;
|
return current;
|
||||||
} else {
|
} else {
|
||||||
return best;
|
return best;
|
||||||
@ -50,7 +53,8 @@ struct ArgMax {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <int N>
|
template <int N>
|
||||||
IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
|
IndexValPair<U>
|
||||||
|
reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
|
||||||
for (int i = 0; i < N; i++) {
|
for (int i = 0; i < N; i++) {
|
||||||
if (vals[i] > best.val) {
|
if (vals[i] > best.val) {
|
||||||
best.val = vals[i];
|
best.val = vals[i];
|
||||||
@ -64,12 +68,9 @@ struct ArgMax {
|
|||||||
template <typename U>
|
template <typename U>
|
||||||
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
|
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
|
||||||
return IndexValPair<U>{
|
return IndexValPair<U>{
|
||||||
simd_shuffle_down(data.index, delta),
|
simd_shuffle_down(data.index, delta), simd_shuffle_down(data.val, delta)};
|
||||||
simd_shuffle_down(data.val, delta)
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
template <typename T, typename Op, int N_READS>
|
template <typename T, typename Op, int N_READS>
|
||||||
[[kernel]] void arg_reduce_general(
|
[[kernel]] void arg_reduce_general(
|
||||||
const device T* in [[buffer(0)]],
|
const device T* in [[buffer(0)]],
|
||||||
@ -86,7 +87,6 @@ template <typename T, typename Op, int N_READS>
|
|||||||
uint simd_size [[threads_per_simdgroup]],
|
uint simd_size [[threads_per_simdgroup]],
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
// Shapes and strides *do not* contain the reduction axis. The reduction size
|
// Shapes and strides *do not* contain the reduction axis. The reduction size
|
||||||
// and stride are provided in axis_stride and axis_size.
|
// and stride are provided in axis_stride and axis_size.
|
||||||
//
|
//
|
||||||
@ -161,8 +161,8 @@ template <typename T, typename Op, int N_READS>
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_arg_reduce_helper(name, itype, op) \
|
#define instantiate_arg_reduce_helper(name, itype, op) \
|
||||||
template [[host_name(name)]] \
|
template [[host_name(name)]] [[kernel]] void \
|
||||||
[[kernel]] void arg_reduce_general<itype, op<itype>, 4>( \
|
arg_reduce_general<itype, op<itype>, 4>( \
|
||||||
const device itype* in [[buffer(0)]], \
|
const device itype* in [[buffer(0)]], \
|
||||||
device uint32_t* out [[buffer(1)]], \
|
device uint32_t* out [[buffer(1)]], \
|
||||||
const device int* shape [[buffer(2)]], \
|
const device int* shape [[buffer(2)]], \
|
||||||
@ -178,6 +178,7 @@ template <typename T, typename Op, int N_READS>
|
|||||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_arg_reduce(name, itype) \
|
#define instantiate_arg_reduce(name, itype) \
|
||||||
instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin) \
|
instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin) \
|
||||||
instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax)
|
instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax)
|
||||||
@ -193,4 +194,4 @@ instantiate_arg_reduce(int32, int32_t)
|
|||||||
instantiate_arg_reduce(int64, int64_t)
|
instantiate_arg_reduce(int64, int64_t)
|
||||||
instantiate_arg_reduce(float16, half)
|
instantiate_arg_reduce(float16, half)
|
||||||
instantiate_arg_reduce(float32, float)
|
instantiate_arg_reduce(float32, float)
|
||||||
instantiate_arg_reduce(bfloat16, bfloat16_t)
|
instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on
|
@ -77,7 +77,8 @@ template <typename T, typename U, typename Op>
|
|||||||
uint3 grid_dim [[threads_per_grid]]) {
|
uint3 grid_dim [[threads_per_grid]]) {
|
||||||
auto a_idx = elem_to_loc_3(index, a_strides);
|
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||||
auto b_idx = elem_to_loc_3(index, b_strides);
|
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||||
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
size_t out_idx =
|
||||||
|
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||||
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -92,7 +93,8 @@ template <typename T, typename U, typename Op, int DIM>
|
|||||||
uint3 index [[thread_position_in_grid]],
|
uint3 index [[thread_position_in_grid]],
|
||||||
uint3 grid_dim [[threads_per_grid]]) {
|
uint3 grid_dim [[threads_per_grid]]) {
|
||||||
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
|
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
|
||||||
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
size_t out_idx =
|
||||||
|
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||||
c[out_idx] = Op()(a[idx.x], b[idx.y]);
|
c[out_idx] = Op()(a[idx.x], b[idx.y]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -113,16 +115,16 @@ template <typename T, typename U, typename Op>
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_binary(name, itype, otype, op, bopt) \
|
#define instantiate_binary(name, itype, otype, op, bopt) \
|
||||||
template [[host_name(name)]] \
|
template \
|
||||||
[[kernel]] void binary_op_##bopt<itype, otype, op>( \
|
[[host_name(name)]] [[kernel]] void binary_op_##bopt<itype, otype, op>( \
|
||||||
device const itype* a, \
|
device const itype* a, \
|
||||||
device const itype* b, \
|
device const itype* b, \
|
||||||
device otype* c, \
|
device otype* c, \
|
||||||
uint index [[thread_position_in_grid]]);
|
uint index [[thread_position_in_grid]]);
|
||||||
|
|
||||||
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
|
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
|
||||||
template [[host_name(name "_" #dims)]] \
|
template [[host_name(name "_" #dims)]] [[kernel]] void \
|
||||||
[[kernel]] void binary_op_g_nd<itype, otype, op, dims>( \
|
binary_op_g_nd<itype, otype, op, dims>( \
|
||||||
device const itype* a, \
|
device const itype* a, \
|
||||||
device const itype* b, \
|
device const itype* b, \
|
||||||
device otype* c, \
|
device otype* c, \
|
||||||
@ -133,16 +135,16 @@ template <typename T, typename U, typename Op>
|
|||||||
uint3 grid_dim [[threads_per_grid]]);
|
uint3 grid_dim [[threads_per_grid]]);
|
||||||
|
|
||||||
#define instantiate_binary_g_nd(name, itype, otype, op) \
|
#define instantiate_binary_g_nd(name, itype, otype, op) \
|
||||||
template [[host_name(name "_1")]] \
|
template [[host_name(name "_1")]] [[kernel]] void \
|
||||||
[[kernel]] void binary_op_g_nd1<itype, otype, op>( \
|
binary_op_g_nd1<itype, otype, op>( \
|
||||||
device const itype* a, \
|
device const itype* a, \
|
||||||
device const itype* b, \
|
device const itype* b, \
|
||||||
device otype* c, \
|
device otype* c, \
|
||||||
constant const size_t& a_stride, \
|
constant const size_t& a_stride, \
|
||||||
constant const size_t& b_stride, \
|
constant const size_t& b_stride, \
|
||||||
uint index [[thread_position_in_grid]]); \
|
uint index [[thread_position_in_grid]]); \
|
||||||
template [[host_name(name "_2")]] \
|
template [[host_name(name "_2")]] [[kernel]] void \
|
||||||
[[kernel]] void binary_op_g_nd2<itype, otype, op>( \
|
binary_op_g_nd2<itype, otype, op>( \
|
||||||
device const itype* a, \
|
device const itype* a, \
|
||||||
device const itype* b, \
|
device const itype* b, \
|
||||||
device otype* c, \
|
device otype* c, \
|
||||||
@ -150,8 +152,8 @@ template <typename T, typename U, typename Op>
|
|||||||
constant const size_t b_strides[2], \
|
constant const size_t b_strides[2], \
|
||||||
uint2 index [[thread_position_in_grid]], \
|
uint2 index [[thread_position_in_grid]], \
|
||||||
uint2 grid_dim [[threads_per_grid]]); \
|
uint2 grid_dim [[threads_per_grid]]); \
|
||||||
template [[host_name(name "_3")]] \
|
template [[host_name(name "_3")]] [[kernel]] void \
|
||||||
[[kernel]] void binary_op_g_nd3<itype, otype, op>( \
|
binary_op_g_nd3<itype, otype, op>( \
|
||||||
device const itype* a, \
|
device const itype* a, \
|
||||||
device const itype* b, \
|
device const itype* b, \
|
||||||
device otype* c, \
|
device otype* c, \
|
||||||
@ -162,10 +164,8 @@ template <typename T, typename U, typename Op>
|
|||||||
instantiate_binary_g_dim(name, itype, otype, op, 4) \
|
instantiate_binary_g_dim(name, itype, otype, op, 4) \
|
||||||
instantiate_binary_g_dim(name, itype, otype, op, 5)
|
instantiate_binary_g_dim(name, itype, otype, op, 5)
|
||||||
|
|
||||||
|
|
||||||
#define instantiate_binary_g(name, itype, otype, op) \
|
#define instantiate_binary_g(name, itype, otype, op) \
|
||||||
template [[host_name(name)]] \
|
template [[host_name(name)]] [[kernel]] void binary_op_g<itype, otype, op>( \
|
||||||
[[kernel]] void binary_op_g<itype, otype, op>( \
|
|
||||||
device const itype* a, \
|
device const itype* a, \
|
||||||
device const itype* b, \
|
device const itype* b, \
|
||||||
device otype* c, \
|
device otype* c, \
|
||||||
@ -176,14 +176,16 @@ template <typename T, typename U, typename Op>
|
|||||||
uint3 index [[thread_position_in_grid]], \
|
uint3 index [[thread_position_in_grid]], \
|
||||||
uint3 grid_dim [[threads_per_grid]]);
|
uint3 grid_dim [[threads_per_grid]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_binary_all(name, tname, itype, otype, op) \
|
#define instantiate_binary_all(name, tname, itype, otype, op) \
|
||||||
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
|
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
|
||||||
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
|
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
|
||||||
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
|
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
|
||||||
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
|
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
|
||||||
instantiate_binary_g("g" #name #tname, itype, otype, op) \
|
instantiate_binary_g("g" #name #tname, itype, otype, op) \
|
||||||
instantiate_binary_g_nd("g" #name #tname, itype, otype, op)
|
instantiate_binary_g_nd("g" #name #tname, itype, otype, op) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_binary_integer(name, op) \
|
#define instantiate_binary_integer(name, op) \
|
||||||
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
|
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
|
||||||
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
|
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
|
||||||
@ -192,19 +194,22 @@ template <typename T, typename U, typename Op>
|
|||||||
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
|
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
|
||||||
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
|
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
|
||||||
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
|
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
|
||||||
instantiate_binary_all(name, int64, int64_t, int64_t, op) \
|
instantiate_binary_all(name, int64, int64_t, int64_t, op) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_binary_float(name, op) \
|
#define instantiate_binary_float(name, op) \
|
||||||
instantiate_binary_all(name, float16, half, half, op) \
|
instantiate_binary_all(name, float16, half, half, op) \
|
||||||
instantiate_binary_all(name, float32, float, float, op) \
|
instantiate_binary_all(name, float32, float, float, op) \
|
||||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
|
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_binary_types(name, op) \
|
#define instantiate_binary_types(name, op) \
|
||||||
instantiate_binary_all(name, bool_, bool, bool, op) \
|
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||||
instantiate_binary_integer(name, op) \
|
instantiate_binary_integer(name, op) \
|
||||||
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
|
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
|
||||||
instantiate_binary_float(name, op)
|
instantiate_binary_float(name, op) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_binary_types_bool(name, op) \
|
#define instantiate_binary_types_bool(name, op) \
|
||||||
instantiate_binary_all(name, bool_, bool, bool, op) \
|
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||||
instantiate_binary_all(name, uint8, uint8_t, bool, op) \
|
instantiate_binary_all(name, uint8, uint8_t, bool, op) \
|
||||||
@ -218,8 +223,9 @@ template <typename T, typename U, typename Op>
|
|||||||
instantiate_binary_all(name, float16, half, bool, op) \
|
instantiate_binary_all(name, float16, half, bool, op) \
|
||||||
instantiate_binary_all(name, float32, float, bool, op) \
|
instantiate_binary_all(name, float32, float, bool, op) \
|
||||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
|
instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
|
||||||
instantiate_binary_all(name, complex64, complex64_t, bool, op)
|
instantiate_binary_all(name, complex64, complex64_t, bool, op) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_binary_types(add, Add)
|
instantiate_binary_types(add, Add)
|
||||||
instantiate_binary_types(div, Divide)
|
instantiate_binary_types(div, Divide)
|
||||||
instantiate_binary_types_bool(eq, Equal)
|
instantiate_binary_types_bool(eq, Equal)
|
||||||
@ -253,4 +259,4 @@ instantiate_binary_all(bitwise_or, bool_, bool, bool, BitwiseOr)
|
|||||||
instantiate_binary_integer(bitwise_xor, BitwiseXor)
|
instantiate_binary_integer(bitwise_xor, BitwiseXor)
|
||||||
instantiate_binary_all(bitwise_xor, bool_, bool, bool, BitwiseXor)
|
instantiate_binary_all(bitwise_xor, bool_, bool, bool, BitwiseXor)
|
||||||
instantiate_binary_integer(left_shift, LeftShift)
|
instantiate_binary_integer(left_shift, LeftShift)
|
||||||
instantiate_binary_integer(right_shift, RightShift)
|
instantiate_binary_integer(right_shift, RightShift) // clang-format on
|
||||||
|
@ -3,23 +3,37 @@
|
|||||||
#include <metal_integer>
|
#include <metal_integer>
|
||||||
#include <metal_math>
|
#include <metal_math>
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
struct FloorDivide {
|
struct FloorDivide {
|
||||||
template <typename T> T operator()(T x, T y) { return x / y; }
|
template <typename T>
|
||||||
template <> float operator()(float x, float y) { return trunc(x / y); }
|
T operator()(T x, T y) {
|
||||||
template <> half operator()(half x, half y) { return trunc(x / y); }
|
return x / y;
|
||||||
template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return trunc(x / y); }
|
}
|
||||||
|
template <>
|
||||||
|
float operator()(float x, float y) {
|
||||||
|
return trunc(x / y);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
half operator()(half x, half y) {
|
||||||
|
return trunc(x / y);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
|
||||||
|
return trunc(x / y);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Remainder {
|
struct Remainder {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T> operator()(T x, T y) {
|
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
|
||||||
|
operator()(T x, T y) {
|
||||||
return x % y;
|
return x % y;
|
||||||
}
|
}
|
||||||
template <typename T>
|
template <typename T>
|
||||||
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T> operator()(T x, T y) {
|
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
|
||||||
|
operator()(T x, T y) {
|
||||||
auto r = x % y;
|
auto r = x % y;
|
||||||
if (r != 0 && (r < 0 != y < 0)) {
|
if (r != 0 && (r < 0 != y < 0)) {
|
||||||
r += y;
|
r += y;
|
||||||
@ -34,7 +48,8 @@ struct Remainder {
|
|||||||
}
|
}
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
template <> complex64_t operator()(complex64_t x, complex64_t y) {
|
template <>
|
||||||
|
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||||
return x % y;
|
return x % y;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -50,7 +65,6 @@ template <typename T, typename U, typename Op1, typename Op2>
|
|||||||
d[index] = Op2()(a[0], b[0]);
|
d[index] = Op2()(a[0], b[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
template <typename T, typename U, typename Op1, typename Op2>
|
template <typename T, typename U, typename Op1, typename Op2>
|
||||||
[[kernel]] void binary_op_ss(
|
[[kernel]] void binary_op_ss(
|
||||||
device const T* a,
|
device const T* a,
|
||||||
@ -139,7 +153,8 @@ template <typename T, typename U, typename Op1, typename Op2>
|
|||||||
uint3 grid_dim [[threads_per_grid]]) {
|
uint3 grid_dim [[threads_per_grid]]) {
|
||||||
auto a_idx = elem_to_loc_3(index, a_strides);
|
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||||
auto b_idx = elem_to_loc_3(index, b_strides);
|
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||||
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
size_t out_idx =
|
||||||
|
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||||
c[out_idx] = Op1()(a[a_idx], b[b_idx]);
|
c[out_idx] = Op1()(a[a_idx], b[b_idx]);
|
||||||
d[out_idx] = Op2()(a[a_idx], b[b_idx]);
|
d[out_idx] = Op2()(a[a_idx], b[b_idx]);
|
||||||
}
|
}
|
||||||
@ -156,7 +171,8 @@ template <typename T, typename U, typename Op1, typename Op2, int DIM>
|
|||||||
uint3 index [[thread_position_in_grid]],
|
uint3 index [[thread_position_in_grid]],
|
||||||
uint3 grid_dim [[threads_per_grid]]) {
|
uint3 grid_dim [[threads_per_grid]]) {
|
||||||
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
|
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
|
||||||
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
size_t out_idx =
|
||||||
|
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||||
c[out_idx] = Op1()(a[idx.x], b[idx.y]);
|
c[out_idx] = Op1()(a[idx.x], b[idx.y]);
|
||||||
d[out_idx] = Op2()(a[idx.x], b[idx.y]);
|
d[out_idx] = Op2()(a[idx.x], b[idx.y]);
|
||||||
}
|
}
|
||||||
@ -180,8 +196,8 @@ template <typename T, typename U, typename Op1, typename Op2>
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_binary(name, itype, otype, op1, op2, bopt) \
|
#define instantiate_binary(name, itype, otype, op1, op2, bopt) \
|
||||||
template [[host_name(name)]] \
|
template [[host_name(name)]] [[kernel]] void \
|
||||||
[[kernel]] void binary_op_##bopt<itype, otype, op1, op2>( \
|
binary_op_##bopt<itype, otype, op1, op2>( \
|
||||||
device const itype* a, \
|
device const itype* a, \
|
||||||
device const itype* b, \
|
device const itype* b, \
|
||||||
device otype* c, \
|
device otype* c, \
|
||||||
@ -189,8 +205,8 @@ template <typename T, typename U, typename Op1, typename Op2>
|
|||||||
uint index [[thread_position_in_grid]]);
|
uint index [[thread_position_in_grid]]);
|
||||||
|
|
||||||
#define instantiate_binary_g_dim(name, itype, otype, op1, op2, dims) \
|
#define instantiate_binary_g_dim(name, itype, otype, op1, op2, dims) \
|
||||||
template [[host_name(name "_" #dims)]] \
|
template [[host_name(name "_" #dims)]] [[kernel]] void \
|
||||||
[[kernel]] void binary_op_g_nd<itype, otype, op1, op2, dims>( \
|
binary_op_g_nd<itype, otype, op1, op2, dims>( \
|
||||||
device const itype* a, \
|
device const itype* a, \
|
||||||
device const itype* b, \
|
device const itype* b, \
|
||||||
device otype* c, \
|
device otype* c, \
|
||||||
@ -201,9 +217,10 @@ template <typename T, typename U, typename Op1, typename Op2>
|
|||||||
uint3 index [[thread_position_in_grid]], \
|
uint3 index [[thread_position_in_grid]], \
|
||||||
uint3 grid_dim [[threads_per_grid]]);
|
uint3 grid_dim [[threads_per_grid]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_binary_g_nd(name, itype, otype, op1, op2) \
|
#define instantiate_binary_g_nd(name, itype, otype, op1, op2) \
|
||||||
template [[host_name(name "_1")]] \
|
template [[host_name(name "_1")]] [[kernel]] void \
|
||||||
[[kernel]] void binary_op_g_nd1<itype, otype, op1, op2>( \
|
binary_op_g_nd1<itype, otype, op1, op2>( \
|
||||||
device const itype* a, \
|
device const itype* a, \
|
||||||
device const itype* b, \
|
device const itype* b, \
|
||||||
device otype* c, \
|
device otype* c, \
|
||||||
@ -211,8 +228,8 @@ template <typename T, typename U, typename Op1, typename Op2>
|
|||||||
constant const size_t& a_stride, \
|
constant const size_t& a_stride, \
|
||||||
constant const size_t& b_stride, \
|
constant const size_t& b_stride, \
|
||||||
uint index [[thread_position_in_grid]]); \
|
uint index [[thread_position_in_grid]]); \
|
||||||
template [[host_name(name "_2")]] \
|
template [[host_name(name "_2")]] [[kernel]] void \
|
||||||
[[kernel]] void binary_op_g_nd2<itype, otype, op1, op2>( \
|
binary_op_g_nd2<itype, otype, op1, op2>( \
|
||||||
device const itype* a, \
|
device const itype* a, \
|
||||||
device const itype* b, \
|
device const itype* b, \
|
||||||
device otype* c, \
|
device otype* c, \
|
||||||
@ -221,8 +238,8 @@ template <typename T, typename U, typename Op1, typename Op2>
|
|||||||
constant const size_t b_strides[2], \
|
constant const size_t b_strides[2], \
|
||||||
uint2 index [[thread_position_in_grid]], \
|
uint2 index [[thread_position_in_grid]], \
|
||||||
uint2 grid_dim [[threads_per_grid]]); \
|
uint2 grid_dim [[threads_per_grid]]); \
|
||||||
template [[host_name(name "_3")]] \
|
template [[host_name(name "_3")]] [[kernel]] void \
|
||||||
[[kernel]] void binary_op_g_nd3<itype, otype, op1, op2>( \
|
binary_op_g_nd3<itype, otype, op1, op2>( \
|
||||||
device const itype* a, \
|
device const itype* a, \
|
||||||
device const itype* b, \
|
device const itype* b, \
|
||||||
device otype* c, \
|
device otype* c, \
|
||||||
@ -232,12 +249,11 @@ template <typename T, typename U, typename Op1, typename Op2>
|
|||||||
uint3 index [[thread_position_in_grid]], \
|
uint3 index [[thread_position_in_grid]], \
|
||||||
uint3 grid_dim [[threads_per_grid]]); \
|
uint3 grid_dim [[threads_per_grid]]); \
|
||||||
instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \
|
instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \
|
||||||
instantiate_binary_g_dim(name, itype, otype, op1, op2, 5)
|
instantiate_binary_g_dim(name, itype, otype, op1, op2, 5) // clang-format on
|
||||||
|
|
||||||
|
|
||||||
#define instantiate_binary_g(name, itype, otype, op1, op2) \
|
#define instantiate_binary_g(name, itype, otype, op1, op2) \
|
||||||
template [[host_name(name)]] \
|
template [[host_name(name)]] [[kernel]] void \
|
||||||
[[kernel]] void binary_op_g<itype, otype, op2, op2>( \
|
binary_op_g<itype, otype, op2, op2>( \
|
||||||
device const itype* a, \
|
device const itype* a, \
|
||||||
device const itype* b, \
|
device const itype* b, \
|
||||||
device otype* c, \
|
device otype* c, \
|
||||||
@ -249,19 +265,22 @@ template <typename T, typename U, typename Op1, typename Op2>
|
|||||||
uint3 index [[thread_position_in_grid]], \
|
uint3 index [[thread_position_in_grid]], \
|
||||||
uint3 grid_dim [[threads_per_grid]]);
|
uint3 grid_dim [[threads_per_grid]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_binary_all(name, tname, itype, otype, op1, op2) \
|
#define instantiate_binary_all(name, tname, itype, otype, op1, op2) \
|
||||||
instantiate_binary("ss" #name #tname, itype, otype, op1, op2, ss) \
|
instantiate_binary("ss" #name #tname, itype, otype, op1, op2, ss) \
|
||||||
instantiate_binary("sv" #name #tname, itype, otype, op1, op2, sv) \
|
instantiate_binary("sv" #name #tname, itype, otype, op1, op2, sv) \
|
||||||
instantiate_binary("vs" #name #tname, itype, otype, op1, op2, vs) \
|
instantiate_binary("vs" #name #tname, itype, otype, op1, op2, vs) \
|
||||||
instantiate_binary("vv" #name #tname, itype, otype, op1, op2, vv) \
|
instantiate_binary("vv" #name #tname, itype, otype, op1, op2, vv) \
|
||||||
instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \
|
instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \
|
||||||
instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2)
|
instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_binary_float(name, op1, op2) \
|
#define instantiate_binary_float(name, op1, op2) \
|
||||||
instantiate_binary_all(name, float16, half, half, op1, op2) \
|
instantiate_binary_all(name, float16, half, half, op1, op2) \
|
||||||
instantiate_binary_all(name, float32, float, float, op1, op2) \
|
instantiate_binary_all(name, float32, float, float, op1, op2) \
|
||||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2)
|
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_binary_types(name, op1, op2) \
|
#define instantiate_binary_types(name, op1, op2) \
|
||||||
instantiate_binary_all(name, bool_, bool, bool, op1, op2) \
|
instantiate_binary_all(name, bool_, bool, bool, op1, op2) \
|
||||||
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \
|
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \
|
||||||
@ -275,4 +294,4 @@ template <typename T, typename U, typename Op1, typename Op2>
|
|||||||
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op1, op2) \
|
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op1, op2) \
|
||||||
instantiate_binary_float(name, op1, op2)
|
instantiate_binary_float(name, op1, op2)
|
||||||
|
|
||||||
instantiate_binary_types(divmod, FloorDivide, Remainder)
|
instantiate_binary_types(divmod, FloorDivide, Remainder) // clang-format on
|
||||||
|
@ -1,13 +1,11 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <metal_stdlib>
|
|
||||||
#include <metal_simdgroup>
|
#include <metal_simdgroup>
|
||||||
#include <metal_simdgroup_matrix>
|
#include <metal_simdgroup_matrix>
|
||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
|
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||||
|
|
||||||
#define MLX_MTL_CONST static constant constexpr const
|
#define MLX_MTL_CONST static constant constexpr const
|
||||||
|
|
||||||
@ -23,12 +21,13 @@ template <typename T, int N>
|
|||||||
device T* out [[buffer(1)]],
|
device T* out [[buffer(1)]],
|
||||||
const constant MLXConvParams<N>* params [[buffer(2)]],
|
const constant MLXConvParams<N>* params [[buffer(2)]],
|
||||||
uint3 gid [[thread_position_in_grid]]) {
|
uint3 gid [[thread_position_in_grid]]) {
|
||||||
|
|
||||||
int filter_size = params->C;
|
int filter_size = params->C;
|
||||||
for(short i = 0; i < N; i++) filter_size *= params->wS[i];
|
for (short i = 0; i < N; i++)
|
||||||
|
filter_size *= params->wS[i];
|
||||||
|
|
||||||
int out_pixels = 1;
|
int out_pixels = 1;
|
||||||
for(short i = 0; i < N; i++) out_pixels *= params->oS[i];
|
for (short i = 0; i < N; i++)
|
||||||
|
out_pixels *= params->oS[i];
|
||||||
|
|
||||||
// Set out
|
// Set out
|
||||||
out += gid.z * filter_size + gid.y * (params->C);
|
out += gid.z * filter_size + gid.y * (params->C);
|
||||||
@ -85,12 +84,13 @@ template <typename T, int N>
|
|||||||
device T* out [[buffer(1)]],
|
device T* out [[buffer(1)]],
|
||||||
const constant MLXConvParams<N>* params [[buffer(2)]],
|
const constant MLXConvParams<N>* params [[buffer(2)]],
|
||||||
uint3 gid [[thread_position_in_grid]]) {
|
uint3 gid [[thread_position_in_grid]]) {
|
||||||
|
|
||||||
int filter_size = params->C;
|
int filter_size = params->C;
|
||||||
for(short i = 0; i < N; i++) filter_size *= params->wS[i];
|
for (short i = 0; i < N; i++)
|
||||||
|
filter_size *= params->wS[i];
|
||||||
|
|
||||||
int out_pixels = 1;
|
int out_pixels = 1;
|
||||||
for(short i = 0; i < N; i++) out_pixels *= params->oS[i];
|
for (short i = 0; i < N; i++)
|
||||||
|
out_pixels *= params->oS[i];
|
||||||
|
|
||||||
// Set out
|
// Set out
|
||||||
out += gid.z * filter_size + gid.x * (filter_size / params->C);
|
out += gid.z * filter_size + gid.x * (filter_size / params->C);
|
||||||
@ -142,23 +142,23 @@ template <typename T, int N>
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_naive_unfold_nd(name, itype, n) \
|
#define instantiate_naive_unfold_nd(name, itype, n) \
|
||||||
template [[host_name("naive_unfold_nd_" #name "_" #n)]] \
|
template [[host_name("naive_unfold_nd_" #name "_" #n)]] [[kernel]] void \
|
||||||
[[kernel]] void naive_unfold_Nd( \
|
naive_unfold_Nd( \
|
||||||
const device itype* in [[buffer(0)]], \
|
const device itype* in [[buffer(0)]], \
|
||||||
device itype* out [[buffer(1)]], \
|
device itype* out [[buffer(1)]], \
|
||||||
const constant MLXConvParams<n>* params [[buffer(2)]], \
|
const constant MLXConvParams<n>* params [[buffer(2)]], \
|
||||||
uint3 gid [[thread_position_in_grid]]); \
|
uint3 gid [[thread_position_in_grid]]); \
|
||||||
template [[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] \
|
template \
|
||||||
[[kernel]] void naive_unfold_transpose_Nd( \
|
[[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] [[kernel]] void \
|
||||||
|
naive_unfold_transpose_Nd( \
|
||||||
const device itype* in [[buffer(0)]], \
|
const device itype* in [[buffer(0)]], \
|
||||||
device itype* out [[buffer(1)]], \
|
device itype* out [[buffer(1)]], \
|
||||||
const constant MLXConvParams<n>* params [[buffer(2)]], \
|
const constant MLXConvParams<n>* params [[buffer(2)]], \
|
||||||
uint3 gid [[thread_position_in_grid]]);
|
uint3 gid [[thread_position_in_grid]]);
|
||||||
|
|
||||||
#define instantiate_naive_unfold_nd_dims(name, itype) \
|
#define instantiate_naive_unfold_nd_dims(name, itype) \
|
||||||
instantiate_naive_unfold_nd(name, itype, 1) \
|
instantiate_naive_unfold_nd(name, itype, 1) instantiate_naive_unfold_nd( \
|
||||||
instantiate_naive_unfold_nd(name, itype, 2) \
|
name, itype, 2) instantiate_naive_unfold_nd(name, itype, 3)
|
||||||
instantiate_naive_unfold_nd(name, itype, 3)
|
|
||||||
|
|
||||||
instantiate_naive_unfold_nd_dims(float32, float);
|
instantiate_naive_unfold_nd_dims(float32, float);
|
||||||
instantiate_naive_unfold_nd_dims(float16, half);
|
instantiate_naive_unfold_nd_dims(float16, half);
|
||||||
@ -168,7 +168,8 @@ instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t);
|
|||||||
/// Slow and naive conv2d kernels
|
/// Slow and naive conv2d kernels
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
template <typename T,
|
template <
|
||||||
|
typename T,
|
||||||
const int BM, /* Threadgroup rows (in threads) */
|
const int BM, /* Threadgroup rows (in threads) */
|
||||||
const int BN, /* Threadgroup cols (in threads) */
|
const int BN, /* Threadgroup cols (in threads) */
|
||||||
const int TM, /* Thread rows (in elements) */
|
const int TM, /* Thread rows (in elements) */
|
||||||
@ -183,7 +184,6 @@ template <typename T,
|
|||||||
uint3 lid [[thread_position_in_threadgroup]],
|
uint3 lid [[thread_position_in_threadgroup]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
(void)simd_gid;
|
(void)simd_gid;
|
||||||
(void)simd_lid;
|
(void)simd_lid;
|
||||||
|
|
||||||
@ -202,7 +202,6 @@ template <typename T,
|
|||||||
out_w[m] = mm % params.oS[1];
|
out_w[m] = mm % params.oS[1];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
T in_local[TM];
|
T in_local[TM];
|
||||||
T wt_local[TN];
|
T wt_local[TN];
|
||||||
T out_local[TM * TN] = {T(0)};
|
T out_local[TM * TN] = {T(0)};
|
||||||
@ -210,22 +209,24 @@ template <typename T,
|
|||||||
for (int h = 0; h < params.wS[0]; ++h) {
|
for (int h = 0; h < params.wS[0]; ++h) {
|
||||||
for (int w = 0; w < params.wS[1]; ++w) {
|
for (int w = 0; w < params.wS[1]; ++w) {
|
||||||
for (int c = 0; c < params.C; ++c) {
|
for (int c = 0; c < params.C; ++c) {
|
||||||
|
|
||||||
// Local in
|
// Local in
|
||||||
for (int m = 0; m < TM; m++) {
|
for (int m = 0; m < TM; m++) {
|
||||||
int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0];
|
int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0];
|
||||||
int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1];
|
int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1];
|
||||||
|
|
||||||
bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
|
bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
|
||||||
in_local[m] = valid ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] : T(0);
|
in_local[m] = valid
|
||||||
|
? in[i * params.in_strides[1] + j * params.in_strides[2] + c]
|
||||||
|
: T(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load weight
|
// Load weight
|
||||||
for (int n = 0; n < TN; ++n) {
|
for (int n = 0; n < TN; ++n) {
|
||||||
int o = out_o + n;
|
int o = out_o + n;
|
||||||
wt_local[n] = o < params.O ? wt[o * params.wt_strides[0] +
|
wt_local[n] = o < params.O
|
||||||
h * params.wt_strides[1] +
|
? wt[o * params.wt_strides[0] + h * params.wt_strides[1] +
|
||||||
w * params.wt_strides[2] + c] : T(0);
|
w * params.wt_strides[2] + c]
|
||||||
|
: T(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accumulate
|
// Accumulate
|
||||||
@ -234,26 +235,27 @@ template <typename T,
|
|||||||
out_local[m * TN + n] += in_local[m] * wt_local[n];
|
out_local[m * TN + n] += in_local[m] * wt_local[n];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int m = 0; m < TM; ++m) {
|
for (int m = 0; m < TM; ++m) {
|
||||||
for (int n = 0; n < TN; ++n) {
|
for (int n = 0; n < TN; ++n) {
|
||||||
if(out_h[m] < params.oS[0] && out_w[m] < params.oS[1] && (out_o + n) < params.O)
|
if (out_h[m] < params.oS[0] && out_w[m] < params.oS[1] &&
|
||||||
|
(out_o + n) < params.O)
|
||||||
out[out_h[m] * params.out_strides[1] +
|
out[out_h[m] * params.out_strides[1] +
|
||||||
out_w[m] * params.out_strides[2] + out_o + n] = out_local[m * TN + n];
|
out_w[m] * params.out_strides[2] + out_o + n] =
|
||||||
|
out_local[m * TN + n];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Instantiations
|
// Instantiations
|
||||||
|
|
||||||
#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
|
#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
|
||||||
template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
|
template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm \
|
||||||
[[kernel]] void naive_conv_2d<itype, bm, bn, tm, tn>( \
|
"_tn" #tn)]] [[kernel]] void \
|
||||||
|
naive_conv_2d<itype, bm, bn, tm, tn>( \
|
||||||
const device itype* in [[buffer(0)]], \
|
const device itype* in [[buffer(0)]], \
|
||||||
const device itype* wt [[buffer(1)]], \
|
const device itype* wt [[buffer(1)]], \
|
||||||
device itype* out [[buffer(2)]], \
|
device itype* out [[buffer(2)]], \
|
||||||
@ -276,9 +278,7 @@ instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
|
|||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
template <int M, int R, int S>
|
template <int M, int R, int S>
|
||||||
struct WinogradTransforms {
|
struct WinogradTransforms {};
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct WinogradTransforms<6, 3, 8> {
|
struct WinogradTransforms<6, 3, 8> {
|
||||||
@ -324,12 +324,9 @@ constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8];
|
|||||||
constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8];
|
constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8];
|
||||||
constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8];
|
constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8];
|
||||||
|
|
||||||
template <typename T,
|
template <typename T, int BC = 32, int BO = 4, int M = 6, int R = 3>
|
||||||
int BC = 32,
|
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void
|
||||||
int BO = 4,
|
winograd_conv_2d_weight_transform(
|
||||||
int M = 6,
|
|
||||||
int R = 3>
|
|
||||||
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void winograd_conv_2d_weight_transform(
|
|
||||||
const device T* wt_in [[buffer(0)]],
|
const device T* wt_in [[buffer(0)]],
|
||||||
device T* wt_out [[buffer(1)]],
|
device T* wt_out [[buffer(1)]],
|
||||||
const constant int& C [[buffer(2)]],
|
const constant int& C [[buffer(2)]],
|
||||||
@ -337,7 +334,6 @@ template <typename T,
|
|||||||
uint tid [[threadgroup_position_in_grid]],
|
uint tid [[threadgroup_position_in_grid]],
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]]) {
|
uint simd_lane_id [[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
using WGT = WinogradTransforms<M, R, 8>;
|
using WGT = WinogradTransforms<M, R, 8>;
|
||||||
|
|
||||||
// Get lane position in simdgroup
|
// Get lane position in simdgroup
|
||||||
@ -384,8 +380,10 @@ template <typename T,
|
|||||||
// Do transform and store the result
|
// Do transform and store the result
|
||||||
for (int c = 0; c < BC; ++c) {
|
for (int c = 0; c < BC; ++c) {
|
||||||
simdgroup_matrix<T, 8, 8> g;
|
simdgroup_matrix<T, 8, 8> g;
|
||||||
g.thread_elements()[0] = sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
|
g.thread_elements()[0] =
|
||||||
g.thread_elements()[1] = sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);
|
sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
|
||||||
|
g.thread_elements()[1] =
|
||||||
|
sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);
|
||||||
|
|
||||||
simdgroup_matrix<T, 8, 8> g_out = (G * g) * Gt;
|
simdgroup_matrix<T, 8, 8> g_out = (G * g) * Gt;
|
||||||
wt_out_0[c * O] = g_out.thread_elements()[0];
|
wt_out_0[c * O] = g_out.thread_elements()[0];
|
||||||
@ -396,12 +394,12 @@ template <typename T,
|
|||||||
wt_out_0 += BC * O;
|
wt_out_0 += BC * O;
|
||||||
wt_out_1 += BC * O;
|
wt_out_1 += BC * O;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \
|
#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \
|
||||||
template [[host_name("winograd_conv_2d_weight_transform_" #name "_bc" #bc)]]\
|
template [[host_name("winograd_conv_2d_weight_transform_" #name \
|
||||||
[[kernel]] void winograd_conv_2d_weight_transform<itype, bc>(\
|
"_bc" #bc)]] [[kernel]] void \
|
||||||
|
winograd_conv_2d_weight_transform<itype, bc>( \
|
||||||
const device itype* wt_in [[buffer(0)]], \
|
const device itype* wt_in [[buffer(0)]], \
|
||||||
device itype* wt_out [[buffer(1)]], \
|
device itype* wt_out [[buffer(1)]], \
|
||||||
const constant int& C [[buffer(2)]], \
|
const constant int& C [[buffer(2)]], \
|
||||||
@ -410,13 +408,9 @@ template <typename T,
|
|||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]]);
|
uint simd_lane_id [[thread_index_in_simdgroup]]);
|
||||||
|
|
||||||
template <typename T,
|
template <typename T, int BC, int WM, int WN, int M = 6, int R = 3>
|
||||||
int BC,
|
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
|
||||||
int WM,
|
winograd_conv_2d_input_transform(
|
||||||
int WN,
|
|
||||||
int M = 6,
|
|
||||||
int R = 3>
|
|
||||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_input_transform(
|
|
||||||
const device T* inp_in [[buffer(0)]],
|
const device T* inp_in [[buffer(0)]],
|
||||||
device T* inp_out [[buffer(1)]],
|
device T* inp_out [[buffer(1)]],
|
||||||
const constant MLXConvParams<2>& params [[buffer(2)]],
|
const constant MLXConvParams<2>& params [[buffer(2)]],
|
||||||
@ -425,7 +419,6 @@ template <typename T,
|
|||||||
uint3 tgp_per_grid [[threadgroups_per_grid]],
|
uint3 tgp_per_grid [[threadgroups_per_grid]],
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]]) {
|
uint simd_lane_id [[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
(void)lid;
|
(void)lid;
|
||||||
|
|
||||||
using WGT = WinogradTransforms<M, R, 8>;
|
using WGT = WinogradTransforms<M, R, 8>;
|
||||||
@ -456,9 +449,8 @@ template <typename T,
|
|||||||
int bw = M * tid.x + kw;
|
int bw = M * tid.x + kw;
|
||||||
|
|
||||||
// Move to the correct input tile
|
// Move to the correct input tile
|
||||||
inp_in += tid.z * params.in_strides[0]
|
inp_in += tid.z * params.in_strides[0] + bh * params.in_strides[1] +
|
||||||
+ bh * params.in_strides[1]
|
bw * params.in_strides[2];
|
||||||
+ bw * params.in_strides[2];
|
|
||||||
|
|
||||||
// Pre compute strides
|
// Pre compute strides
|
||||||
int jump_in[TH][TW];
|
int jump_in[TH][TW];
|
||||||
@ -471,11 +463,14 @@ template <typename T,
|
|||||||
|
|
||||||
// inp_out is stored interleaved (A x A x tiles x C)
|
// inp_out is stored interleaved (A x A x tiles x C)
|
||||||
size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
|
size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
|
||||||
size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
|
size_t tile_id =
|
||||||
|
tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
|
||||||
size_t ohw_0 = sm * 8 + sn;
|
size_t ohw_0 = sm * 8 + sn;
|
||||||
size_t ohw_1 = sm * 8 + sn + 1;
|
size_t ohw_1 = sm * 8 + sn + 1;
|
||||||
device T* inp_out_0 = inp_out + ohw_0 * N_TILES * params.C + tile_id * params.C;
|
device T* inp_out_0 =
|
||||||
device T* inp_out_1 = inp_out + ohw_1 * N_TILES * params.C + tile_id * params.C;
|
inp_out + ohw_0 * N_TILES * params.C + tile_id * params.C;
|
||||||
|
device T* inp_out_1 =
|
||||||
|
inp_out + ohw_1 * N_TILES * params.C + tile_id * params.C;
|
||||||
|
|
||||||
// Prepare shared memory
|
// Prepare shared memory
|
||||||
threadgroup T Is[A][A][BC];
|
threadgroup T Is[A][A][BC];
|
||||||
@ -509,12 +504,12 @@ template <typename T,
|
|||||||
inp_out_0 += BC;
|
inp_out_0 += BC;
|
||||||
inp_out_1 += BC;
|
inp_out_1 += BC;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_winograd_conv_2d_input_transform(name, itype, bc) \
|
#define instantiate_winograd_conv_2d_input_transform(name, itype, bc) \
|
||||||
template [[host_name("winograd_conv_2d_input_transform_" #name "_bc" #bc)]]\
|
template [[host_name("winograd_conv_2d_input_transform_" #name \
|
||||||
[[kernel]] void winograd_conv_2d_input_transform<itype, bc, 2, 2>(\
|
"_bc" #bc)]] [[kernel]] void \
|
||||||
|
winograd_conv_2d_input_transform<itype, bc, 2, 2>( \
|
||||||
const device itype* inp_in [[buffer(0)]], \
|
const device itype* inp_in [[buffer(0)]], \
|
||||||
device itype* inp_out [[buffer(1)]], \
|
device itype* inp_out [[buffer(1)]], \
|
||||||
const constant MLXConvParams<2>& params [[buffer(2)]], \
|
const constant MLXConvParams<2>& params [[buffer(2)]], \
|
||||||
@ -524,13 +519,9 @@ template <typename T,
|
|||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]]);
|
uint simd_lane_id [[thread_index_in_simdgroup]]);
|
||||||
|
|
||||||
template <typename T,
|
template <typename T, int BO, int WM, int WN, int M = 6, int R = 3>
|
||||||
int BO,
|
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
|
||||||
int WM,
|
winograd_conv_2d_output_transform(
|
||||||
int WN,
|
|
||||||
int M = 6,
|
|
||||||
int R = 3>
|
|
||||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_output_transform(
|
|
||||||
const device T* out_in [[buffer(0)]],
|
const device T* out_in [[buffer(0)]],
|
||||||
device T* out_out [[buffer(1)]],
|
device T* out_out [[buffer(1)]],
|
||||||
const constant MLXConvParams<2>& params [[buffer(2)]],
|
const constant MLXConvParams<2>& params [[buffer(2)]],
|
||||||
@ -539,7 +530,6 @@ template <typename T,
|
|||||||
uint3 tgp_per_grid [[threadgroups_per_grid]],
|
uint3 tgp_per_grid [[threadgroups_per_grid]],
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]]) {
|
uint simd_lane_id [[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
(void)lid;
|
(void)lid;
|
||||||
|
|
||||||
using WGT = WinogradTransforms<M, R, 8>;
|
using WGT = WinogradTransforms<M, R, 8>;
|
||||||
@ -572,9 +562,8 @@ template <typename T,
|
|||||||
int bw = M * tid.x + kw;
|
int bw = M * tid.x + kw;
|
||||||
|
|
||||||
// Move to the correct input tile
|
// Move to the correct input tile
|
||||||
out_out += tid.z * params.out_strides[0]
|
out_out += tid.z * params.out_strides[0] + bh * params.out_strides[1] +
|
||||||
+ bh * params.out_strides[1]
|
bw * params.out_strides[2];
|
||||||
+ bw * params.out_strides[2];
|
|
||||||
|
|
||||||
// Pre compute strides
|
// Pre compute strides
|
||||||
int jump_in[TH][TW];
|
int jump_in[TH][TW];
|
||||||
@ -582,24 +571,27 @@ template <typename T,
|
|||||||
for (int h = 0; h < TH; h++) {
|
for (int h = 0; h < TH; h++) {
|
||||||
for (int w = 0; w < TW; w++) {
|
for (int w = 0; w < TW; w++) {
|
||||||
bool valid = ((bh + h) < params.oS[0]) && ((bw + w) < params.oS[1]);
|
bool valid = ((bh + h) < params.oS[0]) && ((bw + w) < params.oS[1]);
|
||||||
jump_in[h][w] = valid ? h * params.out_strides[1] + w * params.out_strides[2] : -1;
|
jump_in[h][w] =
|
||||||
|
valid ? h * params.out_strides[1] + w * params.out_strides[2] : -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// out_in is stored interleaved (A x A x tiles x O)
|
// out_in is stored interleaved (A x A x tiles x O)
|
||||||
size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
|
size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
|
||||||
size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
|
size_t tile_id =
|
||||||
|
tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
|
||||||
size_t ohw_0 = sm * 8 + sn;
|
size_t ohw_0 = sm * 8 + sn;
|
||||||
size_t ohw_1 = sm * 8 + sn + 1;
|
size_t ohw_1 = sm * 8 + sn + 1;
|
||||||
const device T* out_in_0 = out_in + ohw_0 * N_TILES * params.O + tile_id * params.O;
|
const device T* out_in_0 =
|
||||||
const device T* out_in_1 = out_in + ohw_1 * N_TILES * params.O + tile_id * params.O;
|
out_in + ohw_0 * N_TILES * params.O + tile_id * params.O;
|
||||||
|
const device T* out_in_1 =
|
||||||
|
out_in + ohw_1 * N_TILES * params.O + tile_id * params.O;
|
||||||
|
|
||||||
// Prepare shared memory
|
// Prepare shared memory
|
||||||
threadgroup T Os[M][M][BO];
|
threadgroup T Os[M][M][BO];
|
||||||
|
|
||||||
// Loop over O
|
// Loop over O
|
||||||
for (int bo = 0; bo < params.O; bo += BO) {
|
for (int bo = 0; bo < params.O; bo += BO) {
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
// Do transform and store the result
|
// Do transform and store the result
|
||||||
for (int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
|
for (int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
|
||||||
@ -633,12 +625,12 @@ template <typename T,
|
|||||||
out_in_0 += BO;
|
out_in_0 += BO;
|
||||||
out_in_1 += BO;
|
out_in_1 += BO;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_winograd_conv_2d_output_transform(name, itype, bo) \
|
#define instantiate_winograd_conv_2d_output_transform(name, itype, bo) \
|
||||||
template [[host_name("winograd_conv_2d_output_transform_" #name "_bo" #bo)]]\
|
template [[host_name("winograd_conv_2d_output_transform_" #name \
|
||||||
[[kernel]] void winograd_conv_2d_output_transform<itype, bo, 2, 2>(\
|
"_bo" #bo)]] [[kernel]] void \
|
||||||
|
winograd_conv_2d_output_transform<itype, bo, 2, 2>( \
|
||||||
const device itype* out_in [[buffer(0)]], \
|
const device itype* out_in [[buffer(0)]], \
|
||||||
device itype* out_out [[buffer(1)]], \
|
device itype* out_out [[buffer(1)]], \
|
||||||
const constant MLXConvParams<2>& params [[buffer(2)]], \
|
const constant MLXConvParams<2>& params [[buffer(2)]], \
|
||||||
@ -648,10 +640,12 @@ template <typename T,
|
|||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]]);
|
uint simd_lane_id [[thread_index_in_simdgroup]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_winograd_conv_2d(name, itype) \
|
#define instantiate_winograd_conv_2d(name, itype) \
|
||||||
instantiate_winograd_conv_2d_weight_transform_base(name, itype, 32) \
|
instantiate_winograd_conv_2d_weight_transform_base(name, itype, 32) \
|
||||||
instantiate_winograd_conv_2d_input_transform(name, itype, 32) \
|
instantiate_winograd_conv_2d_input_transform(name, itype, 32) \
|
||||||
instantiate_winograd_conv_2d_output_transform(name, itype, 32)
|
instantiate_winograd_conv_2d_output_transform(name, itype, 32) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_winograd_conv_2d(float32, float);
|
instantiate_winograd_conv_2d(float32, float);
|
||||||
instantiate_winograd_conv_2d(float16, half);
|
instantiate_winograd_conv_2d(float16, half); // clang-format on
|
@ -49,7 +49,8 @@ template <typename T, typename U>
|
|||||||
uint3 index [[thread_position_in_grid]],
|
uint3 index [[thread_position_in_grid]],
|
||||||
uint3 grid_dim [[threads_per_grid]]) {
|
uint3 grid_dim [[threads_per_grid]]) {
|
||||||
auto src_idx = elem_to_loc_3(index, src_strides);
|
auto src_idx = elem_to_loc_3(index, src_strides);
|
||||||
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
int64_t dst_idx =
|
||||||
|
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -62,7 +63,8 @@ template <typename T, typename U, int DIM>
|
|||||||
uint3 index [[thread_position_in_grid]],
|
uint3 index [[thread_position_in_grid]],
|
||||||
uint3 grid_dim [[threads_per_grid]]) {
|
uint3 grid_dim [[threads_per_grid]]) {
|
||||||
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
|
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
|
||||||
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
int64_t dst_idx =
|
||||||
|
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -76,7 +78,8 @@ template <typename T, typename U>
|
|||||||
uint3 index [[thread_position_in_grid]],
|
uint3 index [[thread_position_in_grid]],
|
||||||
uint3 grid_dim [[threads_per_grid]]) {
|
uint3 grid_dim [[threads_per_grid]]) {
|
||||||
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
|
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
|
||||||
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
int64_t dst_idx =
|
||||||
|
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -144,23 +147,22 @@ template <typename T, typename U>
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_copy(name, itype, otype, ctype) \
|
#define instantiate_copy(name, itype, otype, ctype) \
|
||||||
template [[host_name(name)]] \
|
template [[host_name(name)]] [[kernel]] void copy_##ctype<itype, otype>( \
|
||||||
[[kernel]] void copy_##ctype<itype, otype>( \
|
|
||||||
device const itype* src [[buffer(0)]], \
|
device const itype* src [[buffer(0)]], \
|
||||||
device otype* dst [[buffer(1)]], \
|
device otype* dst [[buffer(1)]], \
|
||||||
uint index [[thread_position_in_grid]]);
|
uint index [[thread_position_in_grid]]);
|
||||||
|
|
||||||
#define instantiate_copy_g_dim(name, itype, otype, dims) \
|
#define instantiate_copy_g_dim(name, itype, otype, dims) \
|
||||||
template [[host_name(name "_" #dims)]] \
|
template [[host_name(name "_" #dims)]] [[kernel]] void \
|
||||||
[[kernel]] void copy_g_nd<itype, otype, dims>( \
|
copy_g_nd<itype, otype, dims>( \
|
||||||
device const itype* src [[buffer(0)]], \
|
device const itype* src [[buffer(0)]], \
|
||||||
device otype* dst [[buffer(1)]], \
|
device otype* dst [[buffer(1)]], \
|
||||||
constant const int* src_shape [[buffer(2)]], \
|
constant const int* src_shape [[buffer(2)]], \
|
||||||
constant const int64_t* src_strides [[buffer(3)]], \
|
constant const int64_t* src_strides [[buffer(3)]], \
|
||||||
uint3 index [[thread_position_in_grid]], \
|
uint3 index [[thread_position_in_grid]], \
|
||||||
uint3 grid_dim [[threads_per_grid]]); \
|
uint3 grid_dim [[threads_per_grid]]); \
|
||||||
template [[host_name("g" name "_" #dims)]] \
|
template [[host_name("g" name "_" #dims)]] [[kernel]] void \
|
||||||
[[kernel]] void copy_gg_nd<itype, otype, dims>( \
|
copy_gg_nd<itype, otype, dims>( \
|
||||||
device const itype* src [[buffer(0)]], \
|
device const itype* src [[buffer(0)]], \
|
||||||
device otype* dst [[buffer(1)]], \
|
device otype* dst [[buffer(1)]], \
|
||||||
constant const int* src_shape [[buffer(2)]], \
|
constant const int* src_shape [[buffer(2)]], \
|
||||||
@ -168,44 +170,40 @@ template <typename T, typename U>
|
|||||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||||
uint3 index [[thread_position_in_grid]]);
|
uint3 index [[thread_position_in_grid]]);
|
||||||
|
|
||||||
|
|
||||||
#define instantiate_copy_g_nd(name, itype, otype) \
|
#define instantiate_copy_g_nd(name, itype, otype) \
|
||||||
template [[host_name(name "_1")]] \
|
template [[host_name(name "_1")]] [[kernel]] void copy_g_nd1<itype, otype>( \
|
||||||
[[kernel]] void copy_g_nd1<itype, otype>( \
|
|
||||||
device const itype* src [[buffer(0)]], \
|
device const itype* src [[buffer(0)]], \
|
||||||
device otype* dst [[buffer(1)]], \
|
device otype* dst [[buffer(1)]], \
|
||||||
constant const int64_t& src_stride [[buffer(3)]], \
|
constant const int64_t& src_stride [[buffer(3)]], \
|
||||||
uint index [[thread_position_in_grid]]); \
|
uint index [[thread_position_in_grid]]); \
|
||||||
template [[host_name(name "_2")]] \
|
template [[host_name(name "_2")]] [[kernel]] void copy_g_nd2<itype, otype>( \
|
||||||
[[kernel]] void copy_g_nd2<itype, otype>( \
|
|
||||||
device const itype* src [[buffer(0)]], \
|
device const itype* src [[buffer(0)]], \
|
||||||
device otype* dst [[buffer(1)]], \
|
device otype* dst [[buffer(1)]], \
|
||||||
constant const int64_t* src_strides [[buffer(3)]], \
|
constant const int64_t* src_strides [[buffer(3)]], \
|
||||||
uint2 index [[thread_position_in_grid]], \
|
uint2 index [[thread_position_in_grid]], \
|
||||||
uint2 grid_dim [[threads_per_grid]]); \
|
uint2 grid_dim [[threads_per_grid]]); \
|
||||||
template [[host_name(name "_3")]] \
|
template [[host_name(name "_3")]] [[kernel]] void copy_g_nd3<itype, otype>( \
|
||||||
[[kernel]] void copy_g_nd3<itype, otype>( \
|
|
||||||
device const itype* src [[buffer(0)]], \
|
device const itype* src [[buffer(0)]], \
|
||||||
device otype* dst [[buffer(1)]], \
|
device otype* dst [[buffer(1)]], \
|
||||||
constant const int64_t* src_strides [[buffer(3)]], \
|
constant const int64_t* src_strides [[buffer(3)]], \
|
||||||
uint3 index [[thread_position_in_grid]], \
|
uint3 index [[thread_position_in_grid]], \
|
||||||
uint3 grid_dim [[threads_per_grid]]); \
|
uint3 grid_dim [[threads_per_grid]]); \
|
||||||
template [[host_name("g" name "_1")]] \
|
template [[host_name("g" name "_1")]] [[kernel]] void \
|
||||||
[[kernel]] void copy_gg_nd1<itype, otype>( \
|
copy_gg_nd1<itype, otype>( \
|
||||||
device const itype* src [[buffer(0)]], \
|
device const itype* src [[buffer(0)]], \
|
||||||
device otype* dst [[buffer(1)]], \
|
device otype* dst [[buffer(1)]], \
|
||||||
constant const int64_t& src_stride [[buffer(3)]], \
|
constant const int64_t& src_stride [[buffer(3)]], \
|
||||||
constant const int64_t& dst_stride [[buffer(4)]], \
|
constant const int64_t& dst_stride [[buffer(4)]], \
|
||||||
uint index [[thread_position_in_grid]]); \
|
uint index [[thread_position_in_grid]]); \
|
||||||
template [[host_name("g" name "_2")]] \
|
template [[host_name("g" name "_2")]] [[kernel]] void \
|
||||||
[[kernel]] void copy_gg_nd2<itype, otype>( \
|
copy_gg_nd2<itype, otype>( \
|
||||||
device const itype* src [[buffer(0)]], \
|
device const itype* src [[buffer(0)]], \
|
||||||
device otype* dst [[buffer(1)]], \
|
device otype* dst [[buffer(1)]], \
|
||||||
constant const int64_t* src_strides [[buffer(3)]], \
|
constant const int64_t* src_strides [[buffer(3)]], \
|
||||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||||
uint2 index [[thread_position_in_grid]]); \
|
uint2 index [[thread_position_in_grid]]); \
|
||||||
template [[host_name("g" name "_3")]] \
|
template [[host_name("g" name "_3")]] [[kernel]] void \
|
||||||
[[kernel]] void copy_gg_nd3<itype, otype>( \
|
copy_gg_nd3<itype, otype>( \
|
||||||
device const itype* src [[buffer(0)]], \
|
device const itype* src [[buffer(0)]], \
|
||||||
device otype* dst [[buffer(1)]], \
|
device otype* dst [[buffer(1)]], \
|
||||||
constant const int64_t* src_strides [[buffer(3)]], \
|
constant const int64_t* src_strides [[buffer(3)]], \
|
||||||
@ -214,10 +212,8 @@ template <typename T, typename U>
|
|||||||
instantiate_copy_g_dim(name, itype, otype, 4) \
|
instantiate_copy_g_dim(name, itype, otype, 4) \
|
||||||
instantiate_copy_g_dim(name, itype, otype, 5)
|
instantiate_copy_g_dim(name, itype, otype, 5)
|
||||||
|
|
||||||
|
|
||||||
#define instantiate_copy_g(name, itype, otype) \
|
#define instantiate_copy_g(name, itype, otype) \
|
||||||
template [[host_name(name)]] \
|
template [[host_name(name)]] [[kernel]] void copy_g<itype, otype>( \
|
||||||
[[kernel]] void copy_g<itype, otype>( \
|
|
||||||
device const itype* src [[buffer(0)]], \
|
device const itype* src [[buffer(0)]], \
|
||||||
device otype* dst [[buffer(1)]], \
|
device otype* dst [[buffer(1)]], \
|
||||||
constant const int* src_shape [[buffer(2)]], \
|
constant const int* src_shape [[buffer(2)]], \
|
||||||
@ -225,8 +221,7 @@ template <typename T, typename U>
|
|||||||
constant const int& ndim [[buffer(5)]], \
|
constant const int& ndim [[buffer(5)]], \
|
||||||
uint3 index [[thread_position_in_grid]], \
|
uint3 index [[thread_position_in_grid]], \
|
||||||
uint3 grid_dim [[threads_per_grid]]); \
|
uint3 grid_dim [[threads_per_grid]]); \
|
||||||
template [[host_name("g" name)]] \
|
template [[host_name("g" name)]] [[kernel]] void copy_gg<itype, otype>( \
|
||||||
[[kernel]] void copy_gg<itype, otype>( \
|
|
||||||
device const itype* src [[buffer(0)]], \
|
device const itype* src [[buffer(0)]], \
|
||||||
device otype* dst [[buffer(1)]], \
|
device otype* dst [[buffer(1)]], \
|
||||||
constant const int* src_shape [[buffer(2)]], \
|
constant const int* src_shape [[buffer(2)]], \
|
||||||
@ -235,12 +230,14 @@ template <typename T, typename U>
|
|||||||
constant const int& ndim [[buffer(5)]], \
|
constant const int& ndim [[buffer(5)]], \
|
||||||
uint3 index [[thread_position_in_grid]]);
|
uint3 index [[thread_position_in_grid]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_copy_all(tname, itype, otype) \
|
#define instantiate_copy_all(tname, itype, otype) \
|
||||||
instantiate_copy("scopy" #tname, itype, otype, s) \
|
instantiate_copy("scopy" #tname, itype, otype, s) \
|
||||||
instantiate_copy("vcopy" #tname, itype, otype, v) \
|
instantiate_copy("vcopy" #tname, itype, otype, v) \
|
||||||
instantiate_copy_g("gcopy" #tname, itype, otype) \
|
instantiate_copy_g("gcopy" #tname, itype, otype) \
|
||||||
instantiate_copy_g_nd("gcopy" #tname, itype, otype)
|
instantiate_copy_g_nd("gcopy" #tname, itype, otype) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_copy_itype(itname, itype) \
|
#define instantiate_copy_itype(itname, itype) \
|
||||||
instantiate_copy_all(itname ##bool_, itype, bool) \
|
instantiate_copy_all(itname ##bool_, itype, bool) \
|
||||||
instantiate_copy_all(itname ##uint8, itype, uint8_t) \
|
instantiate_copy_all(itname ##uint8, itype, uint8_t) \
|
||||||
@ -268,4 +265,4 @@ instantiate_copy_itype(int64, int64_t)
|
|||||||
instantiate_copy_itype(float16, half)
|
instantiate_copy_itype(float16, half)
|
||||||
instantiate_copy_itype(float32, float)
|
instantiate_copy_itype(float32, float)
|
||||||
instantiate_copy_itype(bfloat16, bfloat16_t)
|
instantiate_copy_itype(bfloat16, bfloat16_t)
|
||||||
instantiate_copy_itype(complex64, complex64_t)
|
instantiate_copy_itype(complex64, complex64_t) // clang-format on
|
||||||
|
@ -6,9 +6,8 @@
|
|||||||
// - VkFFT (https://github.com/DTolm/VkFFT)
|
// - VkFFT (https://github.com/DTolm/VkFFT)
|
||||||
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
|
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
|
||||||
|
|
||||||
#include <metal_math>
|
|
||||||
#include <metal_common>
|
#include <metal_common>
|
||||||
|
#include <metal_math>
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/defines.h"
|
#include "mlx/backend/metal/kernels/defines.h"
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
@ -32,7 +31,12 @@ float2 get_twiddle(int k, int p) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// single threaded radix2 implemetation
|
// single threaded radix2 implemetation
|
||||||
void radix2(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) {
|
void radix2(
|
||||||
|
int i,
|
||||||
|
int p,
|
||||||
|
int m,
|
||||||
|
threadgroup float2* read_buf,
|
||||||
|
threadgroup float2* write_buf) {
|
||||||
float2 x_0 = read_buf[i];
|
float2 x_0 = read_buf[i];
|
||||||
float2 x_1 = read_buf[i + m];
|
float2 x_1 = read_buf[i + m];
|
||||||
|
|
||||||
@ -53,7 +57,12 @@ void radix2(int i, int p, int m, threadgroup float2* read_buf, threadgroup float
|
|||||||
}
|
}
|
||||||
|
|
||||||
// single threaded radix4 implemetation
|
// single threaded radix4 implemetation
|
||||||
void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) {
|
void radix4(
|
||||||
|
int i,
|
||||||
|
int p,
|
||||||
|
int m,
|
||||||
|
threadgroup float2* read_buf,
|
||||||
|
threadgroup float2* write_buf) {
|
||||||
float2 x_0 = read_buf[i];
|
float2 x_0 = read_buf[i];
|
||||||
float2 x_1 = read_buf[i + m];
|
float2 x_1 = read_buf[i + m];
|
||||||
float2 x_2 = read_buf[i + 2 * m];
|
float2 x_2 = read_buf[i + 2 * m];
|
||||||
@ -94,7 +103,6 @@ void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float
|
|||||||
write_buf[j + 3 * p] = y_3;
|
write_buf[j + 3 * p] = y_3;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Each FFT is computed entirely in shared GPU memory.
|
// Each FFT is computed entirely in shared GPU memory.
|
||||||
//
|
//
|
||||||
// N is decomposed into radix-2 and radix-4 DFTs:
|
// N is decomposed into radix-2 and radix-4 DFTs:
|
||||||
@ -111,7 +119,6 @@ template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
|
|||||||
device float2* out [[buffer(1)]],
|
device float2* out [[buffer(1)]],
|
||||||
uint3 thread_position_in_grid [[thread_position_in_grid]],
|
uint3 thread_position_in_grid [[thread_position_in_grid]],
|
||||||
uint3 threads_per_grid [[threads_per_grid]]) {
|
uint3 threads_per_grid [[threads_per_grid]]) {
|
||||||
|
|
||||||
// Index of the DFT in batch
|
// Index of the DFT in batch
|
||||||
int batch_idx = thread_position_in_grid.x * n;
|
int batch_idx = thread_position_in_grid.x * n;
|
||||||
// The index in the DFT we're working on
|
// The index in the DFT we're working on
|
||||||
@ -172,24 +179,21 @@ template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \
|
#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \
|
||||||
template [[host_name("fft_" #name)]] \
|
template [[host_name("fft_" #name)]] [[kernel]] void \
|
||||||
[[kernel]] void fft<n, radix_2_steps, radix_4_steps>( \
|
fft<n, radix_2_steps, radix_4_steps>( \
|
||||||
const device float2* in [[buffer(0)]], \
|
const device float2* in [[buffer(0)]], \
|
||||||
device float2* out [[buffer(1)]], \
|
device float2* out [[buffer(1)]], \
|
||||||
uint3 thread_position_in_grid [[thread_position_in_grid]], \
|
uint3 thread_position_in_grid [[thread_position_in_grid]], \
|
||||||
uint3 threads_per_grid [[threads_per_grid]]);
|
uint3 threads_per_grid [[threads_per_grid]]);
|
||||||
|
|
||||||
|
|
||||||
// Explicitly define kernels for each power of 2.
|
// Explicitly define kernels for each power of 2.
|
||||||
|
// clang-format off
|
||||||
instantiate_fft(4, /* n= */ 4, /* radix_2_steps= */ 0, /* radix_4_steps= */ 1)
|
instantiate_fft(4, /* n= */ 4, /* radix_2_steps= */ 0, /* radix_4_steps= */ 1)
|
||||||
instantiate_fft(8, 8, 1, 1)
|
instantiate_fft(8, 8, 1, 1) instantiate_fft(16, 16, 0, 2)
|
||||||
instantiate_fft(16, 16, 0, 2)
|
instantiate_fft(32, 32, 1, 2) instantiate_fft(64, 64, 0, 3)
|
||||||
instantiate_fft(32, 32, 1, 2)
|
instantiate_fft(128, 128, 1, 3) instantiate_fft(256, 256, 0, 4)
|
||||||
instantiate_fft(64, 64, 0, 3)
|
|
||||||
instantiate_fft(128, 128, 1, 3)
|
|
||||||
instantiate_fft(256, 256, 0, 4)
|
|
||||||
instantiate_fft(512, 512, 1, 4)
|
instantiate_fft(512, 512, 1, 4)
|
||||||
instantiate_fft(1024, 1024, 0, 5)
|
instantiate_fft(1024, 1024, 0, 5)
|
||||||
// 2048 is the max that will fit into 32KB of threadgroup memory.
|
// 2048 is the max that will fit into 32KB of threadgroup memory.
|
||||||
// TODO: implement 4 step FFT for larger n.
|
// TODO: implement 4 step FFT for larger n.
|
||||||
instantiate_fft(2048, 2048, 1, 5)
|
instantiate_fft(2048, 2048, 1, 5) // clang-format on
|
||||||
|
@ -24,7 +24,6 @@ METAL_FUNC void gather_impl(
|
|||||||
const thread Indices<IdxT, NIDX>& indices,
|
const thread Indices<IdxT, NIDX>& indices,
|
||||||
uint2 index [[thread_position_in_grid]],
|
uint2 index [[thread_position_in_grid]],
|
||||||
uint2 grid_dim [[threads_per_grid]]) {
|
uint2 grid_dim [[threads_per_grid]]) {
|
||||||
|
|
||||||
auto ind_idx = index.x;
|
auto ind_idx = index.x;
|
||||||
auto ind_offset = index.y;
|
auto ind_offset = index.y;
|
||||||
|
|
||||||
@ -43,17 +42,14 @@ METAL_FUNC void gather_impl(
|
|||||||
indices.ndim);
|
indices.ndim);
|
||||||
}
|
}
|
||||||
auto ax = axes[i];
|
auto ax = axes[i];
|
||||||
auto idx_val = offset_neg_idx(
|
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
|
||||||
indices.buffers[i][idx_loc], src_shape[ax]);
|
|
||||||
src_idx += idx_val * src_strides[ax];
|
src_idx += idx_val * src_strides[ax];
|
||||||
}
|
}
|
||||||
|
|
||||||
auto src_offset = elem_to_loc(
|
auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim);
|
||||||
ind_offset, slice_sizes, src_strides, src_ndim);
|
|
||||||
|
|
||||||
size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
|
size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
|
||||||
out[out_idx] = src[src_offset + src_idx];
|
out[out_idx] = src[src_offset + src_idx];
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#define make_gather_impl(IDX_ARG, IDX_ARR) \
|
#define make_gather_impl(IDX_ARG, IDX_ARR) \
|
||||||
@ -69,15 +65,10 @@ template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
|
|||||||
const constant int* idx_shapes [[buffer(7)]], \
|
const constant int* idx_shapes [[buffer(7)]], \
|
||||||
const constant size_t* idx_strides [[buffer(8)]], \
|
const constant size_t* idx_strides [[buffer(8)]], \
|
||||||
const constant int& idx_ndim [[buffer(9)]], \
|
const constant int& idx_ndim [[buffer(9)]], \
|
||||||
IDX_ARG(IdxT) \
|
IDX_ARG(IdxT) uint2 index [[thread_position_in_grid]], \
|
||||||
uint2 index [[thread_position_in_grid]], \
|
|
||||||
uint2 grid_dim [[threads_per_grid]]) { \
|
uint2 grid_dim [[threads_per_grid]]) { \
|
||||||
\
|
|
||||||
Indices<IdxT, NIDX> idxs{ \
|
Indices<IdxT, NIDX> idxs{ \
|
||||||
{{IDX_ARR()}}, \
|
{{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \
|
||||||
idx_shapes, \
|
|
||||||
idx_strides, \
|
|
||||||
idx_ndim}; \
|
|
||||||
\
|
\
|
||||||
return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \
|
return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \
|
||||||
src, \
|
src, \
|
||||||
@ -94,16 +85,8 @@ template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
|
|||||||
|
|
||||||
#define make_gather(n) make_gather_impl(IDX_ARG_##n, IDX_ARR_##n)
|
#define make_gather(n) make_gather_impl(IDX_ARG_##n, IDX_ARR_##n)
|
||||||
|
|
||||||
make_gather(0)
|
make_gather(0) make_gather(1) make_gather(2) make_gather(3) make_gather(4)
|
||||||
make_gather(1)
|
make_gather(5) make_gather(6) make_gather(7) make_gather(8) make_gather(9)
|
||||||
make_gather(2)
|
|
||||||
make_gather(3)
|
|
||||||
make_gather(4)
|
|
||||||
make_gather(5)
|
|
||||||
make_gather(6)
|
|
||||||
make_gather(7)
|
|
||||||
make_gather(8)
|
|
||||||
make_gather(9)
|
|
||||||
make_gather(10)
|
make_gather(10)
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////
|
||||||
@ -111,8 +94,8 @@ make_gather(10)
|
|||||||
/////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \
|
#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \
|
||||||
template [[host_name("gather" name "_" #nidx "" #nd_name)]] \
|
template [[host_name("gather" name "_" #nidx "" #nd_name)]] [[kernel]] void \
|
||||||
[[kernel]] void gather<src_t, idx_t, nidx, nd>( \
|
gather<src_t, idx_t, nidx, nd>( \
|
||||||
const device src_t* src [[buffer(0)]], \
|
const device src_t* src [[buffer(0)]], \
|
||||||
device src_t* out [[buffer(1)]], \
|
device src_t* out [[buffer(1)]], \
|
||||||
const constant int* src_shape [[buffer(2)]], \
|
const constant int* src_shape [[buffer(2)]], \
|
||||||
@ -123,13 +106,14 @@ template [[host_name("gather" name "_" #nidx "" #nd_name)]] \
|
|||||||
const constant int* idx_shapes [[buffer(7)]], \
|
const constant int* idx_shapes [[buffer(7)]], \
|
||||||
const constant size_t* idx_strides [[buffer(8)]], \
|
const constant size_t* idx_strides [[buffer(8)]], \
|
||||||
const constant int& idx_ndim [[buffer(9)]], \
|
const constant int& idx_ndim [[buffer(9)]], \
|
||||||
IDX_ARG(idx_t) \
|
IDX_ARG(idx_t) uint2 index [[thread_position_in_grid]], \
|
||||||
uint2 index [[thread_position_in_grid]], \
|
|
||||||
uint2 grid_dim [[threads_per_grid]]);
|
uint2 grid_dim [[threads_per_grid]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \
|
#define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \
|
||||||
instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name)
|
instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_gather4(name, src_t, idx_t, nidx) \
|
#define instantiate_gather4(name, src_t, idx_t, nidx) \
|
||||||
instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \
|
instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \
|
||||||
instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \
|
instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \
|
||||||
@ -148,8 +132,9 @@ instantiate_gather4("int32", int32_t, bool, 0)
|
|||||||
instantiate_gather4("int64", int64_t, bool, 0)
|
instantiate_gather4("int64", int64_t, bool, 0)
|
||||||
instantiate_gather4("float16", half, bool, 0)
|
instantiate_gather4("float16", half, bool, 0)
|
||||||
instantiate_gather4("float32", float, bool, 0)
|
instantiate_gather4("float32", float, bool, 0)
|
||||||
instantiate_gather4("bfloat16", bfloat16_t, bool, 0)
|
instantiate_gather4("bfloat16", bfloat16_t, bool, 0) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_gather3(name, src_type, ind_type) \
|
#define instantiate_gather3(name, src_type, ind_type) \
|
||||||
instantiate_gather4(name, src_type, ind_type, 1) \
|
instantiate_gather4(name, src_type, ind_type, 1) \
|
||||||
instantiate_gather4(name, src_type, ind_type, 2) \
|
instantiate_gather4(name, src_type, ind_type, 2) \
|
||||||
@ -160,8 +145,9 @@ instantiate_gather4("bfloat16", bfloat16_t, bool, 0)
|
|||||||
instantiate_gather4(name, src_type, ind_type, 7) \
|
instantiate_gather4(name, src_type, ind_type, 7) \
|
||||||
instantiate_gather4(name, src_type, ind_type, 8) \
|
instantiate_gather4(name, src_type, ind_type, 8) \
|
||||||
instantiate_gather4(name, src_type, ind_type, 9) \
|
instantiate_gather4(name, src_type, ind_type, 9) \
|
||||||
instantiate_gather4(name, src_type, ind_type, 10)
|
instantiate_gather4(name, src_type, ind_type, 10) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_gather(name, src_type) \
|
#define instantiate_gather(name, src_type) \
|
||||||
instantiate_gather3(#name "bool_", src_type, bool) \
|
instantiate_gather3(#name "bool_", src_type, bool) \
|
||||||
instantiate_gather3(#name "uint8", src_type, uint8_t) \
|
instantiate_gather3(#name "uint8", src_type, uint8_t) \
|
||||||
@ -184,4 +170,4 @@ instantiate_gather(int32, int32_t)
|
|||||||
instantiate_gather(int64, int64_t)
|
instantiate_gather(int64, int64_t)
|
||||||
instantiate_gather(float16, half)
|
instantiate_gather(float16, half)
|
||||||
instantiate_gather(float32, float)
|
instantiate_gather(float32, float)
|
||||||
instantiate_gather(bfloat16, bfloat16_t)
|
instantiate_gather(bfloat16, bfloat16_t) // clang-format on
|
@ -1,7 +1,7 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <metal_stdlib>
|
|
||||||
#include <metal_simdgroup>
|
#include <metal_simdgroup>
|
||||||
|
#include <metal_stdlib>
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
#include "mlx/backend/metal/kernels/defines.h"
|
#include "mlx/backend/metal/kernels/defines.h"
|
||||||
@ -25,7 +25,6 @@ template <
|
|||||||
const int TN, /* Thread cols (in elements) */
|
const int TN, /* Thread cols (in elements) */
|
||||||
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
||||||
struct GEMVKernel {
|
struct GEMVKernel {
|
||||||
|
|
||||||
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
|
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
|
||||||
|
|
||||||
// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
|
// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
|
||||||
@ -35,15 +34,20 @@ struct GEMVKernel {
|
|||||||
//
|
//
|
||||||
// 1. A thread loads TN elements each from mat along TM contiguous rows
|
// 1. A thread loads TN elements each from mat along TM contiguous rows
|
||||||
// and the corresponding scalar from the vector
|
// and the corresponding scalar from the vector
|
||||||
// 2. The thread then multiplies and adds to accumulate its local result for the block
|
// 2. The thread then multiplies and adds to accumulate its local result for
|
||||||
// 3. At the end, each thread has accumulated results over all blocks across the rows
|
// the block
|
||||||
|
// 3. At the end, each thread has accumulated results over all blocks across
|
||||||
|
// the rows
|
||||||
// These are then summed up across the threadgroup
|
// These are then summed up across the threadgroup
|
||||||
// 4. Each threadgroup writes its accumulated BN * TN outputs
|
// 4. Each threadgroup writes its accumulated BN * TN outputs
|
||||||
//
|
//
|
||||||
// Edge case handling:
|
// Edge case handling:
|
||||||
// - The threadgroup with the largest tid will have blocks that exceed the matrix
|
// - The threadgroup with the largest tid will have blocks that exceed the
|
||||||
// * The blocks that start outside the matrix are never read (thread results remain zero)
|
// matrix
|
||||||
// * The last thread that partially overlaps with the matrix is shifted inwards
|
// * The blocks that start outside the matrix are never read (thread results
|
||||||
|
// remain zero)
|
||||||
|
// * The last thread that partially overlaps with the matrix is shifted
|
||||||
|
// inwards
|
||||||
// such that the thread block fits exactly in the matrix
|
// such that the thread block fits exactly in the matrix
|
||||||
|
|
||||||
MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;
|
MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;
|
||||||
@ -64,7 +68,6 @@ struct GEMVKernel {
|
|||||||
uint3 lid [[thread_position_in_threadgroup]],
|
uint3 lid [[thread_position_in_threadgroup]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
// Appease compiler
|
// Appease compiler
|
||||||
(void)lid;
|
(void)lid;
|
||||||
|
|
||||||
@ -91,14 +94,12 @@ struct GEMVKernel {
|
|||||||
|
|
||||||
// Loop over in_vec in blocks of BN * TN
|
// Loop over in_vec in blocks of BN * TN
|
||||||
for (int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {
|
for (int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// Prefetch in_vector for threadgroup use
|
// Prefetch in_vector for threadgroup use
|
||||||
if (simd_gid == 0) {
|
if (simd_gid == 0) {
|
||||||
// Main load loop
|
// Main load loop
|
||||||
if (bn + TN <= in_vec_size) {
|
if (bn + TN <= in_vec_size) {
|
||||||
|
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for (int tn = 0; tn < TN; tn++) {
|
for (int tn = 0; tn < TN; tn++) {
|
||||||
in_vec_block[tn] = in_vec[bn + tn];
|
in_vec_block[tn] = in_vec[bn + tn];
|
||||||
@ -110,7 +111,6 @@ struct GEMVKernel {
|
|||||||
for (int tn = 0; tn < TN; tn++) {
|
for (int tn = 0; tn < TN; tn++) {
|
||||||
in_vec_block[tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0;
|
in_vec_block[tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -125,7 +125,6 @@ struct GEMVKernel {
|
|||||||
// Per thread work loop
|
// Per thread work loop
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for (int tm = 0; tm < TM; tm++) {
|
for (int tm = 0; tm < TM; tm++) {
|
||||||
|
|
||||||
// Load for the row
|
// Load for the row
|
||||||
if (bn + TN <= in_vec_size) {
|
if (bn + TN <= in_vec_size) {
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
@ -136,7 +135,8 @@ struct GEMVKernel {
|
|||||||
} else { // Edgecase
|
} else { // Edgecase
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for (int tn = 0; tn < TN; tn++) {
|
for (int tn = 0; tn < TN; tn++) {
|
||||||
int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
|
int col_idx =
|
||||||
|
(bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
|
||||||
inter[tn] = mat[tm * marix_ld + col_idx];
|
inter[tn] = mat[tm * marix_ld + col_idx];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -145,7 +145,6 @@ struct GEMVKernel {
|
|||||||
for (int tn = 0; tn < TN; tn++) {
|
for (int tn = 0; tn < TN; tn++) {
|
||||||
result[tm] += inter[tn] * v_coeff[tn];
|
result[tm] += inter[tn] * v_coeff[tn];
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -157,22 +156,17 @@ struct GEMVKernel {
|
|||||||
|
|
||||||
// Write outputs
|
// Write outputs
|
||||||
if (simd_lid == 0) {
|
if (simd_lid == 0) {
|
||||||
|
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for (int tm = 0; tm < TM; tm++) {
|
for (int tm = 0; tm < TM; tm++) {
|
||||||
if (kDoAxpby) {
|
if (kDoAxpby) {
|
||||||
out_vec[out_row + tm] =
|
out_vec[out_row + tm] = static_cast<T>(alpha) * result[tm] +
|
||||||
static_cast<T>(alpha) * result[tm] +
|
|
||||||
static_cast<T>(beta) * bias[(out_row + tm) * bias_stride];
|
static_cast<T>(beta) * bias[(out_row + tm) * bias_stride];
|
||||||
} else {
|
} else {
|
||||||
out_vec[out_row + tm] = result[tm];
|
out_vec[out_row + tm] = result[tm];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
@ -187,7 +181,6 @@ template <
|
|||||||
const int TN, /* Thread cols (in elements) */
|
const int TN, /* Thread cols (in elements) */
|
||||||
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
||||||
struct GEMVTKernel {
|
struct GEMVTKernel {
|
||||||
|
|
||||||
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
|
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
|
||||||
// into blocks of (BM * TM, BN * TN) divided among threadgroups
|
// into blocks of (BM * TM, BN * TN) divided among threadgroups
|
||||||
// - Every thread works on a block of (TM, TN)
|
// - Every thread works on a block of (TM, TN)
|
||||||
@ -195,18 +188,22 @@ struct GEMVTKernel {
|
|||||||
//
|
//
|
||||||
// 1. A thread loads TN elements each from mat along TM contiguous rows
|
// 1. A thread loads TN elements each from mat along TM contiguous rows
|
||||||
// and the corresponding scalar from the vector
|
// and the corresponding scalar from the vector
|
||||||
// 2. The thread then multiplies and adds to accumulate its local result for the block
|
// 2. The thread then multiplies and adds to accumulate its local result for
|
||||||
// 3. At the end, each thread has accumulated results over all blocks across the rows
|
// the block
|
||||||
|
// 3. At the end, each thread has accumulated results over all blocks across
|
||||||
|
// the rows
|
||||||
// These are then summed up across the threadgroup
|
// These are then summed up across the threadgroup
|
||||||
// 4. Each threadgroup writes its accumulated BN * TN outputs
|
// 4. Each threadgroup writes its accumulated BN * TN outputs
|
||||||
//
|
//
|
||||||
// Edge case handling:
|
// Edge case handling:
|
||||||
// - The threadgroup with the largest tid will have blocks that exceed the matrix
|
// - The threadgroup with the largest tid will have blocks that exceed the
|
||||||
// * The blocks that start outside the matrix are never read (thread results remain zero)
|
// matrix
|
||||||
// * The last thread that partially overlaps with the matrix is shifted inwards
|
// * The blocks that start outside the matrix are never read (thread results
|
||||||
|
// remain zero)
|
||||||
|
// * The last thread that partially overlaps with the matrix is shifted
|
||||||
|
// inwards
|
||||||
// such that the thread block fits exactly in the matrix
|
// such that the thread block fits exactly in the matrix
|
||||||
|
|
||||||
|
|
||||||
MLX_MTL_CONST short tgp_mem_size = BN * BM * TN;
|
MLX_MTL_CONST short tgp_mem_size = BN * BM * TN;
|
||||||
|
|
||||||
static METAL_FUNC void run(
|
static METAL_FUNC void run(
|
||||||
@ -225,7 +222,6 @@ struct GEMVTKernel {
|
|||||||
uint3 lid [[thread_position_in_threadgroup]],
|
uint3 lid [[thread_position_in_threadgroup]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
// Appease compiler
|
// Appease compiler
|
||||||
(void)simd_gid;
|
(void)simd_gid;
|
||||||
(void)simd_lid;
|
(void)simd_lid;
|
||||||
@ -243,7 +239,6 @@ struct GEMVTKernel {
|
|||||||
|
|
||||||
// Edgecase handling
|
// Edgecase handling
|
||||||
if (out_col < out_vec_size) {
|
if (out_col < out_vec_size) {
|
||||||
|
|
||||||
out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;
|
out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;
|
||||||
|
|
||||||
// Per thread accumulation main loop
|
// Per thread accumulation main loop
|
||||||
@ -254,7 +249,6 @@ struct GEMVTKernel {
|
|||||||
threadgroup_barrier(mem_flags::mem_none);
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
if (bm + TM <= in_vec_size) {
|
if (bm + TM <= in_vec_size) {
|
||||||
|
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for (int tm = 0; tm < TM; tm++) {
|
for (int tm = 0; tm < TM; tm++) {
|
||||||
v_coeff[tm] = in_vec[bm + tm];
|
v_coeff[tm] = in_vec[bm + tm];
|
||||||
@ -280,11 +274,9 @@ struct GEMVTKernel {
|
|||||||
for (int tn = 0; tn < TN; tn++) {
|
for (int tn = 0; tn < TN; tn++) {
|
||||||
result[tn] += v_coeff[tm] * inter[tn];
|
result[tn] += v_coeff[tm] * inter[tn];
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Threadgroup collection
|
// Threadgroup collection
|
||||||
@ -298,10 +290,8 @@ struct GEMVTKernel {
|
|||||||
|
|
||||||
// Threadgroup accumulation and writing out results
|
// Threadgroup accumulation and writing out results
|
||||||
if (lid.y == 0 && out_col < out_vec_size) {
|
if (lid.y == 0 && out_col < out_vec_size) {
|
||||||
|
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for (int i = 1; i < BM; i++) {
|
for (int i = 1; i < BM; i++) {
|
||||||
|
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for (int j = 0; j < TN; j++) {
|
for (int j = 0; j < TN; j++) {
|
||||||
result[j] += tgp_results[i * TN + j];
|
result[j] += tgp_results[i * TN + j];
|
||||||
@ -310,10 +300,8 @@ struct GEMVTKernel {
|
|||||||
|
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for (int j = 0; j < TN; j++) {
|
for (int j = 0; j < TN; j++) {
|
||||||
|
|
||||||
if (kDoAxpby) {
|
if (kDoAxpby) {
|
||||||
out_vec[out_col + j] =
|
out_vec[out_col + j] = static_cast<T>(alpha) * result[j] +
|
||||||
static_cast<T>(alpha) * result[j] +
|
|
||||||
static_cast<T>(beta) * bias[(out_col + j) * bias_stride];
|
static_cast<T>(beta) * bias[(out_col + j) * bias_stride];
|
||||||
} else {
|
} else {
|
||||||
out_vec[out_col + j] = result[j];
|
out_vec[out_col + j] = result[j];
|
||||||
@ -355,7 +343,6 @@ template <
|
|||||||
uint3 lid [[thread_position_in_threadgroup]],
|
uint3 lid [[thread_position_in_threadgroup]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN, kDoAxpby>;
|
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN, kDoAxpby>;
|
||||||
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
||||||
|
|
||||||
@ -394,15 +381,13 @@ template <
|
|||||||
tid,
|
tid,
|
||||||
lid,
|
lid,
|
||||||
simd_gid,
|
simd_gid,
|
||||||
simd_lid
|
simd_lid);
|
||||||
);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#define instantiate_gemv_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
|
#define instantiate_gemv_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
|
||||||
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby)]] \
|
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn \
|
||||||
[[kernel]] void gemv<itype, bm, bn, tm, tn, nc, axpby>( \
|
"_nc" #nc "_axpby" #axpby)]] [[kernel]] void \
|
||||||
|
gemv<itype, bm, bn, tm, tn, nc, axpby>( \
|
||||||
const device itype* mat [[buffer(0)]], \
|
const device itype* mat [[buffer(0)]], \
|
||||||
const device itype* in_vec [[buffer(1)]], \
|
const device itype* in_vec [[buffer(1)]], \
|
||||||
const device itype* bias [[buffer(2)]], \
|
const device itype* bias [[buffer(2)]], \
|
||||||
@ -430,9 +415,8 @@ template <
|
|||||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1)
|
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1)
|
||||||
|
|
||||||
#define instantiate_gemv_blocks(name, itype) \
|
#define instantiate_gemv_blocks(name, itype) \
|
||||||
instantiate_gemv(name, itype, 4, 32, 1, 4) \
|
instantiate_gemv(name, itype, 4, 32, 1, 4) instantiate_gemv( \
|
||||||
instantiate_gemv(name, itype, 4, 32, 4, 4) \
|
name, itype, 4, 32, 4, 4) instantiate_gemv(name, itype, 8, 32, 4, 4)
|
||||||
instantiate_gemv(name, itype, 8, 32, 4, 4)
|
|
||||||
|
|
||||||
instantiate_gemv_blocks(float32, float);
|
instantiate_gemv_blocks(float32, float);
|
||||||
instantiate_gemv_blocks(float16, half);
|
instantiate_gemv_blocks(float16, half);
|
||||||
@ -470,7 +454,6 @@ template <
|
|||||||
uint3 lid [[thread_position_in_threadgroup]],
|
uint3 lid [[thread_position_in_threadgroup]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN, kDoAxpby>;
|
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN, kDoAxpby>;
|
||||||
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
||||||
|
|
||||||
@ -509,14 +492,13 @@ template <
|
|||||||
tid,
|
tid,
|
||||||
lid,
|
lid,
|
||||||
simd_gid,
|
simd_gid,
|
||||||
simd_lid
|
simd_lid);
|
||||||
);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
|
#define instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
|
||||||
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby)]] \
|
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn \
|
||||||
[[kernel]] void gemv_t<itype, bm, bn, tm, tn, nc, axpby>( \
|
"_nc" #nc "_axpby" #axpby)]] [[kernel]] void \
|
||||||
|
gemv_t<itype, bm, bn, tm, tn, nc, axpby>( \
|
||||||
const device itype* mat [[buffer(0)]], \
|
const device itype* mat [[buffer(0)]], \
|
||||||
const device itype* in_vec [[buffer(1)]], \
|
const device itype* in_vec [[buffer(1)]], \
|
||||||
const device itype* bias [[buffer(2)]], \
|
const device itype* bias [[buffer(2)]], \
|
||||||
@ -537,20 +519,23 @@ template <
|
|||||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
|
#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
|
||||||
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 0) \
|
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 0) \
|
||||||
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 1) \
|
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 1) \
|
||||||
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 0) \
|
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 0) \
|
||||||
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 1)
|
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 1) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_gemv_t_blocks(name, itype) \
|
#define instantiate_gemv_t_blocks(name, itype) \
|
||||||
instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
|
instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
|
||||||
instantiate_gemv_t(name, itype, 8, 8, 4, 4) \
|
instantiate_gemv_t(name, itype, 8, 8, 4, 4) \
|
||||||
instantiate_gemv_t(name, itype, 8, 16, 4, 4) \
|
instantiate_gemv_t(name, itype, 8, 16, 4, 4) \
|
||||||
instantiate_gemv_t(name, itype, 8, 32, 4, 4) \
|
instantiate_gemv_t(name, itype, 8, 32, 4, 4) \
|
||||||
instantiate_gemv_t(name, itype, 8, 64, 4, 4) \
|
instantiate_gemv_t(name, itype, 8, 64, 4, 4) \
|
||||||
instantiate_gemv_t(name, itype, 8, 128, 4, 4)
|
instantiate_gemv_t(name, itype, 8, 128, 4, 4) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_gemv_t_blocks(float32, float);
|
instantiate_gemv_t_blocks(float32, float);
|
||||||
instantiate_gemv_t_blocks(float16, half);
|
instantiate_gemv_t_blocks(float16, half);
|
||||||
instantiate_gemv_t_blocks(bfloat16, bfloat16_t);
|
instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on
|
@ -99,7 +99,8 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
if ((lid * N_READS + i) < axis_size) {
|
if ((lid * N_READS + i) < axis_size) {
|
||||||
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
||||||
out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
|
out[i] =
|
||||||
|
w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -192,13 +193,15 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
if (r + lid * N_READS + N_READS <= axis_size) {
|
if (r + lid * N_READS + N_READS <= axis_size) {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
float xi = (x[r + i] - mean) * normalizer;
|
float xi = (x[r + i] - mean) * normalizer;
|
||||||
out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
|
out[r + i] =
|
||||||
|
w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
if ((r + lid * N_READS + i) < axis_size) {
|
if ((r + lid * N_READS + i) < axis_size) {
|
||||||
float xi = (x[r + i] - mean) * normalizer;
|
float xi = (x[r + i] - mean) * normalizer;
|
||||||
out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
|
out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) +
|
||||||
|
b[b_stride * (i + r)];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -323,7 +326,8 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
if (lid * N_READS + N_READS <= axis_size) {
|
if (lid * N_READS + N_READS <= axis_size) {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
||||||
gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) -
|
gx[i] = static_cast<T>(
|
||||||
|
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
|
||||||
thread_x[i] * meanwgxc * normalizer2);
|
thread_x[i] * meanwgxc * normalizer2);
|
||||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
|
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
|
||||||
}
|
}
|
||||||
@ -331,7 +335,8 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
if ((lid * N_READS + i) < axis_size) {
|
if ((lid * N_READS + i) < axis_size) {
|
||||||
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
||||||
gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) -
|
gx[i] = static_cast<T>(
|
||||||
|
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
|
||||||
thread_x[i] * meanwgxc * normalizer2);
|
thread_x[i] * meanwgxc * normalizer2);
|
||||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
|
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
|
||||||
}
|
}
|
||||||
@ -460,8 +465,8 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
float xi = (x[i + r] - mean) * normalizer;
|
float xi = (x[i + r] - mean) * normalizer;
|
||||||
float wi = w[(i + r) * w_stride];
|
float wi = w[(i + r) * w_stride];
|
||||||
float gi = g[i + r];
|
float gi = g[i + r];
|
||||||
gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) -
|
gx[i + r] = static_cast<T>(
|
||||||
xi * meanwgxc * normalizer2);
|
normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
|
||||||
gw[i + r] = static_cast<T>(gi * xi);
|
gw[i + r] = static_cast<T>(gi * xi);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -470,8 +475,8 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
float xi = (x[i + r] - mean) * normalizer;
|
float xi = (x[i + r] - mean) * normalizer;
|
||||||
float wi = w[(i + r) * w_stride];
|
float wi = w[(i + r) * w_stride];
|
||||||
float gi = g[i + r];
|
float gi = g[i + r];
|
||||||
gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) -
|
gx[i + r] = static_cast<T>(
|
||||||
xi * meanwgxc * normalizer2);
|
normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
|
||||||
gw[i + r] = static_cast<T>(gi * xi);
|
gw[i + r] = static_cast<T>(gi * xi);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -548,6 +553,4 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
|
|
||||||
instantiate_layer_norm(float32, float)
|
instantiate_layer_norm(float32, float)
|
||||||
instantiate_layer_norm(float16, half)
|
instantiate_layer_norm(float16, half)
|
||||||
instantiate_layer_norm(bfloat16, bfloat16_t)
|
instantiate_layer_norm(bfloat16, bfloat16_t) // clang-format on
|
||||||
// clang-format on
|
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <metal_stdlib>
|
|
||||||
#include <metal_simdgroup>
|
#include <metal_simdgroup>
|
||||||
|
#include <metal_stdlib>
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
#include "mlx/backend/metal/kernels/defines.h"
|
#include "mlx/backend/metal/kernels/defines.h"
|
||||||
@ -15,10 +15,11 @@ using namespace metal;
|
|||||||
|
|
||||||
MLX_MTL_CONST int SIMD_SIZE = 32;
|
MLX_MTL_CONST int SIMD_SIZE = 32;
|
||||||
|
|
||||||
|
|
||||||
template <typename T, typename U, int values_per_thread, int bits>
|
template <typename T, typename U, int values_per_thread, int bits>
|
||||||
inline U load_vector(const device T* x, thread U* x_thread) {
|
inline U load_vector(const device T* x, thread U* x_thread) {
|
||||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
static_assert(
|
||||||
|
bits == 2 || bits == 4 || bits == 8,
|
||||||
|
"Template undefined for bits not in {2, 4, 8}");
|
||||||
|
|
||||||
U sum = 0;
|
U sum = 0;
|
||||||
|
|
||||||
@ -54,7 +55,9 @@ inline U load_vector(const device T *x, thread U *x_thread) {
|
|||||||
|
|
||||||
template <typename T, typename U, int values_per_thread, int bits>
|
template <typename T, typename U, int values_per_thread, int bits>
|
||||||
inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
|
inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
|
||||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
static_assert(
|
||||||
|
bits == 2 || bits == 4 || bits == 8,
|
||||||
|
"Template undefined for bits not in {2, 4, 8}");
|
||||||
|
|
||||||
U sum = 0;
|
U sum = 0;
|
||||||
|
|
||||||
@ -98,29 +101,36 @@ inline U load_vector_safe(const device T *x, thread U *x_thread, int N) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename U, int values_per_thread, int bits>
|
template <typename U, int values_per_thread, int bits>
|
||||||
inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum) {
|
inline U qdot(
|
||||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
const device uint8_t* w,
|
||||||
|
const thread U* x_thread,
|
||||||
|
U scale,
|
||||||
|
U bias,
|
||||||
|
U sum) {
|
||||||
|
static_assert(
|
||||||
|
bits == 2 || bits == 4 || bits == 8,
|
||||||
|
"Template undefined for bits not in {2, 4, 8}");
|
||||||
|
|
||||||
U accum = 0;
|
U accum = 0;
|
||||||
|
|
||||||
if (bits == 2) {
|
if (bits == 2) {
|
||||||
for (int i = 0; i < (values_per_thread / 4); i++) {
|
for (int i = 0; i < (values_per_thread / 4); i++) {
|
||||||
accum += (
|
accum +=
|
||||||
x_thread[4*i] * (w[i] & 0x03)
|
(x_thread[4 * i] * (w[i] & 0x03) +
|
||||||
+ x_thread[4*i+1] * (w[i] & 0x0c)
|
x_thread[4 * i + 1] * (w[i] & 0x0c) +
|
||||||
+ x_thread[4*i+2] * (w[i] & 0x30)
|
x_thread[4 * i + 2] * (w[i] & 0x30) +
|
||||||
+ x_thread[4*i+3] * (w[i] & 0xc0));
|
x_thread[4 * i + 3] * (w[i] & 0xc0));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
else if (bits == 4) {
|
else if (bits == 4) {
|
||||||
const device uint16_t* ws = (const device uint16_t*)w;
|
const device uint16_t* ws = (const device uint16_t*)w;
|
||||||
for (int i = 0; i < (values_per_thread / 4); i++) {
|
for (int i = 0; i < (values_per_thread / 4); i++) {
|
||||||
accum += (
|
accum +=
|
||||||
x_thread[4*i] * (ws[i] & 0x000f)
|
(x_thread[4 * i] * (ws[i] & 0x000f) +
|
||||||
+ x_thread[4*i+1] * (ws[i] & 0x00f0)
|
x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
|
||||||
+ x_thread[4*i+2] * (ws[i] & 0x0f00)
|
x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
|
||||||
+ x_thread[4*i+3] * (ws[i] & 0xf000));
|
x_thread[4 * i + 3] * (ws[i] & 0xf000));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -134,29 +144,37 @@ inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename U, int values_per_thread, int bits>
|
template <typename U, int values_per_thread, int bits>
|
||||||
inline U qdot_safe(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum, int N) {
|
inline U qdot_safe(
|
||||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
const device uint8_t* w,
|
||||||
|
const thread U* x_thread,
|
||||||
|
U scale,
|
||||||
|
U bias,
|
||||||
|
U sum,
|
||||||
|
int N) {
|
||||||
|
static_assert(
|
||||||
|
bits == 2 || bits == 4 || bits == 8,
|
||||||
|
"Template undefined for bits not in {2, 4, 8}");
|
||||||
|
|
||||||
U accum = 0;
|
U accum = 0;
|
||||||
|
|
||||||
if (bits == 2) {
|
if (bits == 2) {
|
||||||
for (int i = 0; i < (N / 4); i++) {
|
for (int i = 0; i < (N / 4); i++) {
|
||||||
accum += (
|
accum +=
|
||||||
x_thread[4*i] * (w[i] & 0x03)
|
(x_thread[4 * i] * (w[i] & 0x03) +
|
||||||
+ x_thread[4*i+1] * (w[i] & 0x0c)
|
x_thread[4 * i + 1] * (w[i] & 0x0c) +
|
||||||
+ x_thread[4*i+2] * (w[i] & 0x30)
|
x_thread[4 * i + 2] * (w[i] & 0x30) +
|
||||||
+ x_thread[4*i+3] * (w[i] & 0xc0));
|
x_thread[4 * i + 3] * (w[i] & 0xc0));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
else if (bits == 4) {
|
else if (bits == 4) {
|
||||||
const device uint16_t* ws = (const device uint16_t*)w;
|
const device uint16_t* ws = (const device uint16_t*)w;
|
||||||
for (int i = 0; i < (N / 4); i++) {
|
for (int i = 0; i < (N / 4); i++) {
|
||||||
accum += (
|
accum +=
|
||||||
x_thread[4*i] * (ws[i] & 0x000f)
|
(x_thread[4 * i] * (ws[i] & 0x000f) +
|
||||||
+ x_thread[4*i+1] * (ws[i] & 0x00f0)
|
x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
|
||||||
+ x_thread[4*i+2] * (ws[i] & 0x0f00)
|
x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
|
||||||
+ x_thread[4*i+3] * (ws[i] & 0xf000));
|
x_thread[4 * i + 3] * (ws[i] & 0xf000));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -170,8 +188,11 @@ inline U qdot_safe(const device uint8_t* w, const thread U *x_thread, U scale, U
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename U, int values_per_thread, int bits>
|
template <typename U, int values_per_thread, int bits>
|
||||||
inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
|
inline void
|
||||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
|
||||||
|
static_assert(
|
||||||
|
bits == 2 || bits == 4 || bits == 8,
|
||||||
|
"Template undefined for bits not in {2, 4, 8}");
|
||||||
|
|
||||||
if (bits == 2) {
|
if (bits == 2) {
|
||||||
U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
|
U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
|
||||||
@ -202,11 +223,18 @@ inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* resu
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename U, int N, int bits>
|
template <typename U, int N, int bits>
|
||||||
inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
|
inline void
|
||||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
|
||||||
|
static_assert(
|
||||||
|
bits == 2 || bits == 4 || bits == 8,
|
||||||
|
"Template undefined for bits not in {2, 4, 8}");
|
||||||
|
|
||||||
if (bits == 2) {
|
if (bits == 2) {
|
||||||
U s[4] = {scale, scale / static_cast<U>(4.0f), scale / static_cast<U>(16.0f), scale / static_cast<U>(64.0f)};
|
U s[4] = {
|
||||||
|
scale,
|
||||||
|
scale / static_cast<U>(4.0f),
|
||||||
|
scale / static_cast<U>(16.0f),
|
||||||
|
scale / static_cast<U>(64.0f)};
|
||||||
for (int i = 0; i < (N / 4); i++) {
|
for (int i = 0; i < (N / 4); i++) {
|
||||||
w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
|
w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
|
||||||
w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
|
w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
|
||||||
@ -217,7 +245,11 @@ inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U*
|
|||||||
|
|
||||||
else if (bits == 4) {
|
else if (bits == 4) {
|
||||||
const device uint16_t* ws = (const device uint16_t*)w;
|
const device uint16_t* ws = (const device uint16_t*)w;
|
||||||
U s[4] = {scale, scale / static_cast<U>(16.0f), scale / static_cast<U>(256.0f), scale / static_cast<U>(4096.0f)};
|
U s[4] = {
|
||||||
|
scale,
|
||||||
|
scale / static_cast<U>(16.0f),
|
||||||
|
scale / static_cast<U>(256.0f),
|
||||||
|
scale / static_cast<U>(4096.0f)};
|
||||||
for (int i = 0; i < (N / 4); i++) {
|
for (int i = 0; i < (N / 4); i++) {
|
||||||
w_local[4 * i] = s[0] * (ws[i] & 0x000f) + bias;
|
w_local[4 * i] = s[0] * (ws[i] & 0x000f) + bias;
|
||||||
w_local[4 * i + 1] = s[1] * (ws[i] & 0x00f0) + bias;
|
w_local[4 * i + 1] = s[1] * (ws[i] & 0x00f0) + bias;
|
||||||
@ -243,13 +275,20 @@ template <
|
|||||||
short group_size,
|
short group_size,
|
||||||
short bits>
|
short bits>
|
||||||
struct QuantizedBlockLoader {
|
struct QuantizedBlockLoader {
|
||||||
static_assert(BCOLS <= group_size, "The group size should be larger than the columns");
|
static_assert(
|
||||||
static_assert(group_size % BCOLS == 0, "The group size should be divisible by the columns");
|
BCOLS <= group_size,
|
||||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
"The group size should be larger than the columns");
|
||||||
|
static_assert(
|
||||||
|
group_size % BCOLS == 0,
|
||||||
|
"The group size should be divisible by the columns");
|
||||||
|
static_assert(
|
||||||
|
bits == 2 || bits == 4 || bits == 8,
|
||||||
|
"Template undefined for bits not in {2, 4, 8}");
|
||||||
|
|
||||||
MLX_MTL_CONST short pack_factor = 32 / bits;
|
MLX_MTL_CONST short pack_factor = 32 / bits;
|
||||||
MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
|
MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
|
||||||
MLX_MTL_CONST short n_reads = (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
|
MLX_MTL_CONST short n_reads =
|
||||||
|
(BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
|
||||||
MLX_MTL_CONST short group_steps = group_size / BCOLS;
|
MLX_MTL_CONST short group_steps = group_size / BCOLS;
|
||||||
|
|
||||||
const int src_ld;
|
const int src_ld;
|
||||||
@ -275,7 +314,8 @@ struct QuantizedBlockLoader {
|
|||||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
||||||
: src_ld(src_ld_),
|
: src_ld(src_ld_),
|
||||||
tile_stride(reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
|
tile_stride(
|
||||||
|
reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
|
||||||
group_step_cnt(0),
|
group_step_cnt(0),
|
||||||
group_stride(BROWS * src_ld / group_size),
|
group_stride(BROWS * src_ld / group_size),
|
||||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||||
@ -294,7 +334,8 @@ struct QuantizedBlockLoader {
|
|||||||
T scale = *scales;
|
T scale = *scales;
|
||||||
T bias = *biases;
|
T bias = *biases;
|
||||||
for (int i = 0; i < n_reads; i++) {
|
for (int i = 0; i < n_reads; i++) {
|
||||||
dequantize<T, pack_factor, bits>((device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
|
dequantize<T, pack_factor, bits>(
|
||||||
|
(device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -320,7 +361,8 @@ struct QuantizedBlockLoader {
|
|||||||
T scale = *scales;
|
T scale = *scales;
|
||||||
T bias = *biases;
|
T bias = *biases;
|
||||||
for (int i = 0; i < n_reads; i++) {
|
for (int i = 0; i < n_reads; i++) {
|
||||||
dequantize<T, pack_factor, bits>((device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
|
dequantize<T, pack_factor, bits>(
|
||||||
|
(device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -357,7 +399,6 @@ template <typename T, int group_size, int bits, int packs_per_thread>
|
|||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
constexpr int num_simdgroups = 2;
|
constexpr int num_simdgroups = 2;
|
||||||
constexpr int results_per_simdgroup = 4;
|
constexpr int results_per_simdgroup = 4;
|
||||||
constexpr int pack_factor = 32 / bits;
|
constexpr int pack_factor = 32 / bits;
|
||||||
@ -373,7 +414,8 @@ template <typename T, int group_size, int bits, int packs_per_thread>
|
|||||||
// Adjust positions
|
// Adjust positions
|
||||||
const int in_vec_size_w = in_vec_size / pack_factor;
|
const int in_vec_size_w = in_vec_size / pack_factor;
|
||||||
const int in_vec_size_g = in_vec_size / group_size;
|
const int in_vec_size_g = in_vec_size / group_size;
|
||||||
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup;
|
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
|
||||||
|
simd_gid * results_per_simdgroup;
|
||||||
w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
|
w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
|
||||||
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
|
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
|
||||||
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
|
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
|
||||||
@ -384,7 +426,8 @@ template <typename T, int group_size, int bits, int packs_per_thread>
|
|||||||
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
||||||
|
|
||||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||||
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
|
const device uint8_t* wl =
|
||||||
|
(const device uint8_t*)(w + row * in_vec_size_w);
|
||||||
const device T* sl = scales + row * in_vec_size_g;
|
const device T* sl = scales + row * in_vec_size_g;
|
||||||
const device T* bl = biases + row * in_vec_size_g;
|
const device T* bl = biases + row * in_vec_size_g;
|
||||||
|
|
||||||
@ -407,7 +450,6 @@ template <typename T, int group_size, int bits, int packs_per_thread>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
template <typename T, const int group_size, const int bits>
|
template <typename T, const int group_size, const int bits>
|
||||||
[[kernel]] void qmv(
|
[[kernel]] void qmv(
|
||||||
const device uint32_t* w [[buffer(0)]],
|
const device uint32_t* w [[buffer(0)]],
|
||||||
@ -420,7 +462,6 @@ template <typename T, const int group_size, const int bits>
|
|||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
constexpr int num_simdgroups = 2;
|
constexpr int num_simdgroups = 2;
|
||||||
constexpr int results_per_simdgroup = 4;
|
constexpr int results_per_simdgroup = 4;
|
||||||
constexpr int packs_per_thread = 1;
|
constexpr int packs_per_thread = 1;
|
||||||
@ -437,7 +478,8 @@ template <typename T, const int group_size, const int bits>
|
|||||||
// Adjust positions
|
// Adjust positions
|
||||||
const int in_vec_size_w = in_vec_size / pack_factor;
|
const int in_vec_size_w = in_vec_size / pack_factor;
|
||||||
const int in_vec_size_g = in_vec_size / group_size;
|
const int in_vec_size_g = in_vec_size / group_size;
|
||||||
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup;
|
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
|
||||||
|
simd_gid * results_per_simdgroup;
|
||||||
const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
|
const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
|
||||||
|
|
||||||
if (out_row >= out_vec_size) {
|
if (out_row >= out_vec_size) {
|
||||||
@ -458,13 +500,15 @@ template <typename T, const int group_size, const int bits>
|
|||||||
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
||||||
|
|
||||||
for (int row = 0; out_row + row < out_vec_size; row++) {
|
for (int row = 0; out_row + row < out_vec_size; row++) {
|
||||||
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
|
const device uint8_t* wl =
|
||||||
|
(const device uint8_t*)(w + row * in_vec_size_w);
|
||||||
const device T* sl = scales + row * in_vec_size_g;
|
const device T* sl = scales + row * in_vec_size_g;
|
||||||
const device T* bl = biases + row * in_vec_size_g;
|
const device T* bl = biases + row * in_vec_size_g;
|
||||||
|
|
||||||
U s = sl[0];
|
U s = sl[0];
|
||||||
U b = bl[0];
|
U b = bl[0];
|
||||||
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
result[row] +=
|
||||||
|
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
||||||
}
|
}
|
||||||
|
|
||||||
w += block_size / pack_factor;
|
w += block_size / pack_factor;
|
||||||
@ -472,11 +516,16 @@ template <typename T, const int group_size, const int bits>
|
|||||||
biases += block_size / group_size;
|
biases += block_size / group_size;
|
||||||
x += block_size;
|
x += block_size;
|
||||||
}
|
}
|
||||||
const int remaining = clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread);
|
const int remaining = clamp(
|
||||||
U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
|
static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
|
||||||
|
0,
|
||||||
|
values_per_thread);
|
||||||
|
U sum =
|
||||||
|
load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
|
||||||
|
|
||||||
for (int row = 0; out_row + row < out_vec_size; row++) {
|
for (int row = 0; out_row + row < out_vec_size; row++) {
|
||||||
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
|
const device uint8_t* wl =
|
||||||
|
(const device uint8_t*)(w + row * in_vec_size_w);
|
||||||
const device T* sl = scales + row * in_vec_size_g;
|
const device T* sl = scales + row * in_vec_size_g;
|
||||||
const device T* bl = biases + row * in_vec_size_g;
|
const device T* bl = biases + row * in_vec_size_g;
|
||||||
|
|
||||||
@ -506,13 +555,15 @@ template <typename T, const int group_size, const int bits>
|
|||||||
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
||||||
|
|
||||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||||
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
|
const device uint8_t* wl =
|
||||||
|
(const device uint8_t*)(w + row * in_vec_size_w);
|
||||||
const device T* sl = scales + row * in_vec_size_g;
|
const device T* sl = scales + row * in_vec_size_g;
|
||||||
const device T* bl = biases + row * in_vec_size_g;
|
const device T* bl = biases + row * in_vec_size_g;
|
||||||
|
|
||||||
U s = sl[0];
|
U s = sl[0];
|
||||||
U b = bl[0];
|
U b = bl[0];
|
||||||
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
result[row] +=
|
||||||
|
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
||||||
}
|
}
|
||||||
|
|
||||||
w += block_size / pack_factor;
|
w += block_size / pack_factor;
|
||||||
@ -520,17 +571,23 @@ template <typename T, const int group_size, const int bits>
|
|||||||
biases += block_size / group_size;
|
biases += block_size / group_size;
|
||||||
x += block_size;
|
x += block_size;
|
||||||
}
|
}
|
||||||
const int remaining = clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread);
|
const int remaining = clamp(
|
||||||
U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
|
static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
|
||||||
|
0,
|
||||||
|
values_per_thread);
|
||||||
|
U sum =
|
||||||
|
load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
|
||||||
|
|
||||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||||
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
|
const device uint8_t* wl =
|
||||||
|
(const device uint8_t*)(w + row * in_vec_size_w);
|
||||||
const device T* sl = scales + row * in_vec_size_g;
|
const device T* sl = scales + row * in_vec_size_g;
|
||||||
const device T* bl = biases + row * in_vec_size_g;
|
const device T* bl = biases + row * in_vec_size_g;
|
||||||
|
|
||||||
U s = sl[0];
|
U s = sl[0];
|
||||||
U b = bl[0];
|
U b = bl[0];
|
||||||
result[row] += qdot_safe<U, values_per_thread, bits>(wl, x_thread, s, b, sum, remaining);
|
result[row] += qdot_safe<U, values_per_thread, bits>(
|
||||||
|
wl, x_thread, s, b, sum, remaining);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||||
@ -542,7 +599,6 @@ template <typename T, const int group_size, const int bits>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
template <typename T, const int group_size, const int bits>
|
template <typename T, const int group_size, const int bits>
|
||||||
[[kernel]] void qvm(
|
[[kernel]] void qvm(
|
||||||
const device T* x [[buffer(0)]],
|
const device T* x [[buffer(0)]],
|
||||||
@ -555,7 +611,6 @@ template <typename T, const int group_size, const int bits>
|
|||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
constexpr int num_simdgroups = 8;
|
constexpr int num_simdgroups = 8;
|
||||||
constexpr int pack_factor = 32 / bits;
|
constexpr int pack_factor = 32 / bits;
|
||||||
constexpr int blocksize = SIMD_SIZE;
|
constexpr int blocksize = SIMD_SIZE;
|
||||||
@ -590,7 +645,8 @@ template <typename T, const int group_size, const int bits>
|
|||||||
bias = biases[(i + simd_lid) * out_vec_size_g];
|
bias = biases[(i + simd_lid) * out_vec_size_g];
|
||||||
w_local = w[(i + simd_lid) * out_vec_size_w];
|
w_local = w[(i + simd_lid) * out_vec_size_w];
|
||||||
|
|
||||||
qouter<U, pack_factor, bits>((thread uint8_t *)&w_local, x_local, scale, bias, result);
|
qouter<U, pack_factor, bits>(
|
||||||
|
(thread uint8_t*)&w_local, x_local, scale, bias, result);
|
||||||
}
|
}
|
||||||
if (static_cast<int>(i + simd_lid) < in_vec_size) {
|
if (static_cast<int>(i + simd_lid) < in_vec_size) {
|
||||||
x_local = x[i + simd_lid];
|
x_local = x[i + simd_lid];
|
||||||
@ -603,7 +659,8 @@ template <typename T, const int group_size, const int bits>
|
|||||||
bias = 0;
|
bias = 0;
|
||||||
w_local = 0;
|
w_local = 0;
|
||||||
}
|
}
|
||||||
qouter<U, pack_factor, bits>((thread uint8_t *)&w_local, x_local, scale, bias, result);
|
qouter<U, pack_factor, bits>(
|
||||||
|
(thread uint8_t*)&w_local, x_local, scale, bias, result);
|
||||||
|
|
||||||
// Accumulate in the simdgroup
|
// Accumulate in the simdgroup
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
@ -620,8 +677,14 @@ template <typename T, const int group_size, const int bits>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits, const bool aligned_N>
|
typename T,
|
||||||
|
const int BM,
|
||||||
|
const int BK,
|
||||||
|
const int BN,
|
||||||
|
const int group_size,
|
||||||
|
const int bits,
|
||||||
|
const bool aligned_N>
|
||||||
[[kernel]] void qmm_t(
|
[[kernel]] void qmm_t(
|
||||||
const device T* x [[buffer(0)]],
|
const device T* x [[buffer(0)]],
|
||||||
const device uint32_t* w [[buffer(1)]],
|
const device uint32_t* w [[buffer(1)]],
|
||||||
@ -635,7 +698,6 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
|||||||
uint lid [[thread_index_in_threadgroup]],
|
uint lid [[thread_index_in_threadgroup]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
|
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
|
||||||
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
|
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
|
||||||
|
|
||||||
@ -647,9 +709,19 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
|||||||
constexpr int BK_padded = (BK + 16 / sizeof(T));
|
constexpr int BK_padded = (BK + 16 / sizeof(T));
|
||||||
|
|
||||||
// Instantiate the appropriate BlockMMA and Loader
|
// Instantiate the appropriate BlockMMA and Loader
|
||||||
using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
|
using mma_t = mlx::steel::
|
||||||
using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
|
BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
|
||||||
using loader_w_t = QuantizedBlockLoader<T, BN, BK, BK_padded, 1, WM * WN * SIMD_SIZE, group_size, bits>;
|
using loader_x_t =
|
||||||
|
mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
|
||||||
|
using loader_w_t = QuantizedBlockLoader<
|
||||||
|
T,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
BK_padded,
|
||||||
|
1,
|
||||||
|
WM * WN * SIMD_SIZE,
|
||||||
|
group_size,
|
||||||
|
bits>;
|
||||||
|
|
||||||
threadgroup T Xs[BM * BK_padded];
|
threadgroup T Xs[BM * BK_padded];
|
||||||
threadgroup T Ws[BN * BK_padded];
|
threadgroup T Ws[BN * BK_padded];
|
||||||
@ -728,8 +800,13 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
|
typename T,
|
||||||
|
const int BM,
|
||||||
|
const int BK,
|
||||||
|
const int BN,
|
||||||
|
const int group_size,
|
||||||
|
const int bits>
|
||||||
[[kernel]] void qmm_n(
|
[[kernel]] void qmm_n(
|
||||||
const device T* x [[buffer(0)]],
|
const device T* x [[buffer(0)]],
|
||||||
const device uint32_t* w [[buffer(1)]],
|
const device uint32_t* w [[buffer(1)]],
|
||||||
@ -743,7 +820,6 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
|||||||
uint lid [[thread_index_in_threadgroup]],
|
uint lid [[thread_index_in_threadgroup]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
|
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
|
||||||
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
|
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
|
||||||
|
|
||||||
@ -756,9 +832,19 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
|||||||
constexpr int BN_padded = (BN + 16 / sizeof(T));
|
constexpr int BN_padded = (BN + 16 / sizeof(T));
|
||||||
|
|
||||||
// Instantiate the appropriate BlockMMA and Loader
|
// Instantiate the appropriate BlockMMA and Loader
|
||||||
using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
|
using mma_t = mlx::steel::
|
||||||
using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
|
BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
|
||||||
using loader_w_t = QuantizedBlockLoader<T, BK, BN, BN_padded, 0, WM * WN * SIMD_SIZE, group_size, bits>;
|
using loader_x_t = mlx::steel::
|
||||||
|
BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
|
||||||
|
using loader_w_t = QuantizedBlockLoader<
|
||||||
|
T,
|
||||||
|
BK,
|
||||||
|
BN,
|
||||||
|
BN_padded,
|
||||||
|
0,
|
||||||
|
WM * WN * SIMD_SIZE,
|
||||||
|
group_size,
|
||||||
|
bits>;
|
||||||
|
|
||||||
threadgroup T Xs[BM * BK_padded];
|
threadgroup T Xs[BM * BK_padded];
|
||||||
threadgroup T Ws[BK * BN_padded];
|
threadgroup T Ws[BK * BN_padded];
|
||||||
@ -847,10 +933,10 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#define instantiate_qmv_fast(name, itype, group_size, bits, packs_per_thread) \
|
#define instantiate_qmv_fast(name, itype, group_size, bits, packs_per_thread) \
|
||||||
template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits "_fast")]] \
|
template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits \
|
||||||
[[kernel]] void qmv_fast<itype, group_size, bits, packs_per_thread>( \
|
"_fast")]] [[kernel]] void \
|
||||||
|
qmv_fast<itype, group_size, bits, packs_per_thread>( \
|
||||||
const device uint32_t* w [[buffer(0)]], \
|
const device uint32_t* w [[buffer(0)]], \
|
||||||
const device itype* scales [[buffer(1)]], \
|
const device itype* scales [[buffer(1)]], \
|
||||||
const device itype* biases [[buffer(2)]], \
|
const device itype* biases [[buffer(2)]], \
|
||||||
@ -862,11 +948,13 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
|||||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_qmv_fast_types(group_size, bits, packs_per_thread) \
|
#define instantiate_qmv_fast_types(group_size, bits, packs_per_thread) \
|
||||||
instantiate_qmv_fast(float32, float, group_size, bits, packs_per_thread) \
|
instantiate_qmv_fast(float32, float, group_size, bits, packs_per_thread) \
|
||||||
instantiate_qmv_fast(float16, half, group_size, bits, packs_per_thread) \
|
instantiate_qmv_fast(float16, half, group_size, bits, packs_per_thread) \
|
||||||
instantiate_qmv_fast(bfloat16, bfloat16_t, group_size, bits, packs_per_thread)
|
instantiate_qmv_fast(bfloat16, bfloat16_t, group_size, bits, packs_per_thread) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_qmv_fast_types(128, 2, 1)
|
instantiate_qmv_fast_types(128, 2, 1)
|
||||||
instantiate_qmv_fast_types(128, 4, 2)
|
instantiate_qmv_fast_types(128, 4, 2)
|
||||||
instantiate_qmv_fast_types(128, 8, 2)
|
instantiate_qmv_fast_types(128, 8, 2)
|
||||||
@ -875,11 +963,12 @@ instantiate_qmv_fast_types( 64, 4, 2)
|
|||||||
instantiate_qmv_fast_types( 64, 8, 2)
|
instantiate_qmv_fast_types( 64, 8, 2)
|
||||||
instantiate_qmv_fast_types( 32, 2, 1)
|
instantiate_qmv_fast_types( 32, 2, 1)
|
||||||
instantiate_qmv_fast_types( 32, 4, 2)
|
instantiate_qmv_fast_types( 32, 4, 2)
|
||||||
instantiate_qmv_fast_types( 32, 8, 2)
|
instantiate_qmv_fast_types( 32, 8, 2) // clang-format on
|
||||||
|
|
||||||
#define instantiate_qmv(name, itype, group_size, bits) \
|
#define instantiate_qmv(name, itype, group_size, bits) \
|
||||||
template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits)]] \
|
template [[host_name("qmv_" #name "_gs_" #group_size \
|
||||||
[[kernel]] void qmv<itype, group_size, bits>( \
|
"_b_" #bits)]] [[kernel]] void \
|
||||||
|
qmv<itype, group_size, bits>( \
|
||||||
const device uint32_t* w [[buffer(0)]], \
|
const device uint32_t* w [[buffer(0)]], \
|
||||||
const device itype* scales [[buffer(1)]], \
|
const device itype* scales [[buffer(1)]], \
|
||||||
const device itype* biases [[buffer(2)]], \
|
const device itype* biases [[buffer(2)]], \
|
||||||
@ -891,11 +980,13 @@ instantiate_qmv_fast_types( 32, 8, 2)
|
|||||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_qmv_types(group_size, bits) \
|
#define instantiate_qmv_types(group_size, bits) \
|
||||||
instantiate_qmv(float32, float, group_size, bits) \
|
instantiate_qmv(float32, float, group_size, bits) \
|
||||||
instantiate_qmv(float16, half, group_size, bits) \
|
instantiate_qmv(float16, half, group_size, bits) \
|
||||||
instantiate_qmv(bfloat16, bfloat16_t, group_size, bits)
|
instantiate_qmv(bfloat16, bfloat16_t, group_size, bits) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_qmv_types(128, 2)
|
instantiate_qmv_types(128, 2)
|
||||||
instantiate_qmv_types(128, 4)
|
instantiate_qmv_types(128, 4)
|
||||||
instantiate_qmv_types(128, 8)
|
instantiate_qmv_types(128, 8)
|
||||||
@ -904,11 +995,12 @@ instantiate_qmv_types( 64, 4)
|
|||||||
instantiate_qmv_types( 64, 8)
|
instantiate_qmv_types( 64, 8)
|
||||||
instantiate_qmv_types( 32, 2)
|
instantiate_qmv_types( 32, 2)
|
||||||
instantiate_qmv_types( 32, 4)
|
instantiate_qmv_types( 32, 4)
|
||||||
instantiate_qmv_types( 32, 8)
|
instantiate_qmv_types( 32, 8) // clang-format on
|
||||||
|
|
||||||
#define instantiate_qvm(name, itype, group_size, bits) \
|
#define instantiate_qvm(name, itype, group_size, bits) \
|
||||||
template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \
|
template [[host_name("qvm_" #name "_gs_" #group_size \
|
||||||
[[kernel]] void qvm<itype, group_size, bits>( \
|
"_b_" #bits)]] [[kernel]] void \
|
||||||
|
qvm<itype, group_size, bits>( \
|
||||||
const device itype* x [[buffer(0)]], \
|
const device itype* x [[buffer(0)]], \
|
||||||
const device uint32_t* w [[buffer(1)]], \
|
const device uint32_t* w [[buffer(1)]], \
|
||||||
const device itype* scales [[buffer(2)]], \
|
const device itype* scales [[buffer(2)]], \
|
||||||
@ -920,11 +1012,13 @@ instantiate_qmv_types( 32, 8)
|
|||||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_qvm_types(group_size, bits) \
|
#define instantiate_qvm_types(group_size, bits) \
|
||||||
instantiate_qvm(float32, float, group_size, bits) \
|
instantiate_qvm(float32, float, group_size, bits) \
|
||||||
instantiate_qvm(float16, half, group_size, bits) \
|
instantiate_qvm(float16, half, group_size, bits) \
|
||||||
instantiate_qvm(bfloat16, bfloat16_t, group_size, bits)
|
instantiate_qvm(bfloat16, bfloat16_t, group_size, bits) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_qvm_types(128, 2)
|
instantiate_qvm_types(128, 2)
|
||||||
instantiate_qvm_types(128, 4)
|
instantiate_qvm_types(128, 4)
|
||||||
instantiate_qvm_types(128, 8)
|
instantiate_qvm_types(128, 8)
|
||||||
@ -933,11 +1027,12 @@ instantiate_qvm_types( 64, 4)
|
|||||||
instantiate_qvm_types( 64, 8)
|
instantiate_qvm_types( 64, 8)
|
||||||
instantiate_qvm_types( 32, 2)
|
instantiate_qvm_types( 32, 2)
|
||||||
instantiate_qvm_types( 32, 4)
|
instantiate_qvm_types( 32, 4)
|
||||||
instantiate_qvm_types( 32, 8)
|
instantiate_qvm_types( 32, 8) // clang-format on
|
||||||
|
|
||||||
#define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
|
#define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
|
||||||
template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N)]] \
|
template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits \
|
||||||
[[kernel]] void qmm_t<itype, 32, 32, 32, group_size, bits, aligned_N>( \
|
"_alN_" #aligned_N)]] [[kernel]] void \
|
||||||
|
qmm_t<itype, 32, 32, 32, group_size, bits, aligned_N>( \
|
||||||
const device itype* x [[buffer(0)]], \
|
const device itype* x [[buffer(0)]], \
|
||||||
const device uint32_t* w [[buffer(1)]], \
|
const device uint32_t* w [[buffer(1)]], \
|
||||||
const device itype* scales [[buffer(2)]], \
|
const device itype* scales [[buffer(2)]], \
|
||||||
@ -951,14 +1046,16 @@ instantiate_qvm_types( 32, 8)
|
|||||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_qmm_t_types(group_size, bits) \
|
#define instantiate_qmm_t_types(group_size, bits) \
|
||||||
instantiate_qmm_t(float32, float, group_size, bits, false) \
|
instantiate_qmm_t(float32, float, group_size, bits, false) \
|
||||||
instantiate_qmm_t(float16, half, group_size, bits, false) \
|
instantiate_qmm_t(float16, half, group_size, bits, false) \
|
||||||
instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, false) \
|
instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, false) \
|
||||||
instantiate_qmm_t(float32, float, group_size, bits, true) \
|
instantiate_qmm_t(float32, float, group_size, bits, true) \
|
||||||
instantiate_qmm_t(float16, half, group_size, bits, true) \
|
instantiate_qmm_t(float16, half, group_size, bits, true) \
|
||||||
instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, true)
|
instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, true) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_qmm_t_types(128, 2)
|
instantiate_qmm_t_types(128, 2)
|
||||||
instantiate_qmm_t_types(128, 4)
|
instantiate_qmm_t_types(128, 4)
|
||||||
instantiate_qmm_t_types(128, 8)
|
instantiate_qmm_t_types(128, 8)
|
||||||
@ -967,11 +1064,12 @@ instantiate_qmm_t_types( 64, 4)
|
|||||||
instantiate_qmm_t_types( 64, 8)
|
instantiate_qmm_t_types( 64, 8)
|
||||||
instantiate_qmm_t_types( 32, 2)
|
instantiate_qmm_t_types( 32, 2)
|
||||||
instantiate_qmm_t_types( 32, 4)
|
instantiate_qmm_t_types( 32, 4)
|
||||||
instantiate_qmm_t_types( 32, 8)
|
instantiate_qmm_t_types( 32, 8) // clang-format on
|
||||||
|
|
||||||
#define instantiate_qmm_n(name, itype, group_size, bits) \
|
#define instantiate_qmm_n(name, itype, group_size, bits) \
|
||||||
template [[host_name("qmm_n_" #name "_gs_" #group_size "_b_" #bits)]] \
|
template [[host_name("qmm_n_" #name "_gs_" #group_size \
|
||||||
[[kernel]] void qmm_n<itype, 32, 32, 32, group_size, bits>( \
|
"_b_" #bits)]] [[kernel]] void \
|
||||||
|
qmm_n<itype, 32, 32, 32, group_size, bits>( \
|
||||||
const device itype* x [[buffer(0)]], \
|
const device itype* x [[buffer(0)]], \
|
||||||
const device uint32_t* w [[buffer(1)]], \
|
const device uint32_t* w [[buffer(1)]], \
|
||||||
const device itype* scales [[buffer(2)]], \
|
const device itype* scales [[buffer(2)]], \
|
||||||
@ -985,11 +1083,13 @@ instantiate_qmm_t_types( 32, 8)
|
|||||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_qmm_n_types(group_size, bits) \
|
#define instantiate_qmm_n_types(group_size, bits) \
|
||||||
instantiate_qmm_n(float32, float, group_size, bits) \
|
instantiate_qmm_n(float32, float, group_size, bits) \
|
||||||
instantiate_qmm_n(float16, half, group_size, bits) \
|
instantiate_qmm_n(float16, half, group_size, bits) \
|
||||||
instantiate_qmm_n(bfloat16, bfloat16_t, group_size, bits)
|
instantiate_qmm_n(bfloat16, bfloat16_t, group_size, bits) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_qmm_n_types(128, 2)
|
instantiate_qmm_n_types(128, 2)
|
||||||
instantiate_qmm_n_types(128, 4)
|
instantiate_qmm_n_types(128, 4)
|
||||||
instantiate_qmm_n_types(128, 8)
|
instantiate_qmm_n_types(128, 8)
|
||||||
@ -998,4 +1098,4 @@ instantiate_qmm_n_types( 64, 4)
|
|||||||
instantiate_qmm_n_types( 64, 8)
|
instantiate_qmm_n_types( 64, 8)
|
||||||
instantiate_qmm_n_types( 32, 2)
|
instantiate_qmm_n_types( 32, 2)
|
||||||
instantiate_qmm_n_types( 32, 4)
|
instantiate_qmm_n_types( 32, 4)
|
||||||
instantiate_qmm_n_types( 32, 8)
|
instantiate_qmm_n_types( 32, 8) // clang-format on
|
||||||
|
@ -4,8 +4,7 @@
|
|||||||
|
|
||||||
static constexpr constant uint32_t rotations[2][4] = {
|
static constexpr constant uint32_t rotations[2][4] = {
|
||||||
{13, 15, 26, 6},
|
{13, 15, 26, 6},
|
||||||
{17, 29, 16, 24}
|
{17, 29, 16, 24}};
|
||||||
};
|
|
||||||
|
|
||||||
union rbits {
|
union rbits {
|
||||||
uint2 val;
|
uint2 val;
|
||||||
@ -13,7 +12,6 @@ union rbits {
|
|||||||
};
|
};
|
||||||
|
|
||||||
rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
|
rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
|
||||||
|
|
||||||
uint4 ks = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA};
|
uint4 ks = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA};
|
||||||
|
|
||||||
rbits v;
|
rbits v;
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/reduction/utils.h"
|
|
||||||
#include "mlx/backend/metal/kernels/reduction/ops.h"
|
#include "mlx/backend/metal/kernels/reduction/ops.h"
|
||||||
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
|
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
|
||||||
|
#include "mlx/backend/metal/kernels/reduction/utils.h"
|
||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
@ -60,7 +60,6 @@ METAL_FUNC U per_thread_all_reduce(
|
|||||||
// All reduce kernel
|
// All reduce kernel
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
|
||||||
// NB: This kernel assumes threads_per_threadgroup is at most
|
// NB: This kernel assumes threads_per_threadgroup is at most
|
||||||
// 1024. This way with a simd_size of 32, we are guaranteed to
|
// 1024. This way with a simd_size of 32, we are guaranteed to
|
||||||
// complete the reduction in two steps of simd-level reductions.
|
// complete the reduction in two steps of simd-level reductions.
|
||||||
@ -75,11 +74,11 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
|||||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
Op op;
|
Op op;
|
||||||
threadgroup U local_vals[simd_size];
|
threadgroup U local_vals[simd_size];
|
||||||
|
|
||||||
U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
|
U total_val =
|
||||||
|
per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
|
||||||
|
|
||||||
// Reduction within simd group
|
// Reduction within simd group
|
||||||
total_val = op.simd_reduce(total_val);
|
total_val = op.simd_reduce(total_val);
|
||||||
@ -110,14 +109,16 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
|||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
uint thread_group_id [[threadgroup_position_in_grid]]) {
|
uint thread_group_id [[threadgroup_position_in_grid]]) {
|
||||||
|
|
||||||
Op op;
|
Op op;
|
||||||
threadgroup U local_vals[simd_size];
|
threadgroup U local_vals[simd_size];
|
||||||
|
|
||||||
U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
|
U total_val =
|
||||||
|
per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
|
||||||
|
|
||||||
// Reduction within simd group (simd_add isn't supported for uint64/int64 types)
|
// Reduction within simd group (simd_add isn't supported for uint64/int64
|
||||||
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
|
// types)
|
||||||
|
for (uint16_t lane_offset = simd_size / 2; lane_offset > 0;
|
||||||
|
lane_offset /= 2) {
|
||||||
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
|
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
|
||||||
}
|
}
|
||||||
// Write simd group reduction results to local memory
|
// Write simd group reduction results to local memory
|
||||||
@ -128,7 +129,8 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
|||||||
|
|
||||||
// Reduction of simdgroup reduction results within threadgroup.
|
// Reduction of simdgroup reduction results within threadgroup.
|
||||||
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
|
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
|
||||||
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
|
for (uint16_t lane_offset = simd_size / 2; lane_offset > 0;
|
||||||
|
lane_offset /= 2) {
|
||||||
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
|
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -139,8 +141,8 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_all_reduce(name, itype, otype, op) \
|
#define instantiate_all_reduce(name, itype, otype, op) \
|
||||||
template [[host_name("all_reduce_" #name)]] \
|
template [[host_name("all_reduce_" #name)]] [[kernel]] void \
|
||||||
[[kernel]] void all_reduce<itype, otype, op>( \
|
all_reduce<itype, otype, op>( \
|
||||||
const device itype* in [[buffer(0)]], \
|
const device itype* in [[buffer(0)]], \
|
||||||
device mlx_atomic<otype>* out [[buffer(1)]], \
|
device mlx_atomic<otype>* out [[buffer(1)]], \
|
||||||
const device size_t& in_size [[buffer(2)]], \
|
const device size_t& in_size [[buffer(2)]], \
|
||||||
@ -152,8 +154,8 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
|||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||||
|
|
||||||
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
|
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
|
||||||
template [[host_name("all_reduce_no_atomics_" #name)]] \
|
template [[host_name("all_reduce_no_atomics_" #name)]] [[kernel]] void \
|
||||||
[[kernel]] void all_reduce_no_atomics<itype, otype, op>( \
|
all_reduce_no_atomics<itype, otype, op>( \
|
||||||
const device itype* in [[buffer(0)]], \
|
const device itype* in [[buffer(0)]], \
|
||||||
device otype* out [[buffer(1)]], \
|
device otype* out [[buffer(1)]], \
|
||||||
const device size_t& in_size [[buffer(2)]], \
|
const device size_t& in_size [[buffer(2)]], \
|
||||||
@ -175,6 +177,7 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
|||||||
#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
|
#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
|
||||||
instantiate_all_reduce_no_atomics(name##tname, type, type, op<type>)
|
instantiate_all_reduce_no_atomics(name##tname, type, type, op<type>)
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
|
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
|
||||||
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)
|
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)
|
||||||
|
|
||||||
@ -182,4 +185,4 @@ instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And)
|
|||||||
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or)
|
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or)
|
||||||
|
|
||||||
// special case bool with larger output type
|
// special case bool with larger output type
|
||||||
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
|
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>) // clang-format on
|
@ -1,8 +1,8 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/reduction/utils.h"
|
|
||||||
#include "mlx/backend/metal/kernels/reduction/ops.h"
|
#include "mlx/backend/metal/kernels/reduction/ops.h"
|
||||||
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
|
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
|
||||||
|
#include "mlx/backend/metal/kernels/reduction/utils.h"
|
||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
@ -25,7 +25,6 @@ template <typename T, typename U, typename Op>
|
|||||||
const constant size_t* non_col_strides [[buffer(10)]],
|
const constant size_t* non_col_strides [[buffer(10)]],
|
||||||
const constant int& non_col_ndim [[buffer(11)]],
|
const constant int& non_col_ndim [[buffer(11)]],
|
||||||
uint tid [[thread_position_in_grid]]) {
|
uint tid [[thread_position_in_grid]]) {
|
||||||
|
|
||||||
// Appease the compiler
|
// Appease the compiler
|
||||||
(void)out_size;
|
(void)out_size;
|
||||||
|
|
||||||
@ -41,7 +40,8 @@ template <typename T, typename U, typename Op>
|
|||||||
ndim - non_col_ndim);
|
ndim - non_col_ndim);
|
||||||
|
|
||||||
for (uint i = 0; i < non_col_reductions; i++) {
|
for (uint i = 0; i < non_col_reductions; i++) {
|
||||||
size_t in_idx = elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);
|
size_t in_idx =
|
||||||
|
elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);
|
||||||
|
|
||||||
for (uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) {
|
for (uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) {
|
||||||
U val = static_cast<U>(in[in_idx]);
|
U val = static_cast<U>(in[in_idx]);
|
||||||
@ -53,8 +53,8 @@ template <typename T, typename U, typename Op>
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_col_reduce_small(name, itype, otype, op) \
|
#define instantiate_col_reduce_small(name, itype, otype, op) \
|
||||||
template [[host_name("col_reduce_small_" #name)]] \
|
template [[host_name("col_reduce_small_" #name)]] [[kernel]] void \
|
||||||
[[kernel]] void col_reduce_small<itype, otype, op>( \
|
col_reduce_small<itype, otype, op>( \
|
||||||
const device itype* in [[buffer(0)]], \
|
const device itype* in [[buffer(0)]], \
|
||||||
device otype* out [[buffer(1)]], \
|
device otype* out [[buffer(1)]], \
|
||||||
const constant size_t& reduction_size [[buffer(2)]], \
|
const constant size_t& reduction_size [[buffer(2)]], \
|
||||||
@ -125,12 +125,7 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
|||||||
uint3 lid [[thread_position_in_threadgroup]],
|
uint3 lid [[thread_position_in_threadgroup]],
|
||||||
uint3 lsize [[threads_per_threadgroup]]) {
|
uint3 lsize [[threads_per_threadgroup]]) {
|
||||||
auto out_idx = tid.x * lsize.x + lid.x;
|
auto out_idx = tid.x * lsize.x + lid.x;
|
||||||
auto in_idx = elem_to_loc(
|
auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
|
||||||
out_idx + tid.z * out_size,
|
|
||||||
shape,
|
|
||||||
strides,
|
|
||||||
ndim
|
|
||||||
);
|
|
||||||
|
|
||||||
Op op;
|
Op op;
|
||||||
if (out_idx < out_size) {
|
if (out_idx < out_size) {
|
||||||
@ -144,7 +139,8 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
|||||||
lid.xy,
|
lid.xy,
|
||||||
lsize.xy);
|
lsize.xy);
|
||||||
|
|
||||||
// Write out reduction results generated by threadgroups working on specific output element, contiguously.
|
// Write out reduction results generated by threadgroups working on specific
|
||||||
|
// output element, contiguously.
|
||||||
if (lid.y == 0) {
|
if (lid.y == 0) {
|
||||||
op.atomic_update(out, val, out_idx);
|
op.atomic_update(out, val, out_idx);
|
||||||
}
|
}
|
||||||
@ -168,12 +164,7 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
|||||||
uint3 lsize [[threads_per_threadgroup]],
|
uint3 lsize [[threads_per_threadgroup]],
|
||||||
uint3 gsize [[threads_per_grid]]) {
|
uint3 gsize [[threads_per_grid]]) {
|
||||||
auto out_idx = tid.x * lsize.x + lid.x;
|
auto out_idx = tid.x * lsize.x + lid.x;
|
||||||
auto in_idx = elem_to_loc(
|
auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
|
||||||
out_idx + tid.z * out_size,
|
|
||||||
shape,
|
|
||||||
strides,
|
|
||||||
ndim
|
|
||||||
);
|
|
||||||
|
|
||||||
if (out_idx < out_size) {
|
if (out_idx < out_size) {
|
||||||
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
|
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
|
||||||
@ -186,7 +177,8 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
|||||||
lid.xy,
|
lid.xy,
|
||||||
lsize.xy);
|
lsize.xy);
|
||||||
|
|
||||||
// Write out reduction results generated by threadgroups working on specific output element, contiguously.
|
// Write out reduction results generated by threadgroups working on specific
|
||||||
|
// output element, contiguously.
|
||||||
if (lid.y == 0) {
|
if (lid.y == 0) {
|
||||||
uint tgsize_y = ceildiv(gsize.y, lsize.y);
|
uint tgsize_y = ceildiv(gsize.y, lsize.y);
|
||||||
uint tgsize_z = ceildiv(gsize.z, lsize.z);
|
uint tgsize_z = ceildiv(gsize.z, lsize.z);
|
||||||
@ -196,8 +188,8 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_col_reduce_general(name, itype, otype, op) \
|
#define instantiate_col_reduce_general(name, itype, otype, op) \
|
||||||
template [[host_name("col_reduce_general_" #name)]] \
|
template [[host_name("col_reduce_general_" #name)]] [[kernel]] void \
|
||||||
[[kernel]] void col_reduce_general<itype, otype, op>( \
|
col_reduce_general<itype, otype, op>( \
|
||||||
const device itype* in [[buffer(0)]], \
|
const device itype* in [[buffer(0)]], \
|
||||||
device mlx_atomic<otype>* out [[buffer(1)]], \
|
device mlx_atomic<otype>* out [[buffer(1)]], \
|
||||||
const constant size_t& reduction_size [[buffer(2)]], \
|
const constant size_t& reduction_size [[buffer(2)]], \
|
||||||
@ -212,8 +204,9 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
|||||||
uint3 lsize [[threads_per_threadgroup]]);
|
uint3 lsize [[threads_per_threadgroup]]);
|
||||||
|
|
||||||
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
|
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
|
||||||
template [[host_name("col_reduce_general_no_atomics_" #name)]] \
|
template \
|
||||||
[[kernel]] void col_reduce_general_no_atomics<itype, otype, op>( \
|
[[host_name("col_reduce_general_no_atomics_" #name)]] [[kernel]] void \
|
||||||
|
col_reduce_general_no_atomics<itype, otype, op>( \
|
||||||
const device itype* in [[buffer(0)]], \
|
const device itype* in [[buffer(0)]], \
|
||||||
device otype* out [[buffer(1)]], \
|
device otype* out [[buffer(1)]], \
|
||||||
const constant size_t& reduction_size [[buffer(2)]], \
|
const constant size_t& reduction_size [[buffer(2)]], \
|
||||||
@ -233,14 +226,17 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
|||||||
// Instantiations
|
// Instantiations
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
|
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
|
||||||
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
|
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
|
||||||
instantiate_col_reduce_general(name ##tname, type, type, op<type>)
|
instantiate_col_reduce_general(name ##tname, type, type, op<type>) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
|
#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
|
||||||
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
|
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
|
||||||
instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>)
|
instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
|
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
|
||||||
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)
|
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)
|
||||||
|
|
||||||
@ -250,4 +246,4 @@ instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)
|
|||||||
|
|
||||||
instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
|
instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
|
||||||
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And)
|
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And)
|
||||||
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or)
|
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or) // clang-format on
|
@ -1,8 +1,8 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/reduction/utils.h"
|
|
||||||
#include "mlx/backend/metal/kernels/reduction/ops.h"
|
#include "mlx/backend/metal/kernels/reduction/ops.h"
|
||||||
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
|
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
|
||||||
|
#include "mlx/backend/metal/kernels/reduction/utils.h"
|
||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
@ -18,16 +18,15 @@ template <typename T, typename Op>
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_init_reduce(name, otype, op) \
|
#define instantiate_init_reduce(name, otype, op) \
|
||||||
template [[host_name("i" #name)]] \
|
template [[host_name("i" #name)]] [[kernel]] void init_reduce<otype, op>( \
|
||||||
[[kernel]] void init_reduce<otype, op>( \
|
device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]);
|
||||||
device otype *out [[buffer(1)]], \
|
|
||||||
uint tid [[thread_position_in_grid]]);
|
|
||||||
|
|
||||||
#define instantiate_init_reduce_helper(name, tname, type, op) \
|
#define instantiate_init_reduce_helper(name, tname, type, op) \
|
||||||
instantiate_init_reduce(name##tname, type, op<type>)
|
instantiate_init_reduce(name##tname, type, op<type>)
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
|
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
|
||||||
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)
|
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)
|
||||||
|
|
||||||
instantiate_init_reduce(andbool_, bool, And)
|
instantiate_init_reduce(andbool_, bool, And)
|
||||||
instantiate_init_reduce(orbool_, bool, Or)
|
instantiate_init_reduce(orbool_, bool, Or) // clang-format on
|
@ -1,8 +1,8 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/reduction/utils.h"
|
|
||||||
#include "mlx/backend/metal/kernels/reduction/ops.h"
|
#include "mlx/backend/metal/kernels/reduction/ops.h"
|
||||||
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
|
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
|
||||||
|
#include "mlx/backend/metal/kernels/reduction/utils.h"
|
||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
@ -22,7 +22,6 @@ template <typename T, typename U, typename Op>
|
|||||||
const constant size_t* strides [[buffer(6)]],
|
const constant size_t* strides [[buffer(6)]],
|
||||||
const constant int& ndim [[buffer(7)]],
|
const constant int& ndim [[buffer(7)]],
|
||||||
uint lid [[thread_position_in_grid]]) {
|
uint lid [[thread_position_in_grid]]) {
|
||||||
|
|
||||||
Op op;
|
Op op;
|
||||||
|
|
||||||
uint out_idx = lid;
|
uint out_idx = lid;
|
||||||
@ -60,7 +59,6 @@ template <typename T, typename U, typename Op>
|
|||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
|
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
Op op;
|
Op op;
|
||||||
|
|
||||||
uint out_idx = simd_per_group * tid + simd_group_id;
|
uint out_idx = simd_per_group * tid + simd_group_id;
|
||||||
@ -81,24 +79,22 @@ template <typename T, typename U, typename Op>
|
|||||||
}
|
}
|
||||||
|
|
||||||
else if (short(non_row_reductions) >= 32) {
|
else if (short(non_row_reductions) >= 32) {
|
||||||
|
|
||||||
for (short r = simd_lane_id; r < short(non_row_reductions); r += 32) {
|
for (short r = simd_lane_id; r < short(non_row_reductions); r += 32) {
|
||||||
|
|
||||||
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
||||||
const device T* in_row = in + in_idx;
|
const device T* in_row = in + in_idx;
|
||||||
|
|
||||||
for (short i = 0; i < short(reduction_size); i++) {
|
for (short i = 0; i < short(reduction_size); i++) {
|
||||||
total_val = op(static_cast<U>(in_row[i]), total_val);
|
total_val = op(static_cast<U>(in_row[i]), total_val);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
else {
|
else {
|
||||||
|
const short n_reductions =
|
||||||
const short n_reductions = short(reduction_size) * short(non_row_reductions);
|
short(reduction_size) * short(non_row_reductions);
|
||||||
const short reductions_per_thread = (n_reductions + simd_size - 1) / simd_size;
|
const short reductions_per_thread =
|
||||||
|
(n_reductions + simd_size - 1) / simd_size;
|
||||||
|
|
||||||
const short r_st = simd_lane_id / reductions_per_thread;
|
const short r_st = simd_lane_id / reductions_per_thread;
|
||||||
const short r_ed = short(non_row_reductions);
|
const short r_ed = short(non_row_reductions);
|
||||||
@ -110,20 +106,16 @@ template <typename T, typename U, typename Op>
|
|||||||
|
|
||||||
if (r_st < r_jump) {
|
if (r_st < r_jump) {
|
||||||
for (short r = r_st; r < r_ed; r += r_jump) {
|
for (short r = r_st; r < r_ed; r += r_jump) {
|
||||||
|
|
||||||
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
||||||
const device T* in_row = in + in_idx;
|
const device T* in_row = in + in_idx;
|
||||||
|
|
||||||
for (short i = i_st; i < i_ed; i += i_jump) {
|
for (short i = i_st; i < i_ed; i += i_jump) {
|
||||||
total_val = op(static_cast<U>(in_row[i]), total_val);
|
total_val = op(static_cast<U>(in_row[i]), total_val);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
total_val = op.simd_reduce(total_val);
|
total_val = op.simd_reduce(total_val);
|
||||||
|
|
||||||
if (simd_lane_id == 0) {
|
if (simd_lane_id == 0) {
|
||||||
@ -132,8 +124,8 @@ template <typename T, typename U, typename Op>
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_row_reduce_small(name, itype, otype, op) \
|
#define instantiate_row_reduce_small(name, itype, otype, op) \
|
||||||
template[[host_name("row_reduce_general_small_" #name)]] \
|
template [[host_name("row_reduce_general_small_" #name)]] [[kernel]] void \
|
||||||
[[kernel]] void row_reduce_general_small<itype, otype, op>( \
|
row_reduce_general_small<itype, otype, op>( \
|
||||||
const device itype* in [[buffer(0)]], \
|
const device itype* in [[buffer(0)]], \
|
||||||
device otype* out [[buffer(1)]], \
|
device otype* out [[buffer(1)]], \
|
||||||
const constant size_t& reduction_size [[buffer(2)]], \
|
const constant size_t& reduction_size [[buffer(2)]], \
|
||||||
@ -143,8 +135,8 @@ template <typename T, typename U, typename Op>
|
|||||||
const constant size_t* strides [[buffer(6)]], \
|
const constant size_t* strides [[buffer(6)]], \
|
||||||
const constant int& ndim [[buffer(7)]], \
|
const constant int& ndim [[buffer(7)]], \
|
||||||
uint lid [[thread_position_in_grid]]); \
|
uint lid [[thread_position_in_grid]]); \
|
||||||
template[[host_name("row_reduce_general_med_" #name)]] \
|
template [[host_name("row_reduce_general_med_" #name)]] [[kernel]] void \
|
||||||
[[kernel]] void row_reduce_general_med<itype, otype, op>( \
|
row_reduce_general_med<itype, otype, op>( \
|
||||||
const device itype* in [[buffer(0)]], \
|
const device itype* in [[buffer(0)]], \
|
||||||
device otype* out [[buffer(1)]], \
|
device otype* out [[buffer(1)]], \
|
||||||
const constant size_t& reduction_size [[buffer(2)]], \
|
const constant size_t& reduction_size [[buffer(2)]], \
|
||||||
@ -233,13 +225,21 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
|||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
(void)non_row_reductions;
|
(void)non_row_reductions;
|
||||||
|
|
||||||
Op op;
|
Op op;
|
||||||
threadgroup U local_vals[simd_size];
|
threadgroup U local_vals[simd_size];
|
||||||
|
|
||||||
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
|
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
|
||||||
|
in,
|
||||||
|
reduction_size,
|
||||||
|
out_size,
|
||||||
|
shape,
|
||||||
|
strides,
|
||||||
|
ndim,
|
||||||
|
lsize.x,
|
||||||
|
lid.x,
|
||||||
|
tid.xy);
|
||||||
|
|
||||||
total_val = op.simd_reduce(total_val);
|
total_val = op.simd_reduce(total_val);
|
||||||
|
|
||||||
@ -278,13 +278,21 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
|||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
(void)non_row_reductions;
|
(void)non_row_reductions;
|
||||||
|
|
||||||
Op op;
|
Op op;
|
||||||
|
|
||||||
threadgroup U local_vals[simd_size];
|
threadgroup U local_vals[simd_size];
|
||||||
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
|
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
|
||||||
|
in,
|
||||||
|
reduction_size,
|
||||||
|
out_size,
|
||||||
|
shape,
|
||||||
|
strides,
|
||||||
|
ndim,
|
||||||
|
lsize.x,
|
||||||
|
lid.x,
|
||||||
|
tid.xy);
|
||||||
|
|
||||||
// Reduction within simd group - simd_add isn't supported for int64 types
|
// Reduction within simd group - simd_add isn't supported for int64 types
|
||||||
for (uint16_t i = simd_size / 2; i > 0; i /= 2) {
|
for (uint16_t i = simd_size / 2; i > 0; i /= 2) {
|
||||||
@ -312,9 +320,9 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_row_reduce_general(name, itype, otype, op) \
|
#define instantiate_row_reduce_general(name, itype, otype, op) \
|
||||||
instantiate_row_reduce_small(name, itype, otype, op) \
|
instantiate_row_reduce_small(name, itype, otype, op) template \
|
||||||
template [[host_name("row_reduce_general_" #name)]] \
|
[[host_name("row_reduce_general_" #name)]] [[kernel]] void \
|
||||||
[[kernel]] void row_reduce_general<itype, otype, op>( \
|
row_reduce_general<itype, otype, op>( \
|
||||||
const device itype* in [[buffer(0)]], \
|
const device itype* in [[buffer(0)]], \
|
||||||
device mlx_atomic<otype>* out [[buffer(1)]], \
|
device mlx_atomic<otype>* out [[buffer(1)]], \
|
||||||
const constant size_t& reduction_size [[buffer(2)]], \
|
const constant size_t& reduction_size [[buffer(2)]], \
|
||||||
@ -331,9 +339,9 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
|||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||||
|
|
||||||
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
|
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
|
||||||
instantiate_row_reduce_small(name, itype, otype, op) \
|
instantiate_row_reduce_small(name, itype, otype, op) template \
|
||||||
template [[host_name("row_reduce_general_no_atomics_" #name)]] \
|
[[host_name("row_reduce_general_no_atomics_" #name)]] [[kernel]] void \
|
||||||
[[kernel]] void row_reduce_general_no_atomics<itype, otype, op>( \
|
row_reduce_general_no_atomics<itype, otype, op>( \
|
||||||
const device itype* in [[buffer(0)]], \
|
const device itype* in [[buffer(0)]], \
|
||||||
device otype* out [[buffer(1)]], \
|
device otype* out [[buffer(1)]], \
|
||||||
const constant size_t& reduction_size [[buffer(2)]], \
|
const constant size_t& reduction_size [[buffer(2)]], \
|
||||||
@ -350,7 +358,6 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
|||||||
uint simd_per_group [[simdgroups_per_threadgroup]], \
|
uint simd_per_group [[simdgroups_per_threadgroup]], \
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||||
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// Instantiations
|
// Instantiations
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
@ -361,11 +368,11 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
|||||||
#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \
|
#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \
|
||||||
instantiate_row_reduce_general_no_atomics(name##tname, type, type, op<type>)
|
instantiate_row_reduce_general_no_atomics(name##tname, type, type, op<type>)
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
|
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
|
||||||
instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b)
|
instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b)
|
||||||
|
|
||||||
|
|
||||||
instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And)
|
instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And)
|
||||||
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or)
|
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or)
|
||||||
|
|
||||||
instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
|
instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>) // clang-format on
|
@ -237,13 +237,17 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
gw += gid * axis_size + lid * N_READS;
|
gw += gid * axis_size + lid * N_READS;
|
||||||
if (lid * N_READS + N_READS <= axis_size) {
|
if (lid * N_READS + N_READS <= axis_size) {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
|
gx[i] = static_cast<T>(
|
||||||
|
thread_g[i] * thread_w[i] * normalizer -
|
||||||
|
thread_x[i] * meangwx * normalizer3);
|
||||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
|
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
if ((lid * N_READS + i) < axis_size) {
|
if ((lid * N_READS + i) < axis_size) {
|
||||||
gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
|
gx[i] = static_cast<T>(
|
||||||
|
thread_g[i] * thread_w[i] * normalizer -
|
||||||
|
thread_x[i] * meangwx * normalizer3);
|
||||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
|
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -342,7 +346,8 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
float wi = w[w_stride * (i + r)];
|
float wi = w[w_stride * (i + r)];
|
||||||
float gi = g[i + r];
|
float gi = g[i + r];
|
||||||
|
|
||||||
gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
|
gx[i + r] =
|
||||||
|
static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
|
||||||
gw[i + r] = static_cast<T>(gi * xi * normalizer);
|
gw[i + r] = static_cast<T>(gi * xi * normalizer);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -352,7 +357,8 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
float wi = w[w_stride * (i + r)];
|
float wi = w[w_stride * (i + r)];
|
||||||
float gi = g[i + r];
|
float gi = g[i + r];
|
||||||
|
|
||||||
gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
|
gx[i + r] =
|
||||||
|
static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
|
||||||
gw[i + r] = static_cast<T>(gi * xi * normalizer);
|
gw[i + r] = static_cast<T>(gi * xi * normalizer);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -431,5 +437,4 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
|
|
||||||
instantiate_rms(float32, float)
|
instantiate_rms(float32, float)
|
||||||
instantiate_rms(float16, half)
|
instantiate_rms(float16, half)
|
||||||
instantiate_rms(bfloat16, bfloat16_t)
|
instantiate_rms(bfloat16, bfloat16_t) // clang-format on
|
||||||
// clang-format on
|
|
||||||
|
@ -20,12 +20,15 @@ template <typename T, bool traditional, bool forward>
|
|||||||
uint in_index_1, in_index_2;
|
uint in_index_1, in_index_2;
|
||||||
uint out_index_1, out_index_2;
|
uint out_index_1, out_index_2;
|
||||||
if (traditional) {
|
if (traditional) {
|
||||||
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0];
|
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||||
|
pos.z * out_strides[0];
|
||||||
out_index_2 = out_index_1 + 1;
|
out_index_2 = out_index_1 + 1;
|
||||||
in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
|
in_index_1 =
|
||||||
|
2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
|
||||||
in_index_2 = in_index_1 + strides[2];
|
in_index_2 = in_index_1 + strides[2];
|
||||||
} else {
|
} else {
|
||||||
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0];
|
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||||
|
pos.z * out_strides[0];
|
||||||
out_index_2 = out_index_1 + grid.x * out_strides[2];
|
out_index_2 = out_index_1 + grid.x * out_strides[2];
|
||||||
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
|
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
|
||||||
in_index_2 = in_index_1 + grid.x * strides[2];
|
in_index_2 = in_index_1 + grid.x * strides[2];
|
||||||
@ -57,8 +60,8 @@ template <typename T, bool traditional, bool forward>
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_rope(name, type, traditional, forward) \
|
#define instantiate_rope(name, type, traditional, forward) \
|
||||||
template [[host_name("rope_" #name)]] \
|
template [[host_name("rope_" #name)]] [[kernel]] void \
|
||||||
[[kernel]] void rope<type, traditional, forward>( \
|
rope<type, traditional, forward>( \
|
||||||
const device type* in [[buffer(0)]], \
|
const device type* in [[buffer(0)]], \
|
||||||
device type* out [[buffer(1)]], \
|
device type* out [[buffer(1)]], \
|
||||||
constant const size_t strides[3], \
|
constant const size_t strides[3], \
|
||||||
@ -69,6 +72,7 @@ template <typename T, bool traditional, bool forward>
|
|||||||
uint3 pos [[thread_position_in_grid]], \
|
uint3 pos [[thread_position_in_grid]], \
|
||||||
uint3 grid [[threads_per_grid]]);
|
uint3 grid [[threads_per_grid]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_rope(traditional_float16, half, true, true)
|
instantiate_rope(traditional_float16, half, true, true)
|
||||||
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
|
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
|
||||||
instantiate_rope(traditional_float32, float, true, true)
|
instantiate_rope(traditional_float32, float, true, true)
|
||||||
@ -80,4 +84,4 @@ instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false)
|
|||||||
instantiate_rope(vjp_traditional_float32, float, true, false)
|
instantiate_rope(vjp_traditional_float32, float, true, false)
|
||||||
instantiate_rope(vjp_float16, half, false, false)
|
instantiate_rope(vjp_float16, half, false, false)
|
||||||
instantiate_rope(vjp_bfloat16, bfloat16_t, false, false)
|
instantiate_rope(vjp_bfloat16, bfloat16_t, false, false)
|
||||||
instantiate_rope(vjp_float32, float, false, false)
|
instantiate_rope(vjp_float32, float, false, false) // clang-format on
|
@ -1,11 +1,17 @@
|
|||||||
#include <metal_stdlib>
|
|
||||||
#include <metal_simdgroup>
|
#include <metal_simdgroup>
|
||||||
|
#include <metal_stdlib>
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
|
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_t NSIMDGROUPS>
|
template <
|
||||||
[[kernel]] void fast_inference_sdpa_compute_partials_template(const device T *Q [[buffer(0)]],
|
typename T,
|
||||||
|
typename T2,
|
||||||
|
typename T4,
|
||||||
|
uint16_t TILE_SIZE_CONST,
|
||||||
|
uint16_t NSIMDGROUPS>
|
||||||
|
[[kernel]] void fast_inference_sdpa_compute_partials_template(
|
||||||
|
const device T* Q [[buffer(0)]],
|
||||||
const device T* K [[buffer(1)]],
|
const device T* K [[buffer(1)]],
|
||||||
const device T* V [[buffer(2)]],
|
const device T* V [[buffer(2)]],
|
||||||
const device uint64_t& L [[buffer(3)]],
|
const device uint64_t& L [[buffer(3)]],
|
||||||
@ -28,24 +34,31 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
|
|||||||
kv_head_offset_factor = tid.x / q_kv_head_ratio;
|
kv_head_offset_factor = tid.x / q_kv_head_ratio;
|
||||||
}
|
}
|
||||||
constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / NSIMDGROUPS / 4;
|
constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / NSIMDGROUPS / 4;
|
||||||
constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP = TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS);
|
constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP =
|
||||||
|
TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS);
|
||||||
constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR;
|
constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR;
|
||||||
constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR * SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) * NSIMDGROUPS;
|
constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR *
|
||||||
|
SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) *
|
||||||
|
NSIMDGROUPS;
|
||||||
|
|
||||||
threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block;
|
threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block;
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for (uint i = 0; i < 8; i++) {
|
for (uint i = 0; i < 8; i++) {
|
||||||
smemFlush[simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP + i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f);
|
smemFlush
|
||||||
|
[simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP +
|
||||||
|
i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f);
|
||||||
}
|
}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
// TODO: multiple query sequence length for speculative decoding
|
// TODO: multiple query sequence length for speculative decoding
|
||||||
const uint tgroup_query_head_offset = tid.x * DK + tid.z * (params.N_Q_HEADS * DK);
|
const uint tgroup_query_head_offset =
|
||||||
|
tid.x * DK + tid.z * (params.N_Q_HEADS * DK);
|
||||||
|
|
||||||
const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L;
|
const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L;
|
||||||
const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK;
|
const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK;
|
||||||
const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK;
|
const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK;
|
||||||
|
|
||||||
const device T* baseK = K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset;
|
const device T* baseK =
|
||||||
|
K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset;
|
||||||
const device T* baseQ = Q + tgroup_query_head_offset;
|
const device T* baseQ = Q + tgroup_query_head_offset;
|
||||||
|
|
||||||
device T4* simdgroupQueryData = (device T4*)baseQ;
|
device T4* simdgroupQueryData = (device T4*)baseQ;
|
||||||
@ -54,7 +67,8 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
|
|||||||
float threadAccum[ACCUM_PER_GROUP];
|
float threadAccum[ACCUM_PER_GROUP];
|
||||||
|
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for(size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP; threadAccumIndex++) {
|
for (size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP;
|
||||||
|
threadAccumIndex++) {
|
||||||
threadAccum[threadAccumIndex] = -INFINITY;
|
threadAccum[threadAccumIndex] = -INFINITY;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -62,14 +76,16 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
|
|||||||
|
|
||||||
const int32_t SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST;
|
const int32_t SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST;
|
||||||
const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L;
|
const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L;
|
||||||
const bool LAST_TILE_ALIGNED = (SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST));
|
const bool LAST_TILE_ALIGNED =
|
||||||
|
(SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST));
|
||||||
|
|
||||||
T4 thread_data_x4;
|
T4 thread_data_x4;
|
||||||
T4 thread_data_y4;
|
T4 thread_data_y4;
|
||||||
if (!LAST_TILE || LAST_TILE_ALIGNED) {
|
if (!LAST_TILE || LAST_TILE_ALIGNED) {
|
||||||
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
|
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for(size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST; KROW += NSIMDGROUPS) {
|
for (size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST;
|
||||||
|
KROW += NSIMDGROUPS) {
|
||||||
const uint KROW_OFFSET = KROW * DK;
|
const uint KROW_OFFSET = KROW * DK;
|
||||||
const device T* baseKRow = baseK + KROW_OFFSET;
|
const device T* baseKRow = baseK + KROW_OFFSET;
|
||||||
device T4* keysData = (device T4*)baseKRow;
|
device T4* keysData = (device T4*)baseKRow;
|
||||||
@ -81,9 +97,11 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
|
|||||||
} else {
|
} else {
|
||||||
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
|
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
|
||||||
const uint START_ROW = tid.y * TILE_SIZE_CONST;
|
const uint START_ROW = tid.y * TILE_SIZE_CONST;
|
||||||
const device T* baseKThisHead = K + tgroup_k_batch_offset + tgroup_k_head_offset;
|
const device T* baseKThisHead =
|
||||||
|
K + tgroup_k_batch_offset + tgroup_k_head_offset;
|
||||||
|
|
||||||
for(size_t KROW = START_ROW + simd_group_id; KROW < L; KROW += NSIMDGROUPS) {
|
for (size_t KROW = START_ROW + simd_group_id; KROW < L;
|
||||||
|
KROW += NSIMDGROUPS) {
|
||||||
const uint KROW_OFFSET = KROW * DK;
|
const uint KROW_OFFSET = KROW * DK;
|
||||||
const device T* baseKRow = baseKThisHead + KROW_OFFSET;
|
const device T* baseKRow = baseKThisHead + KROW_OFFSET;
|
||||||
device T4* keysData = (device T4*)baseKRow;
|
device T4* keysData = (device T4*)baseKRow;
|
||||||
@ -97,7 +115,11 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
|
|||||||
|
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for (size_t i = 0; i < P_VEC4; i++) {
|
for (size_t i = 0; i < P_VEC4; i++) {
|
||||||
thread_data_x4 = T4(threadAccum[4 * i], threadAccum[4 * i + 1], threadAccum[4 * i + 2], threadAccum[4 * i + 3]);
|
thread_data_x4 =
|
||||||
|
T4(threadAccum[4 * i],
|
||||||
|
threadAccum[4 * i + 1],
|
||||||
|
threadAccum[4 * i + 2],
|
||||||
|
threadAccum[4 * i + 3]);
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
thread_data_y4 = simd_sum(thread_data_x4);
|
thread_data_y4 = simd_sum(thread_data_x4);
|
||||||
if (simd_lane_id == 0) {
|
if (simd_lane_id == 0) {
|
||||||
@ -115,11 +137,13 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
|
|||||||
float lse = 0.f;
|
float lse = 0.f;
|
||||||
|
|
||||||
constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32;
|
constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32;
|
||||||
constexpr const size_t ACCUM_ARRAY_LENGTH = TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1;
|
constexpr const size_t ACCUM_ARRAY_LENGTH =
|
||||||
|
TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1;
|
||||||
float4 pvals[ACCUM_ARRAY_LENGTH];
|
float4 pvals[ACCUM_ARRAY_LENGTH];
|
||||||
|
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for(uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH; accum_array_iter++) {
|
for (uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH;
|
||||||
|
accum_array_iter++) {
|
||||||
pvals[accum_array_iter] = float4(-INFINITY);
|
pvals[accum_array_iter] = float4(-INFINITY);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -192,10 +216,15 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
|
|||||||
uint matrix_load_loop_iter = 0;
|
uint matrix_load_loop_iter = 0;
|
||||||
constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8;
|
constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8;
|
||||||
|
|
||||||
for(size_t tile_start = simd_group_id; tile_start < TILE_SIZE_CONST_DIV_8; tile_start += NSIMDGROUPS) {
|
for (size_t tile_start = simd_group_id;
|
||||||
|
tile_start < TILE_SIZE_CONST_DIV_8;
|
||||||
|
tile_start += NSIMDGROUPS) {
|
||||||
simdgroup_matrix<T, 8, 8> tmp;
|
simdgroup_matrix<T, 8, 8> tmp;
|
||||||
ulong simdgroup_matrix_offset = matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
|
ulong simdgroup_matrix_offset =
|
||||||
ulong2 matrixOrigin = ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset);
|
matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR +
|
||||||
|
simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
|
||||||
|
ulong2 matrixOrigin =
|
||||||
|
ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset);
|
||||||
simdgroup_load(tmp, baseV, DK, matrixOrigin, true);
|
simdgroup_load(tmp, baseV, DK, matrixOrigin, true);
|
||||||
const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0);
|
const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0);
|
||||||
const ulong elemsPerRowSmem = TILE_SIZE_CONST;
|
const ulong elemsPerRowSmem = TILE_SIZE_CONST;
|
||||||
@ -208,10 +237,12 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
|
|||||||
if (TILE_SIZE_CONST == 64) {
|
if (TILE_SIZE_CONST == 64) {
|
||||||
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
|
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
|
||||||
uint loop_iter = 0;
|
uint loop_iter = 0;
|
||||||
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
|
threadgroup float* oPartialSmem =
|
||||||
|
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
|
||||||
|
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for(size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; row += NSIMDGROUPS) {
|
for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR;
|
||||||
|
row += NSIMDGROUPS) {
|
||||||
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
|
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
|
||||||
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
|
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
|
||||||
T2 v_local = *(smemV2 + simd_lane_id);
|
T2 v_local = *(smemV2 + simd_lane_id);
|
||||||
@ -220,16 +251,20 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
|
|||||||
simdgroup_barrier(mem_flags::mem_none);
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
T row_sum = simd_sum(val);
|
T row_sum = simd_sum(val);
|
||||||
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = float(row_sum);
|
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] =
|
||||||
|
float(row_sum);
|
||||||
loop_iter++;
|
loop_iter++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (TILE_SIZE_CONST > 64) {
|
if (TILE_SIZE_CONST > 64) {
|
||||||
constexpr const size_t TILE_SIZE_CONST_DIV_128 = (TILE_SIZE_CONST + 1) / 128;
|
constexpr const size_t TILE_SIZE_CONST_DIV_128 =
|
||||||
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
|
(TILE_SIZE_CONST + 1) / 128;
|
||||||
|
threadgroup float* oPartialSmem =
|
||||||
|
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
|
||||||
uint loop_iter = 0;
|
uint loop_iter = 0;
|
||||||
for(size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; row += NSIMDGROUPS) {
|
for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR;
|
||||||
|
row += NSIMDGROUPS) {
|
||||||
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
|
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
|
||||||
|
|
||||||
T row_sum = 0.f;
|
T row_sum = 0.f;
|
||||||
@ -242,7 +277,8 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
|
|||||||
}
|
}
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
row_sum = simd_sum(row_sum);
|
row_sum = simd_sum(row_sum);
|
||||||
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = float(row_sum);
|
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] =
|
||||||
|
float(row_sum);
|
||||||
loop_iter++;
|
loop_iter++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -256,31 +292,46 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
|
|||||||
for (size_t col = 0; col < MATRIX_COLS; col++) {
|
for (size_t col = 0; col < MATRIX_COLS; col++) {
|
||||||
uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
|
uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
|
||||||
int32_t tile_start;
|
int32_t tile_start;
|
||||||
for(tile_start = START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; tile_start < MAX_START_ROW; tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) {
|
for (tile_start =
|
||||||
|
START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
|
||||||
|
tile_start < MAX_START_ROW;
|
||||||
|
tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) {
|
||||||
simdgroup_matrix<T, 8, 8> tmp;
|
simdgroup_matrix<T, 8, 8> tmp;
|
||||||
ulong2 matrixOrigin = ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start);
|
ulong2 matrixOrigin =
|
||||||
simdgroup_load(tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true);
|
ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start);
|
||||||
|
simdgroup_load(
|
||||||
|
tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true);
|
||||||
const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0);
|
const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0);
|
||||||
constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST;
|
constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST;
|
||||||
simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, /* transpose */ false);
|
simdgroup_store(
|
||||||
|
tmp,
|
||||||
|
smemV,
|
||||||
|
elemsPerRowSmem,
|
||||||
|
matrixOriginSmem,
|
||||||
|
/* transpose */ false);
|
||||||
smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR;
|
smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR;
|
||||||
};
|
};
|
||||||
|
|
||||||
tile_start = ((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR);
|
tile_start =
|
||||||
|
((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR);
|
||||||
|
|
||||||
const int32_t INT_L = int32_t(L);
|
const int32_t INT_L = int32_t(L);
|
||||||
for(int row_index = tile_start + simd_group_id ; row_index < INT_L; row_index += NSIMDGROUPS) {
|
for (int row_index = tile_start + simd_group_id; row_index < INT_L;
|
||||||
|
row_index += NSIMDGROUPS) {
|
||||||
if (simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) {
|
if (simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) {
|
||||||
const uint elems_per_row_gmem = DK;
|
const uint elems_per_row_gmem = DK;
|
||||||
const uint col_index_v_gmem = col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id;
|
const uint col_index_v_gmem =
|
||||||
|
col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id;
|
||||||
const uint row_index_v_gmem = row_index;
|
const uint row_index_v_gmem = row_index;
|
||||||
|
|
||||||
const uint elems_per_row_smem = TILE_SIZE_CONST;
|
const uint elems_per_row_smem = TILE_SIZE_CONST;
|
||||||
const uint col_index_v_smem = row_index % TILE_SIZE_CONST;
|
const uint col_index_v_smem = row_index % TILE_SIZE_CONST;
|
||||||
const uint row_index_v_smem = simd_lane_id;
|
const uint row_index_v_smem = simd_lane_id;
|
||||||
|
|
||||||
const uint scalar_offset_gmem = row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem;
|
const uint scalar_offset_gmem =
|
||||||
const uint scalar_offset_smem = row_index_v_smem * elems_per_row_smem + col_index_v_smem;
|
row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem;
|
||||||
|
const uint scalar_offset_smem =
|
||||||
|
row_index_v_smem * elems_per_row_smem + col_index_v_smem;
|
||||||
T vdata = T(*(baseVThisHead + scalar_offset_gmem));
|
T vdata = T(*(baseVThisHead + scalar_offset_gmem));
|
||||||
smemV[scalar_offset_smem] = vdata;
|
smemV[scalar_offset_smem] = vdata;
|
||||||
smem_col_index += NSIMDGROUPS;
|
smem_col_index += NSIMDGROUPS;
|
||||||
@ -291,9 +342,11 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
|
|||||||
|
|
||||||
if (TILE_SIZE_CONST == 64) {
|
if (TILE_SIZE_CONST == 64) {
|
||||||
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
|
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
|
||||||
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
|
threadgroup float* oPartialSmem =
|
||||||
|
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
|
||||||
for (size_t smem_row_index = simd_group_id;
|
for (size_t smem_row_index = simd_group_id;
|
||||||
smem_row_index < ROWS_PER_ITER; smem_row_index += NSIMDGROUPS) {
|
smem_row_index < ROWS_PER_ITER;
|
||||||
|
smem_row_index += NSIMDGROUPS) {
|
||||||
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index);
|
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index);
|
||||||
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
|
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
|
||||||
T2 v_local = *(smemV2 + simd_lane_id);
|
T2 v_local = *(smemV2 + simd_lane_id);
|
||||||
@ -305,22 +358,25 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (TILE_SIZE_CONST > 64) {
|
if (TILE_SIZE_CONST > 64) {
|
||||||
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
|
threadgroup float* oPartialSmem =
|
||||||
|
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
|
||||||
uint loop_count = 0;
|
uint loop_count = 0;
|
||||||
for(size_t row_index = simd_group_id;
|
for (size_t row_index = simd_group_id; row_index < ROWS_PER_ITER;
|
||||||
row_index < ROWS_PER_ITER; row_index += NSIMDGROUPS) {
|
row_index += NSIMDGROUPS) {
|
||||||
T row_sum = 0.f;
|
T row_sum = 0.f;
|
||||||
for(size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128; tile_iters++) {
|
for (size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128;
|
||||||
|
tile_iters++) {
|
||||||
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index);
|
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index);
|
||||||
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
|
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
|
||||||
T4 v_local = *(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP);
|
T4 v_local =
|
||||||
|
*(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP);
|
||||||
T4 p_local = T4(pvals[tile_iters]);
|
T4 p_local = T4(pvals[tile_iters]);
|
||||||
row_sum += dot(p_local, v_local);
|
row_sum += dot(p_local, v_local);
|
||||||
|
|
||||||
}
|
}
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
row_sum = simd_sum(row_sum);
|
row_sum = simd_sum(row_sum);
|
||||||
oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] = float(row_sum);
|
oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] =
|
||||||
|
float(row_sum);
|
||||||
loop_count++;
|
loop_count++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -332,22 +388,32 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
|
|||||||
if (simd_group_id == 0) {
|
if (simd_group_id == 0) {
|
||||||
threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial;
|
threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial;
|
||||||
float4 vals = *(oPartialVec4 + simd_lane_id);
|
float4 vals = *(oPartialVec4 + simd_lane_id);
|
||||||
device float* oPartialGmem = O_partials + tid.x * DK * params.KV_TILES + tid.y * DK;
|
device float* oPartialGmem =
|
||||||
|
O_partials + tid.x * DK * params.KV_TILES + tid.y * DK;
|
||||||
device float4* oPartialGmemVec4 = (device float4*)oPartialGmem;
|
device float4* oPartialGmemVec4 = (device float4*)oPartialGmem;
|
||||||
oPartialGmemVec4[simd_lane_id] = vals;
|
oPartialGmemVec4[simd_lane_id] = vals;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (simd_group_id == 0 && simd_lane_id == 0) {
|
if (simd_group_id == 0 && simd_lane_id == 0) {
|
||||||
const uint tileIndex = tid.y;
|
const uint tileIndex = tid.y;
|
||||||
const uint gmem_partial_scalar_offset = tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES + tileIndex;
|
const uint gmem_partial_scalar_offset =
|
||||||
|
tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES +
|
||||||
|
tileIndex;
|
||||||
p_lse[gmem_partial_scalar_offset] = lse;
|
p_lse[gmem_partial_scalar_offset] = lse;
|
||||||
p_maxes[gmem_partial_scalar_offset] = groupMax;
|
p_maxes[gmem_partial_scalar_offset] = groupMax;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, nsimdgroups) \
|
#define instantiate_fast_inference_sdpa_to_partials_kernel( \
|
||||||
template [[host_name("fast_inference_sdpa_compute_partials_" #itype "_" #tile_size "_" #nsimdgroups )]] \
|
itype, itype2, itype4, tile_size, nsimdgroups) \
|
||||||
[[kernel]] void fast_inference_sdpa_compute_partials_template<itype, itype2, itype4, tile_size, nsimdgroups>( \
|
template [[host_name("fast_inference_sdpa_compute_partials_" #itype \
|
||||||
|
"_" #tile_size "_" #nsimdgroups)]] [[kernel]] void \
|
||||||
|
fast_inference_sdpa_compute_partials_template< \
|
||||||
|
itype, \
|
||||||
|
itype2, \
|
||||||
|
itype4, \
|
||||||
|
tile_size, \
|
||||||
|
nsimdgroups>( \
|
||||||
const device itype* Q [[buffer(0)]], \
|
const device itype* Q [[buffer(0)]], \
|
||||||
const device itype* K [[buffer(1)]], \
|
const device itype* K [[buffer(1)]], \
|
||||||
const device itype* V [[buffer(2)]], \
|
const device itype* V [[buffer(2)]], \
|
||||||
@ -361,21 +427,55 @@ template [[host_name("fast_inference_sdpa_compute_partials_" #itype "_" #tile_si
|
|||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||||
uint3 tid [[threadgroup_position_in_grid]]);
|
uint3 tid [[threadgroup_position_in_grid]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
#define instantiate_fast_inference_sdpa_to_partials_shapes_helper( \
|
||||||
|
itype, itype2, itype4, tile_size) \
|
||||||
|
instantiate_fast_inference_sdpa_to_partials_kernel( \
|
||||||
|
itype, itype2, itype4, tile_size, 4) \
|
||||||
|
instantiate_fast_inference_sdpa_to_partials_kernel( \
|
||||||
|
itype, itype2, itype4, tile_size, 8) // clang-format on
|
||||||
|
|
||||||
#define instantiate_fast_inference_sdpa_to_partials_shapes_helper(itype, itype2, itype4, tile_size) \
|
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
|
||||||
instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, 4) \
|
float,
|
||||||
instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, 8) \
|
float2,
|
||||||
|
float4,
|
||||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 64);
|
64);
|
||||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 128);
|
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
|
||||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 256);
|
float,
|
||||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 512);
|
float2,
|
||||||
|
float4,
|
||||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 64);
|
128);
|
||||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 128);
|
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
|
||||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 256);
|
float,
|
||||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 512);
|
float2,
|
||||||
|
float4,
|
||||||
|
256);
|
||||||
|
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
|
||||||
|
float,
|
||||||
|
float2,
|
||||||
|
float4,
|
||||||
|
512);
|
||||||
|
|
||||||
|
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
|
||||||
|
half,
|
||||||
|
half2,
|
||||||
|
half4,
|
||||||
|
64);
|
||||||
|
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
|
||||||
|
half,
|
||||||
|
half2,
|
||||||
|
half4,
|
||||||
|
128);
|
||||||
|
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
|
||||||
|
half,
|
||||||
|
half2,
|
||||||
|
half4,
|
||||||
|
256);
|
||||||
|
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
|
||||||
|
half,
|
||||||
|
half2,
|
||||||
|
half4,
|
||||||
|
512);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void fast_inference_sdpa_reduce_tiles_template(
|
void fast_inference_sdpa_reduce_tiles_template(
|
||||||
@ -386,12 +486,13 @@ void fast_inference_sdpa_reduce_tiles_template(
|
|||||||
device T* O [[buffer(4)]],
|
device T* O [[buffer(4)]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
|
||||||
constexpr const int DK = 128;
|
constexpr const int DK = 128;
|
||||||
const ulong offset_rows = tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES;
|
const ulong offset_rows =
|
||||||
|
tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES;
|
||||||
const device float* p_lse_row = p_lse + offset_rows;
|
const device float* p_lse_row = p_lse + offset_rows;
|
||||||
const device float* p_rowmax_row = p_maxes + offset_rows;
|
const device float* p_rowmax_row = p_maxes + offset_rows;
|
||||||
// reserve some number of registers. this constitutes an assumption on max value of KV TILES.
|
// reserve some number of registers. this constitutes an assumption on max
|
||||||
|
// value of KV TILES.
|
||||||
constexpr const uint8_t reserve = 128;
|
constexpr const uint8_t reserve = 128;
|
||||||
float p_lse_regs[reserve];
|
float p_lse_regs[reserve];
|
||||||
float p_rowmax_regs[reserve];
|
float p_rowmax_regs[reserve];
|
||||||
@ -411,7 +512,9 @@ void fast_inference_sdpa_reduce_tiles_template(
|
|||||||
denom += weights[i];
|
denom += weights[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
const device float* O_partials_with_offset = O_partials + tid.z * params.N_Q_HEADS * DK * params.KV_TILES + tid.x * DK * params.KV_TILES;
|
const device float* O_partials_with_offset = O_partials +
|
||||||
|
tid.z * params.N_Q_HEADS * DK * params.KV_TILES +
|
||||||
|
tid.x * DK * params.KV_TILES;
|
||||||
|
|
||||||
float o_value = 0.f;
|
float o_value = 0.f;
|
||||||
for (size_t i = 0; i < params.KV_TILES; i++) {
|
for (size_t i = 0; i < params.KV_TILES; i++) {
|
||||||
@ -423,7 +526,6 @@ void fast_inference_sdpa_reduce_tiles_template(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
kernel void fast_inference_sdpa_reduce_tiles_float(
|
kernel void fast_inference_sdpa_reduce_tiles_float(
|
||||||
const device float* O_partials [[buffer(0)]],
|
const device float* O_partials [[buffer(0)]],
|
||||||
const device float* p_lse [[buffer(1)]],
|
const device float* p_lse [[buffer(1)]],
|
||||||
@ -431,10 +533,9 @@ kernel void fast_inference_sdpa_reduce_tiles_float(
|
|||||||
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
|
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
|
||||||
device float* O [[buffer(4)]],
|
device float* O [[buffer(4)]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint3 lid [[thread_position_in_threadgroup]])
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
{
|
fast_inference_sdpa_reduce_tiles_template<float>(
|
||||||
fast_inference_sdpa_reduce_tiles_template<float>(O_partials, p_lse, p_maxes, params,
|
O_partials, p_lse, p_maxes, params, O, tid, lid);
|
||||||
O, tid, lid);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void fast_inference_sdpa_reduce_tiles_half(
|
kernel void fast_inference_sdpa_reduce_tiles_half(
|
||||||
@ -444,8 +545,7 @@ kernel void fast_inference_sdpa_reduce_tiles_half(
|
|||||||
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
|
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
|
||||||
device half* O [[buffer(4)]],
|
device half* O [[buffer(4)]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint3 lid [[thread_position_in_threadgroup]])
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
{
|
fast_inference_sdpa_reduce_tiles_template<half>(
|
||||||
fast_inference_sdpa_reduce_tiles_template<half>(O_partials, p_lse, p_maxes, params,
|
O_partials, p_lse, p_maxes, params, O, tid, lid);
|
||||||
O, tid, lid);
|
|
||||||
}
|
}
|
||||||
|
@ -127,10 +127,16 @@ inline void load_unsafe(U values[N_READS], const device T * input) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename U, int N_READS, bool reverse>
|
template <typename T, typename U, int N_READS, bool reverse>
|
||||||
inline void load_safe(U values[N_READS], const device T * input, int start, int total, U init) {
|
inline void load_safe(
|
||||||
|
U values[N_READS],
|
||||||
|
const device T* input,
|
||||||
|
int start,
|
||||||
|
int total,
|
||||||
|
U init) {
|
||||||
if (reverse) {
|
if (reverse) {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
values[N_READS-i-1] = (start + N_READS - i - 1 < total) ? input[i] : init;
|
values[N_READS - i - 1] =
|
||||||
|
(start + N_READS - i - 1 < total) ? input[i] : init;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
@ -169,8 +175,13 @@ inline void write_safe(U values[N_READS], device U * out, int start, int total)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
|
typename T,
|
||||||
|
typename U,
|
||||||
|
typename Op,
|
||||||
|
int N_READS,
|
||||||
|
bool inclusive,
|
||||||
|
bool reverse>
|
||||||
[[kernel]] void contiguous_scan(
|
[[kernel]] void contiguous_scan(
|
||||||
const device T* in [[buffer(0)]],
|
const device T* in [[buffer(0)]],
|
||||||
device U* out [[buffer(1)]],
|
device U* out [[buffer(1)]],
|
||||||
@ -195,14 +206,16 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
|
|||||||
U values[N_READS];
|
U values[N_READS];
|
||||||
threadgroup U simdgroup_sums[32];
|
threadgroup U simdgroup_sums[32];
|
||||||
|
|
||||||
// Loop over the reduced axis in blocks of size ceildiv(axis_size, N_READS*lsize)
|
// Loop over the reduced axis in blocks of size ceildiv(axis_size,
|
||||||
|
// N_READS*lsize)
|
||||||
// Read block
|
// Read block
|
||||||
// Compute inclusive scan of the block
|
// Compute inclusive scan of the block
|
||||||
// Compute inclusive scan per thread
|
// Compute inclusive scan per thread
|
||||||
// Compute exclusive scan of thread sums in simdgroup
|
// Compute exclusive scan of thread sums in simdgroup
|
||||||
// Write simdgroup sums in SM
|
// Write simdgroup sums in SM
|
||||||
// Compute exclusive scan of simdgroup sums
|
// Compute exclusive scan of simdgroup sums
|
||||||
// Compute the output by scanning prefix, prev_simdgroup, prev_thread, value
|
// Compute the output by scanning prefix, prev_simdgroup, prev_thread,
|
||||||
|
// value
|
||||||
// Write block
|
// Write block
|
||||||
|
|
||||||
for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
|
for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
|
||||||
@ -212,15 +225,22 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
|
|||||||
// Read the values
|
// Read the values
|
||||||
if (reverse) {
|
if (reverse) {
|
||||||
if ((offset + N_READS) < axis_size) {
|
if ((offset + N_READS) < axis_size) {
|
||||||
load_unsafe<T, U, N_READS, reverse>(values, in + axis_size - offset - N_READS);
|
load_unsafe<T, U, N_READS, reverse>(
|
||||||
|
values, in + axis_size - offset - N_READS);
|
||||||
} else {
|
} else {
|
||||||
load_safe<T, U, N_READS, reverse>(values, in + axis_size - offset - N_READS, offset, axis_size, Op::init);
|
load_safe<T, U, N_READS, reverse>(
|
||||||
|
values,
|
||||||
|
in + axis_size - offset - N_READS,
|
||||||
|
offset,
|
||||||
|
axis_size,
|
||||||
|
Op::init);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if ((offset + N_READS) < axis_size) {
|
if ((offset + N_READS) < axis_size) {
|
||||||
load_unsafe<T, U, N_READS, reverse>(values, in + offset);
|
load_unsafe<T, U, N_READS, reverse>(values, in + offset);
|
||||||
} else {
|
} else {
|
||||||
load_safe<T, U, N_READS, reverse>(values, in + offset, offset, axis_size, Op::init);
|
load_safe<T, U, N_READS, reverse>(
|
||||||
|
values, in + offset, offset, axis_size, Op::init);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -256,18 +276,25 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
|
|||||||
if (reverse) {
|
if (reverse) {
|
||||||
if (inclusive) {
|
if (inclusive) {
|
||||||
if ((offset + N_READS) < axis_size) {
|
if ((offset + N_READS) < axis_size) {
|
||||||
write_unsafe<U, N_READS, reverse>(values, out + axis_size - offset - N_READS);
|
write_unsafe<U, N_READS, reverse>(
|
||||||
|
values, out + axis_size - offset - N_READS);
|
||||||
} else {
|
} else {
|
||||||
write_safe<U, N_READS, reverse>(values, out + axis_size - offset - N_READS, offset, axis_size);
|
write_safe<U, N_READS, reverse>(
|
||||||
|
values, out + axis_size - offset - N_READS, offset, axis_size);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (lid == 0 && offset == 0) {
|
if (lid == 0 && offset == 0) {
|
||||||
out[axis_size - 1] = Op::init;
|
out[axis_size - 1] = Op::init;
|
||||||
}
|
}
|
||||||
if ((offset + N_READS + 1) < axis_size) {
|
if ((offset + N_READS + 1) < axis_size) {
|
||||||
write_unsafe<U, N_READS, reverse>(values, out + axis_size - offset - 1 - N_READS);
|
write_unsafe<U, N_READS, reverse>(
|
||||||
|
values, out + axis_size - offset - 1 - N_READS);
|
||||||
} else {
|
} else {
|
||||||
write_safe<U, N_READS, reverse>(values, out + axis_size - offset - 1 - N_READS, offset + 1, axis_size);
|
write_safe<U, N_READS, reverse>(
|
||||||
|
values,
|
||||||
|
out + axis_size - offset - 1 - N_READS,
|
||||||
|
offset + 1,
|
||||||
|
axis_size);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -275,7 +302,8 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
|
|||||||
if ((offset + N_READS) < axis_size) {
|
if ((offset + N_READS) < axis_size) {
|
||||||
write_unsafe<U, N_READS, reverse>(values, out + offset);
|
write_unsafe<U, N_READS, reverse>(values, out + offset);
|
||||||
} else {
|
} else {
|
||||||
write_safe<U, N_READS, reverse>(values, out + offset, offset, axis_size);
|
write_safe<U, N_READS, reverse>(
|
||||||
|
values, out + offset, offset, axis_size);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (lid == 0 && offset == 0) {
|
if (lid == 0 && offset == 0) {
|
||||||
@ -284,7 +312,8 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
|
|||||||
if ((offset + N_READS + 1) < axis_size) {
|
if ((offset + N_READS + 1) < axis_size) {
|
||||||
write_unsafe<U, N_READS, reverse>(values, out + offset + 1);
|
write_unsafe<U, N_READS, reverse>(values, out + offset + 1);
|
||||||
} else {
|
} else {
|
||||||
write_safe<U, N_READS, reverse>(values, out + offset + 1, offset + 1, axis_size);
|
write_safe<U, N_READS, reverse>(
|
||||||
|
values, out + offset + 1, offset + 1, axis_size);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -298,7 +327,13 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
|
template <
|
||||||
|
typename T,
|
||||||
|
typename U,
|
||||||
|
typename Op,
|
||||||
|
int N_READS,
|
||||||
|
bool inclusive,
|
||||||
|
bool reverse>
|
||||||
[[kernel]] void strided_scan(
|
[[kernel]] void strided_scan(
|
||||||
const device T* in [[buffer(0)]],
|
const device T* in [[buffer(0)]],
|
||||||
device U* out [[buffer(1)]],
|
device U* out [[buffer(1)]],
|
||||||
@ -334,14 +369,17 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
|
|||||||
// Read in SM
|
// Read in SM
|
||||||
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
|
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = in[offset + index_y * stride + index_x + i];
|
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
|
||||||
|
in[offset + index_y * stride + index_x + i];
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
if (check_index_y < axis_size && (index_x + i) < stride) {
|
if (check_index_y < axis_size && (index_x + i) < stride) {
|
||||||
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = in[offset + index_y * stride + index_x + i];
|
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
|
||||||
|
in[offset + index_y * stride + index_x + i];
|
||||||
} else {
|
} else {
|
||||||
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = Op::init;
|
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
|
||||||
|
Op::init;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -349,9 +387,11 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
|
|||||||
|
|
||||||
// Read strided into registers
|
// Read strided into registers
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
values[i] = read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
|
values[i] =
|
||||||
|
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
|
||||||
}
|
}
|
||||||
// Do we need the following barrier? Shouldn't all simd threads execute simultaneously?
|
// Do we need the following barrier? Shouldn't all simd threads execute
|
||||||
|
// simultaneously?
|
||||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// Perform the scan
|
// Perform the scan
|
||||||
@ -363,7 +403,8 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
|
|||||||
|
|
||||||
// Write to SM
|
// Write to SM
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] = values[i];
|
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] =
|
||||||
|
values[i];
|
||||||
}
|
}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
@ -392,21 +433,24 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
|
|||||||
}
|
}
|
||||||
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
|
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
out[offset + index_y * stride + index_x + i] = read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
|
out[offset + index_y * stride + index_x + i] =
|
||||||
|
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
if (check_index_y < axis_size && (index_x + i) < stride) {
|
if (check_index_y < axis_size && (index_x + i) < stride) {
|
||||||
out[offset + index_y * stride + index_x + i] = read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
|
out[offset + index_y * stride + index_x + i] =
|
||||||
|
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_contiguous_scan(name, itype, otype, op, inclusive, reverse, nreads) \
|
#define instantiate_contiguous_scan( \
|
||||||
template [[host_name("contiguous_scan_" #name)]] \
|
name, itype, otype, op, inclusive, reverse, nreads) \
|
||||||
[[kernel]] void contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
|
template [[host_name("contiguous_scan_" #name)]] [[kernel]] void \
|
||||||
|
contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
|
||||||
const device itype* in [[buffer(0)]], \
|
const device itype* in [[buffer(0)]], \
|
||||||
device otype* out [[buffer(1)]], \
|
device otype* out [[buffer(1)]], \
|
||||||
const constant size_t& axis_size [[buffer(2)]], \
|
const constant size_t& axis_size [[buffer(2)]], \
|
||||||
@ -417,9 +461,10 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
|
|||||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||||
|
|
||||||
#define instantiate_strided_scan(name, itype, otype, op, inclusive, reverse, nreads) \
|
#define instantiate_strided_scan( \
|
||||||
template [[host_name("strided_scan_" #name)]] \
|
name, itype, otype, op, inclusive, reverse, nreads) \
|
||||||
[[kernel]] void strided_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
|
template [[host_name("strided_scan_" #name)]] [[kernel]] void \
|
||||||
|
strided_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
|
||||||
const device itype* in [[buffer(0)]], \
|
const device itype* in [[buffer(0)]], \
|
||||||
device otype* out [[buffer(1)]], \
|
device otype* out [[buffer(1)]], \
|
||||||
const constant size_t& axis_size [[buffer(2)]], \
|
const constant size_t& axis_size [[buffer(2)]], \
|
||||||
@ -429,7 +474,7 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
|
|||||||
uint2 lsize [[threads_per_threadgroup]], \
|
uint2 lsize [[threads_per_threadgroup]], \
|
||||||
uint simd_size [[threads_per_simdgroup]]);
|
uint simd_size [[threads_per_simdgroup]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_scan_helper(name, itype, otype, op, nreads) \
|
#define instantiate_scan_helper(name, itype, otype, op, nreads) \
|
||||||
instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
|
instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
|
||||||
instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
|
instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
|
||||||
@ -438,8 +483,9 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
|
|||||||
instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
|
instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
|
||||||
instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
|
instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
|
||||||
instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
|
instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
|
||||||
instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads)
|
instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4)
|
instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4)
|
||||||
instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4)
|
instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4)
|
||||||
instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4)
|
instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4)
|
||||||
@ -491,4 +537,4 @@ instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMi
|
|||||||
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
|
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
|
||||||
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
|
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
|
||||||
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
|
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
|
||||||
//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin)
|
//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin) // clang-format on
|
@ -13,7 +13,7 @@ using namespace metal;
|
|||||||
// Scatter kernel
|
// Scatter kernel
|
||||||
/////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
template <typename T, typename IdxT, typename Op, int NIDX> \
|
template <typename T, typename IdxT, typename Op, int NIDX>
|
||||||
METAL_FUNC void scatter_1d_index_impl(
|
METAL_FUNC void scatter_1d_index_impl(
|
||||||
const device T* updates [[buffer(1)]],
|
const device T* updates [[buffer(1)]],
|
||||||
device mlx_atomic<T>* out [[buffer(2)]],
|
device mlx_atomic<T>* out [[buffer(2)]],
|
||||||
@ -22,13 +22,11 @@ METAL_FUNC void scatter_1d_index_impl(
|
|||||||
const constant size_t& upd_size [[buffer(5)]],
|
const constant size_t& upd_size [[buffer(5)]],
|
||||||
const thread array<const device IdxT*, NIDX>& idx_buffers,
|
const thread array<const device IdxT*, NIDX>& idx_buffers,
|
||||||
uint2 gid [[thread_position_in_grid]]) {
|
uint2 gid [[thread_position_in_grid]]) {
|
||||||
|
|
||||||
Op op;
|
Op op;
|
||||||
|
|
||||||
uint out_idx = 0;
|
uint out_idx = 0;
|
||||||
for (int i = 0; i < NIDX; i++) {
|
for (int i = 0; i < NIDX; i++) {
|
||||||
auto idx_val = offset_neg_idx(
|
auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
|
||||||
idx_buffers[i][gid.y], out_shape[i]);
|
|
||||||
out_idx += idx_val * out_strides[i];
|
out_idx += idx_val * out_strides[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -43,20 +41,11 @@ template <typename T, typename IdxT, typename Op, int NIDX> \
|
|||||||
const constant int* out_shape [[buffer(3)]], \
|
const constant int* out_shape [[buffer(3)]], \
|
||||||
const constant size_t* out_strides [[buffer(4)]], \
|
const constant size_t* out_strides [[buffer(4)]], \
|
||||||
const constant size_t& upd_size [[buffer(5)]], \
|
const constant size_t& upd_size [[buffer(5)]], \
|
||||||
IDX_ARG(IdxT) \
|
IDX_ARG(IdxT) uint2 gid [[thread_position_in_grid]]) { \
|
||||||
uint2 gid [[thread_position_in_grid]]) { \
|
|
||||||
\
|
|
||||||
const array<const device IdxT*, NIDX> idx_buffers = {IDX_ARR()}; \
|
const array<const device IdxT*, NIDX> idx_buffers = {IDX_ARR()}; \
|
||||||
\
|
\
|
||||||
return scatter_1d_index_impl<T, IdxT, Op, NIDX>( \
|
return scatter_1d_index_impl<T, IdxT, Op, NIDX>( \
|
||||||
updates, \
|
updates, out, out_shape, out_strides, upd_size, idx_buffers, gid); \
|
||||||
out, \
|
|
||||||
out_shape, \
|
|
||||||
out_strides, \
|
|
||||||
upd_size, \
|
|
||||||
idx_buffers, \
|
|
||||||
gid); \
|
|
||||||
\
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename IdxT, typename Op, int NIDX>
|
template <typename T, typename IdxT, typename Op, int NIDX>
|
||||||
@ -73,7 +62,6 @@ METAL_FUNC void scatter_impl(
|
|||||||
const constant int* axes [[buffer(10)]],
|
const constant int* axes [[buffer(10)]],
|
||||||
const thread Indices<IdxT, NIDX>& indices,
|
const thread Indices<IdxT, NIDX>& indices,
|
||||||
uint2 gid [[thread_position_in_grid]]) {
|
uint2 gid [[thread_position_in_grid]]) {
|
||||||
|
|
||||||
Op op;
|
Op op;
|
||||||
auto ind_idx = gid.y;
|
auto ind_idx = gid.y;
|
||||||
auto ind_offset = gid.x;
|
auto ind_offset = gid.x;
|
||||||
@ -86,8 +74,7 @@ METAL_FUNC void scatter_impl(
|
|||||||
&indices.strides[indices.ndim * i],
|
&indices.strides[indices.ndim * i],
|
||||||
indices.ndim);
|
indices.ndim);
|
||||||
auto ax = axes[i];
|
auto ax = axes[i];
|
||||||
auto idx_val = offset_neg_idx(
|
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
|
||||||
indices.buffers[i][idx_loc], out_shape[ax]);
|
|
||||||
out_idx += idx_val * out_strides[ax];
|
out_idx += idx_val * out_strides[ax];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -97,7 +84,8 @@ METAL_FUNC void scatter_impl(
|
|||||||
out_idx += out_offset;
|
out_idx += out_offset;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
|
auto upd_idx =
|
||||||
|
elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
|
||||||
op.atomic_update(out, updates[upd_idx], out_idx);
|
op.atomic_update(out, updates[upd_idx], out_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -117,14 +105,9 @@ template <typename T, typename IdxT, typename Op, int NIDX> \
|
|||||||
const constant int* idx_shapes [[buffer(11)]], \
|
const constant int* idx_shapes [[buffer(11)]], \
|
||||||
const constant size_t* idx_strides [[buffer(12)]], \
|
const constant size_t* idx_strides [[buffer(12)]], \
|
||||||
const constant int& idx_ndim [[buffer(13)]], \
|
const constant int& idx_ndim [[buffer(13)]], \
|
||||||
IDX_ARG(IdxT) \
|
IDX_ARG(IdxT) uint2 gid [[thread_position_in_grid]]) { \
|
||||||
uint2 gid [[thread_position_in_grid]]) { \
|
|
||||||
\
|
|
||||||
Indices<IdxT, NIDX> idxs{ \
|
Indices<IdxT, NIDX> idxs{ \
|
||||||
{{IDX_ARR()}}, \
|
{{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \
|
||||||
idx_shapes, \
|
|
||||||
idx_strides, \
|
|
||||||
idx_ndim}; \
|
|
||||||
\
|
\
|
||||||
return scatter_impl<T, IdxT, Op, NIDX>( \
|
return scatter_impl<T, IdxT, Op, NIDX>( \
|
||||||
updates, \
|
updates, \
|
||||||
@ -145,25 +128,17 @@ template <typename T, typename IdxT, typename Op, int NIDX> \
|
|||||||
make_scatter_impl(IDX_ARG_##n, IDX_ARR_##n) \
|
make_scatter_impl(IDX_ARG_##n, IDX_ARR_##n) \
|
||||||
make_scatter_1d_index(IDX_ARG_##n, IDX_ARR_##n)
|
make_scatter_1d_index(IDX_ARG_##n, IDX_ARR_##n)
|
||||||
|
|
||||||
make_scatter(0)
|
make_scatter(0) make_scatter(1) make_scatter(2) make_scatter(3) make_scatter(4)
|
||||||
make_scatter(1)
|
make_scatter(5) make_scatter(6) make_scatter(7) make_scatter(8)
|
||||||
make_scatter(2)
|
make_scatter(9) make_scatter(10)
|
||||||
make_scatter(3)
|
|
||||||
make_scatter(4)
|
|
||||||
make_scatter(5)
|
|
||||||
make_scatter(6)
|
|
||||||
make_scatter(7)
|
|
||||||
make_scatter(8)
|
|
||||||
make_scatter(9)
|
|
||||||
make_scatter(10)
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////
|
||||||
// Scatter instantiations
|
// Scatter instantiations
|
||||||
/////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
#define instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
|
#define instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
|
||||||
template [[host_name("scatter" name "_" #nidx)]] \
|
template [[host_name("scatter" name "_" #nidx)]] [[kernel]] void \
|
||||||
[[kernel]] void scatter<src_t, idx_t, op_t, nidx>( \
|
scatter<src_t, idx_t, op_t, nidx>( \
|
||||||
const device src_t* updates [[buffer(1)]], \
|
const device src_t* updates [[buffer(1)]], \
|
||||||
device mlx_atomic<src_t>* out [[buffer(2)]], \
|
device mlx_atomic<src_t>* out [[buffer(2)]], \
|
||||||
const constant int* upd_shape [[buffer(3)]], \
|
const constant int* upd_shape [[buffer(3)]], \
|
||||||
@ -177,32 +152,33 @@ template [[host_name("scatter" name "_" #nidx)]] \
|
|||||||
const constant int* idx_shapes [[buffer(11)]], \
|
const constant int* idx_shapes [[buffer(11)]], \
|
||||||
const constant size_t* idx_strides [[buffer(12)]], \
|
const constant size_t* idx_strides [[buffer(12)]], \
|
||||||
const constant int& idx_ndim [[buffer(13)]], \
|
const constant int& idx_ndim [[buffer(13)]], \
|
||||||
IDX_ARG(idx_t) \
|
IDX_ARG(idx_t) uint2 gid [[thread_position_in_grid]]);
|
||||||
uint2 gid [[thread_position_in_grid]]);
|
|
||||||
|
|
||||||
#define instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
|
#define instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
|
||||||
template [[host_name("scatter_1d_index" name "_" #nidx)]] \
|
template [[host_name("scatter_1d_index" name "_" #nidx)]] [[kernel]] void \
|
||||||
[[kernel]] void scatter_1d_index<src_t, idx_t, op_t, nidx>( \
|
scatter_1d_index<src_t, idx_t, op_t, nidx>( \
|
||||||
const device src_t* updates [[buffer(1)]], \
|
const device src_t* updates [[buffer(1)]], \
|
||||||
device mlx_atomic<src_t>* out [[buffer(2)]], \
|
device mlx_atomic<src_t>* out [[buffer(2)]], \
|
||||||
const constant int* out_shape [[buffer(3)]], \
|
const constant int* out_shape [[buffer(3)]], \
|
||||||
const constant size_t* out_strides [[buffer(4)]], \
|
const constant size_t* out_strides [[buffer(4)]], \
|
||||||
const constant size_t& upd_size [[buffer(5)]], \
|
const constant size_t& upd_size [[buffer(5)]], \
|
||||||
IDX_ARG(idx_t) \
|
IDX_ARG(idx_t) uint2 gid [[thread_position_in_grid]]);
|
||||||
uint2 gid [[thread_position_in_grid]]);
|
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \
|
#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \
|
||||||
instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) \
|
instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) \
|
||||||
instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx)
|
instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
// Special case NINDEX=0
|
// Special case NINDEX=0
|
||||||
#define instantiate_scatter_nd0(name, type) \
|
#define instantiate_scatter_nd0(name, type) \
|
||||||
instantiate_scatter4(#name "none", type, bool, None, 0) \
|
instantiate_scatter4(#name "none", type, bool, None, 0) \
|
||||||
instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
|
instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
|
||||||
instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
|
instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
|
||||||
instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
|
instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
|
||||||
instantiate_scatter4(#name "_min", type, bool, Min<type>, 0)
|
instantiate_scatter4(#name "_min", type, bool, Min<type>, 0) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_scatter3(name, type, ind_type, op_type) \
|
#define instantiate_scatter3(name, type, ind_type, op_type) \
|
||||||
instantiate_scatter4(name, type, ind_type, op_type, 1) \
|
instantiate_scatter4(name, type, ind_type, op_type, 1) \
|
||||||
instantiate_scatter4(name, type, ind_type, op_type, 2) \
|
instantiate_scatter4(name, type, ind_type, op_type, 2) \
|
||||||
@ -213,15 +189,17 @@ template [[host_name("scatter_1d_index" name "_" #nidx)]] \
|
|||||||
instantiate_scatter4(name, type, ind_type, op_type, 7) \
|
instantiate_scatter4(name, type, ind_type, op_type, 7) \
|
||||||
instantiate_scatter4(name, type, ind_type, op_type, 8) \
|
instantiate_scatter4(name, type, ind_type, op_type, 8) \
|
||||||
instantiate_scatter4(name, type, ind_type, op_type, 9) \
|
instantiate_scatter4(name, type, ind_type, op_type, 9) \
|
||||||
instantiate_scatter4(name, type, ind_type, op_type, 10)
|
instantiate_scatter4(name, type, ind_type, op_type, 10) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_scatter2(name, type, ind_type) \
|
#define instantiate_scatter2(name, type, ind_type) \
|
||||||
instantiate_scatter3(name "_none", type, ind_type, None) \
|
instantiate_scatter3(name "_none", type, ind_type, None) \
|
||||||
instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
|
instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
|
||||||
instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
|
instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
|
||||||
instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
|
instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
|
||||||
instantiate_scatter3(name "_min", type, ind_type, Min<type>)
|
instantiate_scatter3(name "_min", type, ind_type, Min<type>) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_scatter(name, type) \
|
#define instantiate_scatter(name, type) \
|
||||||
instantiate_scatter2(#name "bool_", type, bool) \
|
instantiate_scatter2(#name "bool_", type, bool) \
|
||||||
instantiate_scatter2(#name "uint8", type, uint8_t) \
|
instantiate_scatter2(#name "uint8", type, uint8_t) \
|
||||||
@ -231,8 +209,9 @@ template [[host_name("scatter_1d_index" name "_" #nidx)]] \
|
|||||||
instantiate_scatter2(#name "int8", type, int8_t) \
|
instantiate_scatter2(#name "int8", type, int8_t) \
|
||||||
instantiate_scatter2(#name "int16", type, int16_t) \
|
instantiate_scatter2(#name "int16", type, int16_t) \
|
||||||
instantiate_scatter2(#name "int32", type, int32_t) \
|
instantiate_scatter2(#name "int32", type, int32_t) \
|
||||||
instantiate_scatter2(#name "int64", type, int64_t)
|
instantiate_scatter2(#name "int64", type, int64_t) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
// TODO uint64 and int64 unsupported
|
// TODO uint64 and int64 unsupported
|
||||||
instantiate_scatter_nd0(bool_, bool)
|
instantiate_scatter_nd0(bool_, bool)
|
||||||
instantiate_scatter_nd0(uint8, uint8_t)
|
instantiate_scatter_nd0(uint8, uint8_t)
|
||||||
@ -254,4 +233,4 @@ instantiate_scatter(int16, int16_t)
|
|||||||
instantiate_scatter(int32, int32_t)
|
instantiate_scatter(int32, int32_t)
|
||||||
instantiate_scatter(float16, half)
|
instantiate_scatter(float16, half)
|
||||||
instantiate_scatter(float32, float)
|
instantiate_scatter(float32, float)
|
||||||
instantiate_scatter(bfloat16, bfloat16_t)
|
instantiate_scatter(bfloat16, bfloat16_t) // clang-format on
|
||||||
|
@ -198,7 +198,6 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
#define instantiate_softmax(name, itype) \
|
#define instantiate_softmax(name, itype) \
|
||||||
template [[host_name("softmax_" #name)]] [[kernel]] void \
|
template [[host_name("softmax_" #name)]] [[kernel]] void \
|
||||||
softmax_single_row<itype>( \
|
softmax_single_row<itype>( \
|
||||||
@ -241,9 +240,9 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
|
|||||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_softmax(float32, float)
|
instantiate_softmax(float32, float)
|
||||||
instantiate_softmax(float16, half)
|
instantiate_softmax(float16, half)
|
||||||
instantiate_softmax(bfloat16, bfloat16_t)
|
instantiate_softmax(bfloat16, bfloat16_t)
|
||||||
instantiate_softmax_precise(float16, half)
|
instantiate_softmax_precise(float16, half)
|
||||||
instantiate_softmax_precise(bfloat16, bfloat16_t)
|
instantiate_softmax_precise(bfloat16, bfloat16_t) // clang-format on
|
||||||
// clang-format on
|
|
||||||
|
@ -11,7 +11,8 @@
|
|||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
// Based on GPU merge sort algorithm at https://github.com/NVIDIA/cccl/tree/main/cub/cub
|
// Based on GPU merge sort algorithm at
|
||||||
|
// https://github.com/NVIDIA/cccl/tree/main/cub/cub
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// Thread-level sort
|
// Thread-level sort
|
||||||
@ -43,7 +44,6 @@ struct ThreadSort {
|
|||||||
static METAL_FUNC void sort(
|
static METAL_FUNC void sort(
|
||||||
thread val_t (&vals)[N_PER_THREAD],
|
thread val_t (&vals)[N_PER_THREAD],
|
||||||
thread idx_t (&idxs)[N_PER_THREAD]) {
|
thread idx_t (&idxs)[N_PER_THREAD]) {
|
||||||
|
|
||||||
CompareOp op;
|
CompareOp op;
|
||||||
|
|
||||||
MLX_MTL_LOOP_UNROLL
|
MLX_MTL_LOOP_UNROLL
|
||||||
@ -56,7 +56,6 @@ struct ThreadSort {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -72,14 +71,14 @@ template <
|
|||||||
short N_PER_THREAD,
|
short N_PER_THREAD,
|
||||||
typename CompareOp>
|
typename CompareOp>
|
||||||
struct BlockMergeSort {
|
struct BlockMergeSort {
|
||||||
using thread_sort_t = ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>;
|
using thread_sort_t =
|
||||||
|
ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>;
|
||||||
static METAL_FUNC int merge_partition(
|
static METAL_FUNC int merge_partition(
|
||||||
const threadgroup val_t* As,
|
const threadgroup val_t* As,
|
||||||
const threadgroup val_t* Bs,
|
const threadgroup val_t* Bs,
|
||||||
short A_sz,
|
short A_sz,
|
||||||
short B_sz,
|
short B_sz,
|
||||||
short sort_md) {
|
short sort_md) {
|
||||||
|
|
||||||
CompareOp op;
|
CompareOp op;
|
||||||
|
|
||||||
short A_st = max(0, sort_md - B_sz);
|
short A_st = max(0, sort_md - B_sz);
|
||||||
@ -98,7 +97,6 @@ struct BlockMergeSort {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return A_ed;
|
return A_ed;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static METAL_FUNC void merge_step(
|
static METAL_FUNC void merge_step(
|
||||||
@ -110,7 +108,6 @@ struct BlockMergeSort {
|
|||||||
short B_sz,
|
short B_sz,
|
||||||
thread val_t (&vals)[N_PER_THREAD],
|
thread val_t (&vals)[N_PER_THREAD],
|
||||||
thread idx_t (&idxs)[N_PER_THREAD]) {
|
thread idx_t (&idxs)[N_PER_THREAD]) {
|
||||||
|
|
||||||
CompareOp op;
|
CompareOp op;
|
||||||
short a_idx = 0;
|
short a_idx = 0;
|
||||||
short b_idx = 0;
|
short b_idx = 0;
|
||||||
@ -126,7 +123,6 @@ struct BlockMergeSort {
|
|||||||
b_idx += short(pred);
|
b_idx += short(pred);
|
||||||
a_idx += short(!pred);
|
a_idx += short(!pred);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static METAL_FUNC void sort(
|
static METAL_FUNC void sort(
|
||||||
@ -134,7 +130,6 @@ struct BlockMergeSort {
|
|||||||
threadgroup idx_t* tgp_idxs [[threadgroup(1)]],
|
threadgroup idx_t* tgp_idxs [[threadgroup(1)]],
|
||||||
int size_sorted_axis,
|
int size_sorted_axis,
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
|
||||||
// Get thread location
|
// Get thread location
|
||||||
int idx = lid.x * N_PER_THREAD;
|
int idx = lid.x * N_PER_THREAD;
|
||||||
|
|
||||||
@ -154,7 +149,8 @@ struct BlockMergeSort {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Do merges using threadgroup memory
|
// Do merges using threadgroup memory
|
||||||
for (int merge_threads = 2; merge_threads <= BLOCK_THREADS; merge_threads *= 2) {
|
for (int merge_threads = 2; merge_threads <= BLOCK_THREADS;
|
||||||
|
merge_threads *= 2) {
|
||||||
// Update threadgroup memory
|
// Update threadgroup memory
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
for (int i = 0; i < N_PER_THREAD; ++i) {
|
for (int i = 0; i < N_PER_THREAD; ++i) {
|
||||||
@ -189,12 +185,7 @@ struct BlockMergeSort {
|
|||||||
// of size N_PER_THREAD for each merge lane i
|
// of size N_PER_THREAD for each merge lane i
|
||||||
// C = [Ci] is sorted
|
// C = [Ci] is sorted
|
||||||
int sort_md = N_PER_THREAD * merge_lane;
|
int sort_md = N_PER_THREAD * merge_lane;
|
||||||
int partition = merge_partition(
|
int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md);
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
A_sz,
|
|
||||||
B_sz,
|
|
||||||
sort_md);
|
|
||||||
|
|
||||||
As += partition;
|
As += partition;
|
||||||
Bs += sort_md - partition;
|
Bs += sort_md - partition;
|
||||||
@ -202,20 +193,13 @@ struct BlockMergeSort {
|
|||||||
A_sz -= partition;
|
A_sz -= partition;
|
||||||
B_sz -= sort_md - partition;
|
B_sz -= sort_md - partition;
|
||||||
|
|
||||||
const threadgroup idx_t* As_idx = ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
|
const threadgroup idx_t* As_idx =
|
||||||
const threadgroup idx_t* Bs_idx = ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;
|
ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
|
||||||
|
const threadgroup idx_t* Bs_idx =
|
||||||
|
ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;
|
||||||
|
|
||||||
// Merge starting at the partition and store results in thread registers
|
// Merge starting at the partition and store results in thread registers
|
||||||
merge_step(
|
merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs);
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
As_idx,
|
|
||||||
Bs_idx,
|
|
||||||
A_sz,
|
|
||||||
B_sz,
|
|
||||||
thread_vals,
|
|
||||||
thread_idxs);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write out to shared memory
|
// Write out to shared memory
|
||||||
@ -263,14 +247,14 @@ struct KernelMergeSort {
|
|||||||
threadgroup idx_t* tgp_idxs,
|
threadgroup idx_t* tgp_idxs,
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
|
||||||
// tid.y tells us the segment index
|
// tid.y tells us the segment index
|
||||||
inp += tid.y * stride_segment_axis;
|
inp += tid.y * stride_segment_axis;
|
||||||
out += tid.y * stride_segment_axis;
|
out += tid.y * stride_segment_axis;
|
||||||
|
|
||||||
// Copy into threadgroup memory
|
// Copy into threadgroup memory
|
||||||
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
|
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
|
||||||
tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis] : val_t(CompareOp::init);
|
tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis]
|
||||||
|
: val_t(CompareOp::init);
|
||||||
if (ARG_SORT) {
|
if (ARG_SORT) {
|
||||||
tgp_idxs[i] = i;
|
tgp_idxs[i] = i;
|
||||||
}
|
}
|
||||||
@ -308,8 +292,8 @@ template <
|
|||||||
const constant int& stride_segment_axis [[buffer(4)]],
|
const constant int& stride_segment_axis [[buffer(4)]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
using sort_kernel =
|
||||||
using sort_kernel = KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
|
KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
|
||||||
using val_t = typename sort_kernel::val_t;
|
using val_t = typename sort_kernel::val_t;
|
||||||
using idx_t = typename sort_kernel::idx_t;
|
using idx_t = typename sort_kernel::idx_t;
|
||||||
|
|
||||||
@ -339,7 +323,6 @@ template <
|
|||||||
tid,
|
tid,
|
||||||
lid);
|
lid);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
constant constexpr const int zero_helper = 0;
|
constant constexpr const int zero_helper = 0;
|
||||||
@ -360,8 +343,8 @@ template <
|
|||||||
const device size_t* nc_strides [[buffer(6)]],
|
const device size_t* nc_strides [[buffer(6)]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
using sort_kernel =
|
||||||
using sort_kernel = KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
|
KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
|
||||||
using val_t = typename sort_kernel::val_t;
|
using val_t = typename sort_kernel::val_t;
|
||||||
using idx_t = typename sort_kernel::idx_t;
|
using idx_t = typename sort_kernel::idx_t;
|
||||||
|
|
||||||
@ -395,17 +378,17 @@ template <
|
|||||||
tid,
|
tid,
|
||||||
lid);
|
lid);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// Instantiations
|
// Instantiations
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
#define instantiate_block_sort( \
|
||||||
#define instantiate_block_sort(name, itname, itype, otname, otype, arg_sort, bn, tn) \
|
name, itname, itype, otname, otype, arg_sort, bn, tn) \
|
||||||
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn)]] \
|
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn \
|
||||||
[[kernel]] void block_sort<itype, otype, arg_sort, bn, tn>( \
|
"_tn" #tn)]] [[kernel]] void \
|
||||||
|
block_sort<itype, otype, arg_sort, bn, tn>( \
|
||||||
const device itype* inp [[buffer(0)]], \
|
const device itype* inp [[buffer(0)]], \
|
||||||
device otype* out [[buffer(1)]], \
|
device otype* out [[buffer(1)]], \
|
||||||
const constant int& size_sorted_axis [[buffer(2)]], \
|
const constant int& size_sorted_axis [[buffer(2)]], \
|
||||||
@ -413,8 +396,9 @@ template <
|
|||||||
const constant int& stride_segment_axis [[buffer(4)]], \
|
const constant int& stride_segment_axis [[buffer(4)]], \
|
||||||
uint3 tid [[threadgroup_position_in_grid]], \
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
uint3 lid [[thread_position_in_threadgroup]]); \
|
uint3 lid [[thread_position_in_threadgroup]]); \
|
||||||
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn "_nc")]] \
|
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn \
|
||||||
[[kernel]] void block_sort_nc<itype, otype, arg_sort, bn, tn>( \
|
"_nc")]] [[kernel]] void \
|
||||||
|
block_sort_nc<itype, otype, arg_sort, bn, tn>( \
|
||||||
const device itype* inp [[buffer(0)]], \
|
const device itype* inp [[buffer(0)]], \
|
||||||
device otype* out [[buffer(1)]], \
|
device otype* out [[buffer(1)]], \
|
||||||
const constant int& size_sorted_axis [[buffer(2)]], \
|
const constant int& size_sorted_axis [[buffer(2)]], \
|
||||||
@ -426,15 +410,19 @@ template <
|
|||||||
uint3 lid [[thread_position_in_threadgroup]]);
|
uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
|
||||||
#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
|
#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
|
||||||
instantiate_block_sort(arg_block_merge_sort, itname, itype, uint32, uint32_t, true, bn, tn)
|
instantiate_block_sort( \
|
||||||
|
arg_block_merge_sort, itname, itype, uint32, uint32_t, true, bn, tn)
|
||||||
|
|
||||||
#define instantiate_block_sort_base(itname, itype, bn, tn) \
|
#define instantiate_block_sort_base(itname, itype, bn, tn) \
|
||||||
instantiate_block_sort(block_merge_sort, itname, itype, itname, itype, false, bn, tn)
|
instantiate_block_sort( \
|
||||||
|
block_merge_sort, itname, itype, itname, itype, false, bn, tn)
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_block_sort_tn(itname, itype, bn) \
|
#define instantiate_block_sort_tn(itname, itype, bn) \
|
||||||
instantiate_block_sort_base(itname, itype, bn, 8) \
|
instantiate_block_sort_base(itname, itype, bn, 8) \
|
||||||
instantiate_arg_block_sort_base(itname, itype, bn, 8)
|
instantiate_arg_block_sort_base(itname, itype, bn, 8) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_block_sort_bn(itname, itype) \
|
#define instantiate_block_sort_bn(itname, itype) \
|
||||||
instantiate_block_sort_tn(itname, itype, 128) \
|
instantiate_block_sort_tn(itname, itype, 128) \
|
||||||
instantiate_block_sort_tn(itname, itype, 256) \
|
instantiate_block_sort_tn(itname, itype, 256) \
|
||||||
@ -448,14 +436,14 @@ instantiate_block_sort_bn(int16, int16_t)
|
|||||||
instantiate_block_sort_bn(int32, int32_t)
|
instantiate_block_sort_bn(int32, int32_t)
|
||||||
instantiate_block_sort_bn(float16, half)
|
instantiate_block_sort_bn(float16, half)
|
||||||
instantiate_block_sort_bn(float32, float)
|
instantiate_block_sort_bn(float32, float)
|
||||||
instantiate_block_sort_bn(bfloat16, bfloat16_t)
|
instantiate_block_sort_bn(bfloat16, bfloat16_t) // clang-format on
|
||||||
|
// clang-format off
|
||||||
#define instantiate_block_sort_long(itname, itype) \
|
#define instantiate_block_sort_long(itname, itype) \
|
||||||
instantiate_block_sort_tn(itname, itype, 128) \
|
instantiate_block_sort_tn(itname, itype, 128) \
|
||||||
instantiate_block_sort_tn(itname, itype, 256)
|
instantiate_block_sort_tn(itname, itype, 256)
|
||||||
|
|
||||||
instantiate_block_sort_long(uint64, uint64_t)
|
instantiate_block_sort_long(uint64, uint64_t)
|
||||||
instantiate_block_sort_long(int64, int64_t)
|
instantiate_block_sort_long(int64, int64_t) // clang-format on
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// Multi block merge sort
|
// Multi block merge sort
|
||||||
@ -489,14 +477,14 @@ struct KernelMultiBlockMergeSort {
|
|||||||
threadgroup idx_t* tgp_idxs,
|
threadgroup idx_t* tgp_idxs,
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
|
||||||
// tid.y tells us the segment index
|
// tid.y tells us the segment index
|
||||||
int base_idx = tid.x * N_PER_BLOCK;
|
int base_idx = tid.x * N_PER_BLOCK;
|
||||||
|
|
||||||
// Copy into threadgroup memory
|
// Copy into threadgroup memory
|
||||||
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
|
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
|
||||||
int idx = base_idx + i;
|
int idx = base_idx + i;
|
||||||
tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis] : val_t(CompareOp::init);
|
tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis]
|
||||||
|
: val_t(CompareOp::init);
|
||||||
tgp_idxs[i] = idx;
|
tgp_idxs[i] = idx;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -523,7 +511,6 @@ struct KernelMultiBlockMergeSort {
|
|||||||
int A_sz,
|
int A_sz,
|
||||||
int B_sz,
|
int B_sz,
|
||||||
int sort_md) {
|
int sort_md) {
|
||||||
|
|
||||||
CompareOp op;
|
CompareOp op;
|
||||||
|
|
||||||
int A_st = max(0, sort_md - B_sz);
|
int A_st = max(0, sort_md - B_sz);
|
||||||
@ -542,7 +529,6 @@ struct KernelMultiBlockMergeSort {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return A_ed;
|
return A_ed;
|
||||||
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -563,8 +549,12 @@ template <
|
|||||||
const device size_t* nc_strides [[buffer(7)]],
|
const device size_t* nc_strides [[buffer(7)]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
using sort_kernel = KernelMultiBlockMergeSort<
|
||||||
using sort_kernel = KernelMultiBlockMergeSort<val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
|
val_t,
|
||||||
|
idx_t,
|
||||||
|
ARG_SORT,
|
||||||
|
BLOCK_THREADS,
|
||||||
|
N_PER_THREAD>;
|
||||||
|
|
||||||
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
|
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
|
||||||
inp += block_idx;
|
inp += block_idx;
|
||||||
@ -592,7 +582,8 @@ template <
|
|||||||
bool ARG_SORT,
|
bool ARG_SORT,
|
||||||
short BLOCK_THREADS,
|
short BLOCK_THREADS,
|
||||||
short N_PER_THREAD>
|
short N_PER_THREAD>
|
||||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partition(
|
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
|
||||||
|
mb_block_partition(
|
||||||
device idx_t* block_partitions [[buffer(0)]],
|
device idx_t* block_partitions [[buffer(0)]],
|
||||||
const device val_t* dev_vals [[buffer(1)]],
|
const device val_t* dev_vals [[buffer(1)]],
|
||||||
const device idx_t* dev_idxs [[buffer(2)]],
|
const device idx_t* dev_idxs [[buffer(2)]],
|
||||||
@ -601,7 +592,6 @@ template <
|
|||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint3 lid [[thread_position_in_threadgroup]],
|
uint3 lid [[thread_position_in_threadgroup]],
|
||||||
uint3 tgp_dims [[threads_per_threadgroup]]) {
|
uint3 tgp_dims [[threads_per_threadgroup]]) {
|
||||||
|
|
||||||
using sort_kernel = KernelMultiBlockMergeSort<
|
using sort_kernel = KernelMultiBlockMergeSort<
|
||||||
val_t,
|
val_t,
|
||||||
idx_t,
|
idx_t,
|
||||||
@ -627,14 +617,9 @@ template <
|
|||||||
|
|
||||||
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
|
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
|
||||||
int partition = sort_kernel::merge_partition(
|
int partition = sort_kernel::merge_partition(
|
||||||
dev_vals + A_st,
|
dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at);
|
||||||
dev_vals + B_st,
|
|
||||||
A_ed - A_st,
|
|
||||||
B_ed - B_st,
|
|
||||||
partition_at);
|
|
||||||
|
|
||||||
block_partitions[lid.x] = A_st + partition;
|
block_partitions[lid.x] = A_st + partition;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <
|
template <
|
||||||
@ -644,7 +629,8 @@ template <
|
|||||||
short BLOCK_THREADS,
|
short BLOCK_THREADS,
|
||||||
short N_PER_THREAD,
|
short N_PER_THREAD,
|
||||||
typename CompareOp = LessThan<val_t>>
|
typename CompareOp = LessThan<val_t>>
|
||||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_merge(
|
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
|
||||||
|
mb_block_merge(
|
||||||
const device idx_t* block_partitions [[buffer(0)]],
|
const device idx_t* block_partitions [[buffer(0)]],
|
||||||
const device val_t* dev_vals_in [[buffer(1)]],
|
const device val_t* dev_vals_in [[buffer(1)]],
|
||||||
const device idx_t* dev_idxs_in [[buffer(2)]],
|
const device idx_t* dev_idxs_in [[buffer(2)]],
|
||||||
@ -655,7 +641,6 @@ template <
|
|||||||
const constant int& num_tiles [[buffer(7)]],
|
const constant int& num_tiles [[buffer(7)]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
|
||||||
using sort_kernel = KernelMultiBlockMergeSort<
|
using sort_kernel = KernelMultiBlockMergeSort<
|
||||||
val_t,
|
val_t,
|
||||||
idx_t,
|
idx_t,
|
||||||
@ -681,7 +666,9 @@ template <
|
|||||||
int A_st = block_partitions[block_idx + 0];
|
int A_st = block_partitions[block_idx + 0];
|
||||||
int A_ed = block_partitions[block_idx + 1];
|
int A_ed = block_partitions[block_idx + 1];
|
||||||
int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);
|
int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);
|
||||||
int B_ed = min(size_sorted_axis, 2 * sort_st + sort_sz/2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
|
int B_ed = min(
|
||||||
|
size_sorted_axis,
|
||||||
|
2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
|
||||||
|
|
||||||
if ((block_idx % merge_tiles) == merge_tiles - 1) {
|
if ((block_idx % merge_tiles) == merge_tiles - 1) {
|
||||||
A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
|
A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
|
||||||
@ -697,8 +684,10 @@ template <
|
|||||||
for (int i = 0; i < N_PER_THREAD; i++) {
|
for (int i = 0; i < N_PER_THREAD; i++) {
|
||||||
int idx = BLOCK_THREADS * i + lid.x;
|
int idx = BLOCK_THREADS * i + lid.x;
|
||||||
if (idx < (A_sz + B_sz)) {
|
if (idx < (A_sz + B_sz)) {
|
||||||
thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx] : dev_vals_in[B_st + idx - A_sz];
|
thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]
|
||||||
thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx] : dev_idxs_in[B_st + idx - A_sz];
|
: dev_vals_in[B_st + idx - A_sz];
|
||||||
|
thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]
|
||||||
|
: dev_idxs_in[B_st + idx - A_sz];
|
||||||
} else {
|
} else {
|
||||||
thread_vals[i] = CompareOp::init;
|
thread_vals[i] = CompareOp::init;
|
||||||
thread_idxs[i] = 0;
|
thread_idxs[i] = 0;
|
||||||
@ -720,11 +709,7 @@ template <
|
|||||||
int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x));
|
int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x));
|
||||||
|
|
||||||
int A_st_local = block_sort_t::merge_partition(
|
int A_st_local = block_sort_t::merge_partition(
|
||||||
tgp_vals,
|
tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);
|
||||||
tgp_vals + A_sz,
|
|
||||||
A_sz,
|
|
||||||
B_sz,
|
|
||||||
sort_md_local);
|
|
||||||
int A_ed_local = A_sz;
|
int A_ed_local = A_sz;
|
||||||
|
|
||||||
int B_st_local = sort_md_local - A_st_local;
|
int B_st_local = sort_md_local - A_st_local;
|
||||||
@ -761,12 +746,13 @@ template <
|
|||||||
dev_idxs_out[idx] = tgp_idxs[i];
|
dev_idxs_out[idx] = tgp_idxs[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_multi_block_sort(vtname, vtype, itname, itype, arg_sort, bn, tn) \
|
#define instantiate_multi_block_sort( \
|
||||||
template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
vtname, vtype, itname, itype, arg_sort, bn, tn) \
|
||||||
[[kernel]] void mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
|
template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn \
|
||||||
|
"_tn" #tn)]] [[kernel]] void \
|
||||||
|
mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
|
||||||
const device vtype* inp [[buffer(0)]], \
|
const device vtype* inp [[buffer(0)]], \
|
||||||
device vtype* out_vals [[buffer(1)]], \
|
device vtype* out_vals [[buffer(1)]], \
|
||||||
device itype* out_idxs [[buffer(2)]], \
|
device itype* out_idxs [[buffer(2)]], \
|
||||||
@ -777,8 +763,9 @@ template <
|
|||||||
const device size_t* nc_strides [[buffer(7)]], \
|
const device size_t* nc_strides [[buffer(7)]], \
|
||||||
uint3 tid [[threadgroup_position_in_grid]], \
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
uint3 lid [[thread_position_in_threadgroup]]); \
|
uint3 lid [[thread_position_in_threadgroup]]); \
|
||||||
template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn \
|
||||||
[[kernel]] void mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
|
"_tn" #tn)]] [[kernel]] void \
|
||||||
|
mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
|
||||||
device itype * block_partitions [[buffer(0)]], \
|
device itype * block_partitions [[buffer(0)]], \
|
||||||
const device vtype* dev_vals [[buffer(1)]], \
|
const device vtype* dev_vals [[buffer(1)]], \
|
||||||
const device itype* dev_idxs [[buffer(2)]], \
|
const device itype* dev_idxs [[buffer(2)]], \
|
||||||
@ -787,8 +774,9 @@ template <
|
|||||||
uint3 tid [[threadgroup_position_in_grid]], \
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
uint3 lid [[thread_position_in_threadgroup]], \
|
uint3 lid [[thread_position_in_threadgroup]], \
|
||||||
uint3 tgp_dims [[threads_per_threadgroup]]); \
|
uint3 tgp_dims [[threads_per_threadgroup]]); \
|
||||||
template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn \
|
||||||
[[kernel]] void mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
|
"_tn" #tn)]] [[kernel]] void \
|
||||||
|
mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
|
||||||
const device itype* block_partitions [[buffer(0)]], \
|
const device itype* block_partitions [[buffer(0)]], \
|
||||||
const device vtype* dev_vals_in [[buffer(1)]], \
|
const device vtype* dev_vals_in [[buffer(1)]], \
|
||||||
const device itype* dev_idxs_in [[buffer(2)]], \
|
const device itype* dev_idxs_in [[buffer(2)]], \
|
||||||
@ -800,6 +788,7 @@ template <
|
|||||||
uint3 tid [[threadgroup_position_in_grid]], \
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
uint3 lid [[thread_position_in_threadgroup]]);
|
uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_multi_block_sort_base(vtname, vtype) \
|
#define instantiate_multi_block_sort_base(vtname, vtype) \
|
||||||
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8)
|
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8)
|
||||||
|
|
||||||
@ -811,10 +800,11 @@ instantiate_multi_block_sort_base(int16, int16_t)
|
|||||||
instantiate_multi_block_sort_base(int32, int32_t)
|
instantiate_multi_block_sort_base(int32, int32_t)
|
||||||
instantiate_multi_block_sort_base(float16, half)
|
instantiate_multi_block_sort_base(float16, half)
|
||||||
instantiate_multi_block_sort_base(float32, float)
|
instantiate_multi_block_sort_base(float32, float)
|
||||||
instantiate_multi_block_sort_base(bfloat16, bfloat16_t)
|
instantiate_multi_block_sort_base(bfloat16, bfloat16_t) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_multi_block_sort_long(vtname, vtype) \
|
#define instantiate_multi_block_sort_long(vtname, vtype) \
|
||||||
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8)
|
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8)
|
||||||
|
|
||||||
instantiate_multi_block_sort_long(uint64, uint64_t)
|
instantiate_multi_block_sort_long(uint64, uint64_t)
|
||||||
instantiate_multi_block_sort_long(int64, int64_t)
|
instantiate_multi_block_sort_long(int64, int64_t) // clang-format on
|
@ -4,13 +4,14 @@
|
|||||||
|
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
|
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
|
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
|
||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
template <typename T,
|
template <
|
||||||
|
typename T,
|
||||||
int BM,
|
int BM,
|
||||||
int BN,
|
int BN,
|
||||||
int BK,
|
int BK,
|
||||||
@ -18,7 +19,8 @@ template <typename T,
|
|||||||
int WN,
|
int WN,
|
||||||
int N_CHANNELS = 0,
|
int N_CHANNELS = 0,
|
||||||
bool SMALL_FILTER = false>
|
bool SMALL_FILTER = false>
|
||||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d(
|
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
|
||||||
|
implicit_gemm_conv_2d(
|
||||||
const device T* A [[buffer(0)]],
|
const device T* A [[buffer(0)]],
|
||||||
const device T* B [[buffer(1)]],
|
const device T* B [[buffer(1)]],
|
||||||
device T* C [[buffer(2)]],
|
device T* C [[buffer(2)]],
|
||||||
@ -28,8 +30,6 @@ template <typename T,
|
|||||||
uint3 lid [[thread_position_in_threadgroup]],
|
uint3 lid [[thread_position_in_threadgroup]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
|
|
||||||
using namespace mlx::steel;
|
using namespace mlx::steel;
|
||||||
|
|
||||||
(void)lid;
|
(void)lid;
|
||||||
@ -56,7 +56,13 @@ template <typename T,
|
|||||||
|
|
||||||
// Go to small channel specialization
|
// Go to small channel specialization
|
||||||
Conv2DInputBlockLoaderSmallChannels<
|
Conv2DInputBlockLoaderSmallChannels<
|
||||||
T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_a>,
|
T,
|
||||||
|
BM,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
tgp_size,
|
||||||
|
N_CHANNELS,
|
||||||
|
tgp_padding_a>,
|
||||||
|
|
||||||
// Else go to general loader
|
// Else go to general loader
|
||||||
typename metal::conditional_t<
|
typename metal::conditional_t<
|
||||||
@ -65,14 +71,21 @@ template <typename T,
|
|||||||
|
|
||||||
// Go to small filter specialization
|
// Go to small filter specialization
|
||||||
Conv2DInputBlockLoaderSmallFilter<
|
Conv2DInputBlockLoaderSmallFilter<
|
||||||
T, BM, BN, BK, tgp_size, tgp_padding_a>,
|
T,
|
||||||
|
BM,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
tgp_size,
|
||||||
|
tgp_padding_a>,
|
||||||
|
|
||||||
// Else go to large filter generalization
|
// Else go to large filter generalization
|
||||||
Conv2DInputBlockLoaderLargeFilter<
|
Conv2DInputBlockLoaderLargeFilter<
|
||||||
T, BM, BN, BK, tgp_size, tgp_padding_a>
|
T,
|
||||||
>
|
BM,
|
||||||
>;
|
BN,
|
||||||
|
BK,
|
||||||
|
tgp_size,
|
||||||
|
tgp_padding_a>>>;
|
||||||
|
|
||||||
// Weight loader
|
// Weight loader
|
||||||
using loader_b_t = typename metal::conditional_t<
|
using loader_b_t = typename metal::conditional_t<
|
||||||
@ -81,11 +94,16 @@ template <typename T,
|
|||||||
|
|
||||||
// Go to small channel specialization
|
// Go to small channel specialization
|
||||||
Conv2DWeightBlockLoaderSmallChannels<
|
Conv2DWeightBlockLoaderSmallChannels<
|
||||||
T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_b>,
|
T,
|
||||||
|
BM,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
tgp_size,
|
||||||
|
N_CHANNELS,
|
||||||
|
tgp_padding_b>,
|
||||||
|
|
||||||
// Else go to general loader
|
// Else go to general loader
|
||||||
Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>
|
Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>>;
|
||||||
>;
|
|
||||||
|
|
||||||
using mma_t = BlockMMA<
|
using mma_t = BlockMMA<
|
||||||
T,
|
T,
|
||||||
@ -123,8 +141,10 @@ template <typename T,
|
|||||||
const int2 offsets_b(0, c_col);
|
const int2 offsets_b(0, c_col);
|
||||||
|
|
||||||
// Prepare threadgroup loading operations
|
// Prepare threadgroup loading operations
|
||||||
loader_a_t loader_a(A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
|
loader_a_t loader_a(
|
||||||
loader_b_t loader_b(B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);
|
A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
|
||||||
|
loader_b_t loader_b(
|
||||||
|
B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);
|
||||||
|
|
||||||
// Prepare threadgroup mma operation
|
// Prepare threadgroup mma operation
|
||||||
mma_t mma_op(simd_gid, simd_lid);
|
mma_t mma_op(simd_gid, simd_lid);
|
||||||
@ -152,12 +172,24 @@ template <typename T,
|
|||||||
short tgp_bm = min(BM, gemm_params->M - c_row);
|
short tgp_bm = min(BM, gemm_params->M - c_row);
|
||||||
short tgp_bn = min(BN, gemm_params->N - c_col);
|
short tgp_bn = min(BN, gemm_params->N - c_col);
|
||||||
mma_op.store_result_safe(C, N, short2(tgp_bn, tgp_bm));
|
mma_op.store_result_safe(C, N, short2(tgp_bn, tgp_bm));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, channel_name, n_channels, filter_name, small_filter) \
|
#define instantiate_implicit_conv_2d( \
|
||||||
template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_channel_" #channel_name "_filter_" #filter_name)]] \
|
name, \
|
||||||
[[kernel]] void implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn, n_channels, small_filter>( \
|
itype, \
|
||||||
|
bm, \
|
||||||
|
bn, \
|
||||||
|
bk, \
|
||||||
|
wm, \
|
||||||
|
wn, \
|
||||||
|
channel_name, \
|
||||||
|
n_channels, \
|
||||||
|
filter_name, \
|
||||||
|
small_filter) \
|
||||||
|
template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn \
|
||||||
|
"_bk" #bk "_wm" #wm "_wn" #wn "_channel_" #channel_name \
|
||||||
|
"_filter_" #filter_name)]] [[kernel]] void \
|
||||||
|
implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn, n_channels, small_filter>( \
|
||||||
const device itype* A [[buffer(0)]], \
|
const device itype* A [[buffer(0)]], \
|
||||||
const device itype* B [[buffer(1)]], \
|
const device itype* B [[buffer(1)]], \
|
||||||
device itype* C [[buffer(2)]], \
|
device itype* C [[buffer(2)]], \
|
||||||
@ -168,22 +200,25 @@ template <typename T,
|
|||||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
|
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \
|
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \
|
||||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, l, false) \
|
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, l, false) \
|
||||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 1, 1, l, false) \
|
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 1, 1, l, false) \
|
||||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 2, 2, l, false) \
|
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 2, 2, l, false) \
|
||||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 3, 3, l, false) \
|
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 3, 3, l, false) \
|
||||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false)
|
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_implicit_2d_blocks(name, itype) \
|
#define instantiate_implicit_2d_blocks(name, itype) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
|
instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
|
instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
|
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
|
instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
|
instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)
|
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_implicit_2d_blocks(float32, float);
|
instantiate_implicit_2d_blocks(float32, float);
|
||||||
instantiate_implicit_2d_blocks(float16, half);
|
instantiate_implicit_2d_blocks(float16, half);
|
||||||
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
|
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on
|
@ -4,15 +4,16 @@
|
|||||||
|
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
|
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
|
|
||||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
|
||||||
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
|
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
|
||||||
|
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
|
||||||
|
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
using namespace mlx::steel;
|
using namespace mlx::steel;
|
||||||
|
|
||||||
template <typename T,
|
template <
|
||||||
|
typename T,
|
||||||
int BM,
|
int BM,
|
||||||
int BN,
|
int BN,
|
||||||
int BK,
|
int BK,
|
||||||
@ -20,7 +21,8 @@ template <typename T,
|
|||||||
int WN,
|
int WN,
|
||||||
typename AccumType = float,
|
typename AccumType = float,
|
||||||
typename Epilogue = TransformNone<T, AccumType>>
|
typename Epilogue = TransformNone<T, AccumType>>
|
||||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d_general(
|
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
|
||||||
|
implicit_gemm_conv_2d_general(
|
||||||
const device T* A [[buffer(0)]],
|
const device T* A [[buffer(0)]],
|
||||||
const device T* B [[buffer(1)]],
|
const device T* B [[buffer(1)]],
|
||||||
device T* C [[buffer(2)]],
|
device T* C [[buffer(2)]],
|
||||||
@ -33,7 +35,6 @@ template <typename T,
|
|||||||
uint3 lid [[thread_position_in_threadgroup]],
|
uint3 lid [[thread_position_in_threadgroup]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
(void)lid;
|
(void)lid;
|
||||||
|
|
||||||
constexpr bool transpose_a = false;
|
constexpr bool transpose_a = false;
|
||||||
@ -51,12 +52,12 @@ template <typename T,
|
|||||||
constexpr short tgp_size = WM * WN * 32;
|
constexpr short tgp_size = WM * WN * 32;
|
||||||
|
|
||||||
// Input loader
|
// Input loader
|
||||||
using loader_a_t = Conv2DInputBlockLoaderGeneral<
|
using loader_a_t =
|
||||||
T, BM, BN, BK, tgp_size, tgp_padding_a>;
|
Conv2DInputBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_a>;
|
||||||
|
|
||||||
// Weight loader
|
// Weight loader
|
||||||
using loader_b_t = Conv2DWeightBlockLoaderGeneral<
|
using loader_b_t =
|
||||||
T, BM, BN, BK, tgp_size, tgp_padding_b>;
|
Conv2DWeightBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_b>;
|
||||||
|
|
||||||
using mma_t = BlockMMA<
|
using mma_t = BlockMMA<
|
||||||
T,
|
T,
|
||||||
@ -103,13 +104,32 @@ template <typename T,
|
|||||||
const int2 offsets_b(0, c_col);
|
const int2 offsets_b(0, c_col);
|
||||||
|
|
||||||
// Prepare threadgroup loading operations
|
// Prepare threadgroup loading operations
|
||||||
loader_a_t loader_a(A, As, offsets_a, params, jump_params, base_wh, base_ww, simd_gid, simd_lid);
|
loader_a_t loader_a(
|
||||||
loader_b_t loader_b(B, Bs, offsets_b, params, jump_params, base_wh, base_ww, simd_gid, simd_lid);
|
A,
|
||||||
|
As,
|
||||||
|
offsets_a,
|
||||||
|
params,
|
||||||
|
jump_params,
|
||||||
|
base_wh,
|
||||||
|
base_ww,
|
||||||
|
simd_gid,
|
||||||
|
simd_lid);
|
||||||
|
loader_b_t loader_b(
|
||||||
|
B,
|
||||||
|
Bs,
|
||||||
|
offsets_b,
|
||||||
|
params,
|
||||||
|
jump_params,
|
||||||
|
base_wh,
|
||||||
|
base_ww,
|
||||||
|
simd_gid,
|
||||||
|
simd_lid);
|
||||||
|
|
||||||
// Prepare threadgroup mma operation
|
// Prepare threadgroup mma operation
|
||||||
mma_t mma_op(simd_gid, simd_lid);
|
mma_t mma_op(simd_gid, simd_lid);
|
||||||
|
|
||||||
int gemm_k_iterations = base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
|
int gemm_k_iterations =
|
||||||
|
base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
|
||||||
|
|
||||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
@ -143,22 +163,24 @@ template <typename T,
|
|||||||
|
|
||||||
STEEL_PRAGMA_UNROLL
|
STEEL_PRAGMA_UNROLL
|
||||||
for (int i = 0; i < mma_t::TM; i++) {
|
for (int i = 0; i < mma_t::TM; i++) {
|
||||||
|
|
||||||
int cm = offset_m + i * mma_t::TM_stride;
|
int cm = offset_m + i * mma_t::TM_stride;
|
||||||
|
|
||||||
int n = cm / jump_params->adj_out_hw;
|
int n = cm / jump_params->adj_out_hw;
|
||||||
int hw = cm % jump_params->adj_out_hw;
|
int hw = cm % jump_params->adj_out_hw;
|
||||||
int oh = (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
|
int oh =
|
||||||
int ow = (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
|
(hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
|
||||||
|
int ow =
|
||||||
|
(hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
|
||||||
|
|
||||||
if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
|
if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
|
||||||
|
int offset_cm = n * params->out_strides[0] +
|
||||||
int offset_cm = n * params->out_strides[0] + oh * params->out_strides[1] + ow * params->out_strides[2];
|
oh * params->out_strides[1] + ow * params->out_strides[2];
|
||||||
|
|
||||||
STEEL_PRAGMA_UNROLL
|
STEEL_PRAGMA_UNROLL
|
||||||
for (int j = 0; j < mma_t::TN; j++) {
|
for (int j = 0; j < mma_t::TN; j++) {
|
||||||
// Get accumulated result and associated offset in C
|
// Get accumulated result and associated offset in C
|
||||||
thread const auto& accum = mma_op.results[i * mma_t::TN + j].thread_elements();
|
thread const auto& accum =
|
||||||
|
mma_op.results[i * mma_t::TN + j].thread_elements();
|
||||||
int offset = offset_cm + (j * mma_t::TN_stride);
|
int offset = offset_cm + (j * mma_t::TN_stride);
|
||||||
|
|
||||||
// Apply epilogue and output C
|
// Apply epilogue and output C
|
||||||
@ -170,16 +192,16 @@ template <typename T,
|
|||||||
C[offset + 1] = Epilogue::apply(accum[1]);
|
C[offset + 1] = Epilogue::apply(accum[1]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
|
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
|
||||||
template [[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
|
template \
|
||||||
[[kernel]] void implicit_gemm_conv_2d_general<itype, bm, bn, bk, wm, wn>( \
|
[[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn \
|
||||||
|
"_bk" #bk "_wm" #wm "_wn" #wn)]] [[kernel]] void \
|
||||||
|
implicit_gemm_conv_2d_general<itype, bm, bn, bk, wm, wn>( \
|
||||||
const device itype* A [[buffer(0)]], \
|
const device itype* A [[buffer(0)]], \
|
||||||
const device itype* B [[buffer(1)]], \
|
const device itype* B [[buffer(1)]], \
|
||||||
device itype* C [[buffer(2)]], \
|
device itype* C [[buffer(2)]], \
|
||||||
@ -196,14 +218,16 @@ template <typename T,
|
|||||||
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
|
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn)
|
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn)
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_implicit_2d_blocks(name, itype) \
|
#define instantiate_implicit_2d_blocks(name, itype) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
|
instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
|
instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
|
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
|
instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
|
instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)
|
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_implicit_2d_blocks(float32, float);
|
instantiate_implicit_2d_blocks(float32, float);
|
||||||
instantiate_implicit_2d_blocks(float16, half);
|
instantiate_implicit_2d_blocks(float16, half);
|
||||||
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
|
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on
|
@ -1,8 +1,8 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||||
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
using namespace mlx::steel;
|
using namespace mlx::steel;
|
||||||
@ -11,7 +11,8 @@ using namespace mlx::steel;
|
|||||||
// GEMM kernels
|
// GEMM kernels
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
template <typename T,
|
template <
|
||||||
|
typename T,
|
||||||
int BM,
|
int BM,
|
||||||
int BN,
|
int BN,
|
||||||
int BK,
|
int BK,
|
||||||
@ -32,8 +33,18 @@ template <typename T,
|
|||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
using gemm_kernel = GEMMKernel<
|
||||||
using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
|
T,
|
||||||
|
T,
|
||||||
|
BM,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
WM,
|
||||||
|
WN,
|
||||||
|
transpose_a,
|
||||||
|
transpose_b,
|
||||||
|
MN_aligned,
|
||||||
|
K_aligned>;
|
||||||
|
|
||||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||||
@ -57,20 +68,34 @@ template <typename T,
|
|||||||
D += params->batch_stride_d * tid.z;
|
D += params->batch_stride_d * tid.z;
|
||||||
|
|
||||||
gemm_kernel::run(
|
gemm_kernel::run(
|
||||||
A, B, D,
|
A, B, D, params, As, Bs, simd_lane_id, simd_group_id, tid, lid);
|
||||||
params,
|
|
||||||
As, Bs,
|
|
||||||
simd_lane_id, simd_group_id, tid, lid
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// GEMM kernel initializations
|
// GEMM kernel initializations
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
#define instantiate_gemm( \
|
||||||
template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
|
tname, \
|
||||||
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
|
trans_a, \
|
||||||
|
trans_b, \
|
||||||
|
iname, \
|
||||||
|
itype, \
|
||||||
|
oname, \
|
||||||
|
otype, \
|
||||||
|
bm, \
|
||||||
|
bn, \
|
||||||
|
bk, \
|
||||||
|
wm, \
|
||||||
|
wn, \
|
||||||
|
aname, \
|
||||||
|
mn_aligned, \
|
||||||
|
kname, \
|
||||||
|
k_aligned) \
|
||||||
|
template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm \
|
||||||
|
"_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname \
|
||||||
|
"_K_" #kname)]] [[kernel]] void \
|
||||||
|
gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
|
||||||
const device itype* A [[buffer(0)]], \
|
const device itype* A [[buffer(0)]], \
|
||||||
const device itype* B [[buffer(1)]], \
|
const device itype* B [[buffer(1)]], \
|
||||||
device itype* D [[buffer(3)]], \
|
device itype* D [[buffer(3)]], \
|
||||||
@ -82,26 +107,30 @@ template <typename T,
|
|||||||
uint3 tid [[threadgroup_position_in_grid]], \
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
uint3 lid [[thread_position_in_threadgroup]]);
|
uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
||||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||||
|
|
||||||
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
|
@ -10,7 +10,8 @@ using namespace mlx::steel;
|
|||||||
// GEMM kernels
|
// GEMM kernels
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
template <typename T,
|
template <
|
||||||
|
typename T,
|
||||||
int BM,
|
int BM,
|
||||||
int BN,
|
int BN,
|
||||||
int BK,
|
int BK,
|
||||||
@ -35,15 +36,23 @@ template <typename T,
|
|||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
|
||||||
// Pacifying compiler
|
// Pacifying compiler
|
||||||
(void)lid;
|
(void)lid;
|
||||||
|
|
||||||
using gemm_kernel =
|
using gemm_kernel = GEMMKernel<
|
||||||
GEMMKernel<T, T, BM, BN, BK, WM, WN,
|
T,
|
||||||
transpose_a, transpose_b,
|
T,
|
||||||
MN_aligned, K_aligned,
|
BM,
|
||||||
AccumType, Epilogue>;
|
BN,
|
||||||
|
BK,
|
||||||
|
WM,
|
||||||
|
WN,
|
||||||
|
transpose_a,
|
||||||
|
transpose_b,
|
||||||
|
MN_aligned,
|
||||||
|
K_aligned,
|
||||||
|
AccumType,
|
||||||
|
Epilogue>;
|
||||||
|
|
||||||
using loader_a_t = typename gemm_kernel::loader_a_t;
|
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||||
using loader_b_t = typename gemm_kernel::loader_b_t;
|
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||||
@ -59,7 +68,12 @@ template <typename T,
|
|||||||
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
|
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
|
||||||
|
|
||||||
ulong3 batch_offsets = elem_to_loc_broadcast(
|
ulong3 batch_offsets = elem_to_loc_broadcast(
|
||||||
tid.z, batch_shape, A_bstrides, B_bstrides, C_bstrides, params->batch_ndim);
|
tid.z,
|
||||||
|
batch_shape,
|
||||||
|
A_bstrides,
|
||||||
|
B_bstrides,
|
||||||
|
C_bstrides,
|
||||||
|
params->batch_ndim);
|
||||||
|
|
||||||
A += batch_offsets.x;
|
A += batch_offsets.x;
|
||||||
B += batch_offsets.y;
|
B += batch_offsets.y;
|
||||||
@ -140,7 +154,8 @@ template <typename T,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Store results to device memory
|
// Store results to device memory
|
||||||
mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
|
mma_op.store_result(
|
||||||
|
D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
|
||||||
return;
|
return;
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -164,7 +179,8 @@ template <typename T,
|
|||||||
leftover_bk,
|
leftover_bk,
|
||||||
LoopAlignment<true, true, K_aligned>{});
|
LoopAlignment<true, true, K_aligned>{});
|
||||||
|
|
||||||
mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
|
mma_op.store_result(
|
||||||
|
D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
|
||||||
return;
|
return;
|
||||||
|
|
||||||
} else if (tgp_bn == BN) {
|
} else if (tgp_bn == BN) {
|
||||||
@ -181,8 +197,11 @@ template <typename T,
|
|||||||
LoopAlignment<false, true, K_aligned>{});
|
LoopAlignment<false, true, K_aligned>{});
|
||||||
|
|
||||||
return mma_op.store_result_safe(
|
return mma_op.store_result_safe(
|
||||||
D, params->ldd,
|
D,
|
||||||
C, addmm_params->ldc, addmm_params->fdc,
|
params->ldd,
|
||||||
|
C,
|
||||||
|
addmm_params->ldc,
|
||||||
|
addmm_params->fdc,
|
||||||
short2(tgp_bn, tgp_bm),
|
short2(tgp_bn, tgp_bm),
|
||||||
epilogue_op);
|
epilogue_op);
|
||||||
|
|
||||||
@ -200,8 +219,11 @@ template <typename T,
|
|||||||
LoopAlignment<true, false, K_aligned>{});
|
LoopAlignment<true, false, K_aligned>{});
|
||||||
|
|
||||||
return mma_op.store_result_safe(
|
return mma_op.store_result_safe(
|
||||||
D, params->ldd,
|
D,
|
||||||
C, addmm_params->ldc, addmm_params->fdc,
|
params->ldd,
|
||||||
|
C,
|
||||||
|
addmm_params->ldc,
|
||||||
|
addmm_params->fdc,
|
||||||
short2(tgp_bn, tgp_bm),
|
short2(tgp_bn, tgp_bm),
|
||||||
epilogue_op);
|
epilogue_op);
|
||||||
|
|
||||||
@ -219,8 +241,11 @@ template <typename T,
|
|||||||
LoopAlignment<false, false, K_aligned>{});
|
LoopAlignment<false, false, K_aligned>{});
|
||||||
|
|
||||||
return mma_op.store_result_safe(
|
return mma_op.store_result_safe(
|
||||||
D, params->ldd,
|
D,
|
||||||
C, addmm_params->ldc, addmm_params->fdc,
|
params->ldd,
|
||||||
|
C,
|
||||||
|
addmm_params->ldc,
|
||||||
|
addmm_params->fdc,
|
||||||
short2(tgp_bn, tgp_bm),
|
short2(tgp_bn, tgp_bm),
|
||||||
epilogue_op);
|
epilogue_op);
|
||||||
}
|
}
|
||||||
@ -231,9 +256,41 @@ template <typename T,
|
|||||||
// GEMM kernel initializations
|
// GEMM kernel initializations
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, ep_name, epilogue) \
|
#define instantiate_gemm( \
|
||||||
template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_" #ep_name)]] \
|
tname, \
|
||||||
[[kernel]] void addmm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, float, epilogue<itype, float>>( \
|
trans_a, \
|
||||||
|
trans_b, \
|
||||||
|
iname, \
|
||||||
|
itype, \
|
||||||
|
oname, \
|
||||||
|
otype, \
|
||||||
|
bm, \
|
||||||
|
bn, \
|
||||||
|
bk, \
|
||||||
|
wm, \
|
||||||
|
wn, \
|
||||||
|
aname, \
|
||||||
|
mn_aligned, \
|
||||||
|
kname, \
|
||||||
|
k_aligned, \
|
||||||
|
ep_name, \
|
||||||
|
epilogue) \
|
||||||
|
template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm \
|
||||||
|
"_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname \
|
||||||
|
"_K_" #kname "_" #ep_name)]] [[kernel]] void \
|
||||||
|
addmm< \
|
||||||
|
itype, \
|
||||||
|
bm, \
|
||||||
|
bn, \
|
||||||
|
bk, \
|
||||||
|
wm, \
|
||||||
|
wn, \
|
||||||
|
trans_a, \
|
||||||
|
trans_b, \
|
||||||
|
mn_aligned, \
|
||||||
|
k_aligned, \
|
||||||
|
float, \
|
||||||
|
epilogue<itype, float>>( \
|
||||||
const device itype* A [[buffer(0)]], \
|
const device itype* A [[buffer(0)]], \
|
||||||
const device itype* B [[buffer(1)]], \
|
const device itype* B [[buffer(1)]], \
|
||||||
const device itype* C [[buffer(2)]], \
|
const device itype* C [[buffer(2)]], \
|
||||||
@ -247,30 +304,35 @@ template <typename T,
|
|||||||
uint3 tid [[threadgroup_position_in_grid]], \
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
uint3 lid [[thread_position_in_threadgroup]]);
|
uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \
|
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby)
|
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
||||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
||||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
||||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||||
|
|
||||||
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
|
@ -1,8 +1,8 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||||
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
using namespace mlx::steel;
|
using namespace mlx::steel;
|
||||||
@ -11,7 +11,8 @@ using namespace mlx::steel;
|
|||||||
// GEMM kernels
|
// GEMM kernels
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
template <typename T,
|
template <
|
||||||
|
typename T,
|
||||||
int BM,
|
int BM,
|
||||||
int BN,
|
int BN,
|
||||||
int BK,
|
int BK,
|
||||||
@ -22,7 +23,8 @@ template <typename T,
|
|||||||
bool MN_aligned,
|
bool MN_aligned,
|
||||||
bool K_aligned,
|
bool K_aligned,
|
||||||
bool has_operand_mask = false>
|
bool has_operand_mask = false>
|
||||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void block_masked_gemm(
|
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
|
||||||
|
block_masked_gemm(
|
||||||
const device T* A [[buffer(0)]],
|
const device T* A [[buffer(0)]],
|
||||||
const device T* B [[buffer(1)]],
|
const device T* B [[buffer(1)]],
|
||||||
device T* D [[buffer(3)]],
|
device T* D [[buffer(3)]],
|
||||||
@ -37,11 +39,21 @@ template <typename T,
|
|||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
|
||||||
// Appease the compiler
|
// Appease the compiler
|
||||||
(void)lid;
|
(void)lid;
|
||||||
|
|
||||||
using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
|
using gemm_kernel = GEMMKernel<
|
||||||
|
T,
|
||||||
|
T,
|
||||||
|
BM,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
WM,
|
||||||
|
WN,
|
||||||
|
transpose_a,
|
||||||
|
transpose_b,
|
||||||
|
MN_aligned,
|
||||||
|
K_aligned>;
|
||||||
|
|
||||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||||
((tid.x) & ((1 << params->swizzle_log) - 1));
|
((tid.x) & ((1 << params->swizzle_log) - 1));
|
||||||
@ -52,15 +64,23 @@ template <typename T,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (params->batch_ndim > 1) {
|
if (params->batch_ndim > 1) {
|
||||||
const constant size_t* mask_batch_strides = batch_strides + 2 * params->batch_ndim;
|
const constant size_t* mask_batch_strides =
|
||||||
out_mask += elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
|
batch_strides + 2 * params->batch_ndim;
|
||||||
|
out_mask +=
|
||||||
|
elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
|
||||||
|
|
||||||
if (has_operand_mask) {
|
if (has_operand_mask) {
|
||||||
const constant size_t* mask_strides_lhs = mask_batch_strides + params->batch_ndim;
|
const constant size_t* mask_strides_lhs =
|
||||||
const constant size_t* mask_strides_rhs = mask_strides_lhs + params->batch_ndim;
|
mask_batch_strides + params->batch_ndim;
|
||||||
|
const constant size_t* mask_strides_rhs =
|
||||||
|
mask_strides_lhs + params->batch_ndim;
|
||||||
|
|
||||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||||
tid.z, batch_shape, mask_strides_lhs, mask_strides_rhs, params->batch_ndim);
|
tid.z,
|
||||||
|
batch_shape,
|
||||||
|
mask_strides_lhs,
|
||||||
|
mask_strides_rhs,
|
||||||
|
params->batch_ndim);
|
||||||
|
|
||||||
lhs_mask += batch_offsets.x;
|
lhs_mask += batch_offsets.x;
|
||||||
rhs_mask += batch_offsets.y;
|
rhs_mask += batch_offsets.y;
|
||||||
@ -99,7 +119,6 @@ template <typename T,
|
|||||||
B += transpose_b ? c_col * params->ldb : c_col;
|
B += transpose_b ? c_col * params->ldb : c_col;
|
||||||
D += c_row * params->ldd + c_col;
|
D += c_row * params->ldd + c_col;
|
||||||
|
|
||||||
|
|
||||||
bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];
|
bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];
|
||||||
|
|
||||||
// Write zeros and return
|
// Write zeros and return
|
||||||
@ -151,8 +170,10 @@ template <typename T,
|
|||||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||||
|
|
||||||
// Prepare threadgroup loading operations
|
// Prepare threadgroup loading operations
|
||||||
thread typename gemm_kernel::loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
thread typename gemm_kernel::loader_a_t loader_a(
|
||||||
thread typename gemm_kernel::loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
A, params->lda, As, simd_group_id, simd_lane_id);
|
||||||
|
thread typename gemm_kernel::loader_b_t loader_b(
|
||||||
|
B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// MNK aligned loop
|
// MNK aligned loop
|
||||||
@ -161,9 +182,10 @@ template <typename T,
|
|||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
if (!has_operand_mask ||
|
if (!has_operand_mask ||
|
||||||
(lhs_mask[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
|
(lhs_mask
|
||||||
rhs_mask[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
|
[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
|
||||||
|
rhs_mask
|
||||||
|
[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
|
||||||
// Load elements into threadgroup
|
// Load elements into threadgroup
|
||||||
loader_a.load_unsafe();
|
loader_a.load_unsafe();
|
||||||
loader_b.load_unsafe();
|
loader_b.load_unsafe();
|
||||||
@ -172,7 +194,6 @@ template <typename T,
|
|||||||
|
|
||||||
// Multiply and accumulate threadgroup elements
|
// Multiply and accumulate threadgroup elements
|
||||||
mma_op.mma(As, Bs);
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare for next iteration
|
// Prepare for next iteration
|
||||||
@ -184,11 +205,12 @@ template <typename T,
|
|||||||
|
|
||||||
// Loop tail
|
// Loop tail
|
||||||
if (!K_aligned) {
|
if (!K_aligned) {
|
||||||
|
|
||||||
if (!has_operand_mask ||
|
if (!has_operand_mask ||
|
||||||
(lhs_mask[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
|
(lhs_mask
|
||||||
rhs_mask[(params->K / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
|
[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
|
||||||
|
rhs_mask
|
||||||
|
[(params->K / BM) * mask_strides[5] +
|
||||||
|
tid_x * mask_strides[4]])) {
|
||||||
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
|
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||||
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
||||||
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
||||||
@ -199,7 +221,6 @@ template <typename T,
|
|||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
mma_op.mma(As, Bs);
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -224,9 +245,10 @@ template <typename T,
|
|||||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
if (!has_operand_mask ||
|
if (!has_operand_mask ||
|
||||||
(lhs_mask[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
|
(lhs_mask
|
||||||
rhs_mask[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
|
[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
|
||||||
|
rhs_mask
|
||||||
|
[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
|
||||||
// Load elements into threadgroup
|
// Load elements into threadgroup
|
||||||
if (M_aligned) {
|
if (M_aligned) {
|
||||||
loader_a.load_unsafe();
|
loader_a.load_unsafe();
|
||||||
@ -244,7 +266,6 @@ template <typename T,
|
|||||||
|
|
||||||
// Multiply and accumulate threadgroup elements
|
// Multiply and accumulate threadgroup elements
|
||||||
mma_op.mma(As, Bs);
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare for next iteration
|
// Prepare for next iteration
|
||||||
@ -256,9 +277,11 @@ template <typename T,
|
|||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
if (!has_operand_mask ||
|
if (!has_operand_mask ||
|
||||||
(lhs_mask[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
|
(lhs_mask
|
||||||
rhs_mask[(params->K / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
|
[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
|
||||||
|
rhs_mask
|
||||||
|
[(params->K / BM) * mask_strides[5] +
|
||||||
|
tid_x * mask_strides[4]])) {
|
||||||
short2 tile_dims_A_last =
|
short2 tile_dims_A_last =
|
||||||
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
|
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
|
||||||
short2 tile_dims_B_last =
|
short2 tile_dims_B_last =
|
||||||
@ -270,7 +293,6 @@ template <typename T,
|
|||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
mma_op.mma(As, Bs);
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -286,9 +308,41 @@ template <typename T,
|
|||||||
// GEMM kernel initializations
|
// GEMM kernel initializations
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, omname, op_mask) \
|
#define instantiate_gemm( \
|
||||||
template [[host_name("steel_block_masked_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_op_mask_" #omname)]] \
|
tname, \
|
||||||
[[kernel]] void block_masked_gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, op_mask>( \
|
trans_a, \
|
||||||
|
trans_b, \
|
||||||
|
iname, \
|
||||||
|
itype, \
|
||||||
|
oname, \
|
||||||
|
otype, \
|
||||||
|
bm, \
|
||||||
|
bn, \
|
||||||
|
bk, \
|
||||||
|
wm, \
|
||||||
|
wn, \
|
||||||
|
aname, \
|
||||||
|
mn_aligned, \
|
||||||
|
kname, \
|
||||||
|
k_aligned, \
|
||||||
|
omname, \
|
||||||
|
op_mask) \
|
||||||
|
template [[host_name("steel_block_masked_gemm_" #tname "_" #iname "_" #oname \
|
||||||
|
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
|
||||||
|
"_MN_" #aname "_K_" #kname \
|
||||||
|
"_op_mask_" #omname)]] [[kernel]] void \
|
||||||
|
block_masked_gemm< \
|
||||||
|
itype, \
|
||||||
|
bm, \
|
||||||
|
bn, \
|
||||||
|
bk, \
|
||||||
|
wm, \
|
||||||
|
wn, \
|
||||||
|
trans_a, \
|
||||||
|
trans_b, \
|
||||||
|
mn_aligned, \
|
||||||
|
k_aligned, \
|
||||||
|
op_mask>( \
|
||||||
const device itype* A [[buffer(0)]], \
|
const device itype* A [[buffer(0)]], \
|
||||||
const device itype* B [[buffer(1)]], \
|
const device itype* B [[buffer(1)]], \
|
||||||
device itype* D [[buffer(3)]], \
|
device itype* D [[buffer(3)]], \
|
||||||
@ -304,26 +358,31 @@ template <typename T,
|
|||||||
uint3 tid [[threadgroup_position_in_grid]], \
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
uint3 lid [[thread_position_in_threadgroup]]);
|
uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
#define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, N, false) \
|
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, N, false) \
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, T, true)
|
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, T, true) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||||
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
||||||
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
||||||
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2)
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
||||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||||
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
|
||||||
|
@ -10,7 +10,8 @@ using namespace mlx::steel;
|
|||||||
// GEMM kernels
|
// GEMM kernels
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
template <typename T,
|
template <
|
||||||
|
typename T,
|
||||||
typename U,
|
typename U,
|
||||||
int BM,
|
int BM,
|
||||||
int BN,
|
int BN,
|
||||||
@ -30,10 +31,20 @@ template <typename T,
|
|||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
|
||||||
(void)lid;
|
(void)lid;
|
||||||
|
|
||||||
using gemm_kernel = GEMMKernel<T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
|
using gemm_kernel = GEMMKernel<
|
||||||
|
T,
|
||||||
|
U,
|
||||||
|
BM,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
WM,
|
||||||
|
WN,
|
||||||
|
transpose_a,
|
||||||
|
transpose_b,
|
||||||
|
MN_aligned,
|
||||||
|
K_aligned>;
|
||||||
using loader_a_t = typename gemm_kernel::loader_a_t;
|
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||||
using loader_b_t = typename gemm_kernel::loader_b_t;
|
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||||
using mma_t = typename gemm_kernel::mma_t;
|
using mma_t = typename gemm_kernel::mma_t;
|
||||||
@ -54,9 +65,12 @@ template <typename T,
|
|||||||
const int c_col = tid_x * BN;
|
const int c_col = tid_x * BN;
|
||||||
const int k_start = params->split_k_partition_size * tid_z;
|
const int k_start = params->split_k_partition_size * tid_z;
|
||||||
|
|
||||||
A += transpose_a ? (c_row + k_start * params->lda) : (k_start + c_row * params->lda);
|
A += transpose_a ? (c_row + k_start * params->lda)
|
||||||
B += transpose_b ? (k_start + c_col * params->ldb) : (c_col + k_start * params->ldb);
|
: (k_start + c_row * params->lda);
|
||||||
C += (params->split_k_partition_stride * tid_z) + (c_row * params->ldc + c_col);
|
B += transpose_b ? (k_start + c_col * params->ldb)
|
||||||
|
: (c_col + k_start * params->ldb);
|
||||||
|
C += (params->split_k_partition_stride * tid_z) +
|
||||||
|
(c_row * params->ldc + c_col);
|
||||||
|
|
||||||
// Prepare threadgroup loading operations
|
// Prepare threadgroup loading operations
|
||||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||||
@ -124,7 +138,8 @@ template <typename T,
|
|||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
if ((tid_z + 1) == (params->split_k_partitions)) {
|
if ((tid_z + 1) == (params->split_k_partitions)) {
|
||||||
int gemm_k_iter_remaining = (params->K - (k_start + params->split_k_partition_size)) / BK;
|
int gemm_k_iter_remaining =
|
||||||
|
(params->K - (k_start + params->split_k_partition_size)) / BK;
|
||||||
if (!K_aligned || gemm_k_iter_remaining > 0)
|
if (!K_aligned || gemm_k_iter_remaining > 0)
|
||||||
gemm_kernel::gemm_loop(
|
gemm_kernel::gemm_loop(
|
||||||
As,
|
As,
|
||||||
@ -150,9 +165,38 @@ template <typename T,
|
|||||||
// GEMM kernel initializations
|
// GEMM kernel initializations
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
#define instantiate_gemm( \
|
||||||
template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
|
tname, \
|
||||||
[[kernel]] void gemm_splitk<itype, otype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
|
trans_a, \
|
||||||
|
trans_b, \
|
||||||
|
iname, \
|
||||||
|
itype, \
|
||||||
|
oname, \
|
||||||
|
otype, \
|
||||||
|
bm, \
|
||||||
|
bn, \
|
||||||
|
bk, \
|
||||||
|
wm, \
|
||||||
|
wn, \
|
||||||
|
aname, \
|
||||||
|
mn_aligned, \
|
||||||
|
kname, \
|
||||||
|
k_aligned) \
|
||||||
|
template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname \
|
||||||
|
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
|
||||||
|
"_MN_" #aname "_K_" #kname)]] [[kernel]] void \
|
||||||
|
gemm_splitk< \
|
||||||
|
itype, \
|
||||||
|
otype, \
|
||||||
|
bm, \
|
||||||
|
bn, \
|
||||||
|
bk, \
|
||||||
|
wm, \
|
||||||
|
wn, \
|
||||||
|
trans_a, \
|
||||||
|
trans_b, \
|
||||||
|
mn_aligned, \
|
||||||
|
k_aligned>( \
|
||||||
const device itype* A [[buffer(0)]], \
|
const device itype* A [[buffer(0)]], \
|
||||||
const device itype* B [[buffer(1)]], \
|
const device itype* B [[buffer(1)]], \
|
||||||
device otype* C [[buffer(2)]], \
|
device otype* C [[buffer(2)]], \
|
||||||
@ -162,34 +206,39 @@ template <typename T,
|
|||||||
uint3 tid [[threadgroup_position_in_grid]], \
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
uint3 lid [[thread_position_in_threadgroup]]);
|
uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_gemm_shapes_helper(float16, half, float32, float);
|
instantiate_gemm_shapes_helper(float16, half, float32, float);
|
||||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
|
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
|
||||||
|
|
||||||
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// Split k accumulation kernel
|
// Split k accumulation kernel
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
template <typename AccT,
|
template <
|
||||||
|
typename AccT,
|
||||||
typename OutT,
|
typename OutT,
|
||||||
typename Epilogue = TransformNone<OutT, AccT>>
|
typename Epilogue = TransformNone<OutT, AccT>>
|
||||||
[[kernel]] void gemm_splitk_accum(
|
[[kernel]] void gemm_splitk_accum(
|
||||||
@ -199,7 +248,6 @@ template <typename AccT,
|
|||||||
const constant int& partition_stride [[buffer(3)]],
|
const constant int& partition_stride [[buffer(3)]],
|
||||||
const constant int& ldd [[buffer(4)]],
|
const constant int& ldd [[buffer(4)]],
|
||||||
uint2 gid [[thread_position_in_grid]]) {
|
uint2 gid [[thread_position_in_grid]]) {
|
||||||
|
|
||||||
// Ajust D and C
|
// Ajust D and C
|
||||||
D += gid.x + gid.y * ldd;
|
D += gid.x + gid.y * ldd;
|
||||||
C_split += gid.x + gid.y * ldd;
|
C_split += gid.x + gid.y * ldd;
|
||||||
@ -214,10 +262,10 @@ template <typename AccT,
|
|||||||
|
|
||||||
// Write output
|
// Write output
|
||||||
D[0] = Epilogue::apply(out);
|
D[0] = Epilogue::apply(out);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename AccT,
|
template <
|
||||||
|
typename AccT,
|
||||||
typename OutT,
|
typename OutT,
|
||||||
typename Epilogue = TransformAxpby<OutT, AccT>>
|
typename Epilogue = TransformAxpby<OutT, AccT>>
|
||||||
[[kernel]] void gemm_splitk_accum_axpby(
|
[[kernel]] void gemm_splitk_accum_axpby(
|
||||||
@ -232,7 +280,6 @@ template <typename AccT,
|
|||||||
const constant float& alpha [[buffer(8)]],
|
const constant float& alpha [[buffer(8)]],
|
||||||
const constant float& beta [[buffer(9)]],
|
const constant float& beta [[buffer(9)]],
|
||||||
uint2 gid [[thread_position_in_grid]]) {
|
uint2 gid [[thread_position_in_grid]]) {
|
||||||
|
|
||||||
// Ajust D and C
|
// Ajust D and C
|
||||||
C += gid.x * fdc + gid.y * ldc;
|
C += gid.x * fdc + gid.y * ldc;
|
||||||
D += gid.x + gid.y * ldd;
|
D += gid.x + gid.y * ldd;
|
||||||
@ -249,20 +296,21 @@ template <typename AccT,
|
|||||||
// Write output
|
// Write output
|
||||||
Epilogue op(alpha, beta);
|
Epilogue op(alpha, beta);
|
||||||
D[0] = op.apply(out, *C);
|
D[0] = op.apply(out, *C);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_accum(oname, otype, aname, atype) \
|
#define instantiate_accum(oname, otype, aname, atype) \
|
||||||
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname)]] \
|
template [[host_name("steel_gemm_splitk_accum_" #oname \
|
||||||
[[kernel]] void gemm_splitk_accum<atype, otype>( \
|
"_" #aname)]] [[kernel]] void \
|
||||||
|
gemm_splitk_accum<atype, otype>( \
|
||||||
const device atype* C_split [[buffer(0)]], \
|
const device atype* C_split [[buffer(0)]], \
|
||||||
device otype* D [[buffer(1)]], \
|
device otype* D [[buffer(1)]], \
|
||||||
const constant int& k_partitions [[buffer(2)]], \
|
const constant int& k_partitions [[buffer(2)]], \
|
||||||
const constant int& partition_stride [[buffer(3)]], \
|
const constant int& partition_stride [[buffer(3)]], \
|
||||||
const constant int& ldd [[buffer(4)]], \
|
const constant int& ldd [[buffer(4)]], \
|
||||||
uint2 gid [[thread_position_in_grid]]); \
|
uint2 gid [[thread_position_in_grid]]); \
|
||||||
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname "_axpby")]] \
|
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname \
|
||||||
[[kernel]] void gemm_splitk_accum_axpby<atype, otype>( \
|
"_axpby")]] [[kernel]] void \
|
||||||
|
gemm_splitk_accum_axpby<atype, otype>( \
|
||||||
const device atype* C_split [[buffer(0)]], \
|
const device atype* C_split [[buffer(0)]], \
|
||||||
device otype* D [[buffer(1)]], \
|
device otype* D [[buffer(1)]], \
|
||||||
const constant int& k_partitions [[buffer(2)]], \
|
const constant int& k_partitions [[buffer(2)]], \
|
||||||
@ -275,6 +323,7 @@ template <typename AccT,
|
|||||||
const constant float& beta [[buffer(9)]], \
|
const constant float& beta [[buffer(9)]], \
|
||||||
uint2 gid [[thread_position_in_grid]]);
|
uint2 gid [[thread_position_in_grid]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_accum(bfloat16, bfloat16_t, float32, float);
|
instantiate_accum(bfloat16, bfloat16_t, float32, float);
|
||||||
instantiate_accum(float16, half, float32, float);
|
instantiate_accum(float16, half, float32, float);
|
||||||
instantiate_accum(float32, float, float32, float);
|
instantiate_accum(float32, float, float32, float); // clang-format on
|
@ -3,9 +3,9 @@
|
|||||||
#include <metal_integer>
|
#include <metal_integer>
|
||||||
#include <metal_math>
|
#include <metal_math>
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
#include "mlx/backend/metal/kernels/ternary.h"
|
#include "mlx/backend/metal/kernels/ternary.h"
|
||||||
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
template <typename T, typename Op>
|
template <typename T, typename Op>
|
||||||
[[kernel]] void ternary_op_v(
|
[[kernel]] void ternary_op_v(
|
||||||
@ -65,7 +65,8 @@ template <typename T, typename Op>
|
|||||||
auto a_idx = elem_to_loc_3(index, a_strides);
|
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||||
auto b_idx = elem_to_loc_3(index, b_strides);
|
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||||
auto c_idx = elem_to_loc_3(index, c_strides);
|
auto c_idx = elem_to_loc_3(index, c_strides);
|
||||||
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
size_t out_idx =
|
||||||
|
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||||
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
|
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -81,8 +82,10 @@ template <typename T, typename Op, int DIM>
|
|||||||
constant const size_t c_strides[DIM],
|
constant const size_t c_strides[DIM],
|
||||||
uint3 index [[thread_position_in_grid]],
|
uint3 index [[thread_position_in_grid]],
|
||||||
uint3 grid_dim [[threads_per_grid]]) {
|
uint3 grid_dim [[threads_per_grid]]) {
|
||||||
auto idx = elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
|
auto idx =
|
||||||
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
|
||||||
|
size_t out_idx =
|
||||||
|
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||||
d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
|
d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -99,23 +102,22 @@ template <typename T, typename Op>
|
|||||||
constant const int& ndim,
|
constant const int& ndim,
|
||||||
uint3 index [[thread_position_in_grid]],
|
uint3 index [[thread_position_in_grid]],
|
||||||
uint3 grid_dim [[threads_per_grid]]) {
|
uint3 grid_dim [[threads_per_grid]]) {
|
||||||
auto idx = elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
|
auto idx =
|
||||||
|
elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
|
||||||
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
|
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
|
||||||
d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
|
d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_ternary_v(name, type, op) \
|
#define instantiate_ternary_v(name, type, op) \
|
||||||
template [[host_name(name)]] \
|
template [[host_name(name)]] [[kernel]] void ternary_op_v<type, op>( \
|
||||||
[[kernel]] void ternary_op_v<type, op>( \
|
|
||||||
device const bool* a, \
|
device const bool* a, \
|
||||||
device const type* b, \
|
device const type* b, \
|
||||||
device const type* c, \
|
device const type* c, \
|
||||||
device type* d, \
|
device type* d, \
|
||||||
uint index [[thread_position_in_grid]]); \
|
uint index [[thread_position_in_grid]]);
|
||||||
|
|
||||||
#define instantiate_ternary_g(name, type, op) \
|
#define instantiate_ternary_g(name, type, op) \
|
||||||
template [[host_name(name)]] \
|
template [[host_name(name)]] [[kernel]] void ternary_op_g<type, op>( \
|
||||||
[[kernel]] void ternary_op_g<type, op>( \
|
|
||||||
device const bool* a, \
|
device const bool* a, \
|
||||||
device const type* b, \
|
device const type* b, \
|
||||||
device const type* c, \
|
device const type* c, \
|
||||||
@ -126,11 +128,11 @@ template <typename T, typename Op>
|
|||||||
constant const size_t* c_strides, \
|
constant const size_t* c_strides, \
|
||||||
constant const int& ndim, \
|
constant const int& ndim, \
|
||||||
uint3 index [[thread_position_in_grid]], \
|
uint3 index [[thread_position_in_grid]], \
|
||||||
uint3 grid_dim [[threads_per_grid]]); \
|
uint3 grid_dim [[threads_per_grid]]);
|
||||||
|
|
||||||
#define instantiate_ternary_g_dim(name, type, op, dims) \
|
#define instantiate_ternary_g_dim(name, type, op, dims) \
|
||||||
template [[host_name(name "_" #dims)]] \
|
template [[host_name(name "_" #dims)]] [[kernel]] void \
|
||||||
[[kernel]] void ternary_op_g_nd<type, op, dims>( \
|
ternary_op_g_nd<type, op, dims>( \
|
||||||
device const bool* a, \
|
device const bool* a, \
|
||||||
device const type* b, \
|
device const type* b, \
|
||||||
device const type* c, \
|
device const type* c, \
|
||||||
@ -140,11 +142,11 @@ template <typename T, typename Op>
|
|||||||
constant const size_t b_strides[dims], \
|
constant const size_t b_strides[dims], \
|
||||||
constant const size_t c_strides[dims], \
|
constant const size_t c_strides[dims], \
|
||||||
uint3 index [[thread_position_in_grid]], \
|
uint3 index [[thread_position_in_grid]], \
|
||||||
uint3 grid_dim [[threads_per_grid]]); \
|
uint3 grid_dim [[threads_per_grid]]);
|
||||||
|
|
||||||
#define instantiate_ternary_g_nd(name, type, op) \
|
#define instantiate_ternary_g_nd(name, type, op) \
|
||||||
template [[host_name(name "_1")]] \
|
template [[host_name(name "_1")]] [[kernel]] void \
|
||||||
[[kernel]] void ternary_op_g_nd1<type, op>( \
|
ternary_op_g_nd1<type, op>( \
|
||||||
device const bool* a, \
|
device const bool* a, \
|
||||||
device const type* b, \
|
device const type* b, \
|
||||||
device const type* c, \
|
device const type* c, \
|
||||||
@ -153,8 +155,8 @@ template <typename T, typename Op>
|
|||||||
constant const size_t& b_strides, \
|
constant const size_t& b_strides, \
|
||||||
constant const size_t& c_strides, \
|
constant const size_t& c_strides, \
|
||||||
uint index [[thread_position_in_grid]]); \
|
uint index [[thread_position_in_grid]]); \
|
||||||
template [[host_name(name "_2")]] \
|
template [[host_name(name "_2")]] [[kernel]] void \
|
||||||
[[kernel]] void ternary_op_g_nd2<type, op>( \
|
ternary_op_g_nd2<type, op>( \
|
||||||
device const bool* a, \
|
device const bool* a, \
|
||||||
device const type* b, \
|
device const type* b, \
|
||||||
device const type* c, \
|
device const type* c, \
|
||||||
@ -164,8 +166,8 @@ template <typename T, typename Op>
|
|||||||
constant const size_t c_strides[2], \
|
constant const size_t c_strides[2], \
|
||||||
uint2 index [[thread_position_in_grid]], \
|
uint2 index [[thread_position_in_grid]], \
|
||||||
uint2 grid_dim [[threads_per_grid]]); \
|
uint2 grid_dim [[threads_per_grid]]); \
|
||||||
template [[host_name(name "_3")]] \
|
template [[host_name(name "_3")]] [[kernel]] void \
|
||||||
[[kernel]] void ternary_op_g_nd3<type, op>( \
|
ternary_op_g_nd3<type, op>( \
|
||||||
device const bool* a, \
|
device const bool* a, \
|
||||||
device const type* b, \
|
device const type* b, \
|
||||||
device const type* c, \
|
device const type* c, \
|
||||||
@ -176,13 +178,15 @@ template <typename T, typename Op>
|
|||||||
uint3 index [[thread_position_in_grid]], \
|
uint3 index [[thread_position_in_grid]], \
|
||||||
uint3 grid_dim [[threads_per_grid]]); \
|
uint3 grid_dim [[threads_per_grid]]); \
|
||||||
instantiate_ternary_g_dim(name, type, op, 4) \
|
instantiate_ternary_g_dim(name, type, op, 4) \
|
||||||
instantiate_ternary_g_dim(name, type, op, 5) \
|
instantiate_ternary_g_dim(name, type, op, 5)
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_ternary_all(name, tname, type, op) \
|
#define instantiate_ternary_all(name, tname, type, op) \
|
||||||
instantiate_ternary_v("v" #name #tname, type, op) \
|
instantiate_ternary_v("v" #name #tname, type, op) \
|
||||||
instantiate_ternary_g("g" #name #tname, type, op) \
|
instantiate_ternary_g("g" #name #tname, type, op) \
|
||||||
instantiate_ternary_g_nd("g" #name #tname, type, op) \
|
instantiate_ternary_g_nd("g" #name #tname, type, op) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_ternary_types(name, op) \
|
#define instantiate_ternary_types(name, op) \
|
||||||
instantiate_ternary_all(name, bool_, bool, op) \
|
instantiate_ternary_all(name, bool_, bool, op) \
|
||||||
instantiate_ternary_all(name, uint8, uint8_t, op) \
|
instantiate_ternary_all(name, uint8, uint8_t, op) \
|
||||||
@ -196,6 +200,6 @@ template <typename T, typename Op>
|
|||||||
instantiate_ternary_all(name, float16, half, op) \
|
instantiate_ternary_all(name, float16, half, op) \
|
||||||
instantiate_ternary_all(name, float32, float, op) \
|
instantiate_ternary_all(name, float32, float, op) \
|
||||||
instantiate_ternary_all(name, bfloat16, bfloat16_t, op) \
|
instantiate_ternary_all(name, bfloat16, bfloat16_t, op) \
|
||||||
instantiate_ternary_all(name, complex64, complex64_t, op) \
|
instantiate_ternary_all(name, complex64, complex64_t, op) // clang-format on
|
||||||
|
|
||||||
instantiate_ternary_types(select, Select)
|
instantiate_ternary_types(select, Select)
|
@ -23,15 +23,13 @@ template <typename T, typename Op>
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_unary_v(name, type, op) \
|
#define instantiate_unary_v(name, type, op) \
|
||||||
template [[host_name(name)]] \
|
template [[host_name(name)]] [[kernel]] void unary_op_v<type, op>( \
|
||||||
[[kernel]] void unary_op_v<type, op>( \
|
|
||||||
device const type* in, \
|
device const type* in, \
|
||||||
device type* out, \
|
device type* out, \
|
||||||
uint index [[thread_position_in_grid]]);
|
uint index [[thread_position_in_grid]]);
|
||||||
|
|
||||||
#define instantiate_unary_g(name, type, op) \
|
#define instantiate_unary_g(name, type, op) \
|
||||||
template [[host_name(name)]] \
|
template [[host_name(name)]] [[kernel]] void unary_op_g<type, op>( \
|
||||||
[[kernel]] void unary_op_g<type, op>( \
|
|
||||||
device const type* in, \
|
device const type* in, \
|
||||||
device type* out, \
|
device type* out, \
|
||||||
device const int* in_shape, \
|
device const int* in_shape, \
|
||||||
@ -39,15 +37,18 @@ template <typename T, typename Op>
|
|||||||
device const int& ndim, \
|
device const int& ndim, \
|
||||||
uint index [[thread_position_in_grid]]);
|
uint index [[thread_position_in_grid]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_unary_all(name, tname, type, op) \
|
#define instantiate_unary_all(name, tname, type, op) \
|
||||||
instantiate_unary_v("v" #name #tname, type, op) \
|
instantiate_unary_v("v" #name #tname, type, op) \
|
||||||
instantiate_unary_g("g" #name #tname, type, op)
|
instantiate_unary_g("g" #name #tname, type, op) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_unary_float(name, op) \
|
#define instantiate_unary_float(name, op) \
|
||||||
instantiate_unary_all(name, float16, half, op) \
|
instantiate_unary_all(name, float16, half, op) \
|
||||||
instantiate_unary_all(name, float32, float, op) \
|
instantiate_unary_all(name, float32, float, op) \
|
||||||
instantiate_unary_all(name, bfloat16, bfloat16_t, op) \
|
instantiate_unary_all(name, bfloat16, bfloat16_t, op) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_unary_types(name, op) \
|
#define instantiate_unary_types(name, op) \
|
||||||
instantiate_unary_all(name, bool_, bool, op) \
|
instantiate_unary_all(name, bool_, bool, op) \
|
||||||
instantiate_unary_all(name, uint8, uint8_t, op) \
|
instantiate_unary_all(name, uint8, uint8_t, op) \
|
||||||
@ -58,8 +59,9 @@ template <typename T, typename Op>
|
|||||||
instantiate_unary_all(name, int16, int16_t, op) \
|
instantiate_unary_all(name, int16, int16_t, op) \
|
||||||
instantiate_unary_all(name, int32, int32_t, op) \
|
instantiate_unary_all(name, int32, int32_t, op) \
|
||||||
instantiate_unary_all(name, int64, int64_t, op) \
|
instantiate_unary_all(name, int64, int64_t, op) \
|
||||||
instantiate_unary_float(name, op)
|
instantiate_unary_float(name, op) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_unary_types(abs, Abs)
|
instantiate_unary_types(abs, Abs)
|
||||||
instantiate_unary_float(arccos, ArcCos)
|
instantiate_unary_float(arccos, ArcCos)
|
||||||
instantiate_unary_float(arccosh, ArcCosh)
|
instantiate_unary_float(arccosh, ArcCosh)
|
||||||
@ -102,4 +104,4 @@ instantiate_unary_all(tan, complex64, complex64_t, Tan)
|
|||||||
instantiate_unary_all(tanh, complex64, complex64_t, Tanh)
|
instantiate_unary_all(tanh, complex64, complex64_t, Tanh)
|
||||||
instantiate_unary_all(round, complex64, complex64_t, Round)
|
instantiate_unary_all(round, complex64, complex64_t, Round)
|
||||||
|
|
||||||
instantiate_unary_all(lnot, bool_, bool, LogicalNot)
|
instantiate_unary_all(lnot, bool_, bool, LogicalNot) // clang-format on
|
||||||
|
Loading…
Reference in New Issue
Block a user