mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
feat: metal formatting and pre-commit bump (#1038)
* feat: metal formatting and pre-commit bump * add guards * update * more guards * more guards * smakk fix * Refactor instantiation of ternary types in ternary.metal * fix scan.metal
This commit is contained in:
parent
8db7161c94
commit
a30e7ed2da
@ -1,11 +1,11 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
rev: v18.1.3
|
||||
rev: v18.1.4
|
||||
hooks:
|
||||
- id: clang-format
|
||||
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||
rev: 24.3.0
|
||||
rev: 24.4.2
|
||||
hooks:
|
||||
- id: black
|
||||
- repo: https://github.com/pycqa/isort
|
||||
|
@ -33,7 +33,7 @@ array axpby(
|
||||
class Axpby : public Primitive {
|
||||
public:
|
||||
explicit Axpby(Stream stream, float alpha, float beta)
|
||||
: Primitive(stream), alpha_(alpha), beta_(beta){};
|
||||
: Primitive(stream), alpha_(alpha), beta_(beta) {};
|
||||
|
||||
/**
|
||||
* A primitive must know how to evaluate itself on the CPU/GPU
|
||||
|
@ -19,7 +19,7 @@ template <typename T>
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
|
||||
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
|
||||
out[index] =
|
||||
out[index] =
|
||||
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
|
||||
}
|
||||
|
||||
@ -31,30 +31,30 @@ template <typename T>
|
||||
constant const float& alpha [[buffer(3)]],
|
||||
constant const float& beta [[buffer(4)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
out[index] =
|
||||
out[index] =
|
||||
static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
|
||||
}
|
||||
|
||||
#define instantiate_axpby(type_name, type) \
|
||||
template [[host_name("axpby_general_" #type_name)]] \
|
||||
[[kernel]] void axpby_general<type>( \
|
||||
device const type* x [[buffer(0)]], \
|
||||
device const type* y [[buffer(1)]], \
|
||||
device type* out [[buffer(2)]], \
|
||||
constant const float& alpha [[buffer(3)]], \
|
||||
constant const float& beta [[buffer(4)]], \
|
||||
constant const int* shape [[buffer(5)]], \
|
||||
constant const size_t* x_strides [[buffer(6)]], \
|
||||
constant const size_t* y_strides [[buffer(7)]], \
|
||||
constant const int& ndim [[buffer(8)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("axpby_contiguous_" #type_name)]] \
|
||||
[[kernel]] void axpby_contiguous<type>( \
|
||||
device const type* x [[buffer(0)]], \
|
||||
device const type* y [[buffer(1)]], \
|
||||
device type* out [[buffer(2)]], \
|
||||
constant const float& alpha [[buffer(3)]], \
|
||||
constant const float& beta [[buffer(4)]], \
|
||||
#define instantiate_axpby(type_name, type) \
|
||||
template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \
|
||||
axpby_general<type>( \
|
||||
device const type* x [[buffer(0)]], \
|
||||
device const type* y [[buffer(1)]], \
|
||||
device type* out [[buffer(2)]], \
|
||||
constant const float& alpha [[buffer(3)]], \
|
||||
constant const float& beta [[buffer(4)]], \
|
||||
constant const int* shape [[buffer(5)]], \
|
||||
constant const size_t* x_strides [[buffer(6)]], \
|
||||
constant const size_t* y_strides [[buffer(7)]], \
|
||||
constant const int& ndim [[buffer(8)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
|
||||
axpby_contiguous<type>( \
|
||||
device const type* x [[buffer(0)]], \
|
||||
device const type* y [[buffer(1)]], \
|
||||
device type* out [[buffer(2)]], \
|
||||
constant const float& alpha [[buffer(3)]], \
|
||||
constant const float& beta [[buffer(4)]], \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
instantiate_axpby(float32, float);
|
||||
|
@ -14,7 +14,7 @@ class Buffer {
|
||||
void* ptr_;
|
||||
|
||||
public:
|
||||
Buffer(void* ptr) : ptr_(ptr){};
|
||||
Buffer(void* ptr) : ptr_(ptr) {};
|
||||
|
||||
// Get the raw data pointer from the buffer
|
||||
void* raw_ptr();
|
||||
|
@ -209,7 +209,7 @@ class array {
|
||||
allocator::Buffer buffer;
|
||||
deleter_t d;
|
||||
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
|
||||
: buffer(buffer), d(d){};
|
||||
: buffer(buffer), d(d) {};
|
||||
// Not copyable
|
||||
Data(const Data& d) = delete;
|
||||
Data& operator=(const Data& d) = delete;
|
||||
|
@ -38,7 +38,7 @@ using MTLFCList =
|
||||
|
||||
struct CommandEncoder {
|
||||
CommandEncoder(MTL::ComputeCommandEncoder* enc)
|
||||
: enc(enc), concurrent(false){};
|
||||
: enc(enc), concurrent(false) {};
|
||||
CommandEncoder(const CommandEncoder&) = delete;
|
||||
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
||||
|
||||
|
@ -11,22 +11,22 @@ template <typename T>
|
||||
out[index] = start + index * step;
|
||||
}
|
||||
|
||||
#define instantiate_arange(tname, type) \
|
||||
template [[host_name("arange" #tname)]] \
|
||||
[[kernel]] void arange<type>( \
|
||||
constant const type& start, \
|
||||
constant const type& step, \
|
||||
device type* out, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
#define instantiate_arange(tname, type) \
|
||||
template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
|
||||
constant const type& start, \
|
||||
constant const type& step, \
|
||||
device type* out, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
instantiate_arange(uint8, uint8_t)
|
||||
// clang-format off
|
||||
instantiate_arange(uint8, uint8_t)
|
||||
instantiate_arange(uint16, uint16_t)
|
||||
instantiate_arange(uint32, uint32_t)
|
||||
instantiate_arange(uint32, uint32_t)
|
||||
instantiate_arange(uint64, uint64_t)
|
||||
instantiate_arange(int8, int8_t)
|
||||
instantiate_arange(int8, int8_t)
|
||||
instantiate_arange(int16, int16_t)
|
||||
instantiate_arange(int32, int32_t)
|
||||
instantiate_arange(int64, int64_t)
|
||||
instantiate_arange(float16, half)
|
||||
instantiate_arange(float32, float)
|
||||
instantiate_arange(bfloat16, bfloat16_t)
|
||||
instantiate_arange(bfloat16, bfloat16_t) // clang-format on
|
@ -18,7 +18,8 @@ struct ArgMin {
|
||||
static constexpr constant U init = Limits<U>::max;
|
||||
|
||||
IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
|
||||
if (best.val > current.val || (best.val == current.val && best.index > current.index)) {
|
||||
if (best.val > current.val ||
|
||||
(best.val == current.val && best.index > current.index)) {
|
||||
return current;
|
||||
} else {
|
||||
return best;
|
||||
@ -26,11 +27,12 @@ struct ArgMin {
|
||||
}
|
||||
|
||||
template <int N>
|
||||
IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
|
||||
for (int i=0; i<N; i++) {
|
||||
IndexValPair<U>
|
||||
reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
if (vals[i] < best.val) {
|
||||
best.val = vals[i];
|
||||
best.index = offset+i;
|
||||
best.index = offset + i;
|
||||
}
|
||||
}
|
||||
return best;
|
||||
@ -42,7 +44,8 @@ struct ArgMax {
|
||||
static constexpr constant U init = Limits<U>::min;
|
||||
|
||||
IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
|
||||
if (best.val < current.val || (best.val == current.val && best.index > current.index)) {
|
||||
if (best.val < current.val ||
|
||||
(best.val == current.val && best.index > current.index)) {
|
||||
return current;
|
||||
} else {
|
||||
return best;
|
||||
@ -50,11 +53,12 @@ struct ArgMax {
|
||||
}
|
||||
|
||||
template <int N>
|
||||
IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
|
||||
for (int i=0; i<N; i++) {
|
||||
IndexValPair<U>
|
||||
reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
if (vals[i] > best.val) {
|
||||
best.val = vals[i];
|
||||
best.index = offset+i;
|
||||
best.index = offset + i;
|
||||
}
|
||||
}
|
||||
return best;
|
||||
@ -64,19 +68,16 @@ struct ArgMax {
|
||||
template <typename U>
|
||||
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
|
||||
return IndexValPair<U>{
|
||||
simd_shuffle_down(data.index, delta),
|
||||
simd_shuffle_down(data.val, delta)
|
||||
};
|
||||
simd_shuffle_down(data.index, delta), simd_shuffle_down(data.val, delta)};
|
||||
}
|
||||
|
||||
|
||||
template <typename T, typename Op, int N_READS>
|
||||
[[kernel]] void arg_reduce_general(
|
||||
const device T *in [[buffer(0)]],
|
||||
device uint32_t *out [[buffer(1)]],
|
||||
const device int *shape [[buffer(2)]],
|
||||
const device size_t *in_strides [[buffer(3)]],
|
||||
const device size_t *out_strides [[buffer(4)]],
|
||||
const device T* in [[buffer(0)]],
|
||||
device uint32_t* out [[buffer(1)]],
|
||||
const device int* shape [[buffer(2)]],
|
||||
const device size_t* in_strides [[buffer(3)]],
|
||||
const device size_t* out_strides [[buffer(4)]],
|
||||
const device size_t& ndim [[buffer(5)]],
|
||||
const device size_t& axis_stride [[buffer(6)]],
|
||||
const device size_t& axis_size [[buffer(7)]],
|
||||
@ -86,7 +87,6 @@ template <typename T, typename Op, int N_READS>
|
||||
uint simd_size [[threads_per_simdgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
// Shapes and strides *do not* contain the reduction axis. The reduction size
|
||||
// and stride are provided in axis_stride and axis_size.
|
||||
//
|
||||
@ -113,13 +113,13 @@ template <typename T, typename Op, int N_READS>
|
||||
threadgroup IndexValPair<T> local_data[32];
|
||||
|
||||
// Loop over the reduction axis in lsize*N_READS buckets
|
||||
for (uint r=0; r < ceildiv(axis_size, N_READS*lsize); r++) {
|
||||
for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
|
||||
// Read the current value
|
||||
uint32_t current_index = r*lsize*N_READS + lid*N_READS;
|
||||
uint32_t current_index = r * lsize * N_READS + lid * N_READS;
|
||||
uint32_t offset = current_index;
|
||||
const device T * current_in = in + in_idx + current_index * axis_stride;
|
||||
const device T* current_in = in + in_idx + current_index * axis_stride;
|
||||
T vals[N_READS];
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = (current_index < axis_size) ? *current_in : T(Op::init);
|
||||
current_index++;
|
||||
current_in += axis_stride;
|
||||
@ -130,7 +130,7 @@ template <typename T, typename Op, int N_READS>
|
||||
// need to reduce across the thread group.
|
||||
|
||||
// First per simd reduction.
|
||||
for (uint offset=simd_size/2; offset>0; offset/=2) {
|
||||
for (uint offset = simd_size / 2; offset > 0; offset /= 2) {
|
||||
IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
|
||||
best = op.reduce(best, neighbor);
|
||||
}
|
||||
@ -149,7 +149,7 @@ template <typename T, typename Op, int N_READS>
|
||||
if (simd_lane_id < simd_groups) {
|
||||
best = local_data[simd_lane_id];
|
||||
}
|
||||
for (uint offset=simd_size/2; offset>0; offset/=2) {
|
||||
for (uint offset = simd_size / 2; offset > 0; offset /= 2) {
|
||||
IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
|
||||
best = op.reduce(best, neighbor);
|
||||
}
|
||||
@ -161,24 +161,25 @@ template <typename T, typename Op, int N_READS>
|
||||
}
|
||||
|
||||
#define instantiate_arg_reduce_helper(name, itype, op) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void arg_reduce_general<itype, op<itype>, 4>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device uint32_t * out [[buffer(1)]], \
|
||||
const device int *shape [[buffer(2)]], \
|
||||
const device size_t *in_strides [[buffer(3)]], \
|
||||
const device size_t *out_strides [[buffer(4)]], \
|
||||
const device size_t& ndim [[buffer(5)]], \
|
||||
const device size_t& axis_stride [[buffer(6)]], \
|
||||
const device size_t& axis_size [[buffer(7)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_size [[threads_per_simdgroup]], \
|
||||
template [[host_name(name)]] [[kernel]] void \
|
||||
arg_reduce_general<itype, op<itype>, 4>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device uint32_t* out [[buffer(1)]], \
|
||||
const device int* shape [[buffer(2)]], \
|
||||
const device size_t* in_strides [[buffer(3)]], \
|
||||
const device size_t* out_strides [[buffer(4)]], \
|
||||
const device size_t& ndim [[buffer(5)]], \
|
||||
const device size_t& axis_stride [[buffer(6)]], \
|
||||
const device size_t& axis_size [[buffer(7)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_size [[threads_per_simdgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_arg_reduce(name, itype) \
|
||||
// clang-format off
|
||||
#define instantiate_arg_reduce(name, itype) \
|
||||
instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin) \
|
||||
instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax)
|
||||
|
||||
@ -193,4 +194,4 @@ instantiate_arg_reduce(int32, int32_t)
|
||||
instantiate_arg_reduce(int64, int64_t)
|
||||
instantiate_arg_reduce(float16, half)
|
||||
instantiate_arg_reduce(float32, float)
|
||||
instantiate_arg_reduce(bfloat16, bfloat16_t)
|
||||
instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on
|
@ -77,7 +77,8 @@ template <typename T, typename U, typename Op>
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
size_t out_idx =
|
||||
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
@ -92,7 +93,8 @@ template <typename T, typename U, typename Op, int DIM>
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
|
||||
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
size_t out_idx =
|
||||
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
c[out_idx] = Op()(a[idx.x], b[idx.y]);
|
||||
}
|
||||
|
||||
@ -112,114 +114,118 @@ template <typename T, typename U, typename Op>
|
||||
c[out_idx] = Op()(a[idx.x], b[idx.y]);
|
||||
}
|
||||
|
||||
#define instantiate_binary(name, itype, otype, op, bopt) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void binary_op_##bopt<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
#define instantiate_binary(name, itype, otype, op, bopt) \
|
||||
template \
|
||||
[[host_name(name)]] [[kernel]] void binary_op_##bopt<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
|
||||
template [[host_name(name "_" #dims)]] \
|
||||
[[kernel]] void binary_op_g_nd<itype, otype, op, dims>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const int shape[dims], \
|
||||
constant const size_t a_strides[dims], \
|
||||
constant const size_t b_strides[dims], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
template [[host_name(name "_" #dims)]] [[kernel]] void \
|
||||
binary_op_g_nd<itype, otype, op, dims>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const int shape[dims], \
|
||||
constant const size_t a_strides[dims], \
|
||||
constant const size_t b_strides[dims], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
#define instantiate_binary_g_nd(name, itype, otype, op) \
|
||||
template [[host_name(name "_1")]] \
|
||||
[[kernel]] void binary_op_g_nd1<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const size_t& a_stride, \
|
||||
constant const size_t& b_stride, \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name(name "_2")]] \
|
||||
[[kernel]] void binary_op_g_nd2<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const size_t a_strides[2], \
|
||||
constant const size_t b_strides[2], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name(name "_3")]] \
|
||||
[[kernel]] void binary_op_g_nd3<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const size_t a_strides[3], \
|
||||
constant const size_t b_strides[3], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
instantiate_binary_g_dim(name, itype, otype, op, 4) \
|
||||
instantiate_binary_g_dim(name, itype, otype, op, 5)
|
||||
template [[host_name(name "_1")]] [[kernel]] void \
|
||||
binary_op_g_nd1<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const size_t& a_stride, \
|
||||
constant const size_t& b_stride, \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name(name "_2")]] [[kernel]] void \
|
||||
binary_op_g_nd2<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const size_t a_strides[2], \
|
||||
constant const size_t b_strides[2], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name(name "_3")]] [[kernel]] void \
|
||||
binary_op_g_nd3<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const size_t a_strides[3], \
|
||||
constant const size_t b_strides[3], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
instantiate_binary_g_dim(name, itype, otype, op, 4) \
|
||||
instantiate_binary_g_dim(name, itype, otype, op, 5)
|
||||
|
||||
|
||||
#define instantiate_binary_g(name, itype, otype, op) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void binary_op_g<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const int* shape, \
|
||||
constant const size_t* a_strides, \
|
||||
constant const size_t* b_strides, \
|
||||
constant const int& ndim, \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
#define instantiate_binary_g(name, itype, otype, op) \
|
||||
template [[host_name(name)]] [[kernel]] void binary_op_g<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const int* shape, \
|
||||
constant const size_t* a_strides, \
|
||||
constant const size_t* b_strides, \
|
||||
constant const int& ndim, \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_binary_all(name, tname, itype, otype, op) \
|
||||
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
|
||||
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
|
||||
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
|
||||
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
|
||||
instantiate_binary_g("g" #name #tname, itype, otype, op) \
|
||||
instantiate_binary_g_nd("g" #name #tname, itype, otype, op)
|
||||
instantiate_binary_g("g" #name #tname, itype, otype, op) \
|
||||
instantiate_binary_g_nd("g" #name #tname, itype, otype, op) // clang-format on
|
||||
|
||||
#define instantiate_binary_integer(name, op) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
|
||||
// clang-format off
|
||||
#define instantiate_binary_integer(name, op) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
|
||||
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
|
||||
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
|
||||
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
|
||||
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
|
||||
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
|
||||
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
|
||||
instantiate_binary_all(name, int64, int64_t, int64_t, op) \
|
||||
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
|
||||
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
|
||||
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
|
||||
instantiate_binary_all(name, int64, int64_t, int64_t, op) // clang-format on
|
||||
|
||||
#define instantiate_binary_float(name, op) \
|
||||
instantiate_binary_all(name, float16, half, half, op) \
|
||||
// clang-format off
|
||||
#define instantiate_binary_float(name, op) \
|
||||
instantiate_binary_all(name, float16, half, half, op) \
|
||||
instantiate_binary_all(name, float32, float, float, op) \
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op) // clang-format on
|
||||
|
||||
#define instantiate_binary_types(name, op) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||
instantiate_binary_integer(name, op) \
|
||||
// clang-format off
|
||||
#define instantiate_binary_types(name, op) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||
instantiate_binary_integer(name, op) \
|
||||
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
|
||||
instantiate_binary_float(name, op)
|
||||
instantiate_binary_float(name, op) // clang-format on
|
||||
|
||||
#define instantiate_binary_types_bool(name, op) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, bool, op) \
|
||||
instantiate_binary_all(name, uint16, uint16_t, bool, op) \
|
||||
instantiate_binary_all(name, uint32, uint32_t, bool, op) \
|
||||
instantiate_binary_all(name, uint64, uint64_t, bool, op) \
|
||||
instantiate_binary_all(name, int8, int8_t, bool, op) \
|
||||
instantiate_binary_all(name, int16, int16_t, bool, op) \
|
||||
instantiate_binary_all(name, int32, int32_t, bool, op) \
|
||||
instantiate_binary_all(name, int64, int64_t, bool, op) \
|
||||
instantiate_binary_all(name, float16, half, bool, op) \
|
||||
instantiate_binary_all(name, float32, float, bool, op) \
|
||||
// clang-format off
|
||||
#define instantiate_binary_types_bool(name, op) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, bool, op) \
|
||||
instantiate_binary_all(name, uint16, uint16_t, bool, op) \
|
||||
instantiate_binary_all(name, uint32, uint32_t, bool, op) \
|
||||
instantiate_binary_all(name, uint64, uint64_t, bool, op) \
|
||||
instantiate_binary_all(name, int8, int8_t, bool, op) \
|
||||
instantiate_binary_all(name, int16, int16_t, bool, op) \
|
||||
instantiate_binary_all(name, int32, int32_t, bool, op) \
|
||||
instantiate_binary_all(name, int64, int64_t, bool, op) \
|
||||
instantiate_binary_all(name, float16, half, bool, op) \
|
||||
instantiate_binary_all(name, float32, float, bool, op) \
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
|
||||
instantiate_binary_all(name, complex64, complex64_t, bool, op)
|
||||
instantiate_binary_all(name, complex64, complex64_t, bool, op) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_binary_types(add, Add)
|
||||
instantiate_binary_types(div, Divide)
|
||||
instantiate_binary_types_bool(eq, Equal)
|
||||
@ -253,4 +259,4 @@ instantiate_binary_all(bitwise_or, bool_, bool, bool, BitwiseOr)
|
||||
instantiate_binary_integer(bitwise_xor, BitwiseXor)
|
||||
instantiate_binary_all(bitwise_xor, bool_, bool, bool, BitwiseXor)
|
||||
instantiate_binary_integer(left_shift, LeftShift)
|
||||
instantiate_binary_integer(right_shift, RightShift)
|
||||
instantiate_binary_integer(right_shift, RightShift) // clang-format on
|
||||
|
@ -3,28 +3,42 @@
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
struct FloorDivide {
|
||||
template <typename T> T operator()(T x, T y) { return x / y; }
|
||||
template <> float operator()(float x, float y) { return trunc(x / y); }
|
||||
template <> half operator()(half x, half y) { return trunc(x / y); }
|
||||
template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return trunc(x / y); }
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x / y;
|
||||
}
|
||||
template <>
|
||||
float operator()(float x, float y) {
|
||||
return trunc(x / y);
|
||||
}
|
||||
template <>
|
||||
half operator()(half x, half y) {
|
||||
return trunc(x / y);
|
||||
}
|
||||
template <>
|
||||
bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
|
||||
return trunc(x / y);
|
||||
}
|
||||
};
|
||||
|
||||
struct Remainder {
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T> operator()(T x, T y) {
|
||||
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
|
||||
operator()(T x, T y) {
|
||||
return x % y;
|
||||
}
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T> operator()(T x, T y) {
|
||||
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
|
||||
operator()(T x, T y) {
|
||||
auto r = x % y;
|
||||
if (r != 0 && (r < 0 != y < 0)) {
|
||||
r += y;
|
||||
}
|
||||
return r;
|
||||
return r;
|
||||
}
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
@ -32,10 +46,11 @@ struct Remainder {
|
||||
if (r != 0 && (r < 0 != y < 0)) {
|
||||
r += y;
|
||||
}
|
||||
return r;
|
||||
return r;
|
||||
}
|
||||
template <> complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
return x % y;
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
return x % y;
|
||||
}
|
||||
};
|
||||
|
||||
@ -50,7 +65,6 @@ template <typename T, typename U, typename Op1, typename Op2>
|
||||
d[index] = Op2()(a[0], b[0]);
|
||||
}
|
||||
|
||||
|
||||
template <typename T, typename U, typename Op1, typename Op2>
|
||||
[[kernel]] void binary_op_ss(
|
||||
device const T* a,
|
||||
@ -139,7 +153,8 @@ template <typename T, typename U, typename Op1, typename Op2>
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
size_t out_idx =
|
||||
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
c[out_idx] = Op1()(a[a_idx], b[b_idx]);
|
||||
d[out_idx] = Op2()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
@ -156,7 +171,8 @@ template <typename T, typename U, typename Op1, typename Op2, int DIM>
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
|
||||
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
size_t out_idx =
|
||||
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
c[out_idx] = Op1()(a[idx.x], b[idx.y]);
|
||||
d[out_idx] = Op2()(a[idx.x], b[idx.y]);
|
||||
}
|
||||
@ -180,99 +196,102 @@ template <typename T, typename U, typename Op1, typename Op2>
|
||||
}
|
||||
|
||||
#define instantiate_binary(name, itype, otype, op1, op2, bopt) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void binary_op_##bopt<itype, otype, op1, op2>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name(name)]] [[kernel]] void \
|
||||
binary_op_##bopt<itype, otype, op1, op2>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_binary_g_dim(name, itype, otype, op1, op2, dims) \
|
||||
template [[host_name(name "_" #dims)]] \
|
||||
[[kernel]] void binary_op_g_nd<itype, otype, op1, op2, dims>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const int shape[dims], \
|
||||
constant const size_t a_strides[dims], \
|
||||
constant const size_t b_strides[dims], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
template [[host_name(name "_" #dims)]] [[kernel]] void \
|
||||
binary_op_g_nd<itype, otype, op1, op2, dims>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const int shape[dims], \
|
||||
constant const size_t a_strides[dims], \
|
||||
constant const size_t b_strides[dims], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_binary_g_nd(name, itype, otype, op1, op2) \
|
||||
template [[host_name(name "_1")]] \
|
||||
[[kernel]] void binary_op_g_nd1<itype, otype, op1, op2>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const size_t& a_stride, \
|
||||
constant const size_t& b_stride, \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name(name "_2")]] \
|
||||
[[kernel]] void binary_op_g_nd2<itype, otype, op1, op2>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const size_t a_strides[2], \
|
||||
constant const size_t b_strides[2], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name(name "_3")]] \
|
||||
[[kernel]] void binary_op_g_nd3<itype, otype, op1, op2>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const size_t a_strides[3], \
|
||||
constant const size_t b_strides[3], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \
|
||||
instantiate_binary_g_dim(name, itype, otype, op1, op2, 5)
|
||||
|
||||
template [[host_name(name "_1")]] [[kernel]] void \
|
||||
binary_op_g_nd1<itype, otype, op1, op2>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const size_t& a_stride, \
|
||||
constant const size_t& b_stride, \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name(name "_2")]] [[kernel]] void \
|
||||
binary_op_g_nd2<itype, otype, op1, op2>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const size_t a_strides[2], \
|
||||
constant const size_t b_strides[2], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name(name "_3")]] [[kernel]] void \
|
||||
binary_op_g_nd3<itype, otype, op1, op2>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const size_t a_strides[3], \
|
||||
constant const size_t b_strides[3], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \
|
||||
instantiate_binary_g_dim(name, itype, otype, op1, op2, 5) // clang-format on
|
||||
|
||||
#define instantiate_binary_g(name, itype, otype, op1, op2) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void binary_op_g<itype, otype, op2, op2>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const int* shape, \
|
||||
constant const size_t* a_strides, \
|
||||
constant const size_t* b_strides, \
|
||||
constant const int& ndim, \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
template [[host_name(name)]] [[kernel]] void \
|
||||
binary_op_g<itype, otype, op2, op2>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const int* shape, \
|
||||
constant const size_t* a_strides, \
|
||||
constant const size_t* b_strides, \
|
||||
constant const int& ndim, \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_binary_all(name, tname, itype, otype, op1, op2) \
|
||||
instantiate_binary("ss" #name #tname, itype, otype, op1, op2, ss) \
|
||||
instantiate_binary("sv" #name #tname, itype, otype, op1, op2, sv) \
|
||||
instantiate_binary("vs" #name #tname, itype, otype, op1, op2, vs) \
|
||||
instantiate_binary("vv" #name #tname, itype, otype, op1, op2, vv) \
|
||||
instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \
|
||||
instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2)
|
||||
instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \
|
||||
instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2) // clang-format on
|
||||
|
||||
#define instantiate_binary_float(name, op1, op2) \
|
||||
instantiate_binary_all(name, float16, half, half, op1, op2) \
|
||||
// clang-format off
|
||||
#define instantiate_binary_float(name, op1, op2) \
|
||||
instantiate_binary_all(name, float16, half, half, op1, op2) \
|
||||
instantiate_binary_all(name, float32, float, float, op1, op2) \
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2)
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2) // clang-format on
|
||||
|
||||
#define instantiate_binary_types(name, op1, op2) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op1, op2) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \
|
||||
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op1, op2) \
|
||||
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op1, op2) \
|
||||
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op1, op2) \
|
||||
instantiate_binary_all(name, int8, int8_t, int8_t, op1, op2) \
|
||||
instantiate_binary_all(name, int16, int16_t, int16_t, op1, op2) \
|
||||
instantiate_binary_all(name, int32, int32_t, int32_t, op1, op2) \
|
||||
instantiate_binary_all(name, int64, int64_t, int64_t, op1, op2) \
|
||||
// clang-format off
|
||||
#define instantiate_binary_types(name, op1, op2) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op1, op2) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \
|
||||
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op1, op2) \
|
||||
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op1, op2) \
|
||||
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op1, op2) \
|
||||
instantiate_binary_all(name, int8, int8_t, int8_t, op1, op2) \
|
||||
instantiate_binary_all(name, int16, int16_t, int16_t, op1, op2) \
|
||||
instantiate_binary_all(name, int32, int32_t, int32_t, op1, op2) \
|
||||
instantiate_binary_all(name, int64, int64_t, int64_t, op1, op2) \
|
||||
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op1, op2) \
|
||||
instantiate_binary_float(name, op1, op2)
|
||||
|
||||
instantiate_binary_types(divmod, FloorDivide, Remainder)
|
||||
instantiate_binary_types(divmod, FloorDivide, Remainder) // clang-format on
|
||||
|
@ -22,7 +22,7 @@ struct complex64_t {
|
||||
float imag;
|
||||
|
||||
// Constructors
|
||||
constexpr complex64_t(float real, float imag) : real(real), imag(imag){};
|
||||
constexpr complex64_t(float real, float imag) : real(real), imag(imag) {};
|
||||
|
||||
// Conversions to complex64_t
|
||||
template <
|
||||
|
@ -1,13 +1,11 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_simdgroup_matrix>
|
||||
#include <metal_stdlib>
|
||||
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||
|
||||
#define MLX_MTL_CONST static constant constexpr const
|
||||
|
||||
@ -23,14 +21,15 @@ template <typename T, int N>
|
||||
device T* out [[buffer(1)]],
|
||||
const constant MLXConvParams<N>* params [[buffer(2)]],
|
||||
uint3 gid [[thread_position_in_grid]]) {
|
||||
|
||||
int filter_size = params->C;
|
||||
for(short i = 0; i < N; i++) filter_size *= params->wS[i];
|
||||
for (short i = 0; i < N; i++)
|
||||
filter_size *= params->wS[i];
|
||||
|
||||
int out_pixels = 1;
|
||||
for(short i = 0; i < N; i++) out_pixels *= params->oS[i];
|
||||
for (short i = 0; i < N; i++)
|
||||
out_pixels *= params->oS[i];
|
||||
|
||||
// Set out
|
||||
// Set out
|
||||
out += gid.z * filter_size + gid.y * (params->C);
|
||||
|
||||
// Coordinates in input
|
||||
@ -46,11 +45,11 @@ template <typename T, int N>
|
||||
|
||||
bool valid = n < params->N;
|
||||
|
||||
// Unroll dimensions
|
||||
// Unroll dimensions
|
||||
for (int i = N - 1; i >= 0; --i) {
|
||||
int os_ = (oS % params->oS[i]);
|
||||
int ws_ = (wS % params->wS[i]);
|
||||
|
||||
|
||||
ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;
|
||||
|
||||
int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i];
|
||||
@ -64,10 +63,10 @@ template <typename T, int N>
|
||||
wS /= params->wS[i];
|
||||
}
|
||||
|
||||
if(valid) {
|
||||
if (valid) {
|
||||
size_t in_offset = n * params->in_strides[0];
|
||||
|
||||
for(int i = 0; i < N; ++i) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
in_offset += is[i] * params->in_strides[i + 1];
|
||||
}
|
||||
|
||||
@ -85,12 +84,13 @@ template <typename T, int N>
|
||||
device T* out [[buffer(1)]],
|
||||
const constant MLXConvParams<N>* params [[buffer(2)]],
|
||||
uint3 gid [[thread_position_in_grid]]) {
|
||||
|
||||
int filter_size = params->C;
|
||||
for(short i = 0; i < N; i++) filter_size *= params->wS[i];
|
||||
for (short i = 0; i < N; i++)
|
||||
filter_size *= params->wS[i];
|
||||
|
||||
int out_pixels = 1;
|
||||
for(short i = 0; i < N; i++) out_pixels *= params->oS[i];
|
||||
for (short i = 0; i < N; i++)
|
||||
out_pixels *= params->oS[i];
|
||||
|
||||
// Set out
|
||||
out += gid.z * filter_size + gid.x * (filter_size / params->C);
|
||||
@ -128,10 +128,10 @@ template <typename T, int N>
|
||||
out += ws_ * params->str[i];
|
||||
}
|
||||
|
||||
if(valid) {
|
||||
if (valid) {
|
||||
size_t in_offset = n * params->in_strides[0];
|
||||
|
||||
for(int i = 0; i < N; ++i) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
in_offset += is[i] * params->in_strides[i + 1];
|
||||
}
|
||||
|
||||
@ -141,24 +141,24 @@ template <typename T, int N>
|
||||
}
|
||||
}
|
||||
|
||||
#define instantiate_naive_unfold_nd(name, itype, n) \
|
||||
template [[host_name("naive_unfold_nd_" #name "_" #n)]] \
|
||||
[[kernel]] void naive_unfold_Nd( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device itype* out [[buffer(1)]], \
|
||||
const constant MLXConvParams<n>* params [[buffer(2)]], \
|
||||
uint3 gid [[thread_position_in_grid]]); \
|
||||
template [[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] \
|
||||
[[kernel]] void naive_unfold_transpose_Nd( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device itype* out [[buffer(1)]], \
|
||||
const constant MLXConvParams<n>* params [[buffer(2)]], \
|
||||
uint3 gid [[thread_position_in_grid]]);
|
||||
#define instantiate_naive_unfold_nd(name, itype, n) \
|
||||
template [[host_name("naive_unfold_nd_" #name "_" #n)]] [[kernel]] void \
|
||||
naive_unfold_Nd( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device itype* out [[buffer(1)]], \
|
||||
const constant MLXConvParams<n>* params [[buffer(2)]], \
|
||||
uint3 gid [[thread_position_in_grid]]); \
|
||||
template \
|
||||
[[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] [[kernel]] void \
|
||||
naive_unfold_transpose_Nd( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device itype* out [[buffer(1)]], \
|
||||
const constant MLXConvParams<n>* params [[buffer(2)]], \
|
||||
uint3 gid [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_naive_unfold_nd_dims(name, itype) \
|
||||
instantiate_naive_unfold_nd(name, itype, 1) \
|
||||
instantiate_naive_unfold_nd(name, itype, 2) \
|
||||
instantiate_naive_unfold_nd(name, itype, 3)
|
||||
#define instantiate_naive_unfold_nd_dims(name, itype) \
|
||||
instantiate_naive_unfold_nd(name, itype, 1) instantiate_naive_unfold_nd( \
|
||||
name, itype, 2) instantiate_naive_unfold_nd(name, itype, 3)
|
||||
|
||||
instantiate_naive_unfold_nd_dims(float32, float);
|
||||
instantiate_naive_unfold_nd_dims(float16, half);
|
||||
@ -168,12 +168,13 @@ instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t);
|
||||
/// Slow and naive conv2d kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN, /* Thread cols (in elements) */
|
||||
const int BC = 16>
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN, /* Thread cols (in elements) */
|
||||
const int BC = 16>
|
||||
[[kernel]] void naive_conv_2d(
|
||||
const device T* in [[buffer(0)]],
|
||||
const device T* wt [[buffer(1)]],
|
||||
@ -183,7 +184,6 @@ template <typename T,
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
(void)simd_gid;
|
||||
(void)simd_lid;
|
||||
|
||||
@ -192,80 +192,82 @@ template <typename T,
|
||||
|
||||
int out_o = tid.y * BN * TN + lid.y * TN;
|
||||
int out_hw = tid.x * BM * TM + lid.x * TM;
|
||||
|
||||
|
||||
int out_h[TM];
|
||||
int out_w[TN];
|
||||
|
||||
for(int m = 0; m < TM; ++m) {
|
||||
for (int m = 0; m < TM; ++m) {
|
||||
int mm = (out_hw + m);
|
||||
out_h[m] = mm / params.oS[1];
|
||||
out_w[m] = mm % params.oS[1];
|
||||
}
|
||||
|
||||
|
||||
T in_local[TM];
|
||||
T wt_local[TN];
|
||||
T out_local[TM * TN] = {T(0)};
|
||||
|
||||
for(int h = 0; h < params.wS[0]; ++h) {
|
||||
for(int w = 0; w < params.wS[1]; ++w) {
|
||||
for(int c = 0; c < params.C; ++c) {
|
||||
|
||||
for (int h = 0; h < params.wS[0]; ++h) {
|
||||
for (int w = 0; w < params.wS[1]; ++w) {
|
||||
for (int c = 0; c < params.C; ++c) {
|
||||
// Local in
|
||||
for(int m = 0; m < TM; m++) {
|
||||
for (int m = 0; m < TM; m++) {
|
||||
int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0];
|
||||
int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1];
|
||||
|
||||
bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
|
||||
in_local[m] = valid ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] : T(0);
|
||||
in_local[m] = valid
|
||||
? in[i * params.in_strides[1] + j * params.in_strides[2] + c]
|
||||
: T(0);
|
||||
}
|
||||
|
||||
// Load weight
|
||||
for (int n = 0; n < TN; ++n) {
|
||||
int o = out_o + n;
|
||||
wt_local[n] = o < params.O ? wt[o * params.wt_strides[0] +
|
||||
h * params.wt_strides[1] +
|
||||
w * params.wt_strides[2] + c] : T(0);
|
||||
wt_local[n] = o < params.O
|
||||
? wt[o * params.wt_strides[0] + h * params.wt_strides[1] +
|
||||
w * params.wt_strides[2] + c]
|
||||
: T(0);
|
||||
}
|
||||
|
||||
// Accumulate
|
||||
for(int m = 0; m < TM; ++m) {
|
||||
for(int n = 0; n < TN; ++n) {
|
||||
for (int m = 0; m < TM; ++m) {
|
||||
for (int n = 0; n < TN; ++n) {
|
||||
out_local[m * TN + n] += in_local[m] * wt_local[n];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(int m = 0; m < TM; ++m) {
|
||||
for(int n = 0; n < TN; ++n) {
|
||||
if(out_h[m] < params.oS[0] && out_w[m] < params.oS[1] && (out_o + n) < params.O)
|
||||
out[out_h[m] * params.out_strides[1] +
|
||||
out_w[m] * params.out_strides[2] + out_o + n] = out_local[m * TN + n];
|
||||
for (int m = 0; m < TM; ++m) {
|
||||
for (int n = 0; n < TN; ++n) {
|
||||
if (out_h[m] < params.oS[0] && out_w[m] < params.oS[1] &&
|
||||
(out_o + n) < params.O)
|
||||
out[out_h[m] * params.out_strides[1] +
|
||||
out_w[m] * params.out_strides[2] + out_o + n] =
|
||||
out_local[m * TN + n];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Instantiations
|
||||
|
||||
#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
|
||||
template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
|
||||
[[kernel]] void naive_conv_2d<itype, bm, bn, tm, tn>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
const device itype* wt [[buffer(1)]], \
|
||||
device itype* out [[buffer(2)]], \
|
||||
const constant MLXConvParams<2>& params [[buffer(3)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
|
||||
template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm \
|
||||
"_tn" #tn)]] [[kernel]] void \
|
||||
naive_conv_2d<itype, bm, bn, tm, tn>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
const device itype* wt [[buffer(1)]], \
|
||||
device itype* out [[buffer(2)]], \
|
||||
const constant MLXConvParams<2>& params [[buffer(3)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_naive_conv_2d_blocks(name, itype) \
|
||||
instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \
|
||||
instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4)
|
||||
instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \
|
||||
instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4)
|
||||
|
||||
instantiate_naive_conv_2d_blocks(float32, float);
|
||||
instantiate_naive_conv_2d_blocks(float16, half);
|
||||
@ -276,9 +278,7 @@ instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int M, int R, int S>
|
||||
struct WinogradTransforms {
|
||||
|
||||
};
|
||||
struct WinogradTransforms {};
|
||||
|
||||
template <>
|
||||
struct WinogradTransforms<6, 3, 8> {
|
||||
@ -287,36 +287,36 @@ struct WinogradTransforms<6, 3, 8> {
|
||||
MLX_MTL_CONST int IN_TILE_SIZE = OUT_TILE_SIZE + FILTER_SIZE - 1;
|
||||
MLX_MTL_CONST int SIMD_MATRIX_SIZE = 8;
|
||||
MLX_MTL_CONST float in_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
|
||||
{ 1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
|
||||
{ 0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f},
|
||||
{-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f},
|
||||
{ 0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f},
|
||||
{ 5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f, 0.00f},
|
||||
{ 0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f},
|
||||
{-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f},
|
||||
{ 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
|
||||
{1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
|
||||
{0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f},
|
||||
{-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f},
|
||||
{0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f},
|
||||
{5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f, 0.00f},
|
||||
{0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f},
|
||||
{-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f},
|
||||
{0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
|
||||
};
|
||||
|
||||
MLX_MTL_CONST float out_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
|
||||
{ 1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
|
||||
{ 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f},
|
||||
{ 1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f},
|
||||
{ 1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f},
|
||||
{ 1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f},
|
||||
{ 1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f},
|
||||
{ 1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f},
|
||||
{ 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
|
||||
{1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
|
||||
{1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f},
|
||||
{1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f},
|
||||
{1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f},
|
||||
{1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f},
|
||||
{1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f},
|
||||
{1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f},
|
||||
{0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
|
||||
};
|
||||
|
||||
MLX_MTL_CONST float wt_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
|
||||
{ 1.00, 0.00, 0.00},
|
||||
{ -2.0/9.00, -2.0/9.00, -2.0/9.00},
|
||||
{ -2.0/9.00, 2.0/9.00, -2.0/9.00},
|
||||
{ 1.0/90.0, 1.0/45.0, 2.0/45.0},
|
||||
{ 1.0/90.0, -1.0/45.0, 2.0/45.0},
|
||||
{ 32.0/45.0, 16.0/45.0, 8.0/45.0},
|
||||
{ 32.0/45.0, -16.0/45.0, 8.0/45.0},
|
||||
{ 0.00, 0.00, 1.00},
|
||||
{1.00, 0.00, 0.00},
|
||||
{-2.0 / 9.00, -2.0 / 9.00, -2.0 / 9.00},
|
||||
{-2.0 / 9.00, 2.0 / 9.00, -2.0 / 9.00},
|
||||
{1.0 / 90.0, 1.0 / 45.0, 2.0 / 45.0},
|
||||
{1.0 / 90.0, -1.0 / 45.0, 2.0 / 45.0},
|
||||
{32.0 / 45.0, 16.0 / 45.0, 8.0 / 45.0},
|
||||
{32.0 / 45.0, -16.0 / 45.0, 8.0 / 45.0},
|
||||
{0.00, 0.00, 1.00},
|
||||
};
|
||||
};
|
||||
|
||||
@ -324,12 +324,9 @@ constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8];
|
||||
constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8];
|
||||
constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8];
|
||||
|
||||
template <typename T,
|
||||
int BC = 32,
|
||||
int BO = 4,
|
||||
int M = 6,
|
||||
int R = 3>
|
||||
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void winograd_conv_2d_weight_transform(
|
||||
template <typename T, int BC = 32, int BO = 4, int M = 6, int R = 3>
|
||||
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void
|
||||
winograd_conv_2d_weight_transform(
|
||||
const device T* wt_in [[buffer(0)]],
|
||||
device T* wt_out [[buffer(1)]],
|
||||
const constant int& C [[buffer(2)]],
|
||||
@ -337,7 +334,6 @@ template <typename T,
|
||||
uint tid [[threadgroup_position_in_grid]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]]) {
|
||||
|
||||
using WGT = WinogradTransforms<M, R, 8>;
|
||||
|
||||
// Get lane position in simdgroup
|
||||
@ -357,35 +353,37 @@ template <typename T,
|
||||
|
||||
// Move to the correct output filter
|
||||
size_t ko = BO * tid + simd_group_id;
|
||||
wt_in += ko * R * R * C;
|
||||
wt_in += ko * R * R * C;
|
||||
|
||||
// wt_out is stored transposed (A x A x C x O)
|
||||
short ohw_0 = sm * 8 + sn;
|
||||
short ohw_1 = sm * 8 + sn + 1;
|
||||
device T* wt_out_0 = wt_out + ohw_0 * C * O + ko;
|
||||
device T* wt_out_1 = wt_out + ohw_1 * C * O + ko;
|
||||
device T* wt_out_1 = wt_out + ohw_1 * C * O + ko;
|
||||
|
||||
// Prepare shared memory
|
||||
threadgroup T Ws[BO][R][R][BC];
|
||||
|
||||
// Loop over C
|
||||
for(int bc = 0; bc < C; bc += BC) {
|
||||
for (int bc = 0; bc < C; bc += BC) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Read into shared memory
|
||||
for(int kh = 0; kh < R; ++kh) {
|
||||
for(int kw = 0; kw < R; ++kw) {
|
||||
for(int kc = simd_lane_id; kc < BC; kc += 32) {
|
||||
for (int kh = 0; kh < R; ++kh) {
|
||||
for (int kw = 0; kw < R; ++kw) {
|
||||
for (int kc = simd_lane_id; kc < BC; kc += 32) {
|
||||
Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Do transform and store the result
|
||||
for(int c = 0; c < BC; ++c) {
|
||||
// Do transform and store the result
|
||||
for (int c = 0; c < BC; ++c) {
|
||||
simdgroup_matrix<T, 8, 8> g;
|
||||
g.thread_elements()[0] = sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
|
||||
g.thread_elements()[1] = sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);
|
||||
g.thread_elements()[0] =
|
||||
sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
|
||||
g.thread_elements()[1] =
|
||||
sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);
|
||||
|
||||
simdgroup_matrix<T, 8, 8> g_out = (G * g) * Gt;
|
||||
wt_out_0[c * O] = g_out.thread_elements()[0];
|
||||
@ -396,27 +394,23 @@ template <typename T,
|
||||
wt_out_0 += BC * O;
|
||||
wt_out_1 += BC * O;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \
|
||||
template [[host_name("winograd_conv_2d_weight_transform_" #name "_bc" #bc)]]\
|
||||
[[kernel]] void winograd_conv_2d_weight_transform<itype, bc>(\
|
||||
const device itype* wt_in [[buffer(0)]],\
|
||||
device itype* wt_out [[buffer(1)]],\
|
||||
const constant int& C [[buffer(2)]],\
|
||||
const constant int& O [[buffer(3)]],\
|
||||
uint tid [[threadgroup_position_in_grid]],\
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],\
|
||||
template [[host_name("winograd_conv_2d_weight_transform_" #name \
|
||||
"_bc" #bc)]] [[kernel]] void \
|
||||
winograd_conv_2d_weight_transform<itype, bc>( \
|
||||
const device itype* wt_in [[buffer(0)]], \
|
||||
device itype* wt_out [[buffer(1)]], \
|
||||
const constant int& C [[buffer(2)]], \
|
||||
const constant int& O [[buffer(3)]], \
|
||||
uint tid [[threadgroup_position_in_grid]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]]);
|
||||
|
||||
template <typename T,
|
||||
int BC,
|
||||
int WM,
|
||||
int WN,
|
||||
int M = 6,
|
||||
int R = 3>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_input_transform(
|
||||
template <typename T, int BC, int WM, int WN, int M = 6, int R = 3>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
|
||||
winograd_conv_2d_input_transform(
|
||||
const device T* inp_in [[buffer(0)]],
|
||||
device T* inp_out [[buffer(1)]],
|
||||
const constant MLXConvParams<2>& params [[buffer(2)]],
|
||||
@ -425,7 +419,6 @@ template <typename T,
|
||||
uint3 tgp_per_grid [[threadgroups_per_grid]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]]) {
|
||||
|
||||
(void)lid;
|
||||
|
||||
using WGT = WinogradTransforms<M, R, 8>;
|
||||
@ -456,46 +449,48 @@ template <typename T,
|
||||
int bw = M * tid.x + kw;
|
||||
|
||||
// Move to the correct input tile
|
||||
inp_in += tid.z * params.in_strides[0]
|
||||
+ bh * params.in_strides[1]
|
||||
+ bw * params.in_strides[2];
|
||||
inp_in += tid.z * params.in_strides[0] + bh * params.in_strides[1] +
|
||||
bw * params.in_strides[2];
|
||||
|
||||
// Pre compute strides
|
||||
// Pre compute strides
|
||||
int jump_in[TH][TW];
|
||||
|
||||
for(int h = 0; h < TH; h++) {
|
||||
for(int w = 0; w < TW; w++) {
|
||||
jump_in[h][w] = h * params.in_strides[1] + w * params.in_strides[2];
|
||||
for (int h = 0; h < TH; h++) {
|
||||
for (int w = 0; w < TW; w++) {
|
||||
jump_in[h][w] = h * params.in_strides[1] + w * params.in_strides[2];
|
||||
}
|
||||
}
|
||||
|
||||
// inp_out is stored interleaved (A x A x tiles x C)
|
||||
size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
|
||||
size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
|
||||
size_t tile_id =
|
||||
tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
|
||||
size_t ohw_0 = sm * 8 + sn;
|
||||
size_t ohw_1 = sm * 8 + sn + 1;
|
||||
device T* inp_out_0 = inp_out + ohw_0 * N_TILES * params.C + tile_id * params.C;
|
||||
device T* inp_out_1 = inp_out + ohw_1 * N_TILES * params.C + tile_id * params.C;
|
||||
device T* inp_out_0 =
|
||||
inp_out + ohw_0 * N_TILES * params.C + tile_id * params.C;
|
||||
device T* inp_out_1 =
|
||||
inp_out + ohw_1 * N_TILES * params.C + tile_id * params.C;
|
||||
|
||||
// Prepare shared memory
|
||||
threadgroup T Is[A][A][BC];
|
||||
|
||||
// Loop over C
|
||||
for(int bc = 0; bc < params.C; bc += BC) {
|
||||
for (int bc = 0; bc < params.C; bc += BC) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Read into shared memory
|
||||
for(int h = 0; h < TH; h++) {
|
||||
for(int w = 0; w < TW; w++) {
|
||||
for (int h = 0; h < TH; h++) {
|
||||
for (int w = 0; w < TW; w++) {
|
||||
const device T* in_ptr = inp_in + jump_in[h][w];
|
||||
for(int c = simd_lane_id; c < BC; c += 32) {
|
||||
for (int c = simd_lane_id; c < BC; c += 32) {
|
||||
Is[kh + h][kw + w][c] = in_ptr[c];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Do transform and store the result
|
||||
for(int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) {
|
||||
// Do transform and store the result
|
||||
for (int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) {
|
||||
simdgroup_matrix<T, 8, 8> I;
|
||||
I.thread_elements()[0] = Is[sm][sn][c];
|
||||
I.thread_elements()[1] = Is[sm][sn + 1][c];
|
||||
@ -509,28 +504,24 @@ template <typename T,
|
||||
inp_out_0 += BC;
|
||||
inp_out_1 += BC;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#define instantiate_winograd_conv_2d_input_transform(name, itype, bc) \
|
||||
template [[host_name("winograd_conv_2d_input_transform_" #name "_bc" #bc)]]\
|
||||
[[kernel]] void winograd_conv_2d_input_transform<itype, bc, 2, 2>(\
|
||||
const device itype* inp_in [[buffer(0)]],\
|
||||
device itype* inp_out [[buffer(1)]],\
|
||||
const constant MLXConvParams<2>& params [[buffer(2)]],\
|
||||
uint3 tid [[threadgroup_position_in_grid]],\
|
||||
uint3 lid [[thread_position_in_threadgroup]],\
|
||||
uint3 tgp_per_grid [[threadgroups_per_grid]],\
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],\
|
||||
template [[host_name("winograd_conv_2d_input_transform_" #name \
|
||||
"_bc" #bc)]] [[kernel]] void \
|
||||
winograd_conv_2d_input_transform<itype, bc, 2, 2>( \
|
||||
const device itype* inp_in [[buffer(0)]], \
|
||||
device itype* inp_out [[buffer(1)]], \
|
||||
const constant MLXConvParams<2>& params [[buffer(2)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 tgp_per_grid [[threadgroups_per_grid]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]]);
|
||||
|
||||
template <typename T,
|
||||
int BO,
|
||||
int WM,
|
||||
int WN,
|
||||
int M = 6,
|
||||
int R = 3>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_output_transform(
|
||||
template <typename T, int BO, int WM, int WN, int M = 6, int R = 3>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
|
||||
winograd_conv_2d_output_transform(
|
||||
const device T* out_in [[buffer(0)]],
|
||||
device T* out_out [[buffer(1)]],
|
||||
const constant MLXConvParams<2>& params [[buffer(2)]],
|
||||
@ -539,7 +530,6 @@ template <typename T,
|
||||
uint3 tgp_per_grid [[threadgroups_per_grid]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]]) {
|
||||
|
||||
(void)lid;
|
||||
|
||||
using WGT = WinogradTransforms<M, R, 8>;
|
||||
@ -572,57 +562,59 @@ template <typename T,
|
||||
int bw = M * tid.x + kw;
|
||||
|
||||
// Move to the correct input tile
|
||||
out_out += tid.z * params.out_strides[0]
|
||||
+ bh * params.out_strides[1]
|
||||
+ bw * params.out_strides[2];
|
||||
out_out += tid.z * params.out_strides[0] + bh * params.out_strides[1] +
|
||||
bw * params.out_strides[2];
|
||||
|
||||
// Pre compute strides
|
||||
// Pre compute strides
|
||||
int jump_in[TH][TW];
|
||||
|
||||
for(int h = 0; h < TH; h++) {
|
||||
for(int w = 0; w < TW; w++) {
|
||||
for (int h = 0; h < TH; h++) {
|
||||
for (int w = 0; w < TW; w++) {
|
||||
bool valid = ((bh + h) < params.oS[0]) && ((bw + w) < params.oS[1]);
|
||||
jump_in[h][w] = valid ? h * params.out_strides[1] + w * params.out_strides[2] : -1;
|
||||
jump_in[h][w] =
|
||||
valid ? h * params.out_strides[1] + w * params.out_strides[2] : -1;
|
||||
}
|
||||
}
|
||||
|
||||
// out_in is stored interleaved (A x A x tiles x O)
|
||||
size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
|
||||
size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
|
||||
size_t tile_id =
|
||||
tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
|
||||
size_t ohw_0 = sm * 8 + sn;
|
||||
size_t ohw_1 = sm * 8 + sn + 1;
|
||||
const device T* out_in_0 = out_in + ohw_0 * N_TILES * params.O + tile_id * params.O;
|
||||
const device T* out_in_1 = out_in + ohw_1 * N_TILES * params.O + tile_id * params.O;
|
||||
const device T* out_in_0 =
|
||||
out_in + ohw_0 * N_TILES * params.O + tile_id * params.O;
|
||||
const device T* out_in_1 =
|
||||
out_in + ohw_1 * N_TILES * params.O + tile_id * params.O;
|
||||
|
||||
// Prepare shared memory
|
||||
threadgroup T Os[M][M][BO];
|
||||
|
||||
// Loop over O
|
||||
for(int bo = 0; bo < params.O; bo += BO) {
|
||||
|
||||
for (int bo = 0; bo < params.O; bo += BO) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Do transform and store the result
|
||||
for(int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
|
||||
// Do transform and store the result
|
||||
for (int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
|
||||
simdgroup_matrix<T, 8, 8> O_mat;
|
||||
O_mat.thread_elements()[0] = out_in_0[c];
|
||||
O_mat.thread_elements()[1] = out_in_1[c];
|
||||
|
||||
simdgroup_matrix<T, 8, 8> O_out = (Bt * (O_mat * B));
|
||||
if((sm < M) && (sn < M)) {
|
||||
if ((sm < M) && (sn < M)) {
|
||||
Os[sm][sn][c] = O_out.thread_elements()[0];
|
||||
}
|
||||
if((sm < M) && ((sn + 1) < M)) {
|
||||
if ((sm < M) && ((sn + 1) < M)) {
|
||||
Os[sm][sn + 1][c] = O_out.thread_elements()[1];
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Read out from shared memory
|
||||
for(int h = 0; h < TH; h++) {
|
||||
for(int w = 0; w < TW; w++) {
|
||||
if(jump_in[h][w] >= 0) {
|
||||
for (int h = 0; h < TH; h++) {
|
||||
for (int w = 0; w < TW; w++) {
|
||||
if (jump_in[h][w] >= 0) {
|
||||
device T* out_ptr = out_out + jump_in[h][w];
|
||||
for(int c = simd_lane_id; c < BO; c += 32) {
|
||||
for (int c = simd_lane_id; c < BO; c += 32) {
|
||||
out_ptr[c] = Os[kh + h][kw + w][c];
|
||||
}
|
||||
}
|
||||
@ -633,25 +625,27 @@ template <typename T,
|
||||
out_in_0 += BO;
|
||||
out_in_1 += BO;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#define instantiate_winograd_conv_2d_output_transform(name, itype, bo) \
|
||||
template [[host_name("winograd_conv_2d_output_transform_" #name "_bo" #bo)]]\
|
||||
[[kernel]] void winograd_conv_2d_output_transform<itype, bo, 2, 2>(\
|
||||
const device itype* out_in [[buffer(0)]],\
|
||||
device itype* out_out [[buffer(1)]],\
|
||||
const constant MLXConvParams<2>& params [[buffer(2)]],\
|
||||
uint3 tid [[threadgroup_position_in_grid]],\
|
||||
uint3 lid [[thread_position_in_threadgroup]],\
|
||||
uint3 tgp_per_grid [[threadgroups_per_grid]],\
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],\
|
||||
template [[host_name("winograd_conv_2d_output_transform_" #name \
|
||||
"_bo" #bo)]] [[kernel]] void \
|
||||
winograd_conv_2d_output_transform<itype, bo, 2, 2>( \
|
||||
const device itype* out_in [[buffer(0)]], \
|
||||
device itype* out_out [[buffer(1)]], \
|
||||
const constant MLXConvParams<2>& params [[buffer(2)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 tgp_per_grid [[threadgroups_per_grid]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_winograd_conv_2d(name, itype) \
|
||||
// clang-format off
|
||||
#define instantiate_winograd_conv_2d(name, itype) \
|
||||
instantiate_winograd_conv_2d_weight_transform_base(name, itype, 32) \
|
||||
instantiate_winograd_conv_2d_input_transform(name, itype, 32) \
|
||||
instantiate_winograd_conv_2d_output_transform(name, itype, 32)
|
||||
instantiate_winograd_conv_2d_input_transform(name, itype, 32) \
|
||||
instantiate_winograd_conv_2d_output_transform(name, itype, 32) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_winograd_conv_2d(float32, float);
|
||||
instantiate_winograd_conv_2d(float16, half);
|
||||
instantiate_winograd_conv_2d(float16, half); // clang-format on
|
@ -49,7 +49,8 @@ template <typename T, typename U>
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_3(index, src_strides);
|
||||
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||
int64_t dst_idx =
|
||||
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
@ -62,7 +63,8 @@ template <typename T, typename U, int DIM>
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
|
||||
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||
int64_t dst_idx =
|
||||
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
@ -76,7 +78,8 @@ template <typename T, typename U>
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
|
||||
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||
int64_t dst_idx =
|
||||
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
@ -143,116 +146,110 @@ template <typename T, typename U>
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
#define instantiate_copy(name, itype, otype, ctype) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void copy_##ctype<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
#define instantiate_copy(name, itype, otype, ctype) \
|
||||
template [[host_name(name)]] [[kernel]] void copy_##ctype<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_copy_g_dim(name, itype, otype, dims) \
|
||||
template [[host_name(name "_" #dims)]] \
|
||||
[[kernel]] void copy_g_nd<itype, otype, dims>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g" name "_" #dims)]] \
|
||||
[[kernel]] void copy_gg_nd<itype, otype, dims>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
#define instantiate_copy_g_dim(name, itype, otype, dims) \
|
||||
template [[host_name(name "_" #dims)]] [[kernel]] void \
|
||||
copy_g_nd<itype, otype, dims>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g" name "_" #dims)]] [[kernel]] void \
|
||||
copy_gg_nd<itype, otype, dims>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_copy_g_nd(name, itype, otype) \
|
||||
template [[host_name(name "_1")]] [[kernel]] void copy_g_nd1<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t& src_stride [[buffer(3)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name(name "_2")]] [[kernel]] void copy_g_nd2<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name(name "_3")]] [[kernel]] void copy_g_nd3<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g" name "_1")]] [[kernel]] void \
|
||||
copy_gg_nd1<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t& src_stride [[buffer(3)]], \
|
||||
constant const int64_t& dst_stride [[buffer(4)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("g" name "_2")]] [[kernel]] void \
|
||||
copy_gg_nd2<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
uint2 index [[thread_position_in_grid]]); \
|
||||
template [[host_name("g" name "_3")]] [[kernel]] void \
|
||||
copy_gg_nd3<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
uint3 index [[thread_position_in_grid]]); \
|
||||
instantiate_copy_g_dim(name, itype, otype, 4) \
|
||||
instantiate_copy_g_dim(name, itype, otype, 5)
|
||||
|
||||
#define instantiate_copy_g_nd(name, itype, otype) \
|
||||
template [[host_name(name "_1")]] \
|
||||
[[kernel]] void copy_g_nd1<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t& src_stride [[buffer(3)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name(name "_2")]] \
|
||||
[[kernel]] void copy_g_nd2<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name(name "_3")]] \
|
||||
[[kernel]] void copy_g_nd3<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g" name "_1")]] \
|
||||
[[kernel]] void copy_gg_nd1<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t& src_stride [[buffer(3)]], \
|
||||
constant const int64_t& dst_stride [[buffer(4)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("g" name "_2")]] \
|
||||
[[kernel]] void copy_gg_nd2<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
uint2 index [[thread_position_in_grid]]); \
|
||||
template [[host_name("g" name "_3")]] \
|
||||
[[kernel]] void copy_gg_nd3<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
uint3 index [[thread_position_in_grid]]); \
|
||||
instantiate_copy_g_dim(name, itype, otype, 4) \
|
||||
instantiate_copy_g_dim(name, itype, otype, 5)
|
||||
|
||||
|
||||
#define instantiate_copy_g(name, itype, otype) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void copy_g<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int& ndim [[buffer(5)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g" name)]] \
|
||||
[[kernel]] void copy_gg<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
constant const int& ndim [[buffer(5)]], \
|
||||
#define instantiate_copy_g(name, itype, otype) \
|
||||
template [[host_name(name)]] [[kernel]] void copy_g<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int& ndim [[buffer(5)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g" name)]] [[kernel]] void copy_gg<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
constant const int& ndim [[buffer(5)]], \
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_copy_all(tname, itype, otype) \
|
||||
// clang-format off
|
||||
#define instantiate_copy_all(tname, itype, otype) \
|
||||
instantiate_copy("scopy" #tname, itype, otype, s) \
|
||||
instantiate_copy("vcopy" #tname, itype, otype, v) \
|
||||
instantiate_copy_g("gcopy" #tname, itype, otype) \
|
||||
instantiate_copy_g_nd("gcopy" #tname, itype, otype)
|
||||
instantiate_copy_g("gcopy" #tname, itype, otype) \
|
||||
instantiate_copy_g_nd("gcopy" #tname, itype, otype) // clang-format on
|
||||
|
||||
#define instantiate_copy_itype(itname, itype) \
|
||||
instantiate_copy_all(itname ##bool_, itype, bool) \
|
||||
instantiate_copy_all(itname ##uint8, itype, uint8_t) \
|
||||
instantiate_copy_all(itname ##uint16, itype, uint16_t) \
|
||||
instantiate_copy_all(itname ##uint32, itype, uint32_t) \
|
||||
instantiate_copy_all(itname ##uint64, itype, uint64_t) \
|
||||
instantiate_copy_all(itname ##int8, itype, int8_t) \
|
||||
instantiate_copy_all(itname ##int16, itype, int16_t) \
|
||||
instantiate_copy_all(itname ##int32, itype, int32_t) \
|
||||
instantiate_copy_all(itname ##int64, itype, int64_t) \
|
||||
instantiate_copy_all(itname ##float16, itype, half) \
|
||||
instantiate_copy_all(itname ##float32, itype, float) \
|
||||
// clang-format off
|
||||
#define instantiate_copy_itype(itname, itype) \
|
||||
instantiate_copy_all(itname ##bool_, itype, bool) \
|
||||
instantiate_copy_all(itname ##uint8, itype, uint8_t) \
|
||||
instantiate_copy_all(itname ##uint16, itype, uint16_t) \
|
||||
instantiate_copy_all(itname ##uint32, itype, uint32_t) \
|
||||
instantiate_copy_all(itname ##uint64, itype, uint64_t) \
|
||||
instantiate_copy_all(itname ##int8, itype, int8_t) \
|
||||
instantiate_copy_all(itname ##int16, itype, int16_t) \
|
||||
instantiate_copy_all(itname ##int32, itype, int32_t) \
|
||||
instantiate_copy_all(itname ##int64, itype, int64_t) \
|
||||
instantiate_copy_all(itname ##float16, itype, half) \
|
||||
instantiate_copy_all(itname ##float32, itype, float) \
|
||||
instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \
|
||||
instantiate_copy_all(itname ##complex64, itype, complex64_t)
|
||||
|
||||
@ -268,4 +265,4 @@ instantiate_copy_itype(int64, int64_t)
|
||||
instantiate_copy_itype(float16, half)
|
||||
instantiate_copy_itype(float32, float)
|
||||
instantiate_copy_itype(bfloat16, bfloat16_t)
|
||||
instantiate_copy_itype(complex64, complex64_t)
|
||||
instantiate_copy_itype(complex64, complex64_t) // clang-format on
|
||||
|
@ -6,9 +6,8 @@
|
||||
// - VkFFT (https://github.com/DTolm/VkFFT)
|
||||
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
|
||||
|
||||
#include <metal_math>
|
||||
#include <metal_common>
|
||||
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
@ -23,7 +22,7 @@ float2 complex_mul(float2 a, float2 b) {
|
||||
}
|
||||
|
||||
float2 get_twiddle(int k, int p) {
|
||||
float theta = -1.0f * k * M_PI_F / (2*p);
|
||||
float theta = -1.0f * k * M_PI_F / (2 * p);
|
||||
|
||||
float2 twiddle;
|
||||
twiddle.x = metal::fast::cos(theta);
|
||||
@ -32,7 +31,12 @@ float2 get_twiddle(int k, int p) {
|
||||
}
|
||||
|
||||
// single threaded radix2 implemetation
|
||||
void radix2(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) {
|
||||
void radix2(
|
||||
int i,
|
||||
int p,
|
||||
int m,
|
||||
threadgroup float2* read_buf,
|
||||
threadgroup float2* write_buf) {
|
||||
float2 x_0 = read_buf[i];
|
||||
float2 x_1 = read_buf[i + m];
|
||||
|
||||
@ -53,11 +57,16 @@ void radix2(int i, int p, int m, threadgroup float2* read_buf, threadgroup float
|
||||
}
|
||||
|
||||
// single threaded radix4 implemetation
|
||||
void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) {
|
||||
void radix4(
|
||||
int i,
|
||||
int p,
|
||||
int m,
|
||||
threadgroup float2* read_buf,
|
||||
threadgroup float2* write_buf) {
|
||||
float2 x_0 = read_buf[i];
|
||||
float2 x_1 = read_buf[i + m];
|
||||
float2 x_2 = read_buf[i + 2*m];
|
||||
float2 x_3 = read_buf[i + 3*m];
|
||||
float2 x_2 = read_buf[i + 2 * m];
|
||||
float2 x_3 = read_buf[i + 3 * m];
|
||||
|
||||
// The index within this sub-DFT
|
||||
int k = i & (p - 1);
|
||||
@ -90,11 +99,10 @@ void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float
|
||||
|
||||
write_buf[j] = y_0;
|
||||
write_buf[j + p] = y_1;
|
||||
write_buf[j + 2*p] = y_2;
|
||||
write_buf[j + 3*p] = y_3;
|
||||
write_buf[j + 2 * p] = y_2;
|
||||
write_buf[j + 3 * p] = y_3;
|
||||
}
|
||||
|
||||
|
||||
// Each FFT is computed entirely in shared GPU memory.
|
||||
//
|
||||
// N is decomposed into radix-2 and radix-4 DFTs:
|
||||
@ -107,11 +115,10 @@ void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float
|
||||
// steps at compile time for a ~20% performance boost.
|
||||
template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
|
||||
[[kernel]] void fft(
|
||||
const device float2 *in [[buffer(0)]],
|
||||
device float2 * out [[buffer(1)]],
|
||||
const device float2* in [[buffer(0)]],
|
||||
device float2* out [[buffer(1)]],
|
||||
uint3 thread_position_in_grid [[thread_position_in_grid]],
|
||||
uint3 threads_per_grid [[threads_per_grid]]) {
|
||||
|
||||
// Index of the DFT in batch
|
||||
int batch_idx = thread_position_in_grid.x * n;
|
||||
// The index in the DFT we're working on
|
||||
@ -132,16 +139,16 @@ template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
|
||||
// Copy input into shared memory
|
||||
shared_in[i] = in[batch_idx + i];
|
||||
shared_in[i + m] = in[batch_idx + i + m];
|
||||
shared_in[i + 2*m] = in[batch_idx + i + 2*m];
|
||||
shared_in[i + 3*m] = in[batch_idx + i + 3*m];
|
||||
shared_in[i + 2 * m] = in[batch_idx + i + 2 * m];
|
||||
shared_in[i + 3 * m] = in[batch_idx + i + 3 * m];
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
int p = 1;
|
||||
|
||||
for (size_t r = 0; r < radix_2_steps; r++) {
|
||||
radix2(i, p, m*2, read_buf, write_buf);
|
||||
radix2(i + m, p, m*2, read_buf, write_buf);
|
||||
radix2(i, p, m * 2, read_buf, write_buf);
|
||||
radix2(i + m, p, m * 2, read_buf, write_buf);
|
||||
p *= 2;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
@ -167,29 +174,26 @@ template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
|
||||
// Copy shared memory to output
|
||||
out[batch_idx + i] = read_buf[i];
|
||||
out[batch_idx + i + m] = read_buf[i + m];
|
||||
out[batch_idx + i + 2*m] = read_buf[i + 2*m];
|
||||
out[batch_idx + i + 3*m] = read_buf[i + 3*m];
|
||||
out[batch_idx + i + 2 * m] = read_buf[i + 2 * m];
|
||||
out[batch_idx + i + 3 * m] = read_buf[i + 3 * m];
|
||||
}
|
||||
|
||||
#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \
|
||||
template [[host_name("fft_" #name)]] \
|
||||
[[kernel]] void fft<n, radix_2_steps, radix_4_steps>( \
|
||||
const device float2* in [[buffer(0)]], \
|
||||
device float2* out [[buffer(1)]], \
|
||||
uint3 thread_position_in_grid [[thread_position_in_grid]], \
|
||||
uint3 threads_per_grid [[threads_per_grid]]);
|
||||
|
||||
#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \
|
||||
template [[host_name("fft_" #name)]] [[kernel]] void \
|
||||
fft<n, radix_2_steps, radix_4_steps>( \
|
||||
const device float2* in [[buffer(0)]], \
|
||||
device float2* out [[buffer(1)]], \
|
||||
uint3 thread_position_in_grid [[thread_position_in_grid]], \
|
||||
uint3 threads_per_grid [[threads_per_grid]]);
|
||||
|
||||
// Explicitly define kernels for each power of 2.
|
||||
// clang-format off
|
||||
instantiate_fft(4, /* n= */ 4, /* radix_2_steps= */ 0, /* radix_4_steps= */ 1)
|
||||
instantiate_fft(8, 8, 1, 1)
|
||||
instantiate_fft(16, 16, 0, 2)
|
||||
instantiate_fft(32, 32, 1, 2)
|
||||
instantiate_fft(64, 64, 0, 3)
|
||||
instantiate_fft(128, 128, 1, 3)
|
||||
instantiate_fft(256, 256, 0, 4)
|
||||
instantiate_fft(8, 8, 1, 1) instantiate_fft(16, 16, 0, 2)
|
||||
instantiate_fft(32, 32, 1, 2) instantiate_fft(64, 64, 0, 3)
|
||||
instantiate_fft(128, 128, 1, 3) instantiate_fft(256, 256, 0, 4)
|
||||
instantiate_fft(512, 512, 1, 4)
|
||||
instantiate_fft(1024, 1024, 0, 5)
|
||||
// 2048 is the max that will fit into 32KB of threadgroup memory.
|
||||
// TODO: implement 4 step FFT for larger n.
|
||||
instantiate_fft(2048, 2048, 1, 5)
|
||||
instantiate_fft(2048, 2048, 1, 5) // clang-format on
|
||||
|
@ -14,17 +14,16 @@ using namespace metal;
|
||||
|
||||
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
|
||||
METAL_FUNC void gather_impl(
|
||||
const device T *src [[buffer(0)]],
|
||||
device T *out [[buffer(1)]],
|
||||
const constant int *src_shape [[buffer(2)]],
|
||||
const constant size_t *src_strides [[buffer(3)]],
|
||||
const device T* src [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
const constant int* src_shape [[buffer(2)]],
|
||||
const constant size_t* src_strides [[buffer(3)]],
|
||||
const constant size_t& src_ndim [[buffer(4)]],
|
||||
const constant int *slice_sizes [[buffer(5)]],
|
||||
const constant int *axes [[buffer(6)]],
|
||||
const constant int* slice_sizes [[buffer(5)]],
|
||||
const constant int* axes [[buffer(6)]],
|
||||
const thread Indices<IdxT, NIDX>& indices,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
|
||||
auto ind_idx = index.x;
|
||||
auto ind_offset = index.y;
|
||||
|
||||
@ -43,93 +42,78 @@ METAL_FUNC void gather_impl(
|
||||
indices.ndim);
|
||||
}
|
||||
auto ax = axes[i];
|
||||
auto idx_val = offset_neg_idx(
|
||||
indices.buffers[i][idx_loc], src_shape[ax]);
|
||||
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
|
||||
src_idx += idx_val * src_strides[ax];
|
||||
}
|
||||
|
||||
auto src_offset = elem_to_loc(
|
||||
ind_offset, slice_sizes, src_strides, src_ndim);
|
||||
auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim);
|
||||
|
||||
size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
|
||||
out[out_idx] = src[src_offset + src_idx];
|
||||
|
||||
}
|
||||
|
||||
#define make_gather_impl(IDX_ARG, IDX_ARR) \
|
||||
template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
|
||||
[[kernel]] void gather( \
|
||||
const device T *src [[buffer(0)]], \
|
||||
device T *out [[buffer(1)]], \
|
||||
const constant int *src_shape [[buffer(2)]], \
|
||||
const constant size_t *src_strides [[buffer(3)]], \
|
||||
const constant size_t& src_ndim [[buffer(4)]], \
|
||||
const constant int *slice_sizes [[buffer(5)]], \
|
||||
const constant int *axes [[buffer(6)]], \
|
||||
const constant int *idx_shapes [[buffer(7)]], \
|
||||
const constant size_t *idx_strides [[buffer(8)]], \
|
||||
const constant int& idx_ndim [[buffer(9)]], \
|
||||
IDX_ARG(IdxT) \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]) { \
|
||||
\
|
||||
Indices<IdxT, NIDX> idxs{ \
|
||||
{{IDX_ARR()}}, \
|
||||
idx_shapes, \
|
||||
idx_strides, \
|
||||
idx_ndim}; \
|
||||
\
|
||||
return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \
|
||||
src, \
|
||||
out, \
|
||||
src_shape, \
|
||||
src_strides, \
|
||||
src_ndim, \
|
||||
slice_sizes, \
|
||||
axes, \
|
||||
idxs, \
|
||||
index, \
|
||||
grid_dim); \
|
||||
}
|
||||
#define make_gather_impl(IDX_ARG, IDX_ARR) \
|
||||
template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
|
||||
[[kernel]] void gather( \
|
||||
const device T* src [[buffer(0)]], \
|
||||
device T* out [[buffer(1)]], \
|
||||
const constant int* src_shape [[buffer(2)]], \
|
||||
const constant size_t* src_strides [[buffer(3)]], \
|
||||
const constant size_t& src_ndim [[buffer(4)]], \
|
||||
const constant int* slice_sizes [[buffer(5)]], \
|
||||
const constant int* axes [[buffer(6)]], \
|
||||
const constant int* idx_shapes [[buffer(7)]], \
|
||||
const constant size_t* idx_strides [[buffer(8)]], \
|
||||
const constant int& idx_ndim [[buffer(9)]], \
|
||||
IDX_ARG(IdxT) uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]) { \
|
||||
Indices<IdxT, NIDX> idxs{ \
|
||||
{{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \
|
||||
\
|
||||
return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \
|
||||
src, \
|
||||
out, \
|
||||
src_shape, \
|
||||
src_strides, \
|
||||
src_ndim, \
|
||||
slice_sizes, \
|
||||
axes, \
|
||||
idxs, \
|
||||
index, \
|
||||
grid_dim); \
|
||||
}
|
||||
|
||||
#define make_gather(n) make_gather_impl(IDX_ARG_ ##n, IDX_ARR_ ##n)
|
||||
#define make_gather(n) make_gather_impl(IDX_ARG_##n, IDX_ARR_##n)
|
||||
|
||||
make_gather(0)
|
||||
make_gather(1)
|
||||
make_gather(2)
|
||||
make_gather(3)
|
||||
make_gather(4)
|
||||
make_gather(5)
|
||||
make_gather(6)
|
||||
make_gather(7)
|
||||
make_gather(8)
|
||||
make_gather(9)
|
||||
make_gather(10)
|
||||
make_gather(0) make_gather(1) make_gather(2) make_gather(3) make_gather(4)
|
||||
make_gather(5) make_gather(6) make_gather(7) make_gather(8) make_gather(9)
|
||||
make_gather(10)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
// Gather instantiations
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \
|
||||
template [[host_name("gather" name "_" #nidx "" #nd_name)]] \
|
||||
[[kernel]] void gather<src_t, idx_t, nidx, nd>( \
|
||||
const device src_t *src [[buffer(0)]], \
|
||||
device src_t *out [[buffer(1)]], \
|
||||
const constant int *src_shape [[buffer(2)]], \
|
||||
const constant size_t *src_strides [[buffer(3)]], \
|
||||
const constant size_t& src_ndim [[buffer(4)]], \
|
||||
const constant int *slice_sizes [[buffer(5)]], \
|
||||
const constant int *axes [[buffer(6)]], \
|
||||
const constant int *idx_shapes [[buffer(7)]], \
|
||||
const constant size_t *idx_strides [[buffer(8)]], \
|
||||
const constant int& idx_ndim [[buffer(9)]], \
|
||||
IDX_ARG(idx_t) \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]);
|
||||
#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \
|
||||
template [[host_name("gather" name "_" #nidx "" #nd_name)]] [[kernel]] void \
|
||||
gather<src_t, idx_t, nidx, nd>( \
|
||||
const device src_t* src [[buffer(0)]], \
|
||||
device src_t* out [[buffer(1)]], \
|
||||
const constant int* src_shape [[buffer(2)]], \
|
||||
const constant size_t* src_strides [[buffer(3)]], \
|
||||
const constant size_t& src_ndim [[buffer(4)]], \
|
||||
const constant int* slice_sizes [[buffer(5)]], \
|
||||
const constant int* axes [[buffer(6)]], \
|
||||
const constant int* idx_shapes [[buffer(7)]], \
|
||||
const constant size_t* idx_strides [[buffer(8)]], \
|
||||
const constant int& idx_ndim [[buffer(9)]], \
|
||||
IDX_ARG(idx_t) uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \
|
||||
instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name)
|
||||
instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gather4(name, src_t, idx_t, nidx) \
|
||||
instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \
|
||||
instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \
|
||||
@ -148,29 +132,31 @@ instantiate_gather4("int32", int32_t, bool, 0)
|
||||
instantiate_gather4("int64", int64_t, bool, 0)
|
||||
instantiate_gather4("float16", half, bool, 0)
|
||||
instantiate_gather4("float32", float, bool, 0)
|
||||
instantiate_gather4("bfloat16", bfloat16_t, bool, 0)
|
||||
instantiate_gather4("bfloat16", bfloat16_t, bool, 0) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gather3(name, src_type, ind_type) \
|
||||
instantiate_gather4(name, src_type, ind_type, 1) \
|
||||
instantiate_gather4(name, src_type, ind_type, 2) \
|
||||
instantiate_gather4(name, src_type, ind_type, 3) \
|
||||
instantiate_gather4(name, src_type, ind_type, 4) \
|
||||
instantiate_gather4(name, src_type, ind_type, 5) \
|
||||
instantiate_gather4(name, src_type, ind_type, 6) \
|
||||
instantiate_gather4(name, src_type, ind_type, 7) \
|
||||
instantiate_gather4(name, src_type, ind_type, 8) \
|
||||
instantiate_gather4(name, src_type, ind_type, 9) \
|
||||
instantiate_gather4(name, src_type, ind_type, 10)
|
||||
instantiate_gather4(name, src_type, ind_type, 1) \
|
||||
instantiate_gather4(name, src_type, ind_type, 2) \
|
||||
instantiate_gather4(name, src_type, ind_type, 3) \
|
||||
instantiate_gather4(name, src_type, ind_type, 4) \
|
||||
instantiate_gather4(name, src_type, ind_type, 5) \
|
||||
instantiate_gather4(name, src_type, ind_type, 6) \
|
||||
instantiate_gather4(name, src_type, ind_type, 7) \
|
||||
instantiate_gather4(name, src_type, ind_type, 8) \
|
||||
instantiate_gather4(name, src_type, ind_type, 9) \
|
||||
instantiate_gather4(name, src_type, ind_type, 10) // clang-format on
|
||||
|
||||
#define instantiate_gather(name, src_type) \
|
||||
instantiate_gather3(#name "bool_", src_type, bool) \
|
||||
instantiate_gather3(#name "uint8", src_type, uint8_t) \
|
||||
// clang-format off
|
||||
#define instantiate_gather(name, src_type) \
|
||||
instantiate_gather3(#name "bool_", src_type, bool) \
|
||||
instantiate_gather3(#name "uint8", src_type, uint8_t) \
|
||||
instantiate_gather3(#name "uint16", src_type, uint16_t) \
|
||||
instantiate_gather3(#name "uint32", src_type, uint32_t) \
|
||||
instantiate_gather3(#name "uint64", src_type, uint64_t) \
|
||||
instantiate_gather3(#name "int8", src_type, int8_t) \
|
||||
instantiate_gather3(#name "int16", src_type, int16_t) \
|
||||
instantiate_gather3(#name "int32", src_type, int32_t) \
|
||||
instantiate_gather3(#name "int8", src_type, int8_t) \
|
||||
instantiate_gather3(#name "int16", src_type, int16_t) \
|
||||
instantiate_gather3(#name "int32", src_type, int32_t) \
|
||||
instantiate_gather3(#name "int64", src_type, int64_t)
|
||||
|
||||
instantiate_gather(bool_, bool)
|
||||
@ -184,4 +170,4 @@ instantiate_gather(int32, int32_t)
|
||||
instantiate_gather(int64, int64_t)
|
||||
instantiate_gather(float16, half)
|
||||
instantiate_gather(float32, float)
|
||||
instantiate_gather(bfloat16, bfloat16_t)
|
||||
instantiate_gather(bfloat16, bfloat16_t) // clang-format on
|
@ -1,7 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
@ -18,32 +18,36 @@ using namespace metal;
|
||||
MLX_MTL_CONST int SIMD_SIZE = 32;
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN , /* Thread cols (in elements) */
|
||||
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN, /* Thread cols (in elements) */
|
||||
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
||||
struct GEMVKernel {
|
||||
|
||||
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
|
||||
|
||||
// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
|
||||
// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
|
||||
// into blocks of (BM * TM, BN * TN) divided among threadgroups
|
||||
// - Every thread works on a block of (TM, TN)
|
||||
// - We assume each thead group is launched with (BN, BM, 1) threads
|
||||
//
|
||||
// 1. A thread loads TN elements each from mat along TM contiguous rows
|
||||
// 1. A thread loads TN elements each from mat along TM contiguous rows
|
||||
// and the corresponding scalar from the vector
|
||||
// 2. The thread then multiplies and adds to accumulate its local result for the block
|
||||
// 3. At the end, each thread has accumulated results over all blocks across the rows
|
||||
// 2. The thread then multiplies and adds to accumulate its local result for
|
||||
// the block
|
||||
// 3. At the end, each thread has accumulated results over all blocks across
|
||||
// the rows
|
||||
// These are then summed up across the threadgroup
|
||||
// 4. Each threadgroup writes its accumulated BN * TN outputs
|
||||
//
|
||||
// Edge case handling:
|
||||
// - The threadgroup with the largest tid will have blocks that exceed the matrix
|
||||
// * The blocks that start outside the matrix are never read (thread results remain zero)
|
||||
// * The last thread that partially overlaps with the matrix is shifted inwards
|
||||
// - The threadgroup with the largest tid will have blocks that exceed the
|
||||
// matrix
|
||||
// * The blocks that start outside the matrix are never read (thread results
|
||||
// remain zero)
|
||||
// * The last thread that partially overlaps with the matrix is shifted
|
||||
// inwards
|
||||
// such that the thread block fits exactly in the matrix
|
||||
|
||||
MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;
|
||||
@ -52,7 +56,7 @@ struct GEMVKernel {
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
const device T* bias [[buffer(2)]],
|
||||
device T* out_vec [[buffer(3)]],
|
||||
device T* out_vec [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(4)]],
|
||||
const constant int& out_vec_size [[buffer(5)]],
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
@ -64,14 +68,13 @@ struct GEMVKernel {
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
// Appease compiler
|
||||
// Appease compiler
|
||||
(void)lid;
|
||||
|
||||
// Threadgroup in_vec cache
|
||||
threadgroup T* in_vec_block = tgp_memory + simd_lid * TN * 2;
|
||||
|
||||
// Thread local accumulation results
|
||||
// Thread local accumulation results
|
||||
thread T result[TM] = {0};
|
||||
thread T inter[TN];
|
||||
thread T v_coeff[TN];
|
||||
@ -80,7 +83,7 @@ struct GEMVKernel {
|
||||
int out_row = (tid.x * BM + simd_gid) * TM;
|
||||
|
||||
// Exit simdgroup if rows out of bound
|
||||
if(out_row >= out_vec_size)
|
||||
if (out_row >= out_vec_size)
|
||||
return;
|
||||
|
||||
// Adjust tail simdgroup to ensure in bound reads
|
||||
@ -90,89 +93,80 @@ struct GEMVKernel {
|
||||
mat += out_row * marix_ld;
|
||||
|
||||
// Loop over in_vec in blocks of BN * TN
|
||||
for(int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {
|
||||
|
||||
for (int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Prefetch in_vector for threadgroup use
|
||||
if(simd_gid == 0) {
|
||||
if (simd_gid == 0) {
|
||||
// Main load loop
|
||||
if(bn + TN <= in_vec_size) {
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
if (bn + TN <= in_vec_size) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
in_vec_block[tn] = in_vec[bn + tn];
|
||||
}
|
||||
|
||||
} else { // Edgecase
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
in_vec_block[tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load for all rows
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
// Load for all rows
|
||||
#pragma clang loop unroll(full)
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
v_coeff[tn] = in_vec_block[tn];
|
||||
}
|
||||
|
||||
// Per thread work loop
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tm = 0; tm < TM; tm++) {
|
||||
|
||||
// Load for the row
|
||||
if(bn + TN <= in_vec_size) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
// Per thread work loop
|
||||
#pragma clang loop unroll(full)
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
// Load for the row
|
||||
if (bn + TN <= in_vec_size) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[tm * marix_ld + bn + tn];
|
||||
}
|
||||
|
||||
} else { // Edgecase
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
|
||||
#pragma clang loop unroll(full)
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
int col_idx =
|
||||
(bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
|
||||
inter[tn] = mat[tm * marix_ld + col_idx];
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulate results
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
result[tm] += inter[tn] * v_coeff[tn];
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// Simdgroup accumulations
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tm = 0; tm < TM; tm++) {
|
||||
// Simdgroup accumulations
|
||||
#pragma clang loop unroll(full)
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
result[tm] = simd_sum(result[tm]);
|
||||
}
|
||||
|
||||
// Write outputs
|
||||
if(simd_lid == 0) {
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tm = 0; tm < TM; tm++) {
|
||||
if(kDoAxpby) {
|
||||
out_vec[out_row + tm] =
|
||||
static_cast<T>(alpha) * result[tm] +
|
||||
if (simd_lid == 0) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
if (kDoAxpby) {
|
||||
out_vec[out_row + tm] = static_cast<T>(alpha) * result[tm] +
|
||||
static_cast<T>(beta) * bias[(out_row + tm) * bias_stride];
|
||||
} else {
|
||||
out_vec[out_row + tm] = result[tm];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -180,40 +174,43 @@ struct GEMVKernel {
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN, /* Thread cols (in elements) */
|
||||
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN, /* Thread cols (in elements) */
|
||||
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
||||
struct GEMVTKernel {
|
||||
|
||||
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
|
||||
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
|
||||
// into blocks of (BM * TM, BN * TN) divided among threadgroups
|
||||
// - Every thread works on a block of (TM, TN)
|
||||
// - We assume each thead group is launched with (BN, BM, 1) threads
|
||||
//
|
||||
// 1. A thread loads TN elements each from mat along TM contiguous rows
|
||||
// 1. A thread loads TN elements each from mat along TM contiguous rows
|
||||
// and the corresponding scalar from the vector
|
||||
// 2. The thread then multiplies and adds to accumulate its local result for the block
|
||||
// 3. At the end, each thread has accumulated results over all blocks across the rows
|
||||
// 2. The thread then multiplies and adds to accumulate its local result for
|
||||
// the block
|
||||
// 3. At the end, each thread has accumulated results over all blocks across
|
||||
// the rows
|
||||
// These are then summed up across the threadgroup
|
||||
// 4. Each threadgroup writes its accumulated BN * TN outputs
|
||||
//
|
||||
// Edge case handling:
|
||||
// - The threadgroup with the largest tid will have blocks that exceed the matrix
|
||||
// * The blocks that start outside the matrix are never read (thread results remain zero)
|
||||
// * The last thread that partially overlaps with the matrix is shifted inwards
|
||||
// - The threadgroup with the largest tid will have blocks that exceed the
|
||||
// matrix
|
||||
// * The blocks that start outside the matrix are never read (thread results
|
||||
// remain zero)
|
||||
// * The last thread that partially overlaps with the matrix is shifted
|
||||
// inwards
|
||||
// such that the thread block fits exactly in the matrix
|
||||
|
||||
|
||||
MLX_MTL_CONST short tgp_mem_size = BN * BM * TN;
|
||||
|
||||
static METAL_FUNC void run(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
const device T* bias [[buffer(2)]],
|
||||
device T* out_vec [[buffer(3)]],
|
||||
device T* out_vec [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(4)]],
|
||||
const constant int& out_vec_size [[buffer(5)]],
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
@ -225,8 +222,7 @@ struct GEMVTKernel {
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
// Appease compiler
|
||||
// Appease compiler
|
||||
(void)simd_gid;
|
||||
(void)simd_lid;
|
||||
|
||||
@ -243,77 +239,69 @@ struct GEMVTKernel {
|
||||
|
||||
// Edgecase handling
|
||||
if (out_col < out_vec_size) {
|
||||
|
||||
out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;
|
||||
|
||||
// Per thread accumulation main loop
|
||||
int bm = in_row;
|
||||
for(; bm < in_vec_size; bm += BM * TM) {
|
||||
for (; bm < in_vec_size; bm += BM * TM) {
|
||||
// Adding a threadgroup_barrier improves performance slightly
|
||||
// This is possibly it may help exploit cache better
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
if(bm + TM <= in_vec_size) {
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tm = 0; tm < TM; tm++) {
|
||||
if (bm + TM <= in_vec_size) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
v_coeff[tm] = in_vec[bm + tm];
|
||||
}
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tm = 0; tm < TM; tm++) {
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
||||
}
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
result[tn] += v_coeff[tm] * inter[tn];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
} else { // Edgecase handling
|
||||
for(int tm = 0; bm + tm < in_vec_size; tm++) {
|
||||
for (int tm = 0; bm + tm < in_vec_size; tm++) {
|
||||
v_coeff[tm] = in_vec[bm + tm];
|
||||
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
||||
}
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
result[tn] += v_coeff[tm] * inter[tn];
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Threadgroup collection
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < TN; i++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int i = 0; i < TN; i++) {
|
||||
tgp_results[lid.y * TN + i] = result[i];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Threadgroup accumulation and writing out results
|
||||
if(lid.y == 0 && out_col < out_vec_size) {
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 1; i < BM; i++) {
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < TN; j++) {
|
||||
if (lid.y == 0 && out_col < out_vec_size) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int i = 1; i < BM; i++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int j = 0; j < TN; j++) {
|
||||
result[j] += tgp_results[i * TN + j];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < TN; j++) {
|
||||
|
||||
if(kDoAxpby) {
|
||||
out_vec[out_col + j] =
|
||||
static_cast<T>(alpha) * result[j] +
|
||||
#pragma clang loop unroll(full)
|
||||
for (int j = 0; j < TN; j++) {
|
||||
if (kDoAxpby) {
|
||||
out_vec[out_col + j] = static_cast<T>(alpha) * result[j] +
|
||||
static_cast<T>(beta) * bias[(out_col + j) * bias_stride];
|
||||
} else {
|
||||
out_vec[out_col + j] = result[j];
|
||||
@ -328,18 +316,18 @@ struct GEMVTKernel {
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN, /* Thread cols (in elements) */
|
||||
const bool kDoNCBatch, /* Batch ndim > 1 */
|
||||
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
||||
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv(
|
||||
[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
const device T* bias [[buffer(2)]],
|
||||
device T* out_vec [[buffer(3)]],
|
||||
device T* out_vec [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(4)]],
|
||||
const constant int& out_vec_size [[buffer(5)]],
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
@ -355,16 +343,15 @@ template <
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN, kDoAxpby>;
|
||||
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
||||
|
||||
// Update batch offsets
|
||||
if(kDoNCBatch) {
|
||||
if (kDoNCBatch) {
|
||||
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
|
||||
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
|
||||
|
||||
if(kDoAxpby) {
|
||||
if (kDoAxpby) {
|
||||
bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
|
||||
}
|
||||
|
||||
@ -372,67 +359,64 @@ template <
|
||||
in_vec += tid.z * vector_batch_stride[0];
|
||||
mat += tid.z * matrix_batch_stride[0];
|
||||
|
||||
if(kDoAxpby) {
|
||||
if (kDoAxpby) {
|
||||
bias += tid.z * bias_batch_stride[0];
|
||||
}
|
||||
}
|
||||
|
||||
out_vec += tid.z * out_vec_size;
|
||||
|
||||
gemv_kernel::run(
|
||||
mat,
|
||||
in_vec,
|
||||
bias,
|
||||
out_vec,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
marix_ld,
|
||||
alpha,
|
||||
beta,
|
||||
bias_stride,
|
||||
tgp_memory,
|
||||
tid,
|
||||
lid,
|
||||
simd_gid,
|
||||
simd_lid
|
||||
);
|
||||
|
||||
gemv_kernel::run(
|
||||
mat,
|
||||
in_vec,
|
||||
bias,
|
||||
out_vec,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
marix_ld,
|
||||
alpha,
|
||||
beta,
|
||||
bias_stride,
|
||||
tgp_memory,
|
||||
tid,
|
||||
lid,
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
#define instantiate_gemv_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
|
||||
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn \
|
||||
"_nc" #nc "_axpby" #axpby)]] [[kernel]] void \
|
||||
gemv<itype, bm, bn, tm, tn, nc, axpby>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant float& alpha [[buffer(7)]], \
|
||||
const constant float& beta [[buffer(8)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]], \
|
||||
const constant size_t* bias_batch_stride [[buffer(13)]], \
|
||||
const constant int& bias_stride [[buffer(14)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_gemv_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
|
||||
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby)]] \
|
||||
[[kernel]] void gemv<itype, bm, bn, tm, tn, nc, axpby>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant float& alpha [[buffer(7)]], \
|
||||
const constant float& beta [[buffer(8)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]], \
|
||||
const constant size_t* bias_batch_stride [[buffer(13)]], \
|
||||
const constant int& bias_stride [[buffer(14)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 0) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 1) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 0) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1)
|
||||
|
||||
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 0) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 1) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 0) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1)
|
||||
|
||||
#define instantiate_gemv_blocks(name, itype) \
|
||||
instantiate_gemv(name, itype, 4, 32, 1, 4) \
|
||||
instantiate_gemv(name, itype, 4, 32, 4, 4) \
|
||||
instantiate_gemv(name, itype, 8, 32, 4, 4)
|
||||
#define instantiate_gemv_blocks(name, itype) \
|
||||
instantiate_gemv(name, itype, 4, 32, 1, 4) instantiate_gemv( \
|
||||
name, itype, 4, 32, 4, 4) instantiate_gemv(name, itype, 8, 32, 4, 4)
|
||||
|
||||
instantiate_gemv_blocks(float32, float);
|
||||
instantiate_gemv_blocks(float16, half);
|
||||
@ -443,18 +427,18 @@ instantiate_gemv_blocks(bfloat16, bfloat16_t);
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN, /* Thread cols (in elements) */
|
||||
const bool kDoNCBatch, /* Batch ndim > 1 */
|
||||
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
||||
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t(
|
||||
[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv_t(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
const device T* bias [[buffer(2)]],
|
||||
device T* out_vec [[buffer(3)]],
|
||||
device T* out_vec [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(4)]],
|
||||
const constant int& out_vec_size [[buffer(5)]],
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
@ -470,16 +454,15 @@ template <
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN, kDoAxpby>;
|
||||
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
||||
|
||||
// Update batch offsets
|
||||
if(kDoNCBatch) {
|
||||
if (kDoNCBatch) {
|
||||
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
|
||||
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
|
||||
|
||||
if(kDoAxpby) {
|
||||
if (kDoAxpby) {
|
||||
bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
|
||||
}
|
||||
|
||||
@ -487,70 +470,72 @@ template <
|
||||
in_vec += tid.z * vector_batch_stride[0];
|
||||
mat += tid.z * matrix_batch_stride[0];
|
||||
|
||||
if(kDoAxpby) {
|
||||
if (kDoAxpby) {
|
||||
bias += tid.z * bias_batch_stride[0];
|
||||
}
|
||||
}
|
||||
|
||||
out_vec += tid.z * out_vec_size;
|
||||
|
||||
gemv_kernel::run(
|
||||
mat,
|
||||
in_vec,
|
||||
bias,
|
||||
out_vec,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
marix_ld,
|
||||
alpha,
|
||||
beta,
|
||||
bias_stride,
|
||||
tgp_memory,
|
||||
tid,
|
||||
lid,
|
||||
simd_gid,
|
||||
simd_lid
|
||||
);
|
||||
|
||||
gemv_kernel::run(
|
||||
mat,
|
||||
in_vec,
|
||||
bias,
|
||||
out_vec,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
marix_ld,
|
||||
alpha,
|
||||
beta,
|
||||
bias_stride,
|
||||
tgp_memory,
|
||||
tid,
|
||||
lid,
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
#define instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
|
||||
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby)]] \
|
||||
[[kernel]] void gemv_t<itype, bm, bn, tm, tn, nc, axpby>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant float& alpha [[buffer(7)]], \
|
||||
const constant float& beta [[buffer(8)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]], \
|
||||
const constant size_t* bias_batch_stride [[buffer(13)]], \
|
||||
const constant int& bias_stride [[buffer(14)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
#define instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
|
||||
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn \
|
||||
"_nc" #nc "_axpby" #axpby)]] [[kernel]] void \
|
||||
gemv_t<itype, bm, bn, tm, tn, nc, axpby>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant float& alpha [[buffer(7)]], \
|
||||
const constant float& beta [[buffer(8)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]], \
|
||||
const constant size_t* bias_batch_stride [[buffer(13)]], \
|
||||
const constant int& bias_stride [[buffer(14)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
|
||||
// clang-format off
|
||||
#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 0) \
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 1) \
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 0) \
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 1)
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 1) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemv_t_blocks(name, itype) \
|
||||
instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
|
||||
instantiate_gemv_t(name, itype, 8, 8, 4, 4) \
|
||||
instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
|
||||
instantiate_gemv_t(name, itype, 8, 8, 4, 4) \
|
||||
instantiate_gemv_t(name, itype, 8, 16, 4, 4) \
|
||||
instantiate_gemv_t(name, itype, 8, 32, 4, 4) \
|
||||
instantiate_gemv_t(name, itype, 8, 64, 4, 4) \
|
||||
instantiate_gemv_t(name, itype, 8, 128, 4, 4)
|
||||
instantiate_gemv_t(name, itype, 8, 128, 4, 4) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_gemv_t_blocks(float32, float);
|
||||
instantiate_gemv_t_blocks(float16, half);
|
||||
instantiate_gemv_t_blocks(bfloat16, bfloat16_t);
|
||||
instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on
|
@ -99,7 +99,8 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
||||
out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
|
||||
out[i] =
|
||||
w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -192,13 +193,15 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
if (r + lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = (x[r + i] - mean) * normalizer;
|
||||
out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
|
||||
out[r + i] =
|
||||
w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((r + lid * N_READS + i) < axis_size) {
|
||||
float xi = (x[r + i] - mean) * normalizer;
|
||||
out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
|
||||
out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) +
|
||||
b[b_stride * (i + r)];
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -323,16 +326,18 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
||||
gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) -
|
||||
thread_x[i] * meanwgxc * normalizer2);
|
||||
gx[i] = static_cast<T>(
|
||||
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
|
||||
thread_x[i] * meanwgxc * normalizer2);
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
||||
gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) -
|
||||
thread_x[i] * meanwgxc * normalizer2);
|
||||
gx[i] = static_cast<T>(
|
||||
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
|
||||
thread_x[i] * meanwgxc * normalizer2);
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
|
||||
}
|
||||
}
|
||||
@ -460,8 +465,8 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
float xi = (x[i + r] - mean) * normalizer;
|
||||
float wi = w[(i + r) * w_stride];
|
||||
float gi = g[i + r];
|
||||
gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) -
|
||||
xi * meanwgxc * normalizer2);
|
||||
gx[i + r] = static_cast<T>(
|
||||
normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
|
||||
gw[i + r] = static_cast<T>(gi * xi);
|
||||
}
|
||||
} else {
|
||||
@ -470,8 +475,8 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
float xi = (x[i + r] - mean) * normalizer;
|
||||
float wi = w[(i + r) * w_stride];
|
||||
float gi = g[i + r];
|
||||
gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) -
|
||||
xi * meanwgxc * normalizer2);
|
||||
gx[i + r] = static_cast<T>(
|
||||
normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
|
||||
gw[i + r] = static_cast<T>(gi * xi);
|
||||
}
|
||||
}
|
||||
@ -548,6 +553,4 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
|
||||
instantiate_layer_norm(float32, float)
|
||||
instantiate_layer_norm(float16, half)
|
||||
instantiate_layer_norm(bfloat16, bfloat16_t)
|
||||
// clang-format on
|
||||
|
||||
instantiate_layer_norm(bfloat16, bfloat16_t) // clang-format on
|
||||
|
@ -1,7 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
@ -15,30 +15,31 @@ using namespace metal;
|
||||
|
||||
MLX_MTL_CONST int SIMD_SIZE = 32;
|
||||
|
||||
|
||||
template <typename T, typename U, int values_per_thread, int bits>
|
||||
inline U load_vector(const device T *x, thread U *x_thread) {
|
||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
||||
inline U load_vector(const device T* x, thread U* x_thread) {
|
||||
static_assert(
|
||||
bits == 2 || bits == 4 || bits == 8,
|
||||
"Template undefined for bits not in {2, 4, 8}");
|
||||
|
||||
U sum = 0;
|
||||
|
||||
if (bits == 2) {
|
||||
for (int i = 0; i < values_per_thread; i += 4) {
|
||||
sum += x[i] + x[i+1] + x[i+2] + x[i+3];
|
||||
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
|
||||
x_thread[i] = x[i];
|
||||
x_thread[i+1] = x[i+1] / 4.0f;
|
||||
x_thread[i+2] = x[i+2] / 16.0f;
|
||||
x_thread[i+3] = x[i+3] / 64.0f;
|
||||
x_thread[i + 1] = x[i + 1] / 4.0f;
|
||||
x_thread[i + 2] = x[i + 2] / 16.0f;
|
||||
x_thread[i + 3] = x[i + 3] / 64.0f;
|
||||
}
|
||||
}
|
||||
|
||||
else if (bits == 4) {
|
||||
for (int i = 0; i < values_per_thread; i += 4) {
|
||||
sum += x[i] + x[i+1] + x[i+2] + x[i+3];
|
||||
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
|
||||
x_thread[i] = x[i];
|
||||
x_thread[i+1] = x[i+1] / 16.0f;
|
||||
x_thread[i+2] = x[i+2] / 256.0f;
|
||||
x_thread[i+3] = x[i+3] / 4096.0f;
|
||||
x_thread[i + 1] = x[i + 1] / 16.0f;
|
||||
x_thread[i + 2] = x[i + 2] / 256.0f;
|
||||
x_thread[i + 3] = x[i + 3] / 4096.0f;
|
||||
}
|
||||
}
|
||||
|
||||
@ -53,33 +54,35 @@ inline U load_vector(const device T *x, thread U *x_thread) {
|
||||
}
|
||||
|
||||
template <typename T, typename U, int values_per_thread, int bits>
|
||||
inline U load_vector_safe(const device T *x, thread U *x_thread, int N) {
|
||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
||||
inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
|
||||
static_assert(
|
||||
bits == 2 || bits == 4 || bits == 8,
|
||||
"Template undefined for bits not in {2, 4, 8}");
|
||||
|
||||
U sum = 0;
|
||||
|
||||
if (bits == 2) {
|
||||
for (int i = 0; i < N; i += 4) {
|
||||
sum += x[i] + x[i+1] + x[i+2] + x[i+3];
|
||||
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
|
||||
x_thread[i] = x[i];
|
||||
x_thread[i+1] = x[i+1] / 4.0f;
|
||||
x_thread[i+2] = x[i+2] / 16.0f;
|
||||
x_thread[i+3] = x[i+3] / 64.0f;
|
||||
x_thread[i + 1] = x[i + 1] / 4.0f;
|
||||
x_thread[i + 2] = x[i + 2] / 16.0f;
|
||||
x_thread[i + 3] = x[i + 3] / 64.0f;
|
||||
}
|
||||
for (int i=N; i<values_per_thread; i++) {
|
||||
for (int i = N; i < values_per_thread; i++) {
|
||||
x_thread[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
else if (bits == 4) {
|
||||
for (int i = 0; i < N; i += 4) {
|
||||
sum += x[i] + x[i+1] + x[i+2] + x[i+3];
|
||||
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
|
||||
x_thread[i] = x[i];
|
||||
x_thread[i+1] = x[i+1] / 16.0f;
|
||||
x_thread[i+2] = x[i+2] / 256.0f;
|
||||
x_thread[i+3] = x[i+3] / 4096.0f;
|
||||
x_thread[i + 1] = x[i + 1] / 16.0f;
|
||||
x_thread[i + 2] = x[i + 2] / 256.0f;
|
||||
x_thread[i + 3] = x[i + 3] / 4096.0f;
|
||||
}
|
||||
for (int i=N; i<values_per_thread; i++) {
|
||||
for (int i = N; i < values_per_thread; i++) {
|
||||
x_thread[i] = 0;
|
||||
}
|
||||
}
|
||||
@ -89,7 +92,7 @@ inline U load_vector_safe(const device T *x, thread U *x_thread, int N) {
|
||||
sum += x[i];
|
||||
x_thread[i] = x[i];
|
||||
}
|
||||
for (int i=N; i<values_per_thread; i++) {
|
||||
for (int i = N; i < values_per_thread; i++) {
|
||||
x_thread[i] = 0;
|
||||
}
|
||||
}
|
||||
@ -98,29 +101,36 @@ inline U load_vector_safe(const device T *x, thread U *x_thread, int N) {
|
||||
}
|
||||
|
||||
template <typename U, int values_per_thread, int bits>
|
||||
inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum) {
|
||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
||||
inline U qdot(
|
||||
const device uint8_t* w,
|
||||
const thread U* x_thread,
|
||||
U scale,
|
||||
U bias,
|
||||
U sum) {
|
||||
static_assert(
|
||||
bits == 2 || bits == 4 || bits == 8,
|
||||
"Template undefined for bits not in {2, 4, 8}");
|
||||
|
||||
U accum = 0;
|
||||
|
||||
if (bits == 2) {
|
||||
for (int i = 0; i < (values_per_thread / 4); i++) {
|
||||
accum += (
|
||||
x_thread[4*i] * (w[i] & 0x03)
|
||||
+ x_thread[4*i+1] * (w[i] & 0x0c)
|
||||
+ x_thread[4*i+2] * (w[i] & 0x30)
|
||||
+ x_thread[4*i+3] * (w[i] & 0xc0));
|
||||
accum +=
|
||||
(x_thread[4 * i] * (w[i] & 0x03) +
|
||||
x_thread[4 * i + 1] * (w[i] & 0x0c) +
|
||||
x_thread[4 * i + 2] * (w[i] & 0x30) +
|
||||
x_thread[4 * i + 3] * (w[i] & 0xc0));
|
||||
}
|
||||
}
|
||||
|
||||
else if (bits == 4) {
|
||||
const device uint16_t* ws = (const device uint16_t*)w;
|
||||
for (int i = 0; i < (values_per_thread / 4); i++) {
|
||||
accum += (
|
||||
x_thread[4*i] * (ws[i] & 0x000f)
|
||||
+ x_thread[4*i+1] * (ws[i] & 0x00f0)
|
||||
+ x_thread[4*i+2] * (ws[i] & 0x0f00)
|
||||
+ x_thread[4*i+3] * (ws[i] & 0xf000));
|
||||
accum +=
|
||||
(x_thread[4 * i] * (ws[i] & 0x000f) +
|
||||
x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
|
||||
x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
|
||||
x_thread[4 * i + 3] * (ws[i] & 0xf000));
|
||||
}
|
||||
}
|
||||
|
||||
@ -134,29 +144,37 @@ inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias
|
||||
}
|
||||
|
||||
template <typename U, int values_per_thread, int bits>
|
||||
inline U qdot_safe(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum, int N) {
|
||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
||||
inline U qdot_safe(
|
||||
const device uint8_t* w,
|
||||
const thread U* x_thread,
|
||||
U scale,
|
||||
U bias,
|
||||
U sum,
|
||||
int N) {
|
||||
static_assert(
|
||||
bits == 2 || bits == 4 || bits == 8,
|
||||
"Template undefined for bits not in {2, 4, 8}");
|
||||
|
||||
U accum = 0;
|
||||
|
||||
if (bits == 2) {
|
||||
for (int i = 0; i < (N / 4); i++) {
|
||||
accum += (
|
||||
x_thread[4*i] * (w[i] & 0x03)
|
||||
+ x_thread[4*i+1] * (w[i] & 0x0c)
|
||||
+ x_thread[4*i+2] * (w[i] & 0x30)
|
||||
+ x_thread[4*i+3] * (w[i] & 0xc0));
|
||||
accum +=
|
||||
(x_thread[4 * i] * (w[i] & 0x03) +
|
||||
x_thread[4 * i + 1] * (w[i] & 0x0c) +
|
||||
x_thread[4 * i + 2] * (w[i] & 0x30) +
|
||||
x_thread[4 * i + 3] * (w[i] & 0xc0));
|
||||
}
|
||||
}
|
||||
|
||||
else if (bits == 4) {
|
||||
const device uint16_t* ws = (const device uint16_t*)w;
|
||||
for (int i = 0; i < (N / 4); i++) {
|
||||
accum += (
|
||||
x_thread[4*i] * (ws[i] & 0x000f)
|
||||
+ x_thread[4*i+1] * (ws[i] & 0x00f0)
|
||||
+ x_thread[4*i+2] * (ws[i] & 0x0f00)
|
||||
+ x_thread[4*i+3] * (ws[i] & 0xf000));
|
||||
accum +=
|
||||
(x_thread[4 * i] * (ws[i] & 0x000f) +
|
||||
x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
|
||||
x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
|
||||
x_thread[4 * i + 3] * (ws[i] & 0xf000));
|
||||
}
|
||||
}
|
||||
|
||||
@ -170,16 +188,19 @@ inline U qdot_safe(const device uint8_t* w, const thread U *x_thread, U scale, U
|
||||
}
|
||||
|
||||
template <typename U, int values_per_thread, int bits>
|
||||
inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
|
||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
||||
inline void
|
||||
qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
|
||||
static_assert(
|
||||
bits == 2 || bits == 4 || bits == 8,
|
||||
"Template undefined for bits not in {2, 4, 8}");
|
||||
|
||||
if (bits == 2) {
|
||||
U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
|
||||
for (int i = 0; i < (values_per_thread / 4); i++) {
|
||||
result[4*i] += x * (s[0] * (w[i] & 0x03) + bias);
|
||||
result[4*i+1] += x * (s[1] * (w[i] & 0x0c) + bias);
|
||||
result[4*i+2] += x * (s[2] * (w[i] & 0x30) + bias);
|
||||
result[4*i+3] += x * (s[3] * (w[i] & 0xc0) + bias);
|
||||
result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);
|
||||
result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);
|
||||
result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);
|
||||
result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);
|
||||
}
|
||||
}
|
||||
|
||||
@ -187,10 +208,10 @@ inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* resu
|
||||
const thread uint16_t* ws = (const thread uint16_t*)w;
|
||||
U s[4] = {scale, scale / 16.0f, scale / 256.0f, scale / 4096.0f};
|
||||
for (int i = 0; i < (values_per_thread / 4); i++) {
|
||||
result[4*i] += x * (s[0] * (ws[i] & 0x000f) + bias);
|
||||
result[4*i+1] += x * (s[1] * (ws[i] & 0x00f0) + bias);
|
||||
result[4*i+2] += x * (s[2] * (ws[i] & 0x0f00) + bias);
|
||||
result[4*i+3] += x * (s[3] * (ws[i] & 0xf000) + bias);
|
||||
result[4 * i] += x * (s[0] * (ws[i] & 0x000f) + bias);
|
||||
result[4 * i + 1] += x * (s[1] * (ws[i] & 0x00f0) + bias);
|
||||
result[4 * i + 2] += x * (s[2] * (ws[i] & 0x0f00) + bias);
|
||||
result[4 * i + 3] += x * (s[3] * (ws[i] & 0xf000) + bias);
|
||||
}
|
||||
}
|
||||
|
||||
@ -202,27 +223,38 @@ inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* resu
|
||||
}
|
||||
|
||||
template <typename U, int N, int bits>
|
||||
inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
|
||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
||||
inline void
|
||||
dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
|
||||
static_assert(
|
||||
bits == 2 || bits == 4 || bits == 8,
|
||||
"Template undefined for bits not in {2, 4, 8}");
|
||||
|
||||
if (bits == 2) {
|
||||
U s[4] = {scale, scale / static_cast<U>(4.0f), scale / static_cast<U>(16.0f), scale / static_cast<U>(64.0f)};
|
||||
U s[4] = {
|
||||
scale,
|
||||
scale / static_cast<U>(4.0f),
|
||||
scale / static_cast<U>(16.0f),
|
||||
scale / static_cast<U>(64.0f)};
|
||||
for (int i = 0; i < (N / 4); i++) {
|
||||
w_local[4*i] = s[0] * (w[i] & 0x03) + bias;
|
||||
w_local[4*i+1] = s[1] * (w[i] & 0x0c) + bias;
|
||||
w_local[4*i+2] = s[2] * (w[i] & 0x30) + bias;
|
||||
w_local[4*i+3] = s[3] * (w[i] & 0xc0) + bias;
|
||||
w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
|
||||
w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
|
||||
w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
|
||||
w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
|
||||
}
|
||||
}
|
||||
|
||||
else if (bits == 4) {
|
||||
const device uint16_t* ws = (const device uint16_t*)w;
|
||||
U s[4] = {scale, scale / static_cast<U>(16.0f), scale / static_cast<U>(256.0f), scale / static_cast<U>(4096.0f)};
|
||||
U s[4] = {
|
||||
scale,
|
||||
scale / static_cast<U>(16.0f),
|
||||
scale / static_cast<U>(256.0f),
|
||||
scale / static_cast<U>(4096.0f)};
|
||||
for (int i = 0; i < (N / 4); i++) {
|
||||
w_local[4*i] = s[0] * (ws[i] & 0x000f) + bias;
|
||||
w_local[4*i+1] = s[1] * (ws[i] & 0x00f0) + bias;
|
||||
w_local[4*i+2] = s[2] * (ws[i] & 0x0f00) + bias;
|
||||
w_local[4*i+3] = s[3] * (ws[i] & 0xf000) + bias;
|
||||
w_local[4 * i] = s[0] * (ws[i] & 0x000f) + bias;
|
||||
w_local[4 * i + 1] = s[1] * (ws[i] & 0x00f0) + bias;
|
||||
w_local[4 * i + 2] = s[2] * (ws[i] & 0x0f00) + bias;
|
||||
w_local[4 * i + 3] = s[3] * (ws[i] & 0xf000) + bias;
|
||||
}
|
||||
}
|
||||
|
||||
@ -243,13 +275,20 @@ template <
|
||||
short group_size,
|
||||
short bits>
|
||||
struct QuantizedBlockLoader {
|
||||
static_assert(BCOLS <= group_size, "The group size should be larger than the columns");
|
||||
static_assert(group_size % BCOLS == 0, "The group size should be divisible by the columns");
|
||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
||||
static_assert(
|
||||
BCOLS <= group_size,
|
||||
"The group size should be larger than the columns");
|
||||
static_assert(
|
||||
group_size % BCOLS == 0,
|
||||
"The group size should be divisible by the columns");
|
||||
static_assert(
|
||||
bits == 2 || bits == 4 || bits == 8,
|
||||
"Template undefined for bits not in {2, 4, 8}");
|
||||
|
||||
MLX_MTL_CONST short pack_factor = 32 / bits;
|
||||
MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
|
||||
MLX_MTL_CONST short n_reads = (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
|
||||
MLX_MTL_CONST short n_reads =
|
||||
(BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
|
||||
MLX_MTL_CONST short group_steps = group_size / BCOLS;
|
||||
|
||||
const int src_ld;
|
||||
@ -275,7 +314,8 @@ struct QuantizedBlockLoader {
|
||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(src_ld_),
|
||||
tile_stride(reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
|
||||
tile_stride(
|
||||
reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
|
||||
group_step_cnt(0),
|
||||
group_stride(BROWS * src_ld / group_size),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
@ -293,8 +333,9 @@ struct QuantizedBlockLoader {
|
||||
|
||||
T scale = *scales;
|
||||
T bias = *biases;
|
||||
for (int i=0; i<n_reads; i++) {
|
||||
dequantize<T, pack_factor, bits>((device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
dequantize<T, pack_factor, bits>(
|
||||
(device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
|
||||
}
|
||||
}
|
||||
|
||||
@ -304,14 +345,14 @@ struct QuantizedBlockLoader {
|
||||
}
|
||||
|
||||
if (reduction_dim == 1 && bi >= src_tile_dim.y) {
|
||||
for (int i=0; i<n_reads*pack_factor; i++) {
|
||||
for (int i = 0; i < n_reads * pack_factor; i++) {
|
||||
dst[i] = T(0);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (reduction_dim == 0 && bi >= src_tile_dim.x) {
|
||||
for (int i=0; i<n_reads*pack_factor; i++) {
|
||||
for (int i = 0; i < n_reads * pack_factor; i++) {
|
||||
dst[i] = T(0);
|
||||
}
|
||||
return;
|
||||
@ -319,8 +360,9 @@ struct QuantizedBlockLoader {
|
||||
|
||||
T scale = *scales;
|
||||
T bias = *biases;
|
||||
for (int i=0; i<n_reads; i++) {
|
||||
dequantize<T, pack_factor, bits>((device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
dequantize<T, pack_factor, bits>(
|
||||
(device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
|
||||
}
|
||||
}
|
||||
|
||||
@ -357,7 +399,6 @@ template <typename T, int group_size, int bits, int packs_per_thread>
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
constexpr int num_simdgroups = 2;
|
||||
constexpr int results_per_simdgroup = 4;
|
||||
constexpr int pack_factor = 32 / bits;
|
||||
@ -373,7 +414,8 @@ template <typename T, int group_size, int bits, int packs_per_thread>
|
||||
// Adjust positions
|
||||
const int in_vec_size_w = in_vec_size / pack_factor;
|
||||
const int in_vec_size_g = in_vec_size / group_size;
|
||||
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup;
|
||||
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
|
||||
simd_gid * results_per_simdgroup;
|
||||
w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
|
||||
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
|
||||
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
|
||||
@ -384,7 +426,8 @@ template <typename T, int group_size, int bits, int packs_per_thread>
|
||||
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
|
||||
const device uint8_t* wl =
|
||||
(const device uint8_t*)(w + row * in_vec_size_w);
|
||||
const device T* sl = scales + row * in_vec_size_g;
|
||||
const device T* bl = biases + row * in_vec_size_g;
|
||||
|
||||
@ -407,7 +450,6 @@ template <typename T, int group_size, int bits, int packs_per_thread>
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T, const int group_size, const int bits>
|
||||
[[kernel]] void qmv(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
@ -420,7 +462,6 @@ template <typename T, const int group_size, const int bits>
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
constexpr int num_simdgroups = 2;
|
||||
constexpr int results_per_simdgroup = 4;
|
||||
constexpr int packs_per_thread = 1;
|
||||
@ -437,7 +478,8 @@ template <typename T, const int group_size, const int bits>
|
||||
// Adjust positions
|
||||
const int in_vec_size_w = in_vec_size / pack_factor;
|
||||
const int in_vec_size_g = in_vec_size / group_size;
|
||||
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup;
|
||||
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
|
||||
simd_gid * results_per_simdgroup;
|
||||
const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
|
||||
|
||||
if (out_row >= out_vec_size) {
|
||||
@ -454,17 +496,19 @@ template <typename T, const int group_size, const int bits>
|
||||
y += tid.z * out_vec_size + out_row;
|
||||
|
||||
int k = 0;
|
||||
for (; k < in_vec_size-block_size; k += block_size) {
|
||||
for (; k < in_vec_size - block_size; k += block_size) {
|
||||
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
||||
|
||||
for (int row = 0; out_row + row < out_vec_size; row++) {
|
||||
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
|
||||
const device uint8_t* wl =
|
||||
(const device uint8_t*)(w + row * in_vec_size_w);
|
||||
const device T* sl = scales + row * in_vec_size_g;
|
||||
const device T* bl = biases + row * in_vec_size_g;
|
||||
|
||||
U s = sl[0];
|
||||
U b = bl[0];
|
||||
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
||||
result[row] +=
|
||||
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
||||
}
|
||||
|
||||
w += block_size / pack_factor;
|
||||
@ -472,11 +516,16 @@ template <typename T, const int group_size, const int bits>
|
||||
biases += block_size / group_size;
|
||||
x += block_size;
|
||||
}
|
||||
const int remaining = clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread);
|
||||
U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
|
||||
const int remaining = clamp(
|
||||
static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
|
||||
0,
|
||||
values_per_thread);
|
||||
U sum =
|
||||
load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
|
||||
|
||||
for (int row = 0; out_row + row < out_vec_size; row++) {
|
||||
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
|
||||
const device uint8_t* wl =
|
||||
(const device uint8_t*)(w + row * in_vec_size_w);
|
||||
const device T* sl = scales + row * in_vec_size_g;
|
||||
const device T* bl = biases + row * in_vec_size_g;
|
||||
|
||||
@ -502,17 +551,19 @@ template <typename T, const int group_size, const int bits>
|
||||
y += tid.z * out_vec_size + used_out_row;
|
||||
|
||||
int k = 0;
|
||||
for (; k < in_vec_size-block_size; k += block_size) {
|
||||
for (; k < in_vec_size - block_size; k += block_size) {
|
||||
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
|
||||
const device uint8_t* wl =
|
||||
(const device uint8_t*)(w + row * in_vec_size_w);
|
||||
const device T* sl = scales + row * in_vec_size_g;
|
||||
const device T* bl = biases + row * in_vec_size_g;
|
||||
|
||||
U s = sl[0];
|
||||
U b = bl[0];
|
||||
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
||||
result[row] +=
|
||||
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
||||
}
|
||||
|
||||
w += block_size / pack_factor;
|
||||
@ -520,17 +571,23 @@ template <typename T, const int group_size, const int bits>
|
||||
biases += block_size / group_size;
|
||||
x += block_size;
|
||||
}
|
||||
const int remaining = clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread);
|
||||
U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
|
||||
const int remaining = clamp(
|
||||
static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
|
||||
0,
|
||||
values_per_thread);
|
||||
U sum =
|
||||
load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
|
||||
const device uint8_t* wl =
|
||||
(const device uint8_t*)(w + row * in_vec_size_w);
|
||||
const device T* sl = scales + row * in_vec_size_g;
|
||||
const device T* bl = biases + row * in_vec_size_g;
|
||||
|
||||
U s = sl[0];
|
||||
U b = bl[0];
|
||||
result[row] += qdot_safe<U, values_per_thread, bits>(wl, x_thread, s, b, sum, remaining);
|
||||
result[row] += qdot_safe<U, values_per_thread, bits>(
|
||||
wl, x_thread, s, b, sum, remaining);
|
||||
}
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
@ -542,7 +599,6 @@ template <typename T, const int group_size, const int bits>
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T, const int group_size, const int bits>
|
||||
[[kernel]] void qvm(
|
||||
const device T* x [[buffer(0)]],
|
||||
@ -555,7 +611,6 @@ template <typename T, const int group_size, const int bits>
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
constexpr int num_simdgroups = 8;
|
||||
constexpr int pack_factor = 32 / bits;
|
||||
constexpr int blocksize = SIMD_SIZE;
|
||||
@ -590,7 +645,8 @@ template <typename T, const int group_size, const int bits>
|
||||
bias = biases[(i + simd_lid) * out_vec_size_g];
|
||||
w_local = w[(i + simd_lid) * out_vec_size_w];
|
||||
|
||||
qouter<U, pack_factor, bits>((thread uint8_t *)&w_local, x_local, scale, bias, result);
|
||||
qouter<U, pack_factor, bits>(
|
||||
(thread uint8_t*)&w_local, x_local, scale, bias, result);
|
||||
}
|
||||
if (static_cast<int>(i + simd_lid) < in_vec_size) {
|
||||
x_local = x[i + simd_lid];
|
||||
@ -603,25 +659,32 @@ template <typename T, const int group_size, const int bits>
|
||||
bias = 0;
|
||||
w_local = 0;
|
||||
}
|
||||
qouter<U, pack_factor, bits>((thread uint8_t *)&w_local, x_local, scale, bias, result);
|
||||
qouter<U, pack_factor, bits>(
|
||||
(thread uint8_t*)&w_local, x_local, scale, bias, result);
|
||||
|
||||
// Accumulate in the simdgroup
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k=0; k<pack_factor; k++) {
|
||||
// Accumulate in the simdgroup
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k = 0; k < pack_factor; k++) {
|
||||
result[k] = simd_sum(result[k]);
|
||||
}
|
||||
|
||||
// Store the result
|
||||
if (simd_lid == 0) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k=0; k<pack_factor; k++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k = 0; k < pack_factor; k++) {
|
||||
y[k] = static_cast<T>(result[k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits, const bool aligned_N>
|
||||
template <
|
||||
typename T,
|
||||
const int BM,
|
||||
const int BK,
|
||||
const int BN,
|
||||
const int group_size,
|
||||
const int bits,
|
||||
const bool aligned_N>
|
||||
[[kernel]] void qmm_t(
|
||||
const device T* x [[buffer(0)]],
|
||||
const device uint32_t* w [[buffer(1)]],
|
||||
@ -635,7 +698,6 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
uint lid [[thread_index_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
|
||||
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
|
||||
|
||||
@ -647,9 +709,19 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
constexpr int BK_padded = (BK + 16 / sizeof(T));
|
||||
|
||||
// Instantiate the appropriate BlockMMA and Loader
|
||||
using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
|
||||
using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
|
||||
using loader_w_t = QuantizedBlockLoader<T, BN, BK, BK_padded, 1, WM * WN * SIMD_SIZE, group_size, bits>;
|
||||
using mma_t = mlx::steel::
|
||||
BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
|
||||
using loader_x_t =
|
||||
mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
|
||||
using loader_w_t = QuantizedBlockLoader<
|
||||
T,
|
||||
BN,
|
||||
BK,
|
||||
BK_padded,
|
||||
1,
|
||||
WM * WN * SIMD_SIZE,
|
||||
group_size,
|
||||
bits>;
|
||||
|
||||
threadgroup T Xs[BM * BK_padded];
|
||||
threadgroup T Ws[BN * BK_padded];
|
||||
@ -675,7 +747,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
|
||||
if (num_els < BM) {
|
||||
if (!aligned_N && num_outs < BN) {
|
||||
for (int k=0; k<K; k += BK) {
|
||||
for (int k = 0; k < K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_x.load_safe(short2(BK, num_els));
|
||||
loader_w.load_safe(short2(BK, num_outs));
|
||||
@ -685,7 +757,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
loader_w.next();
|
||||
}
|
||||
} else {
|
||||
for (int k=0; k<K; k += BK) {
|
||||
for (int k = 0; k < K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_x.load_safe(short2(BK, num_els));
|
||||
loader_w.load_unsafe();
|
||||
@ -697,7 +769,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
}
|
||||
} else {
|
||||
if (!aligned_N && num_outs < BN) {
|
||||
for (int k=0; k<K; k += BK) {
|
||||
for (int k = 0; k < K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_x.load_unsafe();
|
||||
loader_w.load_safe(short2(BK, num_outs));
|
||||
@ -707,7 +779,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
loader_w.next();
|
||||
}
|
||||
} else {
|
||||
for (int k=0; k<K; k += BK) {
|
||||
for (int k = 0; k < K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_x.load_unsafe();
|
||||
loader_w.load_unsafe();
|
||||
@ -728,8 +800,13 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
|
||||
template <
|
||||
typename T,
|
||||
const int BM,
|
||||
const int BK,
|
||||
const int BN,
|
||||
const int group_size,
|
||||
const int bits>
|
||||
[[kernel]] void qmm_n(
|
||||
const device T* x [[buffer(0)]],
|
||||
const device uint32_t* w [[buffer(1)]],
|
||||
@ -743,7 +820,6 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
uint lid [[thread_index_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
|
||||
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
|
||||
|
||||
@ -756,9 +832,19 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
constexpr int BN_padded = (BN + 16 / sizeof(T));
|
||||
|
||||
// Instantiate the appropriate BlockMMA and Loader
|
||||
using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
|
||||
using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
|
||||
using loader_w_t = QuantizedBlockLoader<T, BK, BN, BN_padded, 0, WM * WN * SIMD_SIZE, group_size, bits>;
|
||||
using mma_t = mlx::steel::
|
||||
BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
|
||||
using loader_x_t = mlx::steel::
|
||||
BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
|
||||
using loader_w_t = QuantizedBlockLoader<
|
||||
T,
|
||||
BK,
|
||||
BN,
|
||||
BN_padded,
|
||||
0,
|
||||
WM * WN * SIMD_SIZE,
|
||||
group_size,
|
||||
bits>;
|
||||
|
||||
threadgroup T Xs[BM * BK_padded];
|
||||
threadgroup T Ws[BK * BN_padded];
|
||||
@ -780,8 +866,8 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
|
||||
if (num_els < BM) {
|
||||
if ((K % BK) != 0) {
|
||||
const int k_blocks = K/BK;
|
||||
for (int k=0; k<k_blocks; k++) {
|
||||
const int k_blocks = K / BK;
|
||||
for (int k = 0; k < k_blocks; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_x.load_safe(short2(BK, num_els));
|
||||
loader_w.load_unsafe();
|
||||
@ -797,7 +883,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
mma_op.mma(Xs, Ws);
|
||||
} else {
|
||||
for (int k=0; k<K; k += BK) {
|
||||
for (int k = 0; k < K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_x.load_safe(short2(BK, num_els));
|
||||
loader_w.load_unsafe();
|
||||
@ -809,8 +895,8 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
}
|
||||
} else {
|
||||
if ((K % BK) != 0) {
|
||||
const int k_blocks = K/BK;
|
||||
for (int k=0; k<k_blocks; k++) {
|
||||
const int k_blocks = K / BK;
|
||||
for (int k = 0; k < k_blocks; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_x.load_unsafe();
|
||||
loader_w.load_unsafe();
|
||||
@ -826,7 +912,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
mma_op.mma(Xs, Ws);
|
||||
} else {
|
||||
for (int k=0; k<K; k += BK) {
|
||||
for (int k = 0; k < K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_x.load_unsafe();
|
||||
loader_w.load_unsafe();
|
||||
@ -847,26 +933,28 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#define instantiate_qmv_fast(name, itype, group_size, bits, packs_per_thread) \
|
||||
template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits "_fast")]] \
|
||||
[[kernel]] void qmv_fast<itype, group_size, bits, packs_per_thread>( \
|
||||
const device uint32_t* w [[buffer(0)]], \
|
||||
const device itype* scales [[buffer(1)]], \
|
||||
const device itype* biases [[buffer(2)]], \
|
||||
const device itype* x [[buffer(3)]], \
|
||||
device itype* y [[buffer(4)]], \
|
||||
const constant int& in_vec_size [[buffer(5)]], \
|
||||
const constant int& out_vec_size [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits \
|
||||
"_fast")]] [[kernel]] void \
|
||||
qmv_fast<itype, group_size, bits, packs_per_thread>( \
|
||||
const device uint32_t* w [[buffer(0)]], \
|
||||
const device itype* scales [[buffer(1)]], \
|
||||
const device itype* biases [[buffer(2)]], \
|
||||
const device itype* x [[buffer(3)]], \
|
||||
device itype* y [[buffer(4)]], \
|
||||
const constant int& in_vec_size [[buffer(5)]], \
|
||||
const constant int& out_vec_size [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_qmv_fast_types(group_size, bits, packs_per_thread) \
|
||||
// clang-format off
|
||||
#define instantiate_qmv_fast_types(group_size, bits, packs_per_thread) \
|
||||
instantiate_qmv_fast(float32, float, group_size, bits, packs_per_thread) \
|
||||
instantiate_qmv_fast(float16, half, group_size, bits, packs_per_thread) \
|
||||
instantiate_qmv_fast(bfloat16, bfloat16_t, group_size, bits, packs_per_thread)
|
||||
instantiate_qmv_fast(float16, half, group_size, bits, packs_per_thread) \
|
||||
instantiate_qmv_fast(bfloat16, bfloat16_t, group_size, bits, packs_per_thread) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_qmv_fast_types(128, 2, 1)
|
||||
instantiate_qmv_fast_types(128, 4, 2)
|
||||
instantiate_qmv_fast_types(128, 8, 2)
|
||||
@ -875,27 +963,30 @@ instantiate_qmv_fast_types( 64, 4, 2)
|
||||
instantiate_qmv_fast_types( 64, 8, 2)
|
||||
instantiate_qmv_fast_types( 32, 2, 1)
|
||||
instantiate_qmv_fast_types( 32, 4, 2)
|
||||
instantiate_qmv_fast_types( 32, 8, 2)
|
||||
instantiate_qmv_fast_types( 32, 8, 2) // clang-format on
|
||||
|
||||
#define instantiate_qmv(name, itype, group_size, bits) \
|
||||
template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits)]] \
|
||||
[[kernel]] void qmv<itype, group_size, bits>( \
|
||||
const device uint32_t* w [[buffer(0)]], \
|
||||
const device itype* scales [[buffer(1)]], \
|
||||
const device itype* biases [[buffer(2)]], \
|
||||
const device itype* x [[buffer(3)]], \
|
||||
device itype* y [[buffer(4)]], \
|
||||
const constant int& in_vec_size [[buffer(5)]], \
|
||||
const constant int& out_vec_size [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
#define instantiate_qmv(name, itype, group_size, bits) \
|
||||
template [[host_name("qmv_" #name "_gs_" #group_size \
|
||||
"_b_" #bits)]] [[kernel]] void \
|
||||
qmv<itype, group_size, bits>( \
|
||||
const device uint32_t* w [[buffer(0)]], \
|
||||
const device itype* scales [[buffer(1)]], \
|
||||
const device itype* biases [[buffer(2)]], \
|
||||
const device itype* x [[buffer(3)]], \
|
||||
device itype* y [[buffer(4)]], \
|
||||
const constant int& in_vec_size [[buffer(5)]], \
|
||||
const constant int& out_vec_size [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_qmv_types(group_size, bits) \
|
||||
// clang-format off
|
||||
#define instantiate_qmv_types(group_size, bits) \
|
||||
instantiate_qmv(float32, float, group_size, bits) \
|
||||
instantiate_qmv(float16, half, group_size, bits) \
|
||||
instantiate_qmv(bfloat16, bfloat16_t, group_size, bits)
|
||||
instantiate_qmv(float16, half, group_size, bits) \
|
||||
instantiate_qmv(bfloat16, bfloat16_t, group_size, bits) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_qmv_types(128, 2)
|
||||
instantiate_qmv_types(128, 4)
|
||||
instantiate_qmv_types(128, 8)
|
||||
@ -904,27 +995,30 @@ instantiate_qmv_types( 64, 4)
|
||||
instantiate_qmv_types( 64, 8)
|
||||
instantiate_qmv_types( 32, 2)
|
||||
instantiate_qmv_types( 32, 4)
|
||||
instantiate_qmv_types( 32, 8)
|
||||
instantiate_qmv_types( 32, 8) // clang-format on
|
||||
|
||||
#define instantiate_qvm(name, itype, group_size, bits) \
|
||||
template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \
|
||||
[[kernel]] void qvm<itype, group_size, bits>( \
|
||||
const device itype* x [[buffer(0)]], \
|
||||
const device uint32_t* w [[buffer(1)]], \
|
||||
const device itype* scales [[buffer(2)]], \
|
||||
const device itype* biases [[buffer(3)]], \
|
||||
device itype* y [[buffer(4)]], \
|
||||
const constant int& in_vec_size [[buffer(5)]], \
|
||||
const constant int& out_vec_size [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
#define instantiate_qvm(name, itype, group_size, bits) \
|
||||
template [[host_name("qvm_" #name "_gs_" #group_size \
|
||||
"_b_" #bits)]] [[kernel]] void \
|
||||
qvm<itype, group_size, bits>( \
|
||||
const device itype* x [[buffer(0)]], \
|
||||
const device uint32_t* w [[buffer(1)]], \
|
||||
const device itype* scales [[buffer(2)]], \
|
||||
const device itype* biases [[buffer(3)]], \
|
||||
device itype* y [[buffer(4)]], \
|
||||
const constant int& in_vec_size [[buffer(5)]], \
|
||||
const constant int& out_vec_size [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_qvm_types(group_size, bits) \
|
||||
// clang-format off
|
||||
#define instantiate_qvm_types(group_size, bits) \
|
||||
instantiate_qvm(float32, float, group_size, bits) \
|
||||
instantiate_qvm(float16, half, group_size, bits) \
|
||||
instantiate_qvm(bfloat16, bfloat16_t, group_size, bits)
|
||||
instantiate_qvm(float16, half, group_size, bits) \
|
||||
instantiate_qvm(bfloat16, bfloat16_t, group_size, bits) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_qvm_types(128, 2)
|
||||
instantiate_qvm_types(128, 4)
|
||||
instantiate_qvm_types(128, 8)
|
||||
@ -933,32 +1027,35 @@ instantiate_qvm_types( 64, 4)
|
||||
instantiate_qvm_types( 64, 8)
|
||||
instantiate_qvm_types( 32, 2)
|
||||
instantiate_qvm_types( 32, 4)
|
||||
instantiate_qvm_types( 32, 8)
|
||||
instantiate_qvm_types( 32, 8) // clang-format on
|
||||
|
||||
#define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
|
||||
template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N)]] \
|
||||
[[kernel]] void qmm_t<itype, 32, 32, 32, group_size, bits, aligned_N>( \
|
||||
const device itype* x [[buffer(0)]], \
|
||||
const device uint32_t* w [[buffer(1)]], \
|
||||
const device itype* scales [[buffer(2)]], \
|
||||
const device itype* biases [[buffer(3)]], \
|
||||
device itype* y [[buffer(4)]], \
|
||||
const constant int& M [[buffer(5)]], \
|
||||
const constant int& N [[buffer(6)]], \
|
||||
const constant int& K [[buffer(7)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint lid [[thread_index_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
#define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
|
||||
template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits \
|
||||
"_alN_" #aligned_N)]] [[kernel]] void \
|
||||
qmm_t<itype, 32, 32, 32, group_size, bits, aligned_N>( \
|
||||
const device itype* x [[buffer(0)]], \
|
||||
const device uint32_t* w [[buffer(1)]], \
|
||||
const device itype* scales [[buffer(2)]], \
|
||||
const device itype* biases [[buffer(3)]], \
|
||||
device itype* y [[buffer(4)]], \
|
||||
const constant int& M [[buffer(5)]], \
|
||||
const constant int& N [[buffer(6)]], \
|
||||
const constant int& K [[buffer(7)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint lid [[thread_index_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_qmm_t_types(group_size, bits) \
|
||||
instantiate_qmm_t(float32, float, group_size, bits, false) \
|
||||
instantiate_qmm_t(float16, half, group_size, bits, false) \
|
||||
// clang-format off
|
||||
#define instantiate_qmm_t_types(group_size, bits) \
|
||||
instantiate_qmm_t(float32, float, group_size, bits, false) \
|
||||
instantiate_qmm_t(float16, half, group_size, bits, false) \
|
||||
instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, false) \
|
||||
instantiate_qmm_t(float32, float, group_size, bits, true) \
|
||||
instantiate_qmm_t(float16, half, group_size, bits, true) \
|
||||
instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, true)
|
||||
instantiate_qmm_t(float32, float, group_size, bits, true) \
|
||||
instantiate_qmm_t(float16, half, group_size, bits, true) \
|
||||
instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, true) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_qmm_t_types(128, 2)
|
||||
instantiate_qmm_t_types(128, 4)
|
||||
instantiate_qmm_t_types(128, 8)
|
||||
@ -967,29 +1064,32 @@ instantiate_qmm_t_types( 64, 4)
|
||||
instantiate_qmm_t_types( 64, 8)
|
||||
instantiate_qmm_t_types( 32, 2)
|
||||
instantiate_qmm_t_types( 32, 4)
|
||||
instantiate_qmm_t_types( 32, 8)
|
||||
instantiate_qmm_t_types( 32, 8) // clang-format on
|
||||
|
||||
#define instantiate_qmm_n(name, itype, group_size, bits) \
|
||||
template [[host_name("qmm_n_" #name "_gs_" #group_size "_b_" #bits)]] \
|
||||
[[kernel]] void qmm_n<itype, 32, 32, 32, group_size, bits>( \
|
||||
const device itype* x [[buffer(0)]], \
|
||||
const device uint32_t* w [[buffer(1)]], \
|
||||
const device itype* scales [[buffer(2)]], \
|
||||
const device itype* biases [[buffer(3)]], \
|
||||
device itype* y [[buffer(4)]], \
|
||||
const constant int& M [[buffer(5)]], \
|
||||
const constant int& N [[buffer(6)]], \
|
||||
const constant int& K [[buffer(7)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint lid [[thread_index_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
template [[host_name("qmm_n_" #name "_gs_" #group_size \
|
||||
"_b_" #bits)]] [[kernel]] void \
|
||||
qmm_n<itype, 32, 32, 32, group_size, bits>( \
|
||||
const device itype* x [[buffer(0)]], \
|
||||
const device uint32_t* w [[buffer(1)]], \
|
||||
const device itype* scales [[buffer(2)]], \
|
||||
const device itype* biases [[buffer(3)]], \
|
||||
device itype* y [[buffer(4)]], \
|
||||
const constant int& M [[buffer(5)]], \
|
||||
const constant int& N [[buffer(6)]], \
|
||||
const constant int& K [[buffer(7)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint lid [[thread_index_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_qmm_n_types(group_size, bits) \
|
||||
// clang-format off
|
||||
#define instantiate_qmm_n_types(group_size, bits) \
|
||||
instantiate_qmm_n(float32, float, group_size, bits) \
|
||||
instantiate_qmm_n(float16, half, group_size, bits) \
|
||||
instantiate_qmm_n(bfloat16, bfloat16_t, group_size, bits)
|
||||
instantiate_qmm_n(float16, half, group_size, bits) \
|
||||
instantiate_qmm_n(bfloat16, bfloat16_t, group_size, bits) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_qmm_n_types(128, 2)
|
||||
instantiate_qmm_n_types(128, 4)
|
||||
instantiate_qmm_n_types(128, 8)
|
||||
@ -998,4 +1098,4 @@ instantiate_qmm_n_types( 64, 4)
|
||||
instantiate_qmm_n_types( 64, 8)
|
||||
instantiate_qmm_n_types( 32, 2)
|
||||
instantiate_qmm_n_types( 32, 4)
|
||||
instantiate_qmm_n_types( 32, 8)
|
||||
instantiate_qmm_n_types( 32, 8) // clang-format on
|
||||
|
@ -3,9 +3,8 @@
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
static constexpr constant uint32_t rotations[2][4] = {
|
||||
{13, 15, 26, 6},
|
||||
{17, 29, 16, 24}
|
||||
};
|
||||
{13, 15, 26, 6},
|
||||
{17, 29, 16, 24}};
|
||||
|
||||
union rbits {
|
||||
uint2 val;
|
||||
@ -13,7 +12,6 @@ union rbits {
|
||||
};
|
||||
|
||||
rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
|
||||
|
||||
uint4 ks = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA};
|
||||
|
||||
rbits v;
|
||||
@ -51,7 +49,7 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
|
||||
out[4 * count.x + i] = bits.bytes[0][i];
|
||||
}
|
||||
if (!drop_last) {
|
||||
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
|
||||
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
|
||||
int edge_bytes = (bytes_per_key % 4);
|
||||
for (int i = 0; i < edge_bytes; ++i) {
|
||||
out[4 * count.y + i] = bits.bytes[1][i];
|
||||
@ -87,7 +85,7 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
|
||||
out[4 * count.x + i] = bits.bytes[0][i];
|
||||
}
|
||||
if (!drop_last) {
|
||||
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
|
||||
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
|
||||
int edge_bytes = (bytes_per_key % 4);
|
||||
for (int i = 0; i < edge_bytes; ++i) {
|
||||
out[4 * count.y + i] = bits.bytes[1][i];
|
||||
|
@ -1,8 +1,8 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/reduction/utils.h"
|
||||
#include "mlx/backend/metal/kernels/reduction/ops.h"
|
||||
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
|
||||
#include "mlx/backend/metal/kernels/reduction/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
@ -60,14 +60,13 @@ METAL_FUNC U per_thread_all_reduce(
|
||||
// All reduce kernel
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
// NB: This kernel assumes threads_per_threadgroup is at most
|
||||
// 1024. This way with a simd_size of 32, we are guaranteed to
|
||||
// complete the reduction in two steps of simd-level reductions.
|
||||
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void all_reduce(
|
||||
const device T *in [[buffer(0)]],
|
||||
device mlx_atomic<U> *out [[buffer(1)]],
|
||||
const device T* in [[buffer(0)]],
|
||||
device mlx_atomic<U>* out [[buffer(1)]],
|
||||
const device size_t& in_size [[buffer(2)]],
|
||||
uint gid [[thread_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
@ -75,11 +74,11 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
Op op;
|
||||
threadgroup U local_vals[simd_size];
|
||||
|
||||
U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
|
||||
U total_val =
|
||||
per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
|
||||
|
||||
// Reduction within simd group
|
||||
total_val = op.simd_reduce(total_val);
|
||||
@ -98,10 +97,10 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void all_reduce_no_atomics(
|
||||
const device T *in [[buffer(0)]],
|
||||
device U *out [[buffer(1)]],
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const device size_t& in_size [[buffer(2)]],
|
||||
uint gid [[thread_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
@ -110,14 +109,16 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint thread_group_id [[threadgroup_position_in_grid]]) {
|
||||
|
||||
Op op;
|
||||
threadgroup U local_vals[simd_size];
|
||||
|
||||
U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
|
||||
U total_val =
|
||||
per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
|
||||
|
||||
// Reduction within simd group (simd_add isn't supported for uint64/int64 types)
|
||||
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
|
||||
// Reduction within simd group (simd_add isn't supported for uint64/int64
|
||||
// types)
|
||||
for (uint16_t lane_offset = simd_size / 2; lane_offset > 0;
|
||||
lane_offset /= 2) {
|
||||
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
|
||||
}
|
||||
// Write simd group reduction results to local memory
|
||||
@ -128,7 +129,8 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
|
||||
// Reduction of simdgroup reduction results within threadgroup.
|
||||
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
|
||||
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
|
||||
for (uint16_t lane_offset = simd_size / 2; lane_offset > 0;
|
||||
lane_offset /= 2) {
|
||||
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
|
||||
}
|
||||
|
||||
@ -138,31 +140,31 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
}
|
||||
}
|
||||
|
||||
#define instantiate_all_reduce(name, itype, otype, op) \
|
||||
template [[host_name("all_reduce_" #name)]] \
|
||||
[[kernel]] void all_reduce<itype, otype, op>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device mlx_atomic<otype> *out [[buffer(1)]], \
|
||||
const device size_t& in_size [[buffer(2)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint grid_size [[threads_per_grid]], \
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
#define instantiate_all_reduce(name, itype, otype, op) \
|
||||
template [[host_name("all_reduce_" #name)]] [[kernel]] void \
|
||||
all_reduce<itype, otype, op>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device mlx_atomic<otype>* out [[buffer(1)]], \
|
||||
const device size_t& in_size [[buffer(2)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint grid_size [[threads_per_grid]], \
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
|
||||
template [[host_name("all_reduce_no_atomics_" #name)]] \
|
||||
[[kernel]] void all_reduce_no_atomics<itype, otype, op>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device otype *out [[buffer(1)]], \
|
||||
const device size_t& in_size [[buffer(2)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint grid_size [[threads_per_grid]], \
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
|
||||
template [[host_name("all_reduce_no_atomics_" #name)]] [[kernel]] void \
|
||||
all_reduce_no_atomics<itype, otype, op>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const device size_t& in_size [[buffer(2)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint grid_size [[threads_per_grid]], \
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint thread_group_id [[threadgroup_position_in_grid]]);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -170,11 +172,12 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_same_all_reduce_helper(name, tname, type, op) \
|
||||
instantiate_all_reduce(name ##tname, type, type, op<type>)
|
||||
instantiate_all_reduce(name##tname, type, type, op<type>)
|
||||
|
||||
#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
|
||||
instantiate_all_reduce_no_atomics(name ##tname, type, type, op<type>)
|
||||
instantiate_all_reduce_no_atomics(name##tname, type, type, op<type>)
|
||||
|
||||
// clang-format off
|
||||
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
|
||||
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)
|
||||
|
||||
@ -182,4 +185,4 @@ instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And)
|
||||
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or)
|
||||
|
||||
// special case bool with larger output type
|
||||
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
|
||||
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>) // clang-format on
|
@ -1,8 +1,8 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/reduction/utils.h"
|
||||
#include "mlx/backend/metal/kernels/reduction/ops.h"
|
||||
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
|
||||
#include "mlx/backend/metal/kernels/reduction/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
@ -12,8 +12,8 @@ using namespace metal;
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void col_reduce_small(
|
||||
const device T *in [[buffer(0)]],
|
||||
device U *out [[buffer(1)]],
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant size_t& out_size [[buffer(4)]],
|
||||
@ -25,7 +25,6 @@ template <typename T, typename U, typename Op>
|
||||
const constant size_t* non_col_strides [[buffer(10)]],
|
||||
const constant int& non_col_ndim [[buffer(11)]],
|
||||
uint tid [[thread_position_in_grid]]) {
|
||||
|
||||
// Appease the compiler
|
||||
(void)out_size;
|
||||
|
||||
@ -35,15 +34,16 @@ template <typename T, typename U, typename Op>
|
||||
auto out_idx = tid;
|
||||
|
||||
in += elem_to_loc(
|
||||
out_idx,
|
||||
shape + non_col_ndim,
|
||||
strides + non_col_ndim,
|
||||
ndim - non_col_ndim);
|
||||
out_idx,
|
||||
shape + non_col_ndim,
|
||||
strides + non_col_ndim,
|
||||
ndim - non_col_ndim);
|
||||
|
||||
for(uint i = 0; i < non_col_reductions; i++) {
|
||||
size_t in_idx = elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);
|
||||
for (uint i = 0; i < non_col_reductions; i++) {
|
||||
size_t in_idx =
|
||||
elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);
|
||||
|
||||
for(uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) {
|
||||
for (uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) {
|
||||
U val = static_cast<U>(in[in_idx]);
|
||||
total_val = op(total_val, val);
|
||||
}
|
||||
@ -52,21 +52,21 @@ template <typename T, typename U, typename Op>
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
|
||||
#define instantiate_col_reduce_small(name, itype, otype, op) \
|
||||
template [[host_name("col_reduce_small_" #name)]] \
|
||||
[[kernel]] void col_reduce_small<itype, otype, op>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device otype *out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& reduction_stride [[buffer(3)]], \
|
||||
const constant size_t& out_size [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
const constant size_t& non_col_reductions [[buffer(8)]], \
|
||||
const constant int* non_col_shapes [[buffer(9)]], \
|
||||
const constant size_t* non_col_strides [[buffer(10)]], \
|
||||
const constant int& non_col_ndim [[buffer(11)]], \
|
||||
#define instantiate_col_reduce_small(name, itype, otype, op) \
|
||||
template [[host_name("col_reduce_small_" #name)]] [[kernel]] void \
|
||||
col_reduce_small<itype, otype, op>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& reduction_stride [[buffer(3)]], \
|
||||
const constant size_t& out_size [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
const constant size_t& non_col_reductions [[buffer(8)]], \
|
||||
const constant int* non_col_shapes [[buffer(9)]], \
|
||||
const constant size_t* non_col_strides [[buffer(10)]], \
|
||||
const constant int& non_col_ndim [[buffer(11)]], \
|
||||
uint tid [[thread_position_in_grid]]);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -112,39 +112,35 @@ METAL_FUNC U _contiguous_strided_reduce(
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void col_reduce_general(
|
||||
const device T *in [[buffer(0)]],
|
||||
device mlx_atomic<U> *out [[buffer(1)]],
|
||||
const device T* in [[buffer(0)]],
|
||||
device mlx_atomic<U>* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant size_t& out_size [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
threadgroup U *local_data [[threadgroup(0)]],
|
||||
threadgroup U* local_data [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]]) {
|
||||
auto out_idx = tid.x * lsize.x + lid.x;
|
||||
auto in_idx = elem_to_loc(
|
||||
out_idx + tid.z * out_size,
|
||||
shape,
|
||||
strides,
|
||||
ndim
|
||||
);
|
||||
auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
|
||||
|
||||
Op op;
|
||||
if(out_idx < out_size) {
|
||||
if (out_idx < out_size) {
|
||||
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
|
||||
in,
|
||||
local_data,
|
||||
in_idx,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
tid.xy,
|
||||
lid.xy,
|
||||
lsize.xy);
|
||||
in,
|
||||
local_data,
|
||||
in_idx,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
tid.xy,
|
||||
lid.xy,
|
||||
lsize.xy);
|
||||
|
||||
// Write out reduction results generated by threadgroups working on specific output element, contiguously.
|
||||
// Write out reduction results generated by threadgroups working on specific
|
||||
// output element, contiguously.
|
||||
if (lid.y == 0) {
|
||||
op.atomic_update(out, val, out_idx);
|
||||
}
|
||||
@ -153,40 +149,36 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void col_reduce_general_no_atomics(
|
||||
const device T *in [[buffer(0)]],
|
||||
device U *out [[buffer(1)]],
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant size_t& out_size [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
threadgroup U *local_data [[threadgroup(0)]],
|
||||
threadgroup U* local_data [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 gid [[thread_position_in_grid]],
|
||||
uint3 lsize [[threads_per_threadgroup]],
|
||||
uint3 gsize [[threads_per_grid]]) {
|
||||
auto out_idx = tid.x * lsize.x + lid.x;
|
||||
auto in_idx = elem_to_loc(
|
||||
out_idx + tid.z * out_size,
|
||||
shape,
|
||||
strides,
|
||||
ndim
|
||||
);
|
||||
auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
|
||||
|
||||
if(out_idx < out_size) {
|
||||
if (out_idx < out_size) {
|
||||
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
|
||||
in,
|
||||
local_data,
|
||||
in_idx,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
tid.xy,
|
||||
lid.xy,
|
||||
lsize.xy);
|
||||
in,
|
||||
local_data,
|
||||
in_idx,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
tid.xy,
|
||||
lid.xy,
|
||||
lsize.xy);
|
||||
|
||||
// Write out reduction results generated by threadgroups working on specific output element, contiguously.
|
||||
// Write out reduction results generated by threadgroups working on specific
|
||||
// output element, contiguously.
|
||||
if (lid.y == 0) {
|
||||
uint tgsize_y = ceildiv(gsize.y, lsize.y);
|
||||
uint tgsize_z = ceildiv(gsize.z, lsize.z);
|
||||
@ -195,52 +187,56 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
}
|
||||
}
|
||||
|
||||
#define instantiate_col_reduce_general(name, itype, otype, op) \
|
||||
template [[host_name("col_reduce_general_" #name)]] \
|
||||
[[kernel]] void col_reduce_general<itype, otype, op>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device mlx_atomic<otype> *out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& reduction_stride [[buffer(3)]], \
|
||||
const constant size_t& out_size [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
threadgroup otype *local_data [[threadgroup(0)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
#define instantiate_col_reduce_general(name, itype, otype, op) \
|
||||
template [[host_name("col_reduce_general_" #name)]] [[kernel]] void \
|
||||
col_reduce_general<itype, otype, op>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device mlx_atomic<otype>* out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& reduction_stride [[buffer(3)]], \
|
||||
const constant size_t& out_size [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
threadgroup otype* local_data [[threadgroup(0)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 lsize [[threads_per_threadgroup]]);
|
||||
|
||||
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
|
||||
template [[host_name("col_reduce_general_no_atomics_" #name)]] \
|
||||
[[kernel]] void col_reduce_general_no_atomics<itype, otype, op>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device otype *out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& reduction_stride [[buffer(3)]], \
|
||||
const constant size_t& out_size [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
threadgroup otype *local_data [[threadgroup(0)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 gid [[thread_position_in_grid]], \
|
||||
uint3 lsize [[threads_per_threadgroup]], \
|
||||
uint3 gsize [[threads_per_grid]]);
|
||||
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
|
||||
template \
|
||||
[[host_name("col_reduce_general_no_atomics_" #name)]] [[kernel]] void \
|
||||
col_reduce_general_no_atomics<itype, otype, op>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& reduction_stride [[buffer(3)]], \
|
||||
const constant size_t& out_size [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
threadgroup otype* local_data [[threadgroup(0)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 gid [[thread_position_in_grid]], \
|
||||
uint3 lsize [[threads_per_threadgroup]], \
|
||||
uint3 gsize [[threads_per_grid]]);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Instantiations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
|
||||
// clang-format off
|
||||
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
|
||||
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
|
||||
instantiate_col_reduce_general(name ##tname, type, type, op<type>)
|
||||
instantiate_col_reduce_general(name ##tname, type, type, op<type>) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
|
||||
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
|
||||
instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>)
|
||||
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
|
||||
instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
|
||||
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)
|
||||
|
||||
@ -250,4 +246,4 @@ instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)
|
||||
|
||||
instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
|
||||
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And)
|
||||
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or)
|
||||
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or) // clang-format on
|
@ -1,8 +1,8 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/reduction/utils.h"
|
||||
#include "mlx/backend/metal/kernels/reduction/ops.h"
|
||||
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
|
||||
#include "mlx/backend/metal/kernels/reduction/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
@ -12,22 +12,21 @@ using namespace metal;
|
||||
|
||||
template <typename T, typename Op>
|
||||
[[kernel]] void init_reduce(
|
||||
device T *out [[buffer(0)]],
|
||||
device T* out [[buffer(0)]],
|
||||
uint tid [[thread_position_in_grid]]) {
|
||||
out[tid] = Op::init;
|
||||
}
|
||||
|
||||
#define instantiate_init_reduce(name, otype, op) \
|
||||
template [[host_name("i" #name)]] \
|
||||
[[kernel]] void init_reduce<otype, op>( \
|
||||
device otype *out [[buffer(1)]], \
|
||||
uint tid [[thread_position_in_grid]]);
|
||||
#define instantiate_init_reduce(name, otype, op) \
|
||||
template [[host_name("i" #name)]] [[kernel]] void init_reduce<otype, op>( \
|
||||
device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_init_reduce_helper(name, tname, type, op) \
|
||||
instantiate_init_reduce(name ##tname, type, op<type>)
|
||||
instantiate_init_reduce(name##tname, type, op<type>)
|
||||
|
||||
// clang-format off
|
||||
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
|
||||
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)
|
||||
|
||||
instantiate_init_reduce(andbool_, bool, And)
|
||||
instantiate_init_reduce(orbool_, bool, Or)
|
||||
instantiate_init_reduce(orbool_, bool, Or) // clang-format on
|
@ -1,8 +1,8 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/reduction/utils.h"
|
||||
#include "mlx/backend/metal/kernels/reduction/ops.h"
|
||||
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
|
||||
#include "mlx/backend/metal/kernels/reduction/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
@ -13,8 +13,8 @@ using namespace metal;
|
||||
// Each thread reduces for one output
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void row_reduce_general_small(
|
||||
const device T *in [[buffer(0)]],
|
||||
device U *out [[buffer(1)]],
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||
@ -22,22 +22,21 @@ template <typename T, typename U, typename Op>
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
uint lid [[thread_position_in_grid]]) {
|
||||
|
||||
Op op;
|
||||
|
||||
|
||||
uint out_idx = lid;
|
||||
|
||||
if(out_idx >= out_size) {
|
||||
if (out_idx >= out_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
U total_val = Op::init;
|
||||
|
||||
for(short r = 0; r < short(non_row_reductions); r++) {
|
||||
for (short r = 0; r < short(non_row_reductions); r++) {
|
||||
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
||||
const device T * in_row = in + in_idx;
|
||||
|
||||
for(short i = 0; i < short(reduction_size); i++) {
|
||||
const device T* in_row = in + in_idx;
|
||||
|
||||
for (short i = 0; i < short(reduction_size); i++) {
|
||||
total_val = op(static_cast<U>(in_row[i]), total_val);
|
||||
}
|
||||
}
|
||||
@ -48,8 +47,8 @@ template <typename T, typename U, typename Op>
|
||||
// Each simdgroup reduces for one output
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void row_reduce_general_med(
|
||||
const device T *in [[buffer(0)]],
|
||||
device U *out [[buffer(1)]],
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||
@ -60,45 +59,42 @@ template <typename T, typename U, typename Op>
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
Op op;
|
||||
|
||||
|
||||
uint out_idx = simd_per_group * tid + simd_group_id;
|
||||
|
||||
if(out_idx >= out_size) {
|
||||
if (out_idx >= out_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
U total_val = Op::init;
|
||||
|
||||
if(short(non_row_reductions) == 1) {
|
||||
if (short(non_row_reductions) == 1) {
|
||||
uint in_idx = elem_to_loc(out_idx, shape, strides, ndim);
|
||||
const device T * in_row = in + in_idx;
|
||||
const device T* in_row = in + in_idx;
|
||||
|
||||
for(short i = simd_lane_id; i < short(reduction_size); i += 32) {
|
||||
for (short i = simd_lane_id; i < short(reduction_size); i += 32) {
|
||||
total_val = op(static_cast<U>(in_row[i]), total_val);
|
||||
}
|
||||
}
|
||||
|
||||
else if (short(non_row_reductions) >= 32) {
|
||||
|
||||
for(short r = simd_lane_id; r < short(non_row_reductions); r+=32) {
|
||||
|
||||
for (short r = simd_lane_id; r < short(non_row_reductions); r += 32) {
|
||||
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
||||
const device T * in_row = in + in_idx;
|
||||
const device T* in_row = in + in_idx;
|
||||
|
||||
for(short i = 0; i < short(reduction_size); i++) {
|
||||
for (short i = 0; i < short(reduction_size); i++) {
|
||||
total_val = op(static_cast<U>(in_row[i]), total_val);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
else {
|
||||
|
||||
const short n_reductions = short(reduction_size) * short(non_row_reductions);
|
||||
const short reductions_per_thread = (n_reductions + simd_size - 1) / simd_size;
|
||||
const short n_reductions =
|
||||
short(reduction_size) * short(non_row_reductions);
|
||||
const short reductions_per_thread =
|
||||
(n_reductions + simd_size - 1) / simd_size;
|
||||
|
||||
const short r_st = simd_lane_id / reductions_per_thread;
|
||||
const short r_ed = short(non_row_reductions);
|
||||
@ -108,54 +104,50 @@ template <typename T, typename U, typename Op>
|
||||
const short i_ed = short(reduction_size);
|
||||
const short i_jump = reductions_per_thread;
|
||||
|
||||
if(r_st < r_jump) {
|
||||
for(short r = r_st; r < r_ed; r += r_jump) {
|
||||
|
||||
if (r_st < r_jump) {
|
||||
for (short r = r_st; r < r_ed; r += r_jump) {
|
||||
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
||||
const device T * in_row = in + in_idx;
|
||||
const device T* in_row = in + in_idx;
|
||||
|
||||
for(short i = i_st; i < i_ed; i += i_jump) {
|
||||
for (short i = i_st; i < i_ed; i += i_jump) {
|
||||
total_val = op(static_cast<U>(in_row[i]), total_val);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
total_val = op.simd_reduce(total_val);
|
||||
|
||||
if(simd_lane_id == 0) {
|
||||
if (simd_lane_id == 0) {
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
}
|
||||
|
||||
#define instantiate_row_reduce_small(name, itype, otype, op) \
|
||||
template[[host_name("row_reduce_general_small_" #name)]] \
|
||||
[[kernel]] void row_reduce_general_small<itype, otype, op>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device otype *out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& out_size [[buffer(3)]], \
|
||||
const constant size_t& non_row_reductions [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
uint lid [[thread_position_in_grid]]); \
|
||||
template[[host_name("row_reduce_general_med_" #name)]] \
|
||||
[[kernel]] void row_reduce_general_med<itype, otype, op>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device otype *out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& out_size [[buffer(3)]], \
|
||||
const constant size_t& non_row_reductions [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
uint tid [[threadgroup_position_in_grid]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \
|
||||
#define instantiate_row_reduce_small(name, itype, otype, op) \
|
||||
template [[host_name("row_reduce_general_small_" #name)]] [[kernel]] void \
|
||||
row_reduce_general_small<itype, otype, op>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& out_size [[buffer(3)]], \
|
||||
const constant size_t& non_row_reductions [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
uint lid [[thread_position_in_grid]]); \
|
||||
template [[host_name("row_reduce_general_med_" #name)]] [[kernel]] void \
|
||||
row_reduce_general_med<itype, otype, op>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& out_size [[buffer(3)]], \
|
||||
const constant size_t& non_row_reductions [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
uint tid [[threadgroup_position_in_grid]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -217,10 +209,10 @@ METAL_FUNC U per_thread_row_reduce(
|
||||
return total_val;
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void row_reduce_general(
|
||||
const device T *in [[buffer(0)]],
|
||||
device mlx_atomic<U> *out [[buffer(1)]],
|
||||
const device T* in [[buffer(0)]],
|
||||
device mlx_atomic<U>* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||
@ -233,25 +225,33 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
(void)non_row_reductions;
|
||||
|
||||
Op op;
|
||||
threadgroup U local_vals[simd_size];
|
||||
|
||||
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
|
||||
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
|
||||
in,
|
||||
reduction_size,
|
||||
out_size,
|
||||
shape,
|
||||
strides,
|
||||
ndim,
|
||||
lsize.x,
|
||||
lid.x,
|
||||
tid.xy);
|
||||
|
||||
total_val = op.simd_reduce(total_val);
|
||||
|
||||
|
||||
// Prepare next level
|
||||
if (simd_lane_id == 0) {
|
||||
local_vals[simd_group_id] = total_val;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
|
||||
// Reduction within thread group
|
||||
// Only needed if multiple simd groups
|
||||
if(reduction_size > simd_size) {
|
||||
if (reduction_size > simd_size) {
|
||||
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
|
||||
total_val = op.simd_reduce(total_val);
|
||||
}
|
||||
@ -261,10 +261,10 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void row_reduce_general_no_atomics(
|
||||
const device T *in [[buffer(0)]],
|
||||
device U *out [[buffer(1)]],
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||
@ -278,16 +278,24 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
(void)non_row_reductions;
|
||||
|
||||
Op op;
|
||||
|
||||
threadgroup U local_vals[simd_size];
|
||||
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
|
||||
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
|
||||
in,
|
||||
reduction_size,
|
||||
out_size,
|
||||
shape,
|
||||
strides,
|
||||
ndim,
|
||||
lsize.x,
|
||||
lid.x,
|
||||
tid.xy);
|
||||
|
||||
// Reduction within simd group - simd_add isn't supported for int64 types
|
||||
for (uint16_t i = simd_size/2; i > 0; i /= 2) {
|
||||
for (uint16_t i = simd_size / 2; i > 0; i /= 2) {
|
||||
total_val = op(total_val, simd_shuffle_down(total_val, i));
|
||||
}
|
||||
|
||||
@ -299,9 +307,9 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
|
||||
// Reduction within thread group
|
||||
// Only needed if thread group has multiple simd groups
|
||||
if(ceildiv(reduction_size, N_READS) > simd_size) {
|
||||
if (ceildiv(reduction_size, N_READS) > simd_size) {
|
||||
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
|
||||
for (uint16_t i = simd_size/2; i > 0; i /= 2) {
|
||||
for (uint16_t i = simd_size / 2; i > 0; i /= 2) {
|
||||
total_val = op(total_val, simd_shuffle_down(total_val, i));
|
||||
}
|
||||
}
|
||||
@ -311,61 +319,60 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
}
|
||||
}
|
||||
|
||||
#define instantiate_row_reduce_general(name, itype, otype, op) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op) \
|
||||
template [[host_name("row_reduce_general_" #name)]] \
|
||||
[[kernel]] void row_reduce_general<itype, otype, op>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device mlx_atomic<otype> *out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& out_size [[buffer(3)]], \
|
||||
const constant size_t& non_row_reductions [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 lsize [[threads_per_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op) \
|
||||
template [[host_name("row_reduce_general_no_atomics_" #name)]] \
|
||||
[[kernel]] void row_reduce_general_no_atomics<itype, otype, op>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device otype *out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& out_size [[buffer(3)]], \
|
||||
const constant size_t& non_row_reductions [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 lsize [[threads_per_threadgroup]], \
|
||||
uint3 gsize [[threads_per_grid]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
#define instantiate_row_reduce_general(name, itype, otype, op) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op) template \
|
||||
[[host_name("row_reduce_general_" #name)]] [[kernel]] void \
|
||||
row_reduce_general<itype, otype, op>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device mlx_atomic<otype>* out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& out_size [[buffer(3)]], \
|
||||
const constant size_t& non_row_reductions [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 lsize [[threads_per_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op) template \
|
||||
[[host_name("row_reduce_general_no_atomics_" #name)]] [[kernel]] void \
|
||||
row_reduce_general_no_atomics<itype, otype, op>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& out_size [[buffer(3)]], \
|
||||
const constant size_t& non_row_reductions [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 lsize [[threads_per_threadgroup]], \
|
||||
uint3 gsize [[threads_per_grid]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Instantiations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_same_row_reduce_helper(name, tname, type, op) \
|
||||
instantiate_row_reduce_general(name ##tname, type, type, op<type>)
|
||||
instantiate_row_reduce_general(name##tname, type, type, op<type>)
|
||||
|
||||
#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \
|
||||
instantiate_row_reduce_general_no_atomics(name ##tname, type, type, op<type>)
|
||||
instantiate_row_reduce_general_no_atomics(name##tname, type, type, op<type>)
|
||||
|
||||
// clang-format off
|
||||
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
|
||||
instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b)
|
||||
|
||||
|
||||
instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And)
|
||||
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or)
|
||||
|
||||
instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
|
||||
instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>) // clang-format on
|
@ -237,13 +237,17 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
gw += gid * axis_size + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
|
||||
gx[i] = static_cast<T>(
|
||||
thread_g[i] * thread_w[i] * normalizer -
|
||||
thread_x[i] * meangwx * normalizer3);
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
|
||||
gx[i] = static_cast<T>(
|
||||
thread_g[i] * thread_w[i] * normalizer -
|
||||
thread_x[i] * meangwx * normalizer3);
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
|
||||
}
|
||||
}
|
||||
@ -342,7 +346,8 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
float wi = w[w_stride * (i + r)];
|
||||
float gi = g[i + r];
|
||||
|
||||
gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
|
||||
gx[i + r] =
|
||||
static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
|
||||
gw[i + r] = static_cast<T>(gi * xi * normalizer);
|
||||
}
|
||||
} else {
|
||||
@ -352,7 +357,8 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
float wi = w[w_stride * (i + r)];
|
||||
float gi = g[i + r];
|
||||
|
||||
gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
|
||||
gx[i + r] =
|
||||
static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
|
||||
gw[i + r] = static_cast<T>(gi * xi * normalizer);
|
||||
}
|
||||
}
|
||||
@ -431,5 +437,4 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
|
||||
instantiate_rms(float32, float)
|
||||
instantiate_rms(float16, half)
|
||||
instantiate_rms(bfloat16, bfloat16_t)
|
||||
// clang-format on
|
||||
instantiate_rms(bfloat16, bfloat16_t) // clang-format on
|
||||
|
@ -7,8 +7,8 @@
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
[[kernel]] void rope(
|
||||
const device T *in [[buffer(0)]],
|
||||
device T * out [[buffer(1)]],
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
constant const size_t strides[3],
|
||||
constant const size_t out_strides[3],
|
||||
constant const int& offset,
|
||||
@ -20,12 +20,15 @@ template <typename T, bool traditional, bool forward>
|
||||
uint in_index_1, in_index_2;
|
||||
uint out_index_1, out_index_2;
|
||||
if (traditional) {
|
||||
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0];
|
||||
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
pos.z * out_strides[0];
|
||||
out_index_2 = out_index_1 + 1;
|
||||
in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
|
||||
in_index_1 =
|
||||
2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
|
||||
in_index_2 = in_index_1 + strides[2];
|
||||
} else {
|
||||
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0];
|
||||
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
pos.z * out_strides[0];
|
||||
out_index_2 = out_index_1 + grid.x * out_strides[2];
|
||||
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
|
||||
in_index_2 = in_index_1 + grid.x * strides[2];
|
||||
@ -57,18 +60,19 @@ template <typename T, bool traditional, bool forward>
|
||||
}
|
||||
|
||||
#define instantiate_rope(name, type, traditional, forward) \
|
||||
template [[host_name("rope_" #name)]] \
|
||||
[[kernel]] void rope<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const size_t strides[3], \
|
||||
constant const size_t out_strides[3], \
|
||||
constant const int& offset, \
|
||||
constant const float& base, \
|
||||
constant const float& scale, \
|
||||
uint3 pos [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
template [[host_name("rope_" #name)]] [[kernel]] void \
|
||||
rope<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const size_t strides[3], \
|
||||
constant const size_t out_strides[3], \
|
||||
constant const int& offset, \
|
||||
constant const float& base, \
|
||||
constant const float& scale, \
|
||||
uint3 pos [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
|
||||
// clang-format off
|
||||
instantiate_rope(traditional_float16, half, true, true)
|
||||
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
|
||||
instantiate_rope(traditional_float32, float, true, true)
|
||||
@ -80,4 +84,4 @@ instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false)
|
||||
instantiate_rope(vjp_traditional_float32, float, true, false)
|
||||
instantiate_rope(vjp_float16, half, false, false)
|
||||
instantiate_rope(vjp_bfloat16, bfloat16_t, false, false)
|
||||
instantiate_rope(vjp_float32, float, false, false)
|
||||
instantiate_rope(vjp_float32, float, false, false) // clang-format on
|
@ -1,451 +1,551 @@
|
||||
#include <metal_stdlib>
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
|
||||
using namespace metal;
|
||||
|
||||
template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_t NSIMDGROUPS>
|
||||
[[kernel]] void fast_inference_sdpa_compute_partials_template(const device T *Q [[buffer(0)]],
|
||||
const device T *K [[buffer(1)]],
|
||||
const device T *V [[buffer(2)]],
|
||||
const device uint64_t& L [[buffer(3)]],
|
||||
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]],
|
||||
device float* O_partials [[buffer(5)]],
|
||||
device float* p_lse [[buffer(6)]],
|
||||
device float* p_maxes [[buffer(7)]],
|
||||
threadgroup T* threadgroup_block [[threadgroup(0)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]]) {
|
||||
constexpr const size_t DK = 128;
|
||||
constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8;
|
||||
constexpr const size_t THREADS_PER_SIMDGROUP = 32;
|
||||
constexpr const uint iter_offset = NSIMDGROUPS * 4;
|
||||
const bool is_gqa = params.N_KV_HEADS != params.N_Q_HEADS;
|
||||
uint kv_head_offset_factor = tid.x;
|
||||
if(is_gqa) {
|
||||
int q_kv_head_ratio = params.N_Q_HEADS / params.N_KV_HEADS;
|
||||
kv_head_offset_factor = tid.x / q_kv_head_ratio;
|
||||
template <
|
||||
typename T,
|
||||
typename T2,
|
||||
typename T4,
|
||||
uint16_t TILE_SIZE_CONST,
|
||||
uint16_t NSIMDGROUPS>
|
||||
[[kernel]] void fast_inference_sdpa_compute_partials_template(
|
||||
const device T* Q [[buffer(0)]],
|
||||
const device T* K [[buffer(1)]],
|
||||
const device T* V [[buffer(2)]],
|
||||
const device uint64_t& L [[buffer(3)]],
|
||||
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]],
|
||||
device float* O_partials [[buffer(5)]],
|
||||
device float* p_lse [[buffer(6)]],
|
||||
device float* p_maxes [[buffer(7)]],
|
||||
threadgroup T* threadgroup_block [[threadgroup(0)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]]) {
|
||||
constexpr const size_t DK = 128;
|
||||
constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8;
|
||||
constexpr const size_t THREADS_PER_SIMDGROUP = 32;
|
||||
constexpr const uint iter_offset = NSIMDGROUPS * 4;
|
||||
const bool is_gqa = params.N_KV_HEADS != params.N_Q_HEADS;
|
||||
uint kv_head_offset_factor = tid.x;
|
||||
if (is_gqa) {
|
||||
int q_kv_head_ratio = params.N_Q_HEADS / params.N_KV_HEADS;
|
||||
kv_head_offset_factor = tid.x / q_kv_head_ratio;
|
||||
}
|
||||
constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / NSIMDGROUPS / 4;
|
||||
constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP =
|
||||
TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS);
|
||||
constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR;
|
||||
constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR *
|
||||
SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) *
|
||||
NSIMDGROUPS;
|
||||
|
||||
threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block;
|
||||
#pragma clang loop unroll(full)
|
||||
for (uint i = 0; i < 8; i++) {
|
||||
smemFlush
|
||||
[simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP +
|
||||
i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// TODO: multiple query sequence length for speculative decoding
|
||||
const uint tgroup_query_head_offset =
|
||||
tid.x * DK + tid.z * (params.N_Q_HEADS * DK);
|
||||
|
||||
const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L;
|
||||
const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK;
|
||||
const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK;
|
||||
|
||||
const device T* baseK =
|
||||
K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset;
|
||||
const device T* baseQ = Q + tgroup_query_head_offset;
|
||||
|
||||
device T4* simdgroupQueryData = (device T4*)baseQ;
|
||||
|
||||
constexpr const size_t ACCUM_PER_GROUP = TILE_SIZE_CONST / NSIMDGROUPS;
|
||||
float threadAccum[ACCUM_PER_GROUP];
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for (size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP;
|
||||
threadAccumIndex++) {
|
||||
threadAccum[threadAccumIndex] = -INFINITY;
|
||||
}
|
||||
|
||||
uint KROW_ACCUM_INDEX = 0;
|
||||
|
||||
const int32_t SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST;
|
||||
const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L;
|
||||
const bool LAST_TILE_ALIGNED =
|
||||
(SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST));
|
||||
|
||||
T4 thread_data_x4;
|
||||
T4 thread_data_y4;
|
||||
if (!LAST_TILE || LAST_TILE_ALIGNED) {
|
||||
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
|
||||
#pragma clang loop unroll(full)
|
||||
for (size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST;
|
||||
KROW += NSIMDGROUPS) {
|
||||
const uint KROW_OFFSET = KROW * DK;
|
||||
const device T* baseKRow = baseK + KROW_OFFSET;
|
||||
device T4* keysData = (device T4*)baseKRow;
|
||||
thread_data_y4 = *(keysData + simd_lane_id);
|
||||
T kq_scalar = dot(thread_data_x4, thread_data_y4);
|
||||
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
|
||||
KROW_ACCUM_INDEX++;
|
||||
}
|
||||
constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / NSIMDGROUPS / 4;
|
||||
constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP = TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS);
|
||||
constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR;
|
||||
constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR * SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) * NSIMDGROUPS;
|
||||
} else {
|
||||
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
|
||||
const uint START_ROW = tid.y * TILE_SIZE_CONST;
|
||||
const device T* baseKThisHead =
|
||||
K + tgroup_k_batch_offset + tgroup_k_head_offset;
|
||||
|
||||
threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block;
|
||||
#pragma clang loop unroll(full)
|
||||
for(uint i = 0; i < 8; i++) {
|
||||
smemFlush[simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP + i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f);
|
||||
for (size_t KROW = START_ROW + simd_group_id; KROW < L;
|
||||
KROW += NSIMDGROUPS) {
|
||||
const uint KROW_OFFSET = KROW * DK;
|
||||
const device T* baseKRow = baseKThisHead + KROW_OFFSET;
|
||||
device T4* keysData = (device T4*)baseKRow;
|
||||
thread_data_y4 = *(keysData + simd_lane_id);
|
||||
T kq_scalar = dot(thread_data_x4, thread_data_y4);
|
||||
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
|
||||
KROW_ACCUM_INDEX++;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// TODO: multiple query sequence length for speculative decoding
|
||||
const uint tgroup_query_head_offset = tid.x * DK + tid.z * (params.N_Q_HEADS * DK);
|
||||
}
|
||||
threadgroup float* smemP = (threadgroup float*)threadgroup_block;
|
||||
|
||||
const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L;
|
||||
const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK;
|
||||
const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK;
|
||||
|
||||
const device T* baseK = K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset;
|
||||
const device T* baseQ = Q + tgroup_query_head_offset;
|
||||
|
||||
device T4* simdgroupQueryData = (device T4*)baseQ;
|
||||
|
||||
constexpr const size_t ACCUM_PER_GROUP = TILE_SIZE_CONST / NSIMDGROUPS;
|
||||
float threadAccum[ACCUM_PER_GROUP];
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP; threadAccumIndex++) {
|
||||
threadAccum[threadAccumIndex] = -INFINITY;
|
||||
#pragma clang loop unroll(full)
|
||||
for (size_t i = 0; i < P_VEC4; i++) {
|
||||
thread_data_x4 =
|
||||
T4(threadAccum[4 * i],
|
||||
threadAccum[4 * i + 1],
|
||||
threadAccum[4 * i + 2],
|
||||
threadAccum[4 * i + 3]);
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
thread_data_y4 = simd_sum(thread_data_x4);
|
||||
if (simd_lane_id == 0) {
|
||||
const uint base_smem_p_offset = i * iter_offset + simd_group_id;
|
||||
smemP[base_smem_p_offset + NSIMDGROUPS * 0] = float(thread_data_y4.x);
|
||||
smemP[base_smem_p_offset + NSIMDGROUPS * 1] = float(thread_data_y4.y);
|
||||
smemP[base_smem_p_offset + NSIMDGROUPS * 2] = float(thread_data_y4.z);
|
||||
smemP[base_smem_p_offset + NSIMDGROUPS * 3] = float(thread_data_y4.w);
|
||||
}
|
||||
}
|
||||
|
||||
uint KROW_ACCUM_INDEX = 0;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
const int32_t SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST;
|
||||
const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L;
|
||||
const bool LAST_TILE_ALIGNED = (SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST));
|
||||
float groupMax;
|
||||
float lse = 0.f;
|
||||
|
||||
T4 thread_data_x4;
|
||||
T4 thread_data_y4;
|
||||
if(!LAST_TILE || LAST_TILE_ALIGNED) {
|
||||
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
|
||||
#pragma clang loop unroll(full)
|
||||
for(size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST; KROW += NSIMDGROUPS) {
|
||||
const uint KROW_OFFSET = KROW * DK;
|
||||
const device T* baseKRow = baseK + KROW_OFFSET;
|
||||
device T4* keysData = (device T4*)baseKRow;
|
||||
thread_data_y4 = *(keysData + simd_lane_id);
|
||||
T kq_scalar = dot(thread_data_x4, thread_data_y4);
|
||||
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
|
||||
KROW_ACCUM_INDEX++;
|
||||
constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32;
|
||||
constexpr const size_t ACCUM_ARRAY_LENGTH =
|
||||
TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1;
|
||||
float4 pvals[ACCUM_ARRAY_LENGTH];
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for (uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH;
|
||||
accum_array_iter++) {
|
||||
pvals[accum_array_iter] = float4(-INFINITY);
|
||||
}
|
||||
|
||||
if (TILE_SIZE_CONST == 64) {
|
||||
threadgroup float2* smemPtrFlt2 = (threadgroup float2*)threadgroup_block;
|
||||
float2 vals = smemPtrFlt2[simd_lane_id];
|
||||
vals *= params.INV_ALPHA;
|
||||
float maxval = max(vals.x, vals.y);
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
groupMax = simd_max(maxval);
|
||||
|
||||
float2 expf_shifted = exp(vals - groupMax);
|
||||
float sumExpLocal = expf_shifted.x + expf_shifted.y;
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
float tgroupExpSum = simd_sum(sumExpLocal);
|
||||
|
||||
lse = log(tgroupExpSum);
|
||||
float2 local_p_hat = expf_shifted / tgroupExpSum;
|
||||
pvals[0].x = local_p_hat.x;
|
||||
pvals[0].y = local_p_hat.y;
|
||||
smemPtrFlt2[simd_lane_id] = float2(0.f);
|
||||
}
|
||||
constexpr const bool TILE_SIZE_LARGER_THAN_64 = TILE_SIZE_CONST > 64;
|
||||
constexpr const int TILE_SIZE_ITERS_128 = TILE_SIZE_CONST / 128;
|
||||
|
||||
if (TILE_SIZE_LARGER_THAN_64) {
|
||||
float maxval = -INFINITY;
|
||||
threadgroup float4* smemPtrFlt4 = (threadgroup float4*)threadgroup_block;
|
||||
#pragma clang loop unroll(full)
|
||||
for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
|
||||
float4 vals = smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP];
|
||||
vals *= params.INV_ALPHA;
|
||||
pvals[i] = vals;
|
||||
maxval = fmax3(vals.x, vals.y, maxval);
|
||||
maxval = fmax3(vals.z, vals.w, maxval);
|
||||
}
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
groupMax = simd_max(maxval);
|
||||
|
||||
float sumExpLocal = 0.f;
|
||||
#pragma clang loop unroll(full)
|
||||
for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
|
||||
pvals[i] = exp(pvals[i] - groupMax);
|
||||
sumExpLocal += pvals[i].x + pvals[i].y + pvals[i].z + pvals[i].w;
|
||||
}
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
float tgroupExpSum = simd_sum(sumExpLocal);
|
||||
lse = log(tgroupExpSum);
|
||||
#pragma clang loop unroll(full)
|
||||
for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
|
||||
pvals[i] = pvals[i] / tgroupExpSum;
|
||||
smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP] = float4(0.f);
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup T* smemV = (threadgroup T*)threadgroup_block;
|
||||
|
||||
const size_t v_batch_offset = tid.z * params.N_KV_HEADS * L * DK;
|
||||
const size_t v_head_offset = kv_head_offset_factor * L * DK;
|
||||
|
||||
const size_t v_tile_offset = tid.y * TILE_SIZE_CONST * DK;
|
||||
const size_t v_offset = v_batch_offset + v_head_offset + v_tile_offset;
|
||||
device T* baseV = (device T*)V + v_offset;
|
||||
|
||||
threadgroup float* smemOpartial = (threadgroup float*)(smemV + totalSmemV);
|
||||
|
||||
if (!LAST_TILE || LAST_TILE_ALIGNED) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (size_t col = 0; col < MATRIX_COLS; col++) {
|
||||
uint matrix_load_loop_iter = 0;
|
||||
constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8;
|
||||
|
||||
for (size_t tile_start = simd_group_id;
|
||||
tile_start < TILE_SIZE_CONST_DIV_8;
|
||||
tile_start += NSIMDGROUPS) {
|
||||
simdgroup_matrix<T, 8, 8> tmp;
|
||||
ulong simdgroup_matrix_offset =
|
||||
matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR +
|
||||
simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
|
||||
ulong2 matrixOrigin =
|
||||
ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset);
|
||||
simdgroup_load(tmp, baseV, DK, matrixOrigin, true);
|
||||
const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0);
|
||||
const ulong elemsPerRowSmem = TILE_SIZE_CONST;
|
||||
simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, false);
|
||||
matrix_load_loop_iter++;
|
||||
};
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (TILE_SIZE_CONST == 64) {
|
||||
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
|
||||
uint loop_iter = 0;
|
||||
threadgroup float* oPartialSmem =
|
||||
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR;
|
||||
row += NSIMDGROUPS) {
|
||||
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
|
||||
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
|
||||
T2 v_local = *(smemV2 + simd_lane_id);
|
||||
|
||||
T val = dot(local_p_hat, v_local);
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
T row_sum = simd_sum(val);
|
||||
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] =
|
||||
float(row_sum);
|
||||
loop_iter++;
|
||||
}
|
||||
} else {
|
||||
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
|
||||
const uint START_ROW = tid.y * TILE_SIZE_CONST;
|
||||
const device T* baseKThisHead = K + tgroup_k_batch_offset + tgroup_k_head_offset;
|
||||
}
|
||||
|
||||
for(size_t KROW = START_ROW + simd_group_id; KROW < L; KROW += NSIMDGROUPS) {
|
||||
const uint KROW_OFFSET = KROW * DK;
|
||||
const device T* baseKRow = baseKThisHead + KROW_OFFSET;
|
||||
device T4* keysData = (device T4*)baseKRow;
|
||||
thread_data_y4 = *(keysData + simd_lane_id);
|
||||
T kq_scalar = dot(thread_data_x4, thread_data_y4);
|
||||
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
|
||||
KROW_ACCUM_INDEX++;
|
||||
if (TILE_SIZE_CONST > 64) {
|
||||
constexpr const size_t TILE_SIZE_CONST_DIV_128 =
|
||||
(TILE_SIZE_CONST + 1) / 128;
|
||||
threadgroup float* oPartialSmem =
|
||||
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
|
||||
uint loop_iter = 0;
|
||||
for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR;
|
||||
row += NSIMDGROUPS) {
|
||||
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
|
||||
|
||||
T row_sum = 0.f;
|
||||
for (size_t i = 0; i < TILE_SIZE_CONST_DIV_128; i++) {
|
||||
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
|
||||
T4 v_local = *(smemV2 + simd_lane_id + i * THREADS_PER_SIMDGROUP);
|
||||
T4 p_local = T4(pvals[i]);
|
||||
T val = dot(p_local, v_local);
|
||||
row_sum += val;
|
||||
}
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
row_sum = simd_sum(row_sum);
|
||||
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] =
|
||||
float(row_sum);
|
||||
loop_iter++;
|
||||
}
|
||||
}
|
||||
}
|
||||
threadgroup float* smemP = (threadgroup float*)threadgroup_block;
|
||||
} else {
|
||||
const int32_t START_ROW = tid.y * TILE_SIZE_CONST;
|
||||
const int32_t MAX_START_ROW = L - SIMDGROUP_MATRIX_LOAD_FACTOR + 1;
|
||||
const device T* baseVThisHead = V + v_batch_offset + v_head_offset;
|
||||
constexpr const int ROWS_PER_ITER = 8;
|
||||
#pragma clang loop unroll(full)
|
||||
for (size_t col = 0; col < MATRIX_COLS; col++) {
|
||||
uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
|
||||
int32_t tile_start;
|
||||
for (tile_start =
|
||||
START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
|
||||
tile_start < MAX_START_ROW;
|
||||
tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) {
|
||||
simdgroup_matrix<T, 8, 8> tmp;
|
||||
ulong2 matrixOrigin =
|
||||
ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start);
|
||||
simdgroup_load(
|
||||
tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true);
|
||||
const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0);
|
||||
constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST;
|
||||
simdgroup_store(
|
||||
tmp,
|
||||
smemV,
|
||||
elemsPerRowSmem,
|
||||
matrixOriginSmem,
|
||||
/* transpose */ false);
|
||||
smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR;
|
||||
};
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(size_t i = 0; i < P_VEC4; i++) {
|
||||
thread_data_x4 = T4(threadAccum[4 * i], threadAccum[4 * i + 1], threadAccum[4 * i + 2], threadAccum[4 * i + 3]);
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
thread_data_y4 = simd_sum(thread_data_x4);
|
||||
if(simd_lane_id == 0) {
|
||||
const uint base_smem_p_offset = i * iter_offset + simd_group_id;
|
||||
smemP[base_smem_p_offset + NSIMDGROUPS * 0] = float(thread_data_y4.x);
|
||||
smemP[base_smem_p_offset + NSIMDGROUPS * 1] = float(thread_data_y4.y);
|
||||
smemP[base_smem_p_offset + NSIMDGROUPS * 2] = float(thread_data_y4.z);
|
||||
smemP[base_smem_p_offset + NSIMDGROUPS * 3] = float(thread_data_y4.w);
|
||||
tile_start =
|
||||
((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR);
|
||||
|
||||
const int32_t INT_L = int32_t(L);
|
||||
for (int row_index = tile_start + simd_group_id; row_index < INT_L;
|
||||
row_index += NSIMDGROUPS) {
|
||||
if (simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) {
|
||||
const uint elems_per_row_gmem = DK;
|
||||
const uint col_index_v_gmem =
|
||||
col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id;
|
||||
const uint row_index_v_gmem = row_index;
|
||||
|
||||
const uint elems_per_row_smem = TILE_SIZE_CONST;
|
||||
const uint col_index_v_smem = row_index % TILE_SIZE_CONST;
|
||||
const uint row_index_v_smem = simd_lane_id;
|
||||
|
||||
const uint scalar_offset_gmem =
|
||||
row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem;
|
||||
const uint scalar_offset_smem =
|
||||
row_index_v_smem * elems_per_row_smem + col_index_v_smem;
|
||||
T vdata = T(*(baseVThisHead + scalar_offset_gmem));
|
||||
smemV[scalar_offset_smem] = vdata;
|
||||
smem_col_index += NSIMDGROUPS;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float groupMax;
|
||||
float lse = 0.f;
|
||||
|
||||
constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32;
|
||||
constexpr const size_t ACCUM_ARRAY_LENGTH = TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1;
|
||||
float4 pvals[ACCUM_ARRAY_LENGTH];
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH; accum_array_iter++) {
|
||||
pvals[accum_array_iter] = float4(-INFINITY);
|
||||
}
|
||||
|
||||
if (TILE_SIZE_CONST == 64) {
|
||||
threadgroup float2* smemPtrFlt2 = (threadgroup float2*)threadgroup_block;
|
||||
float2 vals = smemPtrFlt2[simd_lane_id];
|
||||
vals *= params.INV_ALPHA;
|
||||
float maxval = max(vals.x, vals.y);
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
groupMax = simd_max(maxval);
|
||||
|
||||
float2 expf_shifted = exp(vals - groupMax);
|
||||
float sumExpLocal = expf_shifted.x + expf_shifted.y;
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
float tgroupExpSum = simd_sum(sumExpLocal);
|
||||
|
||||
lse = log(tgroupExpSum);
|
||||
float2 local_p_hat = expf_shifted / tgroupExpSum;
|
||||
pvals[0].x = local_p_hat.x;
|
||||
pvals[0].y = local_p_hat.y;
|
||||
smemPtrFlt2[simd_lane_id] = float2(0.f);
|
||||
}
|
||||
constexpr const bool TILE_SIZE_LARGER_THAN_64 = TILE_SIZE_CONST > 64;
|
||||
constexpr const int TILE_SIZE_ITERS_128 = TILE_SIZE_CONST / 128;
|
||||
|
||||
if (TILE_SIZE_LARGER_THAN_64) {
|
||||
float maxval = -INFINITY;
|
||||
threadgroup float4* smemPtrFlt4 = (threadgroup float4*)threadgroup_block;
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < TILE_SIZE_ITERS_128; i++) {
|
||||
float4 vals = smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP];
|
||||
vals *= params.INV_ALPHA;
|
||||
pvals[i] = vals;
|
||||
maxval = fmax3(vals.x, vals.y, maxval);
|
||||
maxval = fmax3(vals.z, vals.w, maxval);
|
||||
if (TILE_SIZE_CONST == 64) {
|
||||
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
|
||||
threadgroup float* oPartialSmem =
|
||||
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
|
||||
for (size_t smem_row_index = simd_group_id;
|
||||
smem_row_index < ROWS_PER_ITER;
|
||||
smem_row_index += NSIMDGROUPS) {
|
||||
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index);
|
||||
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
|
||||
T2 v_local = *(smemV2 + simd_lane_id);
|
||||
T val = dot(local_p_hat, v_local);
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
T row_sum = simd_sum(val);
|
||||
oPartialSmem[smem_row_index] = float(row_sum);
|
||||
}
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
groupMax = simd_max(maxval);
|
||||
}
|
||||
|
||||
float sumExpLocal = 0.f;
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < TILE_SIZE_ITERS_128; i++) {
|
||||
pvals[i] = exp(pvals[i] - groupMax);
|
||||
sumExpLocal += pvals[i].x + pvals[i].y + pvals[i].z + pvals[i].w;
|
||||
}
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
float tgroupExpSum = simd_sum(sumExpLocal);
|
||||
lse = log(tgroupExpSum);
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < TILE_SIZE_ITERS_128; i++) {
|
||||
pvals[i] = pvals[i] / tgroupExpSum;
|
||||
smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP] = float4(0.f);
|
||||
if (TILE_SIZE_CONST > 64) {
|
||||
threadgroup float* oPartialSmem =
|
||||
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
|
||||
uint loop_count = 0;
|
||||
for (size_t row_index = simd_group_id; row_index < ROWS_PER_ITER;
|
||||
row_index += NSIMDGROUPS) {
|
||||
T row_sum = 0.f;
|
||||
for (size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128;
|
||||
tile_iters++) {
|
||||
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index);
|
||||
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
|
||||
T4 v_local =
|
||||
*(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP);
|
||||
T4 p_local = T4(pvals[tile_iters]);
|
||||
row_sum += dot(p_local, v_local);
|
||||
}
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
row_sum = simd_sum(row_sum);
|
||||
oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] =
|
||||
float(row_sum);
|
||||
loop_count++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup T* smemV = (threadgroup T*)threadgroup_block;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
const size_t v_batch_offset = tid.z * params.N_KV_HEADS * L * DK;
|
||||
const size_t v_head_offset = kv_head_offset_factor * L * DK;
|
||||
if (simd_group_id == 0) {
|
||||
threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial;
|
||||
float4 vals = *(oPartialVec4 + simd_lane_id);
|
||||
device float* oPartialGmem =
|
||||
O_partials + tid.x * DK * params.KV_TILES + tid.y * DK;
|
||||
device float4* oPartialGmemVec4 = (device float4*)oPartialGmem;
|
||||
oPartialGmemVec4[simd_lane_id] = vals;
|
||||
}
|
||||
|
||||
const size_t v_tile_offset = tid.y * TILE_SIZE_CONST * DK;
|
||||
const size_t v_offset = v_batch_offset + v_head_offset + v_tile_offset;
|
||||
device T* baseV = (device T*)V + v_offset;
|
||||
|
||||
threadgroup float* smemOpartial = (threadgroup float*)(smemV + totalSmemV);
|
||||
|
||||
if (!LAST_TILE || LAST_TILE_ALIGNED) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(size_t col = 0; col < MATRIX_COLS; col++) {
|
||||
uint matrix_load_loop_iter = 0;
|
||||
constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8;
|
||||
|
||||
for(size_t tile_start = simd_group_id; tile_start < TILE_SIZE_CONST_DIV_8; tile_start += NSIMDGROUPS) {
|
||||
simdgroup_matrix<T, 8, 8> tmp;
|
||||
ulong simdgroup_matrix_offset = matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
|
||||
ulong2 matrixOrigin = ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset);
|
||||
simdgroup_load(tmp, baseV, DK, matrixOrigin, true);
|
||||
const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0);
|
||||
const ulong elemsPerRowSmem = TILE_SIZE_CONST;
|
||||
simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, false);
|
||||
matrix_load_loop_iter++;
|
||||
};
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (TILE_SIZE_CONST == 64) {
|
||||
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
|
||||
uint loop_iter = 0;
|
||||
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; row += NSIMDGROUPS) {
|
||||
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
|
||||
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
|
||||
T2 v_local = *(smemV2 + simd_lane_id);
|
||||
|
||||
T val = dot(local_p_hat, v_local);
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
T row_sum = simd_sum(val);
|
||||
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = float(row_sum);
|
||||
loop_iter++;
|
||||
}
|
||||
}
|
||||
|
||||
if (TILE_SIZE_CONST > 64) {
|
||||
constexpr const size_t TILE_SIZE_CONST_DIV_128 = (TILE_SIZE_CONST + 1) / 128;
|
||||
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
|
||||
uint loop_iter = 0;
|
||||
for(size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; row += NSIMDGROUPS) {
|
||||
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
|
||||
|
||||
T row_sum = 0.f;
|
||||
for(size_t i = 0; i < TILE_SIZE_CONST_DIV_128; i++) {
|
||||
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
|
||||
T4 v_local = *(smemV2 + simd_lane_id + i * THREADS_PER_SIMDGROUP);
|
||||
T4 p_local = T4(pvals[i]);
|
||||
T val = dot(p_local, v_local);
|
||||
row_sum += val;
|
||||
}
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
row_sum = simd_sum(row_sum);
|
||||
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = float(row_sum);
|
||||
loop_iter++;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const int32_t START_ROW = tid.y * TILE_SIZE_CONST;
|
||||
const int32_t MAX_START_ROW = L - SIMDGROUP_MATRIX_LOAD_FACTOR + 1;
|
||||
const device T* baseVThisHead = V + v_batch_offset + v_head_offset;
|
||||
constexpr const int ROWS_PER_ITER = 8;
|
||||
#pragma clang loop unroll(full)
|
||||
for(size_t col = 0; col < MATRIX_COLS; col++) {
|
||||
uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
|
||||
int32_t tile_start;
|
||||
for(tile_start = START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; tile_start < MAX_START_ROW; tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) {
|
||||
simdgroup_matrix<T, 8, 8> tmp;
|
||||
ulong2 matrixOrigin = ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start);
|
||||
simdgroup_load(tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true);
|
||||
const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0);
|
||||
constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST;
|
||||
simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, /* transpose */ false);
|
||||
smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR;
|
||||
};
|
||||
|
||||
tile_start = ((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR);
|
||||
|
||||
const int32_t INT_L = int32_t(L);
|
||||
for(int row_index = tile_start + simd_group_id ; row_index < INT_L; row_index += NSIMDGROUPS) {
|
||||
if(simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) {
|
||||
const uint elems_per_row_gmem = DK;
|
||||
const uint col_index_v_gmem = col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id;
|
||||
const uint row_index_v_gmem = row_index;
|
||||
|
||||
const uint elems_per_row_smem = TILE_SIZE_CONST;
|
||||
const uint col_index_v_smem = row_index % TILE_SIZE_CONST;
|
||||
const uint row_index_v_smem = simd_lane_id;
|
||||
|
||||
const uint scalar_offset_gmem = row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem;
|
||||
const uint scalar_offset_smem = row_index_v_smem * elems_per_row_smem + col_index_v_smem;
|
||||
T vdata = T(*(baseVThisHead + scalar_offset_gmem));
|
||||
smemV[scalar_offset_smem] = vdata;
|
||||
smem_col_index += NSIMDGROUPS;
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (TILE_SIZE_CONST == 64) {
|
||||
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
|
||||
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
|
||||
for(size_t smem_row_index = simd_group_id;
|
||||
smem_row_index < ROWS_PER_ITER; smem_row_index += NSIMDGROUPS) {
|
||||
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index);
|
||||
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
|
||||
T2 v_local = *(smemV2 + simd_lane_id);
|
||||
T val = dot(local_p_hat, v_local);
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
T row_sum = simd_sum(val);
|
||||
oPartialSmem[smem_row_index] = float(row_sum);
|
||||
}
|
||||
}
|
||||
|
||||
if (TILE_SIZE_CONST > 64) {
|
||||
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
|
||||
uint loop_count = 0;
|
||||
for(size_t row_index = simd_group_id;
|
||||
row_index < ROWS_PER_ITER; row_index += NSIMDGROUPS) {
|
||||
T row_sum = 0.f;
|
||||
for(size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128; tile_iters++) {
|
||||
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index);
|
||||
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
|
||||
T4 v_local = *(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP);
|
||||
T4 p_local = T4(pvals[tile_iters]);
|
||||
row_sum += dot(p_local, v_local);
|
||||
|
||||
}
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
row_sum = simd_sum(row_sum);
|
||||
oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] = float(row_sum);
|
||||
loop_count++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if(simd_group_id == 0) {
|
||||
threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial;
|
||||
float4 vals = *(oPartialVec4 + simd_lane_id);
|
||||
device float* oPartialGmem = O_partials + tid.x * DK * params.KV_TILES + tid.y * DK;
|
||||
device float4* oPartialGmemVec4 = (device float4*)oPartialGmem;
|
||||
oPartialGmemVec4[simd_lane_id] = vals;
|
||||
}
|
||||
|
||||
if(simd_group_id == 0 && simd_lane_id == 0) {
|
||||
const uint tileIndex = tid.y;
|
||||
const uint gmem_partial_scalar_offset = tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES + tileIndex;
|
||||
p_lse[gmem_partial_scalar_offset] = lse;
|
||||
p_maxes[gmem_partial_scalar_offset] = groupMax;
|
||||
}
|
||||
if (simd_group_id == 0 && simd_lane_id == 0) {
|
||||
const uint tileIndex = tid.y;
|
||||
const uint gmem_partial_scalar_offset =
|
||||
tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES +
|
||||
tileIndex;
|
||||
p_lse[gmem_partial_scalar_offset] = lse;
|
||||
p_maxes[gmem_partial_scalar_offset] = groupMax;
|
||||
}
|
||||
}
|
||||
|
||||
#define instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, nsimdgroups) \
|
||||
template [[host_name("fast_inference_sdpa_compute_partials_" #itype "_" #tile_size "_" #nsimdgroups )]] \
|
||||
[[kernel]] void fast_inference_sdpa_compute_partials_template<itype, itype2, itype4, tile_size, nsimdgroups>( \
|
||||
const device itype *Q [[buffer(0)]], \
|
||||
const device itype *K [[buffer(1)]], \
|
||||
const device itype *V [[buffer(2)]], \
|
||||
const device uint64_t& L [[buffer(3)]], \
|
||||
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], \
|
||||
device float* O_partials [[buffer(5)]], \
|
||||
device float* p_lse [[buffer(6)]], \
|
||||
device float* p_maxes [[buffer(7)]], \
|
||||
threadgroup itype *threadgroup_block [[threadgroup(0)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]]);
|
||||
#define instantiate_fast_inference_sdpa_to_partials_kernel( \
|
||||
itype, itype2, itype4, tile_size, nsimdgroups) \
|
||||
template [[host_name("fast_inference_sdpa_compute_partials_" #itype \
|
||||
"_" #tile_size "_" #nsimdgroups)]] [[kernel]] void \
|
||||
fast_inference_sdpa_compute_partials_template< \
|
||||
itype, \
|
||||
itype2, \
|
||||
itype4, \
|
||||
tile_size, \
|
||||
nsimdgroups>( \
|
||||
const device itype* Q [[buffer(0)]], \
|
||||
const device itype* K [[buffer(1)]], \
|
||||
const device itype* V [[buffer(2)]], \
|
||||
const device uint64_t& L [[buffer(3)]], \
|
||||
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], \
|
||||
device float* O_partials [[buffer(5)]], \
|
||||
device float* p_lse [[buffer(6)]], \
|
||||
device float* p_maxes [[buffer(7)]], \
|
||||
threadgroup itype* threadgroup_block [[threadgroup(0)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_fast_inference_sdpa_to_partials_shapes_helper( \
|
||||
itype, itype2, itype4, tile_size) \
|
||||
instantiate_fast_inference_sdpa_to_partials_kernel( \
|
||||
itype, itype2, itype4, tile_size, 4) \
|
||||
instantiate_fast_inference_sdpa_to_partials_kernel( \
|
||||
itype, itype2, itype4, tile_size, 8) // clang-format on
|
||||
|
||||
#define instantiate_fast_inference_sdpa_to_partials_shapes_helper(itype, itype2, itype4, tile_size) \
|
||||
instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, 4) \
|
||||
instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, 8) \
|
||||
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 64);
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 128);
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 256);
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 512);
|
||||
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 64);
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 128);
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 256);
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 512);
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
|
||||
float,
|
||||
float2,
|
||||
float4,
|
||||
64);
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
|
||||
float,
|
||||
float2,
|
||||
float4,
|
||||
128);
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
|
||||
float,
|
||||
float2,
|
||||
float4,
|
||||
256);
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
|
||||
float,
|
||||
float2,
|
||||
float4,
|
||||
512);
|
||||
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
|
||||
half,
|
||||
half2,
|
||||
half4,
|
||||
64);
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
|
||||
half,
|
||||
half2,
|
||||
half4,
|
||||
128);
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
|
||||
half,
|
||||
half2,
|
||||
half4,
|
||||
256);
|
||||
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
|
||||
half,
|
||||
half2,
|
||||
half4,
|
||||
512);
|
||||
|
||||
template <typename T>
|
||||
void fast_inference_sdpa_reduce_tiles_template(
|
||||
const device float *O_partials [[buffer(0)]],
|
||||
const device float *p_lse[[buffer(1)]],
|
||||
const device float *p_maxes [[buffer(2)]],
|
||||
const device float* O_partials [[buffer(0)]],
|
||||
const device float* p_lse [[buffer(1)]],
|
||||
const device float* p_maxes [[buffer(2)]],
|
||||
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
|
||||
device T* O [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
constexpr const int DK = 128;
|
||||
const ulong offset_rows =
|
||||
tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES;
|
||||
const device float* p_lse_row = p_lse + offset_rows;
|
||||
const device float* p_rowmax_row = p_maxes + offset_rows;
|
||||
// reserve some number of registers. this constitutes an assumption on max
|
||||
// value of KV TILES.
|
||||
constexpr const uint8_t reserve = 128;
|
||||
float p_lse_regs[reserve];
|
||||
float p_rowmax_regs[reserve];
|
||||
float weights[reserve];
|
||||
|
||||
constexpr const int DK = 128;
|
||||
const ulong offset_rows = tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES;
|
||||
const device float* p_lse_row = p_lse + offset_rows;
|
||||
const device float* p_rowmax_row = p_maxes + offset_rows;
|
||||
// reserve some number of registers. this constitutes an assumption on max value of KV TILES.
|
||||
constexpr const uint8_t reserve = 128;
|
||||
float p_lse_regs[reserve];
|
||||
float p_rowmax_regs[reserve];
|
||||
float weights[reserve];
|
||||
float true_max = -INFINITY;
|
||||
for (size_t i = 0; i < params.KV_TILES; i++) {
|
||||
p_lse_regs[i] = float(*(p_lse_row + i));
|
||||
p_rowmax_regs[i] = float(*(p_rowmax_row + i));
|
||||
true_max = fmax(p_rowmax_regs[i], true_max);
|
||||
weights[i] = exp(p_lse_regs[i]);
|
||||
}
|
||||
|
||||
float true_max = -INFINITY;
|
||||
for(size_t i = 0; i < params.KV_TILES; i++) {
|
||||
p_lse_regs[i] = float(*(p_lse_row + i));
|
||||
p_rowmax_regs[i] = float(*(p_rowmax_row + i));
|
||||
true_max = fmax(p_rowmax_regs[i], true_max);
|
||||
weights[i] = exp(p_lse_regs[i]);
|
||||
}
|
||||
float denom = 0.f;
|
||||
for (size_t i = 0; i < params.KV_TILES; i++) {
|
||||
weights[i] *= exp(p_rowmax_regs[i] - true_max);
|
||||
denom += weights[i];
|
||||
}
|
||||
|
||||
float denom = 0.f;
|
||||
for(size_t i = 0; i < params.KV_TILES; i++) {
|
||||
weights[i] *= exp(p_rowmax_regs[i]-true_max);
|
||||
denom += weights[i];
|
||||
}
|
||||
const device float* O_partials_with_offset = O_partials +
|
||||
tid.z * params.N_Q_HEADS * DK * params.KV_TILES +
|
||||
tid.x * DK * params.KV_TILES;
|
||||
|
||||
const device float* O_partials_with_offset = O_partials + tid.z * params.N_Q_HEADS * DK * params.KV_TILES + tid.x * DK * params.KV_TILES;
|
||||
|
||||
float o_value = 0.f;
|
||||
for(size_t i = 0; i < params.KV_TILES; i++) {
|
||||
float val = *(O_partials_with_offset + i * DK + lid.x);
|
||||
o_value += val * weights[i] / denom;
|
||||
}
|
||||
device T* O_gmem = O + tid.z * params.N_Q_HEADS * DK + tid.x * DK;
|
||||
O_gmem[lid.x] = T(o_value);
|
||||
return;
|
||||
float o_value = 0.f;
|
||||
for (size_t i = 0; i < params.KV_TILES; i++) {
|
||||
float val = *(O_partials_with_offset + i * DK + lid.x);
|
||||
o_value += val * weights[i] / denom;
|
||||
}
|
||||
device T* O_gmem = O + tid.z * params.N_Q_HEADS * DK + tid.x * DK;
|
||||
O_gmem[lid.x] = T(o_value);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
kernel void fast_inference_sdpa_reduce_tiles_float(
|
||||
const device float *O_partials [[buffer(0)]],
|
||||
const device float *p_lse[[buffer(1)]],
|
||||
const device float *p_maxes [[buffer(2)]],
|
||||
const device float* O_partials [[buffer(0)]],
|
||||
const device float* p_lse [[buffer(1)]],
|
||||
const device float* p_maxes [[buffer(2)]],
|
||||
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
|
||||
device float* O [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]])
|
||||
{
|
||||
fast_inference_sdpa_reduce_tiles_template<float>(O_partials, p_lse, p_maxes, params,
|
||||
O, tid, lid);
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
fast_inference_sdpa_reduce_tiles_template<float>(
|
||||
O_partials, p_lse, p_maxes, params, O, tid, lid);
|
||||
}
|
||||
|
||||
kernel void fast_inference_sdpa_reduce_tiles_half(
|
||||
const device float *O_partials [[buffer(0)]],
|
||||
const device float *p_lse[[buffer(1)]],
|
||||
const device float *p_maxes [[buffer(2)]],
|
||||
const device float* O_partials [[buffer(0)]],
|
||||
const device float* p_lse [[buffer(1)]],
|
||||
const device float* p_maxes [[buffer(2)]],
|
||||
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
|
||||
device half* O [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]])
|
||||
{
|
||||
fast_inference_sdpa_reduce_tiles_template<half>(O_partials, p_lse, p_maxes, params,
|
||||
O, tid, lid);
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
fast_inference_sdpa_reduce_tiles_template<half>(
|
||||
O_partials, p_lse, p_maxes, params, O, tid, lid);
|
||||
}
|
||||
|
@ -54,7 +54,7 @@ struct CumProd<bool> {
|
||||
}
|
||||
|
||||
bool simd_scan(bool x) {
|
||||
for (int i=1; i<=16; i*=2) {
|
||||
for (int i = 1; i <= 16; i *= 2) {
|
||||
bool other = simd_shuffle_up(x, i);
|
||||
x &= other;
|
||||
}
|
||||
@ -77,7 +77,7 @@ struct CumMax {
|
||||
}
|
||||
|
||||
U simd_scan(U x) {
|
||||
for (int i=1; i<=16; i*=2) {
|
||||
for (int i = 1; i <= 16; i *= 2) {
|
||||
U other = simd_shuffle_up(x, i);
|
||||
x = (x >= other) ? x : other;
|
||||
}
|
||||
@ -100,7 +100,7 @@ struct CumMin {
|
||||
}
|
||||
|
||||
U simd_scan(U x) {
|
||||
for (int i=1; i<=16; i*=2) {
|
||||
for (int i = 1; i <= 16; i *= 2) {
|
||||
U other = simd_shuffle_up(x, i);
|
||||
x = (x <= other) ? x : other;
|
||||
}
|
||||
@ -114,54 +114,60 @@ struct CumMin {
|
||||
};
|
||||
|
||||
template <typename T, typename U, int N_READS, bool reverse>
|
||||
inline void load_unsafe(U values[N_READS], const device T * input) {
|
||||
inline void load_unsafe(U values[N_READS], const device T* input) {
|
||||
if (reverse) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
values[N_READS-i-1] = input[i];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
values[N_READS - i - 1] = input[i];
|
||||
}
|
||||
} else {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
values[i] = input[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, int N_READS, bool reverse>
|
||||
inline void load_safe(U values[N_READS], const device T * input, int start, int total, U init) {
|
||||
inline void load_safe(
|
||||
U values[N_READS],
|
||||
const device T* input,
|
||||
int start,
|
||||
int total,
|
||||
U init) {
|
||||
if (reverse) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
values[N_READS-i-1] = (start + N_READS - i - 1 < total) ? input[i] : init;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
values[N_READS - i - 1] =
|
||||
(start + N_READS - i - 1 < total) ? input[i] : init;
|
||||
}
|
||||
} else {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
values[i] = (start + i < total) ? input[i] : init;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int N_READS, bool reverse>
|
||||
inline void write_unsafe(U values[N_READS], device U * out) {
|
||||
inline void write_unsafe(U values[N_READS], device U* out) {
|
||||
if (reverse) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
out[i] = values[N_READS-i-1];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
out[i] = values[N_READS - i - 1];
|
||||
}
|
||||
} else {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
out[i] = values[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int N_READS, bool reverse>
|
||||
inline void write_safe(U values[N_READS], device U * out, int start, int total) {
|
||||
inline void write_safe(U values[N_READS], device U* out, int start, int total) {
|
||||
if (reverse) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if (start + N_READS - i - 1 < total) {
|
||||
out[i] = values[N_READS-i-1];
|
||||
out[i] = values[N_READS - i - 1];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if (start + i < total) {
|
||||
out[i] = values[i];
|
||||
}
|
||||
@ -169,12 +175,17 @@ inline void write_safe(U values[N_READS], device U * out, int start, int total)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int N_READS,
|
||||
bool inclusive,
|
||||
bool reverse>
|
||||
[[kernel]] void contiguous_scan(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t & axis_size [[buffer(2)]],
|
||||
const constant size_t& axis_size [[buffer(2)]],
|
||||
uint gid [[thread_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
@ -195,42 +206,51 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
|
||||
U values[N_READS];
|
||||
threadgroup U simdgroup_sums[32];
|
||||
|
||||
// Loop over the reduced axis in blocks of size ceildiv(axis_size, N_READS*lsize)
|
||||
// Loop over the reduced axis in blocks of size ceildiv(axis_size,
|
||||
// N_READS*lsize)
|
||||
// Read block
|
||||
// Compute inclusive scan of the block
|
||||
// Compute inclusive scan per thread
|
||||
// Compute exclusive scan of thread sums in simdgroup
|
||||
// Write simdgroup sums in SM
|
||||
// Compute exclusive scan of simdgroup sums
|
||||
// Compute the output by scanning prefix, prev_simdgroup, prev_thread, value
|
||||
// Compute the output by scanning prefix, prev_simdgroup, prev_thread,
|
||||
// value
|
||||
// Write block
|
||||
|
||||
for (uint r = 0; r < ceildiv(axis_size, N_READS*lsize); r++) {
|
||||
for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
|
||||
// Compute the block offset
|
||||
uint offset = r*lsize*N_READS + lid*N_READS;
|
||||
uint offset = r * lsize * N_READS + lid * N_READS;
|
||||
|
||||
// Read the values
|
||||
if (reverse) {
|
||||
if ((offset + N_READS) < axis_size) {
|
||||
load_unsafe<T, U, N_READS, reverse>(values, in + axis_size - offset - N_READS);
|
||||
load_unsafe<T, U, N_READS, reverse>(
|
||||
values, in + axis_size - offset - N_READS);
|
||||
} else {
|
||||
load_safe<T, U, N_READS, reverse>(values, in + axis_size - offset - N_READS, offset, axis_size, Op::init);
|
||||
load_safe<T, U, N_READS, reverse>(
|
||||
values,
|
||||
in + axis_size - offset - N_READS,
|
||||
offset,
|
||||
axis_size,
|
||||
Op::init);
|
||||
}
|
||||
} else {
|
||||
if ((offset + N_READS) < axis_size) {
|
||||
load_unsafe<T, U, N_READS, reverse>(values, in + offset);
|
||||
} else {
|
||||
load_safe<T, U, N_READS, reverse>(values, in + offset, offset, axis_size, Op::init);
|
||||
load_safe<T, U, N_READS, reverse>(
|
||||
values, in + offset, offset, axis_size, Op::init);
|
||||
}
|
||||
}
|
||||
|
||||
// Compute an inclusive scan per thread
|
||||
for (int i=1; i<N_READS; i++) {
|
||||
values[i] = op(values[i], values[i-1]);
|
||||
for (int i = 1; i < N_READS; i++) {
|
||||
values[i] = op(values[i], values[i - 1]);
|
||||
}
|
||||
|
||||
// Compute exclusive scan of thread sums
|
||||
U prev_thread = op.simd_exclusive_scan(values[N_READS-1]);
|
||||
U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]);
|
||||
|
||||
// Write simdgroup_sums to SM
|
||||
if (simd_lane_id == simd_size - 1) {
|
||||
@ -246,7 +266,7 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Compute the output
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
values[i] = op(values[i], prefix);
|
||||
values[i] = op(values[i], simdgroup_sums[simd_group_id]);
|
||||
values[i] = op(values[i], prev_thread);
|
||||
@ -256,18 +276,25 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
|
||||
if (reverse) {
|
||||
if (inclusive) {
|
||||
if ((offset + N_READS) < axis_size) {
|
||||
write_unsafe<U, N_READS, reverse>(values, out + axis_size - offset - N_READS);
|
||||
write_unsafe<U, N_READS, reverse>(
|
||||
values, out + axis_size - offset - N_READS);
|
||||
} else {
|
||||
write_safe<U, N_READS, reverse>(values, out + axis_size - offset - N_READS, offset, axis_size);
|
||||
write_safe<U, N_READS, reverse>(
|
||||
values, out + axis_size - offset - N_READS, offset, axis_size);
|
||||
}
|
||||
} else {
|
||||
if (lid == 0 && offset == 0) {
|
||||
out[axis_size-1] = Op::init;
|
||||
out[axis_size - 1] = Op::init;
|
||||
}
|
||||
if ((offset + N_READS + 1) < axis_size) {
|
||||
write_unsafe<U, N_READS, reverse>(values, out + axis_size - offset - 1 - N_READS);
|
||||
write_unsafe<U, N_READS, reverse>(
|
||||
values, out + axis_size - offset - 1 - N_READS);
|
||||
} else {
|
||||
write_safe<U, N_READS, reverse>(values, out + axis_size - offset - 1 - N_READS, offset + 1, axis_size);
|
||||
write_safe<U, N_READS, reverse>(
|
||||
values,
|
||||
out + axis_size - offset - 1 - N_READS,
|
||||
offset + 1,
|
||||
axis_size);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@ -275,7 +302,8 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
|
||||
if ((offset + N_READS) < axis_size) {
|
||||
write_unsafe<U, N_READS, reverse>(values, out + offset);
|
||||
} else {
|
||||
write_safe<U, N_READS, reverse>(values, out + offset, offset, axis_size);
|
||||
write_safe<U, N_READS, reverse>(
|
||||
values, out + offset, offset, axis_size);
|
||||
}
|
||||
} else {
|
||||
if (lid == 0 && offset == 0) {
|
||||
@ -284,26 +312,33 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
|
||||
if ((offset + N_READS + 1) < axis_size) {
|
||||
write_unsafe<U, N_READS, reverse>(values, out + offset + 1);
|
||||
} else {
|
||||
write_safe<U, N_READS, reverse>(values, out + offset + 1, offset + 1, axis_size);
|
||||
write_safe<U, N_READS, reverse>(
|
||||
values, out + offset + 1, offset + 1, axis_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Share the prefix
|
||||
if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {
|
||||
simdgroup_sums[0] = values[N_READS-1];
|
||||
simdgroup_sums[0] = values[N_READS - 1];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
prefix = simdgroup_sums[0];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int N_READS,
|
||||
bool inclusive,
|
||||
bool reverse>
|
||||
[[kernel]] void strided_scan(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t & axis_size [[buffer(2)]],
|
||||
const constant size_t & stride [[buffer(3)]],
|
||||
const constant size_t& axis_size [[buffer(2)]],
|
||||
const constant size_t& stride [[buffer(3)]],
|
||||
uint2 gid [[threadgroup_position_in_grid]],
|
||||
uint2 lid [[thread_position_in_threadgroup]],
|
||||
uint2 lsize [[threads_per_threadgroup]],
|
||||
@ -311,10 +346,10 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
|
||||
Op op;
|
||||
|
||||
// Allocate memory
|
||||
threadgroup U read_buffer[N_READS*32*32 + N_READS*32];
|
||||
threadgroup U read_buffer[N_READS * 32 * 32 + N_READS * 32];
|
||||
U values[N_READS];
|
||||
U prefix[N_READS];
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
prefix[i] = Op::init;
|
||||
}
|
||||
|
||||
@ -322,7 +357,7 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
|
||||
int offset = gid.y * axis_size * stride;
|
||||
int global_index_x = gid.x * lsize.y * N_READS;
|
||||
|
||||
for (uint j=0; j<axis_size; j+=simd_size) {
|
||||
for (uint j = 0; j < axis_size; j += simd_size) {
|
||||
// Calculate the indices for the current thread
|
||||
uint index_y = j + lid.y;
|
||||
uint check_index_y = index_y;
|
||||
@ -333,37 +368,43 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
|
||||
|
||||
// Read in SM
|
||||
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = in[offset + index_y * stride + index_x + i];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
|
||||
in[offset + index_y * stride + index_x + i];
|
||||
}
|
||||
} else {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if (check_index_y < axis_size && (index_x + i) < stride) {
|
||||
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = in[offset + index_y * stride + index_x + i];
|
||||
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
|
||||
in[offset + index_y * stride + index_x + i];
|
||||
} else {
|
||||
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = Op::init;
|
||||
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
|
||||
Op::init;
|
||||
}
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Read strided into registers
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
values[i] = read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
values[i] =
|
||||
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
|
||||
}
|
||||
// Do we need the following barrier? Shouldn't all simd threads execute simultaneously?
|
||||
// Do we need the following barrier? Shouldn't all simd threads execute
|
||||
// simultaneously?
|
||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Perform the scan
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
values[i] = op.simd_scan(values[i]);
|
||||
values[i] = op(values[i], prefix[i]);
|
||||
prefix[i] = simd_shuffle(values[i], simd_size-1);
|
||||
prefix[i] = simd_shuffle(values[i], simd_size - 1);
|
||||
}
|
||||
|
||||
// Write to SM
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] = values[i];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] =
|
||||
values[i];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
@ -371,11 +412,11 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
|
||||
if (!inclusive) {
|
||||
if (check_index_y == 0) {
|
||||
if ((index_x + N_READS) < stride) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
out[offset + index_y * stride + index_x + i] = Op::init;
|
||||
}
|
||||
} else {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((index_x + i) < stride) {
|
||||
out[offset + index_y * stride + index_x + i] = Op::init;
|
||||
}
|
||||
@ -391,55 +432,60 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
|
||||
}
|
||||
}
|
||||
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
out[offset + index_y * stride + index_x + i] = read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
out[offset + index_y * stride + index_x + i] =
|
||||
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
|
||||
}
|
||||
} else {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if (check_index_y < axis_size && (index_x + i) < stride) {
|
||||
out[offset + index_y * stride + index_x + i] = read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
|
||||
out[offset + index_y * stride + index_x + i] =
|
||||
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define instantiate_contiguous_scan(name, itype, otype, op, inclusive, reverse, nreads) \
|
||||
template [[host_name("contiguous_scan_" #name)]] \
|
||||
[[kernel]] void contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant size_t & axis_size [[buffer(2)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_size [[threads_per_simdgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
#define instantiate_contiguous_scan( \
|
||||
name, itype, otype, op, inclusive, reverse, nreads) \
|
||||
template [[host_name("contiguous_scan_" #name)]] [[kernel]] void \
|
||||
contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant size_t& axis_size [[buffer(2)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_size [[threads_per_simdgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_strided_scan(name, itype, otype, op, inclusive, reverse, nreads) \
|
||||
template [[host_name("strided_scan_" #name)]] \
|
||||
[[kernel]] void strided_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant size_t & axis_size [[buffer(2)]], \
|
||||
const constant size_t & stride [[buffer(3)]], \
|
||||
uint2 gid [[thread_position_in_grid]], \
|
||||
uint2 lid [[thread_position_in_threadgroup]], \
|
||||
uint2 lsize [[threads_per_threadgroup]], \
|
||||
uint simd_size [[threads_per_simdgroup]]);
|
||||
#define instantiate_strided_scan( \
|
||||
name, itype, otype, op, inclusive, reverse, nreads) \
|
||||
template [[host_name("strided_scan_" #name)]] [[kernel]] void \
|
||||
strided_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant size_t& axis_size [[buffer(2)]], \
|
||||
const constant size_t& stride [[buffer(3)]], \
|
||||
uint2 gid [[thread_position_in_grid]], \
|
||||
uint2 lid [[thread_position_in_threadgroup]], \
|
||||
uint2 lsize [[threads_per_threadgroup]], \
|
||||
uint simd_size [[threads_per_simdgroup]]);
|
||||
|
||||
|
||||
#define instantiate_scan_helper(name, itype, otype, op, nreads) \
|
||||
instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
|
||||
instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
|
||||
instantiate_contiguous_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
|
||||
// clang-format off
|
||||
#define instantiate_scan_helper(name, itype, otype, op, nreads) \
|
||||
instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
|
||||
instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
|
||||
instantiate_contiguous_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
|
||||
instantiate_contiguous_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) \
|
||||
instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
|
||||
instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
|
||||
instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
|
||||
instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads)
|
||||
instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
|
||||
instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
|
||||
instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
|
||||
instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4)
|
||||
instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4)
|
||||
instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4)
|
||||
@ -491,4 +537,4 @@ instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMi
|
||||
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
|
||||
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
|
||||
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
|
||||
//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin)
|
||||
//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin) // clang-format on
|
@ -13,67 +13,55 @@ using namespace metal;
|
||||
// Scatter kernel
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename IdxT, typename Op, int NIDX> \
|
||||
template <typename T, typename IdxT, typename Op, int NIDX>
|
||||
METAL_FUNC void scatter_1d_index_impl(
|
||||
const device T *updates [[buffer(1)]],
|
||||
device mlx_atomic<T> *out [[buffer(2)]],
|
||||
const constant int* out_shape [[buffer(3)]],
|
||||
const constant size_t* out_strides [[buffer(4)]],
|
||||
const constant size_t& upd_size [[buffer(5)]],
|
||||
const thread array<const device IdxT*, NIDX>& idx_buffers,
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
|
||||
const device T* updates [[buffer(1)]],
|
||||
device mlx_atomic<T>* out [[buffer(2)]],
|
||||
const constant int* out_shape [[buffer(3)]],
|
||||
const constant size_t* out_strides [[buffer(4)]],
|
||||
const constant size_t& upd_size [[buffer(5)]],
|
||||
const thread array<const device IdxT*, NIDX>& idx_buffers,
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
Op op;
|
||||
|
||||
uint out_idx = 0;
|
||||
for (int i = 0; i < NIDX; i++) {
|
||||
auto idx_val = offset_neg_idx(
|
||||
idx_buffers[i][gid.y], out_shape[i]);
|
||||
auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
|
||||
out_idx += idx_val * out_strides[i];
|
||||
}
|
||||
|
||||
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
|
||||
}
|
||||
|
||||
#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \
|
||||
template <typename T, typename IdxT, typename Op, int NIDX> \
|
||||
[[kernel]] void scatter_1d_index( \
|
||||
const device T *updates [[buffer(1)]], \
|
||||
device mlx_atomic<T> *out [[buffer(2)]], \
|
||||
const constant int* out_shape [[buffer(3)]], \
|
||||
const constant size_t* out_strides [[buffer(4)]], \
|
||||
const constant size_t& upd_size [[buffer(5)]], \
|
||||
IDX_ARG(IdxT) \
|
||||
uint2 gid [[thread_position_in_grid]]) { \
|
||||
\
|
||||
const array<const device IdxT*, NIDX> idx_buffers = {IDX_ARR()}; \
|
||||
\
|
||||
return scatter_1d_index_impl<T, IdxT, Op, NIDX>( \
|
||||
updates, \
|
||||
out, \
|
||||
out_shape, \
|
||||
out_strides, \
|
||||
upd_size, \
|
||||
idx_buffers, \
|
||||
gid); \
|
||||
\
|
||||
}
|
||||
#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \
|
||||
template <typename T, typename IdxT, typename Op, int NIDX> \
|
||||
[[kernel]] void scatter_1d_index( \
|
||||
const device T* updates [[buffer(1)]], \
|
||||
device mlx_atomic<T>* out [[buffer(2)]], \
|
||||
const constant int* out_shape [[buffer(3)]], \
|
||||
const constant size_t* out_strides [[buffer(4)]], \
|
||||
const constant size_t& upd_size [[buffer(5)]], \
|
||||
IDX_ARG(IdxT) uint2 gid [[thread_position_in_grid]]) { \
|
||||
const array<const device IdxT*, NIDX> idx_buffers = {IDX_ARR()}; \
|
||||
\
|
||||
return scatter_1d_index_impl<T, IdxT, Op, NIDX>( \
|
||||
updates, out, out_shape, out_strides, upd_size, idx_buffers, gid); \
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT, typename Op, int NIDX>
|
||||
METAL_FUNC void scatter_impl(
|
||||
const device T *updates [[buffer(1)]],
|
||||
device mlx_atomic<T> *out [[buffer(2)]],
|
||||
const constant int *upd_shape [[buffer(3)]],
|
||||
const constant size_t *upd_strides [[buffer(4)]],
|
||||
const device T* updates [[buffer(1)]],
|
||||
device mlx_atomic<T>* out [[buffer(2)]],
|
||||
const constant int* upd_shape [[buffer(3)]],
|
||||
const constant size_t* upd_strides [[buffer(4)]],
|
||||
const constant size_t& upd_ndim [[buffer(5)]],
|
||||
const constant size_t& upd_size [[buffer(6)]],
|
||||
const constant int *out_shape [[buffer(7)]],
|
||||
const constant size_t *out_strides [[buffer(8)]],
|
||||
const constant int* out_shape [[buffer(7)]],
|
||||
const constant size_t* out_strides [[buffer(8)]],
|
||||
const constant size_t& out_ndim [[buffer(9)]],
|
||||
const constant int* axes [[buffer(10)]],
|
||||
const thread Indices<IdxT, NIDX>& indices,
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
|
||||
Op op;
|
||||
auto ind_idx = gid.y;
|
||||
auto ind_offset = gid.x;
|
||||
@ -86,8 +74,7 @@ METAL_FUNC void scatter_impl(
|
||||
&indices.strides[indices.ndim * i],
|
||||
indices.ndim);
|
||||
auto ax = axes[i];
|
||||
auto idx_val = offset_neg_idx(
|
||||
indices.buffers[i][idx_loc], out_shape[ax]);
|
||||
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
|
||||
out_idx += idx_val * out_strides[ax];
|
||||
}
|
||||
|
||||
@ -97,142 +84,134 @@ METAL_FUNC void scatter_impl(
|
||||
out_idx += out_offset;
|
||||
}
|
||||
|
||||
auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
|
||||
auto upd_idx =
|
||||
elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
|
||||
op.atomic_update(out, updates[upd_idx], out_idx);
|
||||
}
|
||||
|
||||
#define make_scatter_impl(IDX_ARG, IDX_ARR) \
|
||||
template <typename T, typename IdxT, typename Op, int NIDX> \
|
||||
[[kernel]] void scatter( \
|
||||
const device T *updates [[buffer(1)]], \
|
||||
device mlx_atomic<T> *out [[buffer(2)]], \
|
||||
const constant int *upd_shape [[buffer(3)]], \
|
||||
const constant size_t *upd_strides [[buffer(4)]], \
|
||||
const constant size_t& upd_ndim [[buffer(5)]], \
|
||||
const constant size_t& upd_size [[buffer(6)]], \
|
||||
const constant int *out_shape [[buffer(7)]], \
|
||||
const constant size_t *out_strides [[buffer(8)]], \
|
||||
const constant size_t& out_ndim [[buffer(9)]], \
|
||||
const constant int* axes [[buffer(10)]], \
|
||||
const constant int *idx_shapes [[buffer(11)]], \
|
||||
const constant size_t *idx_strides [[buffer(12)]], \
|
||||
const constant int& idx_ndim [[buffer(13)]], \
|
||||
IDX_ARG(IdxT) \
|
||||
uint2 gid [[thread_position_in_grid]]) { \
|
||||
\
|
||||
Indices<IdxT, NIDX> idxs{ \
|
||||
{{IDX_ARR()}}, \
|
||||
idx_shapes, \
|
||||
idx_strides, \
|
||||
idx_ndim}; \
|
||||
\
|
||||
return scatter_impl<T, IdxT, Op, NIDX>( \
|
||||
updates, \
|
||||
out, \
|
||||
upd_shape, \
|
||||
upd_strides, \
|
||||
upd_ndim, \
|
||||
upd_size, \
|
||||
out_shape, \
|
||||
out_strides, \
|
||||
out_ndim, \
|
||||
axes, \
|
||||
idxs, \
|
||||
gid); \
|
||||
}
|
||||
#define make_scatter_impl(IDX_ARG, IDX_ARR) \
|
||||
template <typename T, typename IdxT, typename Op, int NIDX> \
|
||||
[[kernel]] void scatter( \
|
||||
const device T* updates [[buffer(1)]], \
|
||||
device mlx_atomic<T>* out [[buffer(2)]], \
|
||||
const constant int* upd_shape [[buffer(3)]], \
|
||||
const constant size_t* upd_strides [[buffer(4)]], \
|
||||
const constant size_t& upd_ndim [[buffer(5)]], \
|
||||
const constant size_t& upd_size [[buffer(6)]], \
|
||||
const constant int* out_shape [[buffer(7)]], \
|
||||
const constant size_t* out_strides [[buffer(8)]], \
|
||||
const constant size_t& out_ndim [[buffer(9)]], \
|
||||
const constant int* axes [[buffer(10)]], \
|
||||
const constant int* idx_shapes [[buffer(11)]], \
|
||||
const constant size_t* idx_strides [[buffer(12)]], \
|
||||
const constant int& idx_ndim [[buffer(13)]], \
|
||||
IDX_ARG(IdxT) uint2 gid [[thread_position_in_grid]]) { \
|
||||
Indices<IdxT, NIDX> idxs{ \
|
||||
{{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \
|
||||
\
|
||||
return scatter_impl<T, IdxT, Op, NIDX>( \
|
||||
updates, \
|
||||
out, \
|
||||
upd_shape, \
|
||||
upd_strides, \
|
||||
upd_ndim, \
|
||||
upd_size, \
|
||||
out_shape, \
|
||||
out_strides, \
|
||||
out_ndim, \
|
||||
axes, \
|
||||
idxs, \
|
||||
gid); \
|
||||
}
|
||||
|
||||
#define make_scatter(n) \
|
||||
make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n) \
|
||||
make_scatter_1d_index(IDX_ARG_ ##n, IDX_ARR_ ##n)
|
||||
#define make_scatter(n) \
|
||||
make_scatter_impl(IDX_ARG_##n, IDX_ARR_##n) \
|
||||
make_scatter_1d_index(IDX_ARG_##n, IDX_ARR_##n)
|
||||
|
||||
make_scatter(0)
|
||||
make_scatter(1)
|
||||
make_scatter(2)
|
||||
make_scatter(3)
|
||||
make_scatter(4)
|
||||
make_scatter(5)
|
||||
make_scatter(6)
|
||||
make_scatter(7)
|
||||
make_scatter(8)
|
||||
make_scatter(9)
|
||||
make_scatter(10)
|
||||
make_scatter(0) make_scatter(1) make_scatter(2) make_scatter(3) make_scatter(4)
|
||||
make_scatter(5) make_scatter(6) make_scatter(7) make_scatter(8)
|
||||
make_scatter(9) make_scatter(10)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
// Scatter instantiations
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
|
||||
template [[host_name("scatter" name "_" #nidx)]] \
|
||||
[[kernel]] void scatter<src_t, idx_t, op_t, nidx>( \
|
||||
const device src_t *updates [[buffer(1)]], \
|
||||
device mlx_atomic<src_t> *out [[buffer(2)]], \
|
||||
const constant int *upd_shape [[buffer(3)]], \
|
||||
const constant size_t *upd_strides [[buffer(4)]], \
|
||||
const constant size_t& upd_ndim [[buffer(5)]], \
|
||||
const constant size_t& upd_size [[buffer(6)]], \
|
||||
const constant int *out_shape [[buffer(7)]], \
|
||||
const constant size_t *out_strides [[buffer(8)]], \
|
||||
const constant size_t& out_ndim [[buffer(9)]], \
|
||||
const constant int* axes [[buffer(10)]], \
|
||||
const constant int *idx_shapes [[buffer(11)]], \
|
||||
const constant size_t *idx_strides [[buffer(12)]], \
|
||||
const constant int& idx_ndim [[buffer(13)]], \
|
||||
IDX_ARG(idx_t) \
|
||||
uint2 gid [[thread_position_in_grid]]);
|
||||
template [[host_name("scatter" name "_" #nidx)]] [[kernel]] void \
|
||||
scatter<src_t, idx_t, op_t, nidx>( \
|
||||
const device src_t* updates [[buffer(1)]], \
|
||||
device mlx_atomic<src_t>* out [[buffer(2)]], \
|
||||
const constant int* upd_shape [[buffer(3)]], \
|
||||
const constant size_t* upd_strides [[buffer(4)]], \
|
||||
const constant size_t& upd_ndim [[buffer(5)]], \
|
||||
const constant size_t& upd_size [[buffer(6)]], \
|
||||
const constant int* out_shape [[buffer(7)]], \
|
||||
const constant size_t* out_strides [[buffer(8)]], \
|
||||
const constant size_t& out_ndim [[buffer(9)]], \
|
||||
const constant int* axes [[buffer(10)]], \
|
||||
const constant int* idx_shapes [[buffer(11)]], \
|
||||
const constant size_t* idx_strides [[buffer(12)]], \
|
||||
const constant int& idx_ndim [[buffer(13)]], \
|
||||
IDX_ARG(idx_t) uint2 gid [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
|
||||
template [[host_name("scatter_1d_index" name "_" #nidx)]] \
|
||||
[[kernel]] void scatter_1d_index<src_t, idx_t, op_t, nidx>( \
|
||||
const device src_t *updates [[buffer(1)]], \
|
||||
device mlx_atomic<src_t> *out [[buffer(2)]], \
|
||||
const constant int* out_shape [[buffer(3)]], \
|
||||
const constant size_t* out_strides [[buffer(4)]], \
|
||||
const constant size_t& upd_size [[buffer(5)]], \
|
||||
IDX_ARG(idx_t) \
|
||||
uint2 gid [[thread_position_in_grid]]);
|
||||
#define instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
|
||||
template [[host_name("scatter_1d_index" name "_" #nidx)]] [[kernel]] void \
|
||||
scatter_1d_index<src_t, idx_t, op_t, nidx>( \
|
||||
const device src_t* updates [[buffer(1)]], \
|
||||
device mlx_atomic<src_t>* out [[buffer(2)]], \
|
||||
const constant int* out_shape [[buffer(3)]], \
|
||||
const constant size_t* out_strides [[buffer(4)]], \
|
||||
const constant size_t& upd_size [[buffer(5)]], \
|
||||
IDX_ARG(idx_t) uint2 gid [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \
|
||||
// clang-format off
|
||||
#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \
|
||||
instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) \
|
||||
instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx)
|
||||
instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
// Special case NINDEX=0
|
||||
#define instantiate_scatter_nd0(name, type) \
|
||||
instantiate_scatter4(#name "none", type, bool, None, 0) \
|
||||
instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
|
||||
#define instantiate_scatter_nd0(name, type) \
|
||||
instantiate_scatter4(#name "none", type, bool, None, 0) \
|
||||
instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
|
||||
instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
|
||||
instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
|
||||
instantiate_scatter4(#name "_min", type, bool, Min<type>, 0)
|
||||
instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
|
||||
instantiate_scatter4(#name "_min", type, bool, Min<type>, 0) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_scatter3(name, type, ind_type, op_type) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 1) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 2) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 3) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 4) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 5) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 6) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 7) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 8) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 9) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 10)
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 1) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 2) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 3) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 4) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 5) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 6) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 7) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 8) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 9) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 10) // clang-format on
|
||||
|
||||
#define instantiate_scatter2(name, type, ind_type) \
|
||||
instantiate_scatter3(name "_none", type, ind_type, None) \
|
||||
instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
|
||||
// clang-format off
|
||||
#define instantiate_scatter2(name, type, ind_type) \
|
||||
instantiate_scatter3(name "_none", type, ind_type, None) \
|
||||
instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
|
||||
instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
|
||||
instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
|
||||
instantiate_scatter3(name "_min", type, ind_type, Min<type>)
|
||||
instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
|
||||
instantiate_scatter3(name "_min", type, ind_type, Min<type>) // clang-format on
|
||||
|
||||
#define instantiate_scatter(name, type) \
|
||||
instantiate_scatter2(#name "bool_", type, bool) \
|
||||
instantiate_scatter2(#name "uint8", type, uint8_t) \
|
||||
// clang-format off
|
||||
#define instantiate_scatter(name, type) \
|
||||
instantiate_scatter2(#name "bool_", type, bool) \
|
||||
instantiate_scatter2(#name "uint8", type, uint8_t) \
|
||||
instantiate_scatter2(#name "uint16", type, uint16_t) \
|
||||
instantiate_scatter2(#name "uint32", type, uint32_t) \
|
||||
instantiate_scatter2(#name "uint64", type, uint64_t) \
|
||||
instantiate_scatter2(#name "int8", type, int8_t) \
|
||||
instantiate_scatter2(#name "int16", type, int16_t) \
|
||||
instantiate_scatter2(#name "int32", type, int32_t) \
|
||||
instantiate_scatter2(#name "int64", type, int64_t)
|
||||
instantiate_scatter2(#name "int8", type, int8_t) \
|
||||
instantiate_scatter2(#name "int16", type, int16_t) \
|
||||
instantiate_scatter2(#name "int32", type, int32_t) \
|
||||
instantiate_scatter2(#name "int64", type, int64_t) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
// TODO uint64 and int64 unsupported
|
||||
instantiate_scatter_nd0(bool_, bool)
|
||||
instantiate_scatter_nd0(uint8, uint8_t)
|
||||
@ -254,4 +233,4 @@ instantiate_scatter(int16, int16_t)
|
||||
instantiate_scatter(int32, int32_t)
|
||||
instantiate_scatter(float16, half)
|
||||
instantiate_scatter(float32, float)
|
||||
instantiate_scatter(bfloat16, bfloat16_t)
|
||||
instantiate_scatter(bfloat16, bfloat16_t) // clang-format on
|
||||
|
@ -198,17 +198,16 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
|
||||
}
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_softmax(name, itype) \
|
||||
template [[host_name("softmax_" #name)]] [[kernel]] void \
|
||||
softmax_single_row<itype>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint _lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
#define instantiate_softmax(name, itype) \
|
||||
template [[host_name("softmax_" #name)]] [[kernel]] void \
|
||||
softmax_single_row<itype>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint _lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
template [[host_name("softmax_looped_" #name)]] [[kernel]] void \
|
||||
softmax_looped<itype>( \
|
||||
const device itype* in, \
|
||||
@ -220,16 +219,16 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_softmax_precise(name, itype) \
|
||||
template [[host_name("softmax_precise_" #name)]] [[kernel]] void \
|
||||
softmax_single_row<itype, float>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint _lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
#define instantiate_softmax_precise(name, itype) \
|
||||
template [[host_name("softmax_precise_" #name)]] [[kernel]] void \
|
||||
softmax_single_row<itype, float>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint _lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
template [[host_name("softmax_looped_precise_" #name)]] [[kernel]] void \
|
||||
softmax_looped<itype, float>( \
|
||||
const device itype* in, \
|
||||
@ -241,9 +240,9 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
// clang-format off
|
||||
instantiate_softmax(float32, float)
|
||||
instantiate_softmax(float16, half)
|
||||
instantiate_softmax(bfloat16, bfloat16_t)
|
||||
instantiate_softmax_precise(float16, half)
|
||||
instantiate_softmax_precise(bfloat16, bfloat16_t)
|
||||
// clang-format on
|
||||
instantiate_softmax_precise(bfloat16, bfloat16_t) // clang-format on
|
||||
|
@ -11,7 +11,8 @@
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// Based on GPU merge sort algorithm at https://github.com/NVIDIA/cccl/tree/main/cub/cub
|
||||
// Based on GPU merge sort algorithm at
|
||||
// https://github.com/NVIDIA/cccl/tree/main/cub/cub
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Thread-level sort
|
||||
@ -43,20 +44,18 @@ struct ThreadSort {
|
||||
static METAL_FUNC void sort(
|
||||
thread val_t (&vals)[N_PER_THREAD],
|
||||
thread idx_t (&idxs)[N_PER_THREAD]) {
|
||||
|
||||
CompareOp op;
|
||||
|
||||
MLX_MTL_LOOP_UNROLL
|
||||
for(short i = 0; i < N_PER_THREAD; ++i) {
|
||||
MLX_MTL_LOOP_UNROLL
|
||||
for(short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
|
||||
if(op(vals[j + 1], vals[j])) {
|
||||
for (short i = 0; i < N_PER_THREAD; ++i) {
|
||||
MLX_MTL_LOOP_UNROLL
|
||||
for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
|
||||
if (op(vals[j + 1], vals[j])) {
|
||||
thread_swap(vals[j + 1], vals[j]);
|
||||
thread_swap(idxs[j + 1], idxs[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
@ -72,25 +71,25 @@ template <
|
||||
short N_PER_THREAD,
|
||||
typename CompareOp>
|
||||
struct BlockMergeSort {
|
||||
using thread_sort_t = ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>;
|
||||
using thread_sort_t =
|
||||
ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>;
|
||||
static METAL_FUNC int merge_partition(
|
||||
const threadgroup val_t* As,
|
||||
const threadgroup val_t* Bs,
|
||||
short A_sz,
|
||||
short B_sz,
|
||||
short sort_md) {
|
||||
|
||||
CompareOp op;
|
||||
|
||||
short A_st = max(0, sort_md - B_sz);
|
||||
short A_ed = min(sort_md, A_sz);
|
||||
|
||||
while(A_st < A_ed) {
|
||||
while (A_st < A_ed) {
|
||||
short md = A_st + (A_ed - A_st) / 2;
|
||||
auto a = As[md];
|
||||
auto b = Bs[sort_md - 1 - md];
|
||||
|
||||
if(op(b, a)) {
|
||||
if (op(b, a)) {
|
||||
A_ed = md;
|
||||
} else {
|
||||
A_st = md + 1;
|
||||
@ -98,7 +97,6 @@ struct BlockMergeSort {
|
||||
}
|
||||
|
||||
return A_ed;
|
||||
|
||||
}
|
||||
|
||||
static METAL_FUNC void merge_step(
|
||||
@ -110,12 +108,11 @@ struct BlockMergeSort {
|
||||
short B_sz,
|
||||
thread val_t (&vals)[N_PER_THREAD],
|
||||
thread idx_t (&idxs)[N_PER_THREAD]) {
|
||||
|
||||
CompareOp op;
|
||||
short a_idx = 0;
|
||||
short b_idx = 0;
|
||||
|
||||
for(int i = 0; i < N_PER_THREAD; ++i) {
|
||||
for (int i = 0; i < N_PER_THREAD; ++i) {
|
||||
auto a = As[a_idx];
|
||||
auto b = Bs[b_idx];
|
||||
bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
|
||||
@ -126,7 +123,6 @@ struct BlockMergeSort {
|
||||
b_idx += short(pred);
|
||||
a_idx += short(!pred);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
static METAL_FUNC void sort(
|
||||
@ -134,32 +130,32 @@ struct BlockMergeSort {
|
||||
threadgroup idx_t* tgp_idxs [[threadgroup(1)]],
|
||||
int size_sorted_axis,
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
// Get thread location
|
||||
int idx = lid.x * N_PER_THREAD;
|
||||
|
||||
// Load from shared memory
|
||||
thread val_t thread_vals[N_PER_THREAD];
|
||||
thread idx_t thread_idxs[N_PER_THREAD];
|
||||
for(int i = 0; i < N_PER_THREAD; ++i) {
|
||||
for (int i = 0; i < N_PER_THREAD; ++i) {
|
||||
thread_vals[i] = tgp_vals[idx + i];
|
||||
if(ARG_SORT) {
|
||||
if (ARG_SORT) {
|
||||
thread_idxs[i] = tgp_idxs[idx + i];
|
||||
}
|
||||
}
|
||||
|
||||
// Per thread sort
|
||||
if(idx < size_sorted_axis) {
|
||||
// Per thread sort
|
||||
if (idx < size_sorted_axis) {
|
||||
thread_sort_t::sort(thread_vals, thread_idxs);
|
||||
}
|
||||
|
||||
// Do merges using threadgroup memory
|
||||
for (int merge_threads = 2; merge_threads <= BLOCK_THREADS; merge_threads *= 2) {
|
||||
for (int merge_threads = 2; merge_threads <= BLOCK_THREADS;
|
||||
merge_threads *= 2) {
|
||||
// Update threadgroup memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for(int i = 0; i < N_PER_THREAD; ++i) {
|
||||
for (int i = 0; i < N_PER_THREAD; ++i) {
|
||||
tgp_vals[idx + i] = thread_vals[i];
|
||||
if(ARG_SORT) {
|
||||
if (ARG_SORT) {
|
||||
tgp_idxs[idx + i] = thread_idxs[i];
|
||||
}
|
||||
}
|
||||
@ -167,7 +163,7 @@ struct BlockMergeSort {
|
||||
|
||||
// Find location in merge step
|
||||
int merge_group = lid.x / merge_threads;
|
||||
int merge_lane = lid.x % merge_threads;
|
||||
int merge_lane = lid.x % merge_threads;
|
||||
|
||||
int sort_sz = N_PER_THREAD * merge_threads;
|
||||
int sort_st = N_PER_THREAD * merge_threads * merge_group;
|
||||
@ -185,16 +181,11 @@ struct BlockMergeSort {
|
||||
int B_sz = B_ed - B_st;
|
||||
|
||||
// Find a partition of merge elements
|
||||
// Ci = merge(As[partition:], Bs[sort_md - partition:])
|
||||
// Ci = merge(As[partition:], Bs[sort_md - partition:])
|
||||
// of size N_PER_THREAD for each merge lane i
|
||||
// C = [Ci] is sorted
|
||||
// C = [Ci] is sorted
|
||||
int sort_md = N_PER_THREAD * merge_lane;
|
||||
int partition = merge_partition(
|
||||
As,
|
||||
Bs,
|
||||
A_sz,
|
||||
B_sz,
|
||||
sort_md);
|
||||
int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md);
|
||||
|
||||
As += partition;
|
||||
Bs += sort_md - partition;
|
||||
@ -202,27 +193,20 @@ struct BlockMergeSort {
|
||||
A_sz -= partition;
|
||||
B_sz -= sort_md - partition;
|
||||
|
||||
const threadgroup idx_t* As_idx = ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
|
||||
const threadgroup idx_t* Bs_idx = ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;
|
||||
const threadgroup idx_t* As_idx =
|
||||
ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
|
||||
const threadgroup idx_t* Bs_idx =
|
||||
ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;
|
||||
|
||||
// Merge starting at the partition and store results in thread registers
|
||||
merge_step(
|
||||
As,
|
||||
Bs,
|
||||
As_idx,
|
||||
Bs_idx,
|
||||
A_sz,
|
||||
B_sz,
|
||||
thread_vals,
|
||||
thread_idxs);
|
||||
|
||||
merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs);
|
||||
}
|
||||
|
||||
// Write out to shared memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for(int i = 0; i < N_PER_THREAD; ++i) {
|
||||
for (int i = 0; i < N_PER_THREAD; ++i) {
|
||||
tgp_vals[idx + i] = thread_vals[i];
|
||||
if(ARG_SORT) {
|
||||
if (ARG_SORT) {
|
||||
tgp_idxs[idx + i] = thread_idxs[i];
|
||||
}
|
||||
}
|
||||
@ -235,7 +219,7 @@ struct BlockMergeSort {
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename U,
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD,
|
||||
@ -244,13 +228,13 @@ struct KernelMergeSort {
|
||||
using val_t = T;
|
||||
using idx_t = uint;
|
||||
using block_merge_sort_t = BlockMergeSort<
|
||||
val_t,
|
||||
val_t,
|
||||
idx_t,
|
||||
ARG_SORT,
|
||||
BLOCK_THREADS,
|
||||
BLOCK_THREADS,
|
||||
N_PER_THREAD,
|
||||
CompareOp>;
|
||||
|
||||
|
||||
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
|
||||
|
||||
static METAL_FUNC void block_sort(
|
||||
@ -263,15 +247,15 @@ struct KernelMergeSort {
|
||||
threadgroup idx_t* tgp_idxs,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
// tid.y tells us the segment index
|
||||
inp += tid.y * stride_segment_axis;
|
||||
out += tid.y * stride_segment_axis;
|
||||
|
||||
// Copy into threadgroup memory
|
||||
for(short i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
|
||||
tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis] : val_t(CompareOp::init);
|
||||
if(ARG_SORT) {
|
||||
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
|
||||
tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis]
|
||||
: val_t(CompareOp::init);
|
||||
if (ARG_SORT) {
|
||||
tgp_idxs[i] = i;
|
||||
}
|
||||
}
|
||||
@ -284,8 +268,8 @@ struct KernelMergeSort {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Write output
|
||||
for(int i = lid.x; i < size_sorted_axis; i+= BLOCK_THREADS) {
|
||||
if(ARG_SORT) {
|
||||
for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) {
|
||||
if (ARG_SORT) {
|
||||
out[i * stride_sorted_axis] = tgp_idxs[i];
|
||||
} else {
|
||||
out[i * stride_sorted_axis] = tgp_vals[i];
|
||||
@ -296,7 +280,7 @@ struct KernelMergeSort {
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename U,
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD>
|
||||
@ -308,12 +292,12 @@ template <
|
||||
const constant int& stride_segment_axis [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
using sort_kernel = KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
|
||||
using sort_kernel =
|
||||
KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
|
||||
using val_t = typename sort_kernel::val_t;
|
||||
using idx_t = typename sort_kernel::idx_t;
|
||||
|
||||
if(ARG_SORT) {
|
||||
if (ARG_SORT) {
|
||||
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
||||
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
|
||||
sort_kernel::block_sort(
|
||||
@ -339,14 +323,13 @@ template <
|
||||
tid,
|
||||
lid);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
constant constexpr const int zero_helper = 0;
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename U,
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD>
|
||||
@ -360,8 +343,8 @@ template <
|
||||
const device size_t* nc_strides [[buffer(6)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
using sort_kernel = KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
|
||||
using sort_kernel =
|
||||
KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
|
||||
using val_t = typename sort_kernel::val_t;
|
||||
using idx_t = typename sort_kernel::idx_t;
|
||||
|
||||
@ -369,7 +352,7 @@ template <
|
||||
inp += block_idx;
|
||||
out += block_idx;
|
||||
|
||||
if(ARG_SORT) {
|
||||
if (ARG_SORT) {
|
||||
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
||||
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
|
||||
sort_kernel::block_sort(
|
||||
@ -395,50 +378,55 @@ template <
|
||||
tid,
|
||||
lid);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Instantiations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
#define instantiate_block_sort(name, itname, itype, otname, otype, arg_sort, bn, tn) \
|
||||
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn)]] \
|
||||
[[kernel]] void block_sort<itype, otype, arg_sort, bn, tn>( \
|
||||
const device itype* inp [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant int& size_sorted_axis [[buffer(2)]], \
|
||||
const constant int& stride_sorted_axis [[buffer(3)]], \
|
||||
const constant int& stride_segment_axis [[buffer(4)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]); \
|
||||
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn "_nc")]] \
|
||||
[[kernel]] void block_sort_nc<itype, otype, arg_sort, bn, tn>( \
|
||||
const device itype* inp [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant int& size_sorted_axis [[buffer(2)]], \
|
||||
const constant int& stride_sorted_axis [[buffer(3)]], \
|
||||
const constant int& nc_dim [[buffer(4)]], \
|
||||
const device int* nc_shape [[buffer(5)]], \
|
||||
const device size_t* nc_strides [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
#define instantiate_block_sort( \
|
||||
name, itname, itype, otname, otype, arg_sort, bn, tn) \
|
||||
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn \
|
||||
"_tn" #tn)]] [[kernel]] void \
|
||||
block_sort<itype, otype, arg_sort, bn, tn>( \
|
||||
const device itype* inp [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant int& size_sorted_axis [[buffer(2)]], \
|
||||
const constant int& stride_sorted_axis [[buffer(3)]], \
|
||||
const constant int& stride_segment_axis [[buffer(4)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]); \
|
||||
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn \
|
||||
"_nc")]] [[kernel]] void \
|
||||
block_sort_nc<itype, otype, arg_sort, bn, tn>( \
|
||||
const device itype* inp [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant int& size_sorted_axis [[buffer(2)]], \
|
||||
const constant int& stride_sorted_axis [[buffer(3)]], \
|
||||
const constant int& nc_dim [[buffer(4)]], \
|
||||
const device int* nc_shape [[buffer(5)]], \
|
||||
const device size_t* nc_strides [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
|
||||
instantiate_block_sort(arg_block_merge_sort, itname, itype, uint32, uint32_t, true, bn, tn)
|
||||
instantiate_block_sort( \
|
||||
arg_block_merge_sort, itname, itype, uint32, uint32_t, true, bn, tn)
|
||||
|
||||
#define instantiate_block_sort_base(itname, itype, bn, tn) \
|
||||
instantiate_block_sort(block_merge_sort, itname, itype, itname, itype, false, bn, tn)
|
||||
instantiate_block_sort( \
|
||||
block_merge_sort, itname, itype, itname, itype, false, bn, tn)
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_block_sort_tn(itname, itype, bn) \
|
||||
instantiate_block_sort_base(itname, itype, bn, 8) \
|
||||
instantiate_arg_block_sort_base(itname, itype, bn, 8)
|
||||
instantiate_block_sort_base(itname, itype, bn, 8) \
|
||||
instantiate_arg_block_sort_base(itname, itype, bn, 8) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_block_sort_bn(itname, itype) \
|
||||
instantiate_block_sort_tn(itname, itype, 128) \
|
||||
instantiate_block_sort_tn(itname, itype, 256) \
|
||||
instantiate_block_sort_tn(itname, itype, 512)
|
||||
instantiate_block_sort_tn(itname, itype, 128) \
|
||||
instantiate_block_sort_tn(itname, itype, 256) \
|
||||
instantiate_block_sort_tn(itname, itype, 512)
|
||||
|
||||
instantiate_block_sort_bn(uint8, uint8_t)
|
||||
instantiate_block_sort_bn(uint16, uint16_t)
|
||||
@ -448,35 +436,35 @@ instantiate_block_sort_bn(int16, int16_t)
|
||||
instantiate_block_sort_bn(int32, int32_t)
|
||||
instantiate_block_sort_bn(float16, half)
|
||||
instantiate_block_sort_bn(float32, float)
|
||||
instantiate_block_sort_bn(bfloat16, bfloat16_t)
|
||||
|
||||
instantiate_block_sort_bn(bfloat16, bfloat16_t) // clang-format on
|
||||
// clang-format off
|
||||
#define instantiate_block_sort_long(itname, itype) \
|
||||
instantiate_block_sort_tn(itname, itype, 128) \
|
||||
instantiate_block_sort_tn(itname, itype, 128) \
|
||||
instantiate_block_sort_tn(itname, itype, 256)
|
||||
|
||||
instantiate_block_sort_long(uint64, uint64_t)
|
||||
instantiate_block_sort_long(int64, int64_t)
|
||||
instantiate_block_sort_long(int64, int64_t) // clang-format on
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Multi block merge sort
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Multi block merge sort
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename val_t,
|
||||
typename idx_t,
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD,
|
||||
typename CompareOp = LessThan<val_t>>
|
||||
struct KernelMultiBlockMergeSort {
|
||||
template <
|
||||
typename val_t,
|
||||
typename idx_t,
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD,
|
||||
typename CompareOp = LessThan<val_t>>
|
||||
struct KernelMultiBlockMergeSort {
|
||||
using block_merge_sort_t = BlockMergeSort<
|
||||
val_t,
|
||||
val_t,
|
||||
idx_t,
|
||||
ARG_SORT,
|
||||
BLOCK_THREADS,
|
||||
BLOCK_THREADS,
|
||||
N_PER_THREAD,
|
||||
CompareOp>;
|
||||
|
||||
|
||||
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
|
||||
|
||||
static METAL_FUNC void block_sort(
|
||||
@ -489,14 +477,14 @@ struct KernelMultiBlockMergeSort {
|
||||
threadgroup idx_t* tgp_idxs,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
// tid.y tells us the segment index
|
||||
int base_idx = tid.x * N_PER_BLOCK;
|
||||
|
||||
// Copy into threadgroup memory
|
||||
for(short i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
|
||||
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
|
||||
int idx = base_idx + i;
|
||||
tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis] : val_t(CompareOp::init);
|
||||
tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis]
|
||||
: val_t(CompareOp::init);
|
||||
tgp_idxs[i] = idx;
|
||||
}
|
||||
|
||||
@ -508,9 +496,9 @@ struct KernelMultiBlockMergeSort {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Write output
|
||||
for(int i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
|
||||
for (int i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
|
||||
int idx = base_idx + i;
|
||||
if(idx < size_sorted_axis) {
|
||||
if (idx < size_sorted_axis) {
|
||||
out_vals[idx] = tgp_vals[i];
|
||||
out_idxs[idx] = tgp_idxs[i];
|
||||
}
|
||||
@ -523,18 +511,17 @@ struct KernelMultiBlockMergeSort {
|
||||
int A_sz,
|
||||
int B_sz,
|
||||
int sort_md) {
|
||||
|
||||
CompareOp op;
|
||||
|
||||
int A_st = max(0, sort_md - B_sz);
|
||||
int A_ed = min(sort_md, A_sz);
|
||||
|
||||
while(A_st < A_ed) {
|
||||
while (A_st < A_ed) {
|
||||
int md = A_st + (A_ed - A_st) / 2;
|
||||
auto a = As[md];
|
||||
auto b = Bs[sort_md - 1 - md];
|
||||
|
||||
if(op(b, a)) {
|
||||
if (op(b, a)) {
|
||||
A_ed = md;
|
||||
} else {
|
||||
A_st = md + 1;
|
||||
@ -542,7 +529,6 @@ struct KernelMultiBlockMergeSort {
|
||||
}
|
||||
|
||||
return A_ed;
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
@ -563,8 +549,12 @@ template <
|
||||
const device size_t* nc_strides [[buffer(7)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
using sort_kernel = KernelMultiBlockMergeSort<val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
|
||||
using sort_kernel = KernelMultiBlockMergeSort<
|
||||
val_t,
|
||||
idx_t,
|
||||
ARG_SORT,
|
||||
BLOCK_THREADS,
|
||||
N_PER_THREAD>;
|
||||
|
||||
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
|
||||
inp += block_idx;
|
||||
@ -575,12 +565,12 @@ template <
|
||||
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
|
||||
|
||||
sort_kernel::block_sort(
|
||||
inp,
|
||||
out_vals,
|
||||
out_idxs,
|
||||
size_sorted_axis,
|
||||
stride_sorted_axis,
|
||||
tgp_vals,
|
||||
inp,
|
||||
out_vals,
|
||||
out_idxs,
|
||||
size_sorted_axis,
|
||||
stride_sorted_axis,
|
||||
tgp_vals,
|
||||
tgp_idxs,
|
||||
tid,
|
||||
lid);
|
||||
@ -592,7 +582,8 @@ template <
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD>
|
||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partition(
|
||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
|
||||
mb_block_partition(
|
||||
device idx_t* block_partitions [[buffer(0)]],
|
||||
const device val_t* dev_vals [[buffer(1)]],
|
||||
const device idx_t* dev_idxs [[buffer(2)]],
|
||||
@ -601,21 +592,20 @@ template <
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 tgp_dims [[threads_per_threadgroup]]) {
|
||||
|
||||
using sort_kernel = KernelMultiBlockMergeSort<
|
||||
val_t,
|
||||
idx_t,
|
||||
ARG_SORT,
|
||||
BLOCK_THREADS,
|
||||
val_t,
|
||||
idx_t,
|
||||
ARG_SORT,
|
||||
BLOCK_THREADS,
|
||||
N_PER_THREAD>;
|
||||
|
||||
|
||||
block_partitions += tid.y * tgp_dims.x;
|
||||
dev_vals += tid.y * size_sorted_axis;
|
||||
dev_idxs += tid.y * size_sorted_axis;
|
||||
|
||||
// Find location in merge step
|
||||
int merge_group = lid.x / merge_tiles;
|
||||
int merge_lane = lid.x % merge_tiles;
|
||||
int merge_lane = lid.x % merge_tiles;
|
||||
|
||||
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
|
||||
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
|
||||
@ -627,14 +617,9 @@ template <
|
||||
|
||||
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
|
||||
int partition = sort_kernel::merge_partition(
|
||||
dev_vals + A_st,
|
||||
dev_vals + B_st,
|
||||
A_ed - A_st,
|
||||
B_ed - B_st,
|
||||
partition_at);
|
||||
dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at);
|
||||
|
||||
block_partitions[lid.x] = A_st + partition;
|
||||
|
||||
}
|
||||
|
||||
template <
|
||||
@ -644,7 +629,8 @@ template <
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD,
|
||||
typename CompareOp = LessThan<val_t>>
|
||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_merge(
|
||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
|
||||
mb_block_merge(
|
||||
const device idx_t* block_partitions [[buffer(0)]],
|
||||
const device val_t* dev_vals_in [[buffer(1)]],
|
||||
const device idx_t* dev_idxs_in [[buffer(2)]],
|
||||
@ -655,20 +641,19 @@ template <
|
||||
const constant int& num_tiles [[buffer(7)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
using sort_kernel = KernelMultiBlockMergeSort<
|
||||
val_t,
|
||||
idx_t,
|
||||
ARG_SORT,
|
||||
BLOCK_THREADS,
|
||||
N_PER_THREAD,
|
||||
val_t,
|
||||
idx_t,
|
||||
ARG_SORT,
|
||||
BLOCK_THREADS,
|
||||
N_PER_THREAD,
|
||||
CompareOp>;
|
||||
|
||||
|
||||
using block_sort_t = typename sort_kernel::block_merge_sort_t;
|
||||
|
||||
block_partitions += tid.y * (num_tiles + 1);
|
||||
dev_vals_in += tid.y * size_sorted_axis;
|
||||
dev_idxs_in += tid.y * size_sorted_axis;
|
||||
dev_vals_in += tid.y * size_sorted_axis;
|
||||
dev_idxs_in += tid.y * size_sorted_axis;
|
||||
dev_vals_out += tid.y * size_sorted_axis;
|
||||
dev_idxs_out += tid.y * size_sorted_axis;
|
||||
|
||||
@ -680,25 +665,29 @@ template <
|
||||
|
||||
int A_st = block_partitions[block_idx + 0];
|
||||
int A_ed = block_partitions[block_idx + 1];
|
||||
int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz/2 + sort_md - A_st);
|
||||
int B_ed = min(size_sorted_axis, 2 * sort_st + sort_sz/2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
|
||||
int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);
|
||||
int B_ed = min(
|
||||
size_sorted_axis,
|
||||
2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
|
||||
|
||||
if((block_idx % merge_tiles) == merge_tiles - 1) {
|
||||
A_ed = min(size_sorted_axis, sort_st + sort_sz/2);
|
||||
if ((block_idx % merge_tiles) == merge_tiles - 1) {
|
||||
A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
|
||||
B_ed = min(size_sorted_axis, sort_st + sort_sz);
|
||||
}
|
||||
|
||||
|
||||
int A_sz = A_ed - A_st;
|
||||
int B_sz = B_ed - B_st;
|
||||
|
||||
// Load from global memory
|
||||
thread val_t thread_vals[N_PER_THREAD];
|
||||
thread idx_t thread_idxs[N_PER_THREAD];
|
||||
for(int i = 0; i < N_PER_THREAD; i++) {
|
||||
for (int i = 0; i < N_PER_THREAD; i++) {
|
||||
int idx = BLOCK_THREADS * i + lid.x;
|
||||
if(idx < (A_sz + B_sz)) {
|
||||
thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx] : dev_vals_in[B_st + idx - A_sz];
|
||||
thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx] : dev_idxs_in[B_st + idx - A_sz];
|
||||
if (idx < (A_sz + B_sz)) {
|
||||
thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]
|
||||
: dev_vals_in[B_st + idx - A_sz];
|
||||
thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]
|
||||
: dev_idxs_in[B_st + idx - A_sz];
|
||||
} else {
|
||||
thread_vals[i] = CompareOp::init;
|
||||
thread_idxs[i] = 0;
|
||||
@ -709,7 +698,7 @@ template <
|
||||
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
||||
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for(int i = 0; i < N_PER_THREAD; i++) {
|
||||
for (int i = 0; i < N_PER_THREAD; i++) {
|
||||
int idx = BLOCK_THREADS * i + lid.x;
|
||||
tgp_vals[idx] = thread_vals[i];
|
||||
tgp_idxs[idx] = thread_idxs[i];
|
||||
@ -720,11 +709,7 @@ template <
|
||||
int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x));
|
||||
|
||||
int A_st_local = block_sort_t::merge_partition(
|
||||
tgp_vals,
|
||||
tgp_vals + A_sz,
|
||||
A_sz,
|
||||
B_sz,
|
||||
sort_md_local);
|
||||
tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);
|
||||
int A_ed_local = A_sz;
|
||||
|
||||
int B_st_local = sort_md_local - A_st_local;
|
||||
@ -733,7 +718,7 @@ template <
|
||||
int A_sz_local = A_ed_local - A_st_local;
|
||||
int B_sz_local = B_ed_local - B_st_local;
|
||||
|
||||
// Do merge
|
||||
// Do merge
|
||||
block_sort_t::merge_step(
|
||||
tgp_vals + A_st_local,
|
||||
tgp_vals + A_ed_local + B_st_local,
|
||||
@ -745,61 +730,65 @@ template <
|
||||
thread_idxs);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for(int i = 0; i < N_PER_THREAD; ++i) {
|
||||
for (int i = 0; i < N_PER_THREAD; ++i) {
|
||||
int idx = lid.x * N_PER_THREAD;
|
||||
tgp_vals[idx + i] = thread_vals[i];
|
||||
tgp_idxs[idx + i] = thread_idxs[i];
|
||||
}
|
||||
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Write output
|
||||
int base_idx = tid.x * sort_kernel::N_PER_BLOCK;
|
||||
for(int i = lid.x; i < sort_kernel::N_PER_BLOCK; i+= BLOCK_THREADS) {
|
||||
for (int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) {
|
||||
int idx = base_idx + i;
|
||||
if(idx < size_sorted_axis) {
|
||||
if (idx < size_sorted_axis) {
|
||||
dev_vals_out[idx] = tgp_vals[i];
|
||||
dev_idxs_out[idx] = tgp_idxs[i];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#define instantiate_multi_block_sort(vtname, vtype, itname, itype, arg_sort, bn, tn) \
|
||||
template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
||||
[[kernel]] void mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
|
||||
const device vtype* inp [[buffer(0)]], \
|
||||
device vtype* out_vals [[buffer(1)]], \
|
||||
device itype* out_idxs [[buffer(2)]], \
|
||||
const constant int& size_sorted_axis [[buffer(3)]], \
|
||||
const constant int& stride_sorted_axis [[buffer(4)]], \
|
||||
const constant int& nc_dim [[buffer(5)]], \
|
||||
const device int* nc_shape [[buffer(6)]], \
|
||||
const device size_t* nc_strides [[buffer(7)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]); \
|
||||
template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
||||
[[kernel]] void mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
|
||||
device itype* block_partitions [[buffer(0)]], \
|
||||
const device vtype* dev_vals [[buffer(1)]], \
|
||||
const device itype* dev_idxs [[buffer(2)]], \
|
||||
const constant int& size_sorted_axis [[buffer(3)]], \
|
||||
const constant int& merge_tiles [[buffer(4)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 tgp_dims [[threads_per_threadgroup]]); \
|
||||
template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
||||
[[kernel]] void mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
|
||||
const device itype* block_partitions [[buffer(0)]], \
|
||||
const device vtype* dev_vals_in [[buffer(1)]], \
|
||||
const device itype* dev_idxs_in [[buffer(2)]], \
|
||||
device vtype* dev_vals_out [[buffer(3)]], \
|
||||
device itype* dev_idxs_out [[buffer(4)]], \
|
||||
const constant int& size_sorted_axis [[buffer(5)]], \
|
||||
const constant int& merge_tiles [[buffer(6)]], \
|
||||
const constant int& num_tiles [[buffer(7)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
#define instantiate_multi_block_sort( \
|
||||
vtname, vtype, itname, itype, arg_sort, bn, tn) \
|
||||
template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn \
|
||||
"_tn" #tn)]] [[kernel]] void \
|
||||
mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
|
||||
const device vtype* inp [[buffer(0)]], \
|
||||
device vtype* out_vals [[buffer(1)]], \
|
||||
device itype* out_idxs [[buffer(2)]], \
|
||||
const constant int& size_sorted_axis [[buffer(3)]], \
|
||||
const constant int& stride_sorted_axis [[buffer(4)]], \
|
||||
const constant int& nc_dim [[buffer(5)]], \
|
||||
const device int* nc_shape [[buffer(6)]], \
|
||||
const device size_t* nc_strides [[buffer(7)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]); \
|
||||
template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn \
|
||||
"_tn" #tn)]] [[kernel]] void \
|
||||
mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
|
||||
device itype * block_partitions [[buffer(0)]], \
|
||||
const device vtype* dev_vals [[buffer(1)]], \
|
||||
const device itype* dev_idxs [[buffer(2)]], \
|
||||
const constant int& size_sorted_axis [[buffer(3)]], \
|
||||
const constant int& merge_tiles [[buffer(4)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 tgp_dims [[threads_per_threadgroup]]); \
|
||||
template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn \
|
||||
"_tn" #tn)]] [[kernel]] void \
|
||||
mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
|
||||
const device itype* block_partitions [[buffer(0)]], \
|
||||
const device vtype* dev_vals_in [[buffer(1)]], \
|
||||
const device itype* dev_idxs_in [[buffer(2)]], \
|
||||
device vtype* dev_vals_out [[buffer(3)]], \
|
||||
device itype* dev_idxs_out [[buffer(4)]], \
|
||||
const constant int& size_sorted_axis [[buffer(5)]], \
|
||||
const constant int& merge_tiles [[buffer(6)]], \
|
||||
const constant int& num_tiles [[buffer(7)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_multi_block_sort_base(vtname, vtype) \
|
||||
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8)
|
||||
|
||||
@ -811,10 +800,11 @@ instantiate_multi_block_sort_base(int16, int16_t)
|
||||
instantiate_multi_block_sort_base(int32, int32_t)
|
||||
instantiate_multi_block_sort_base(float16, half)
|
||||
instantiate_multi_block_sort_base(float32, float)
|
||||
instantiate_multi_block_sort_base(bfloat16, bfloat16_t)
|
||||
instantiate_multi_block_sort_base(bfloat16, bfloat16_t) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_multi_block_sort_long(vtname, vtype) \
|
||||
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8)
|
||||
|
||||
instantiate_multi_block_sort_long(uint64, uint64_t)
|
||||
instantiate_multi_block_sort_long(int64, int64_t)
|
||||
instantiate_multi_block_sort_long(int64, int64_t) // clang-format on
|
@ -4,21 +4,23 @@
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
template <typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
int N_CHANNELS = 0,
|
||||
bool SMALL_FILTER = false>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d(
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
int N_CHANNELS = 0,
|
||||
bool SMALL_FILTER = false>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
|
||||
implicit_gemm_conv_2d(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
device T* C [[buffer(2)]],
|
||||
@ -28,12 +30,10 @@ template <typename T,
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
|
||||
using namespace mlx::steel;
|
||||
|
||||
(void)lid;
|
||||
|
||||
|
||||
constexpr bool transpose_a = false;
|
||||
constexpr bool transpose_b = true;
|
||||
constexpr short tgp_padding_a = 16 / sizeof(T);
|
||||
@ -47,46 +47,64 @@ template <typename T,
|
||||
constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
|
||||
|
||||
constexpr short tgp_size = WM * WN * 32;
|
||||
|
||||
// Input loader
|
||||
|
||||
// Input loader
|
||||
|
||||
using loader_a_t = typename metal::conditional_t<
|
||||
// Check for small channel specialization
|
||||
N_CHANNELS != 0 && N_CHANNELS <= 4,
|
||||
|
||||
// Go to small channel specialization
|
||||
Conv2DInputBlockLoaderSmallChannels<
|
||||
T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_a>,
|
||||
// Go to small channel specialization
|
||||
Conv2DInputBlockLoaderSmallChannels<
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
tgp_size,
|
||||
N_CHANNELS,
|
||||
tgp_padding_a>,
|
||||
|
||||
// Else go to general loader
|
||||
typename metal::conditional_t<
|
||||
// Check if filter size is small enough
|
||||
SMALL_FILTER,
|
||||
// Else go to general loader
|
||||
typename metal::conditional_t<
|
||||
// Check if filter size is small enough
|
||||
SMALL_FILTER,
|
||||
|
||||
// Go to small filter specialization
|
||||
Conv2DInputBlockLoaderSmallFilter<
|
||||
T, BM, BN, BK, tgp_size, tgp_padding_a>,
|
||||
|
||||
// Else go to large filter generalization
|
||||
Conv2DInputBlockLoaderLargeFilter<
|
||||
T, BM, BN, BK, tgp_size, tgp_padding_a>
|
||||
>
|
||||
>;
|
||||
// Go to small filter specialization
|
||||
Conv2DInputBlockLoaderSmallFilter<
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
tgp_size,
|
||||
tgp_padding_a>,
|
||||
|
||||
// Else go to large filter generalization
|
||||
Conv2DInputBlockLoaderLargeFilter<
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
tgp_size,
|
||||
tgp_padding_a>>>;
|
||||
|
||||
// Weight loader
|
||||
using loader_b_t = typename metal::conditional_t<
|
||||
// Check for small channel specialization
|
||||
N_CHANNELS != 0 && N_CHANNELS <= 4,
|
||||
|
||||
// Go to small channel specialization
|
||||
Conv2DWeightBlockLoaderSmallChannels<
|
||||
T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_b>,
|
||||
// Go to small channel specialization
|
||||
Conv2DWeightBlockLoaderSmallChannels<
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
tgp_size,
|
||||
N_CHANNELS,
|
||||
tgp_padding_b>,
|
||||
|
||||
// Else go to general loader
|
||||
Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>>;
|
||||
|
||||
// Else go to general loader
|
||||
Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>
|
||||
>;
|
||||
|
||||
using mma_t = BlockMMA<
|
||||
T,
|
||||
T,
|
||||
@ -99,12 +117,12 @@ template <typename T,
|
||||
transpose_b,
|
||||
shape_a_cols,
|
||||
shape_b_cols>;
|
||||
|
||||
|
||||
threadgroup T As[tgp_mem_size_a];
|
||||
threadgroup T Bs[tgp_mem_size_b];
|
||||
|
||||
const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
|
||||
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
|
||||
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
|
||||
const int tid_x = (tid.x) >> gemm_params->swizzle_log;
|
||||
|
||||
if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
|
||||
@ -123,8 +141,10 @@ template <typename T,
|
||||
const int2 offsets_b(0, c_col);
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
loader_a_t loader_a(A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
|
||||
loader_b_t loader_b(B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);
|
||||
loader_a_t loader_a(
|
||||
A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
|
||||
loader_b_t loader_b(
|
||||
B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
mma_t mma_op(simd_gid, simd_lid);
|
||||
@ -152,38 +172,53 @@ template <typename T,
|
||||
short tgp_bm = min(BM, gemm_params->M - c_row);
|
||||
short tgp_bn = min(BN, gemm_params->N - c_col);
|
||||
mma_op.store_result_safe(C, N, short2(tgp_bn, tgp_bm));
|
||||
|
||||
}
|
||||
|
||||
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, channel_name, n_channels, filter_name, small_filter) \
|
||||
template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_channel_" #channel_name "_filter_" #filter_name)]] \
|
||||
[[kernel]] void implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn, n_channels, small_filter>( \
|
||||
const device itype* A [[buffer(0)]], \
|
||||
const device itype* B [[buffer(1)]], \
|
||||
device itype* C [[buffer(2)]], \
|
||||
const constant MLXConvParams<2>* params [[buffer(3)]], \
|
||||
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
#define instantiate_implicit_conv_2d( \
|
||||
name, \
|
||||
itype, \
|
||||
bm, \
|
||||
bn, \
|
||||
bk, \
|
||||
wm, \
|
||||
wn, \
|
||||
channel_name, \
|
||||
n_channels, \
|
||||
filter_name, \
|
||||
small_filter) \
|
||||
template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn \
|
||||
"_bk" #bk "_wm" #wm "_wn" #wn "_channel_" #channel_name \
|
||||
"_filter_" #filter_name)]] [[kernel]] void \
|
||||
implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn, n_channels, small_filter>( \
|
||||
const device itype* A [[buffer(0)]], \
|
||||
const device itype* B [[buffer(1)]], \
|
||||
device itype* C [[buffer(2)]], \
|
||||
const constant MLXConvParams<2>* params [[buffer(3)]], \
|
||||
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
|
||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \
|
||||
// clang-format off
|
||||
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
|
||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \
|
||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, l, false) \
|
||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 1, 1, l, false) \
|
||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 2, 2, l, false) \
|
||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 3, 3, l, false) \
|
||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false)
|
||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false) // clang-format on
|
||||
|
||||
#define instantiate_implicit_2d_blocks(name, itype) \
|
||||
// clang-format off
|
||||
#define instantiate_implicit_2d_blocks(name, itype) \
|
||||
instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
|
||||
instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
|
||||
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
|
||||
instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
|
||||
instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
|
||||
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)
|
||||
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_implicit_2d_blocks(float32, float);
|
||||
instantiate_implicit_2d_blocks(float16, half);
|
||||
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
|
||||
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on
|
@ -4,23 +4,25 @@
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||
|
||||
using namespace metal;
|
||||
using namespace mlx::steel;
|
||||
|
||||
template <typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
typename AccumType = float,
|
||||
typename Epilogue = TransformNone<T, AccumType>>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d_general(
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
typename AccumType = float,
|
||||
typename Epilogue = TransformNone<T, AccumType>>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
|
||||
implicit_gemm_conv_2d_general(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
device T* C [[buffer(2)]],
|
||||
@ -33,9 +35,8 @@ template <typename T,
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
(void)lid;
|
||||
|
||||
|
||||
constexpr bool transpose_a = false;
|
||||
constexpr bool transpose_b = true;
|
||||
constexpr short tgp_padding_a = 16 / sizeof(T);
|
||||
@ -49,15 +50,15 @@ template <typename T,
|
||||
constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
|
||||
|
||||
constexpr short tgp_size = WM * WN * 32;
|
||||
|
||||
// Input loader
|
||||
using loader_a_t = Conv2DInputBlockLoaderGeneral<
|
||||
T, BM, BN, BK, tgp_size, tgp_padding_a>;
|
||||
|
||||
// Input loader
|
||||
using loader_a_t =
|
||||
Conv2DInputBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_a>;
|
||||
|
||||
// Weight loader
|
||||
using loader_b_t = Conv2DWeightBlockLoaderGeneral<
|
||||
T, BM, BN, BK, tgp_size, tgp_padding_b>;
|
||||
|
||||
using loader_b_t =
|
||||
Conv2DWeightBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_b>;
|
||||
|
||||
using mma_t = BlockMMA<
|
||||
T,
|
||||
T,
|
||||
@ -70,12 +71,12 @@ template <typename T,
|
||||
transpose_b,
|
||||
shape_a_cols,
|
||||
shape_b_cols>;
|
||||
|
||||
|
||||
threadgroup T As[tgp_mem_size_a];
|
||||
threadgroup T Bs[tgp_mem_size_b];
|
||||
|
||||
const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
|
||||
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
|
||||
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
|
||||
const int tid_x = (tid.x) >> gemm_params->swizzle_log;
|
||||
|
||||
if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
|
||||
@ -103,13 +104,32 @@ template <typename T,
|
||||
const int2 offsets_b(0, c_col);
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
loader_a_t loader_a(A, As, offsets_a, params, jump_params, base_wh, base_ww, simd_gid, simd_lid);
|
||||
loader_b_t loader_b(B, Bs, offsets_b, params, jump_params, base_wh, base_ww, simd_gid, simd_lid);
|
||||
loader_a_t loader_a(
|
||||
A,
|
||||
As,
|
||||
offsets_a,
|
||||
params,
|
||||
jump_params,
|
||||
base_wh,
|
||||
base_ww,
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
loader_b_t loader_b(
|
||||
B,
|
||||
Bs,
|
||||
offsets_b,
|
||||
params,
|
||||
jump_params,
|
||||
base_wh,
|
||||
base_ww,
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
mma_t mma_op(simd_gid, simd_lid);
|
||||
|
||||
int gemm_k_iterations = base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
|
||||
int gemm_k_iterations =
|
||||
base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
|
||||
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
@ -143,22 +163,24 @@ template <typename T,
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < mma_t::TM; i++) {
|
||||
|
||||
int cm = offset_m + i * mma_t::TM_stride;
|
||||
|
||||
int n = cm / jump_params->adj_out_hw;
|
||||
int hw = cm % jump_params->adj_out_hw;
|
||||
int oh = (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
|
||||
int ow = (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
|
||||
int oh =
|
||||
(hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
|
||||
int ow =
|
||||
(hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
|
||||
|
||||
if(n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
|
||||
|
||||
int offset_cm = n * params->out_strides[0] + oh * params->out_strides[1] + ow * params->out_strides[2];
|
||||
if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
|
||||
int offset_cm = n * params->out_strides[0] +
|
||||
oh * params->out_strides[1] + ow * params->out_strides[2];
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < mma_t::TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = mma_op.results[i * mma_t::TN + j].thread_elements();
|
||||
thread const auto& accum =
|
||||
mma_op.results[i * mma_t::TN + j].thread_elements();
|
||||
int offset = offset_cm + (j * mma_t::TN_stride);
|
||||
|
||||
// Apply epilogue and output C
|
||||
@ -170,40 +192,42 @@ template <typename T,
|
||||
C[offset + 1] = Epilogue::apply(accum[1]);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
|
||||
template [[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
|
||||
[[kernel]] void implicit_gemm_conv_2d_general<itype, bm, bn, bk, wm, wn>( \
|
||||
const device itype* A [[buffer(0)]], \
|
||||
const device itype* B [[buffer(1)]], \
|
||||
device itype* C [[buffer(2)]], \
|
||||
const constant MLXConvParams<2>* params [[buffer(3)]], \
|
||||
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \
|
||||
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], \
|
||||
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], \
|
||||
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
|
||||
template \
|
||||
[[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn \
|
||||
"_bk" #bk "_wm" #wm "_wn" #wn)]] [[kernel]] void \
|
||||
implicit_gemm_conv_2d_general<itype, bm, bn, bk, wm, wn>( \
|
||||
const device itype* A [[buffer(0)]], \
|
||||
const device itype* B [[buffer(1)]], \
|
||||
device itype* C [[buffer(2)]], \
|
||||
const constant MLXConvParams<2>* params [[buffer(3)]], \
|
||||
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \
|
||||
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], \
|
||||
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], \
|
||||
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
|
||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn)
|
||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn)
|
||||
|
||||
#define instantiate_implicit_2d_blocks(name, itype) \
|
||||
// clang-format off
|
||||
#define instantiate_implicit_2d_blocks(name, itype) \
|
||||
instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
|
||||
instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
|
||||
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
|
||||
instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
|
||||
instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
|
||||
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)
|
||||
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_implicit_2d_blocks(float32, float);
|
||||
instantiate_implicit_2d_blocks(float16, half);
|
||||
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
|
||||
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on
|
@ -1,8 +1,8 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
using namespace mlx::steel;
|
||||
@ -11,97 +11,126 @@ using namespace mlx::steel;
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
|
||||
const device T *A [[buffer(0)]],
|
||||
const device T *B [[buffer(1)]],
|
||||
device T *D [[buffer(3)]],
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
device T* D [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
|
||||
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
using gemm_kernel = GEMMKernel<
|
||||
T,
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
MN_aligned,
|
||||
K_aligned>;
|
||||
|
||||
// Adjust for batch
|
||||
if(params->batch_ndim > 1) {
|
||||
const constant size_t* A_bstrides = batch_strides;
|
||||
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
||||
// Adjust for batch
|
||||
if (params->batch_ndim > 1) {
|
||||
const constant size_t* A_bstrides = batch_strides;
|
||||
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
||||
|
||||
A += batch_offsets.x;
|
||||
B += batch_offsets.y;
|
||||
|
||||
} else {
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
}
|
||||
|
||||
D += params->batch_stride_d * tid.z;
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
||||
|
||||
gemm_kernel::run(
|
||||
A, B, D,
|
||||
params,
|
||||
As, Bs,
|
||||
simd_lane_id, simd_group_id, tid, lid
|
||||
);
|
||||
A += batch_offsets.x;
|
||||
B += batch_offsets.y;
|
||||
|
||||
} else {
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
}
|
||||
|
||||
D += params->batch_stride_d * tid.z;
|
||||
|
||||
gemm_kernel::run(
|
||||
A, B, D, params, As, Bs, simd_lane_id, simd_group_id, tid, lid);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernel initializations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
||||
template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
|
||||
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
|
||||
const device itype *A [[buffer(0)]], \
|
||||
const device itype *B [[buffer(1)]], \
|
||||
device itype *D [[buffer(3)]], \
|
||||
const constant GEMMParams* params [[buffer(4)]], \
|
||||
const constant int* batch_shape [[buffer(6)]], \
|
||||
const constant size_t* batch_strides [[buffer(7)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
#define instantiate_gemm( \
|
||||
tname, \
|
||||
trans_a, \
|
||||
trans_b, \
|
||||
iname, \
|
||||
itype, \
|
||||
oname, \
|
||||
otype, \
|
||||
bm, \
|
||||
bn, \
|
||||
bk, \
|
||||
wm, \
|
||||
wn, \
|
||||
aname, \
|
||||
mn_aligned, \
|
||||
kname, \
|
||||
k_aligned) \
|
||||
template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm \
|
||||
"_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname \
|
||||
"_K_" #kname)]] [[kernel]] void \
|
||||
gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
|
||||
const device itype* A [[buffer(0)]], \
|
||||
const device itype* B [[buffer(1)]], \
|
||||
device itype* D [[buffer(3)]], \
|
||||
const constant GEMMParams* params [[buffer(4)]], \
|
||||
const constant int* batch_shape [[buffer(6)]], \
|
||||
const constant size_t* batch_strides [[buffer(7)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||
// clang-format off
|
||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
|
||||
|
||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
// clang-format off
|
||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
|
||||
|
||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||
// clang-format off
|
||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||
|
||||
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
||||
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
|
@ -10,23 +10,24 @@ using namespace mlx::steel;
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned,
|
||||
typename AccumType = float,
|
||||
typename Epilogue = TransformAdd<T, AccumType>>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void addmm(
|
||||
const device T *A [[buffer(0)]],
|
||||
const device T *B [[buffer(1)]],
|
||||
const device T *C [[buffer(2)]],
|
||||
device T *D [[buffer(3)]],
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned,
|
||||
typename AccumType = float,
|
||||
typename Epilogue = TransformAdd<T, AccumType>>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void addmm(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
const device T* C [[buffer(2)]],
|
||||
device T* D [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant GEMMAddMMParams* addmm_params [[buffer(5)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
@ -34,243 +35,304 @@ template <typename T,
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
// Pacifying compiler
|
||||
(void)lid;
|
||||
|
||||
using gemm_kernel =
|
||||
GEMMKernel<T, T, BM, BN, BK, WM, WN,
|
||||
transpose_a, transpose_b,
|
||||
MN_aligned, K_aligned,
|
||||
AccumType, Epilogue>;
|
||||
|
||||
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||
using mma_t = typename gemm_kernel::mma_t;
|
||||
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
// Pacifying compiler
|
||||
(void)lid;
|
||||
|
||||
// Adjust for batch
|
||||
if(params->batch_ndim > 1) {
|
||||
const constant size_t* A_bstrides = batch_strides;
|
||||
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
||||
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
|
||||
using gemm_kernel = GEMMKernel<
|
||||
T,
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
MN_aligned,
|
||||
K_aligned,
|
||||
AccumType,
|
||||
Epilogue>;
|
||||
|
||||
ulong3 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, A_bstrides, B_bstrides, C_bstrides, params->batch_ndim);
|
||||
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||
using mma_t = typename gemm_kernel::mma_t;
|
||||
|
||||
A += batch_offsets.x;
|
||||
B += batch_offsets.y;
|
||||
C += batch_offsets.z;
|
||||
|
||||
} else {
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
C += addmm_params->batch_stride_c * tid.z;
|
||||
}
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
D += params->batch_stride_d * tid.z;
|
||||
// Adjust for batch
|
||||
if (params->batch_ndim > 1) {
|
||||
const constant size_t* A_bstrides = batch_strides;
|
||||
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
||||
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
|
||||
|
||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||
((tid.x) & ((1 << params->swizzle_log) - 1));
|
||||
const int tid_x = (tid.x) >> params->swizzle_log;
|
||||
ulong3 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z,
|
||||
batch_shape,
|
||||
A_bstrides,
|
||||
B_bstrides,
|
||||
C_bstrides,
|
||||
params->batch_ndim);
|
||||
|
||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||
return;
|
||||
A += batch_offsets.x;
|
||||
B += batch_offsets.y;
|
||||
C += batch_offsets.z;
|
||||
|
||||
} else {
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
C += addmm_params->batch_stride_c * tid.z;
|
||||
}
|
||||
|
||||
D += params->batch_stride_d * tid.z;
|
||||
|
||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||
((tid.x) & ((1 << params->swizzle_log) - 1));
|
||||
const int tid_x = (tid.x) >> params->swizzle_log;
|
||||
|
||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||
return;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid_y * BM;
|
||||
const int c_col = tid_x * BN;
|
||||
|
||||
A += transpose_a ? c_row : c_row * params->lda;
|
||||
B += transpose_b ? c_col * params->ldb : c_col;
|
||||
D += c_row * params->ldd + c_col;
|
||||
|
||||
C += c_row * addmm_params->ldc + c_col * addmm_params->fdc;
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
|
||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
|
||||
const Epilogue epilogue_op(addmm_params->alpha, addmm_params->beta);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MNK aligned loop
|
||||
if (MN_aligned) {
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid_y * BM;
|
||||
const int c_col = tid_x * BN;
|
||||
// Loop tail
|
||||
if (!K_aligned) {
|
||||
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
||||
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
||||
|
||||
A += transpose_a ? c_row : c_row * params->lda;
|
||||
B += transpose_b ? c_col * params->ldb : c_col;
|
||||
D += c_row * params->ldd + c_col;
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
|
||||
C += c_row * addmm_params->ldc + c_col * addmm_params->fdc;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
// Store results to device memory
|
||||
mma_op.store_result(
|
||||
D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
|
||||
return;
|
||||
|
||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MN unaligned loop
|
||||
else { // Loop over K - unaligned case
|
||||
short tgp_bm = min(BM, params->M - c_row);
|
||||
short tgp_bn = min(BN, params->N - c_col);
|
||||
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
|
||||
const Epilogue epilogue_op(addmm_params->alpha, addmm_params->beta);
|
||||
if (tgp_bm == BM && tgp_bn == BN) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<true, true, K_aligned>{});
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MNK aligned loop
|
||||
if (MN_aligned) {
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Loop tail
|
||||
if (!K_aligned) {
|
||||
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
||||
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
||||
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
|
||||
mma_op.store_result(
|
||||
D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
|
||||
return;
|
||||
|
||||
} else if (tgp_bn == BN) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<false, true, K_aligned>{});
|
||||
|
||||
return mma_op.store_result_safe(
|
||||
D,
|
||||
params->ldd,
|
||||
C,
|
||||
addmm_params->ldc,
|
||||
addmm_params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op);
|
||||
|
||||
} else if (tgp_bm == BM) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<true, false, K_aligned>{});
|
||||
|
||||
return mma_op.store_result_safe(
|
||||
D,
|
||||
params->ldd,
|
||||
C,
|
||||
addmm_params->ldc,
|
||||
addmm_params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op);
|
||||
|
||||
} else {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<false, false, K_aligned>{});
|
||||
|
||||
return mma_op.store_result_safe(
|
||||
D,
|
||||
params->ldd,
|
||||
C,
|
||||
addmm_params->ldc,
|
||||
addmm_params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op);
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MN unaligned loop
|
||||
else { // Loop over K - unaligned case
|
||||
short tgp_bm = min(BM, params->M - c_row);
|
||||
short tgp_bn = min(BN, params->N - c_col);
|
||||
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
|
||||
if (tgp_bm == BM && tgp_bn == BN) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<true, true, K_aligned>{});
|
||||
|
||||
mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
|
||||
return;
|
||||
|
||||
} else if (tgp_bn == BN) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<false, true, K_aligned>{});
|
||||
|
||||
return mma_op.store_result_safe(
|
||||
D, params->ldd,
|
||||
C, addmm_params->ldc, addmm_params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op);
|
||||
|
||||
} else if (tgp_bm == BM) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<true, false, K_aligned>{});
|
||||
|
||||
return mma_op.store_result_safe(
|
||||
D, params->ldd,
|
||||
C, addmm_params->ldc, addmm_params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op);
|
||||
|
||||
} else {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<false, false, K_aligned>{});
|
||||
|
||||
return mma_op.store_result_safe(
|
||||
D, params->ldd,
|
||||
C, addmm_params->ldc, addmm_params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernel initializations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, ep_name, epilogue) \
|
||||
template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_" #ep_name)]] \
|
||||
[[kernel]] void addmm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, float, epilogue<itype, float>>( \
|
||||
const device itype *A [[buffer(0)]], \
|
||||
const device itype *B [[buffer(1)]], \
|
||||
const device itype *C [[buffer(2)]], \
|
||||
device itype *D [[buffer(3)]], \
|
||||
const constant GEMMParams* gemm_params [[buffer(4)]], \
|
||||
const constant GEMMAddMMParams* params [[buffer(5)]], \
|
||||
const constant int* batch_shape [[buffer(6)]], \
|
||||
const constant size_t* batch_strides [[buffer(7)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
#define instantiate_gemm( \
|
||||
tname, \
|
||||
trans_a, \
|
||||
trans_b, \
|
||||
iname, \
|
||||
itype, \
|
||||
oname, \
|
||||
otype, \
|
||||
bm, \
|
||||
bn, \
|
||||
bk, \
|
||||
wm, \
|
||||
wn, \
|
||||
aname, \
|
||||
mn_aligned, \
|
||||
kname, \
|
||||
k_aligned, \
|
||||
ep_name, \
|
||||
epilogue) \
|
||||
template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm \
|
||||
"_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname \
|
||||
"_K_" #kname "_" #ep_name)]] [[kernel]] void \
|
||||
addmm< \
|
||||
itype, \
|
||||
bm, \
|
||||
bn, \
|
||||
bk, \
|
||||
wm, \
|
||||
wn, \
|
||||
trans_a, \
|
||||
trans_b, \
|
||||
mn_aligned, \
|
||||
k_aligned, \
|
||||
float, \
|
||||
epilogue<itype, float>>( \
|
||||
const device itype* A [[buffer(0)]], \
|
||||
const device itype* B [[buffer(1)]], \
|
||||
const device itype* C [[buffer(2)]], \
|
||||
device itype* D [[buffer(3)]], \
|
||||
const constant GEMMParams* gemm_params [[buffer(4)]], \
|
||||
const constant GEMMAddMMParams* params [[buffer(5)]], \
|
||||
const constant int* batch_shape [[buffer(6)]], \
|
||||
const constant size_t* batch_strides [[buffer(7)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby)
|
||||
// clang-format off
|
||||
#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby) // clang-format on
|
||||
|
||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||
// clang-format off
|
||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
|
||||
|
||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
// clang-format off
|
||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
|
||||
|
||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||
// clang-format off
|
||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||
|
||||
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
||||
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
|
@ -1,8 +1,8 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
using namespace mlx::steel;
|
||||
@ -11,319 +11,378 @@ using namespace mlx::steel;
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned,
|
||||
bool has_operand_mask=false>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void block_masked_gemm(
|
||||
const device T *A [[buffer(0)]],
|
||||
const device T *B [[buffer(1)]],
|
||||
device T *D [[buffer(3)]],
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned,
|
||||
bool has_operand_mask = false>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
|
||||
block_masked_gemm(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
device T* D [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
const device bool *out_mask [[buffer(10)]],
|
||||
const device bool *lhs_mask [[buffer(11)]],
|
||||
const device bool *rhs_mask [[buffer(12)]],
|
||||
const device bool* out_mask [[buffer(10)]],
|
||||
const device bool* lhs_mask [[buffer(11)]],
|
||||
const device bool* rhs_mask [[buffer(12)]],
|
||||
const constant int* mask_strides [[buffer(13)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
// Appease the compiler
|
||||
(void)lid;
|
||||
|
||||
// Appease the compiler
|
||||
(void)lid;
|
||||
|
||||
using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
|
||||
|
||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||
((tid.x) & ((1 << params->swizzle_log) - 1));
|
||||
const int tid_x = (tid.x) >> params->swizzle_log;
|
||||
using gemm_kernel = GEMMKernel<
|
||||
T,
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
MN_aligned,
|
||||
K_aligned>;
|
||||
|
||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||
return;
|
||||
}
|
||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||
((tid.x) & ((1 << params->swizzle_log) - 1));
|
||||
const int tid_x = (tid.x) >> params->swizzle_log;
|
||||
|
||||
if(params->batch_ndim > 1) {
|
||||
const constant size_t* mask_batch_strides = batch_strides + 2 * params->batch_ndim;
|
||||
out_mask += elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
|
||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||
return;
|
||||
}
|
||||
|
||||
if(has_operand_mask) {
|
||||
const constant size_t* mask_strides_lhs = mask_batch_strides + params->batch_ndim;
|
||||
const constant size_t* mask_strides_rhs = mask_strides_lhs + params->batch_ndim;
|
||||
if (params->batch_ndim > 1) {
|
||||
const constant size_t* mask_batch_strides =
|
||||
batch_strides + 2 * params->batch_ndim;
|
||||
out_mask +=
|
||||
elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, mask_strides_lhs, mask_strides_rhs, params->batch_ndim);
|
||||
|
||||
lhs_mask += batch_offsets.x;
|
||||
rhs_mask += batch_offsets.y;
|
||||
}
|
||||
} else {
|
||||
out_mask += tid.z * batch_strides[2 * params->batch_ndim];
|
||||
if(has_operand_mask) {
|
||||
lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
|
||||
rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
|
||||
}
|
||||
}
|
||||
|
||||
// Adjust for batch
|
||||
if(params->batch_ndim > 1) {
|
||||
const constant size_t* A_bstrides = batch_strides;
|
||||
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
||||
if (has_operand_mask) {
|
||||
const constant size_t* mask_strides_lhs =
|
||||
mask_batch_strides + params->batch_ndim;
|
||||
const constant size_t* mask_strides_rhs =
|
||||
mask_strides_lhs + params->batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
||||
tid.z,
|
||||
batch_shape,
|
||||
mask_strides_lhs,
|
||||
mask_strides_rhs,
|
||||
params->batch_ndim);
|
||||
|
||||
A += batch_offsets.x;
|
||||
B += batch_offsets.y;
|
||||
|
||||
} else {
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
lhs_mask += batch_offsets.x;
|
||||
rhs_mask += batch_offsets.y;
|
||||
}
|
||||
|
||||
D += params->batch_stride_d * tid.z;
|
||||
} else {
|
||||
out_mask += tid.z * batch_strides[2 * params->batch_ndim];
|
||||
if (has_operand_mask) {
|
||||
lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
|
||||
rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
|
||||
}
|
||||
}
|
||||
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid_y * BM;
|
||||
const int c_col = tid_x * BN;
|
||||
// Adjust for batch
|
||||
if (params->batch_ndim > 1) {
|
||||
const constant size_t* A_bstrides = batch_strides;
|
||||
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
||||
|
||||
A += transpose_a ? c_row : c_row * params->lda;
|
||||
B += transpose_b ? c_col * params->ldb : c_col;
|
||||
D += c_row * params->ldd + c_col;
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
||||
|
||||
A += batch_offsets.x;
|
||||
B += batch_offsets.y;
|
||||
|
||||
bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];
|
||||
} else {
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
}
|
||||
|
||||
// Write zeros and return
|
||||
if(!mask_out) {
|
||||
constexpr short tgp_size = WM * WN * 32;
|
||||
constexpr short vec_size = 4;
|
||||
D += params->batch_stride_d * tid.z;
|
||||
|
||||
// Tile threads in threadgroup
|
||||
constexpr short TN = BN / vec_size;
|
||||
constexpr short TM = tgp_size / TN;
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid_y * BM;
|
||||
const int c_col = tid_x * BN;
|
||||
|
||||
const short thread_idx = simd_group_id * 32 + simd_lane_id;
|
||||
const short bi = thread_idx / TN;
|
||||
const short bj = vec_size * (thread_idx % TN);
|
||||
A += transpose_a ? c_row : c_row * params->lda;
|
||||
B += transpose_b ? c_col * params->ldb : c_col;
|
||||
D += c_row * params->ldd + c_col;
|
||||
|
||||
D += bi * params->ldd + bj;
|
||||
bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];
|
||||
|
||||
short tgp_bm = min(BM, params->M - c_row);
|
||||
short tgp_bn = min(BN, params->N - c_col);
|
||||
// Write zeros and return
|
||||
if (!mask_out) {
|
||||
constexpr short tgp_size = WM * WN * 32;
|
||||
constexpr short vec_size = 4;
|
||||
|
||||
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
||||
for (short ti = 0; ti < BM; ti += TM) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for(short j = 0; j < vec_size; j++) {
|
||||
D[ti * params->ldd + j] = T(0.);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
short jmax = tgp_bn - bj;
|
||||
jmax = jmax < vec_size ? jmax : vec_size;
|
||||
for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
|
||||
for(short j = 0; j < jmax; j++) {
|
||||
D[ti * params->ldd + j] = T(0.);
|
||||
}
|
||||
// Tile threads in threadgroup
|
||||
constexpr short TN = BN / vec_size;
|
||||
constexpr short TM = tgp_size / TN;
|
||||
|
||||
const short thread_idx = simd_group_id * 32 + simd_lane_id;
|
||||
const short bi = thread_idx / TN;
|
||||
const short bj = vec_size * (thread_idx % TN);
|
||||
|
||||
D += bi * params->ldd + bj;
|
||||
|
||||
short tgp_bm = min(BM, params->M - c_row);
|
||||
short tgp_bn = min(BN, params->N - c_col);
|
||||
|
||||
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
||||
for (short ti = 0; ti < BM; ti += TM) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
D[ti * params->ldd + j] = T(0.);
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
} else {
|
||||
short jmax = tgp_bn - bj;
|
||||
jmax = jmax < vec_size ? jmax : vec_size;
|
||||
for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
|
||||
for (short j = 0; j < jmax; j++) {
|
||||
D[ti * params->ldd + j] = T(0.);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
|
||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread typename gemm_kernel::loader_a_t loader_a(
|
||||
A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread typename gemm_kernel::loader_b_t loader_b(
|
||||
B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MNK aligned loop
|
||||
if (MN_aligned) {
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (!has_operand_mask ||
|
||||
(lhs_mask
|
||||
[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
|
||||
rhs_mask
|
||||
[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
// Loop tail
|
||||
if (!K_aligned) {
|
||||
if (!has_operand_mask ||
|
||||
(lhs_mask
|
||||
[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
|
||||
rhs_mask
|
||||
[(params->K / BM) * mask_strides[5] +
|
||||
tid_x * mask_strides[4]])) {
|
||||
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
||||
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
||||
|
||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread typename gemm_kernel::loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread typename gemm_kernel::loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MNK aligned loop
|
||||
if (MN_aligned) {
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if(!has_operand_mask ||
|
||||
(lhs_mask[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
|
||||
rhs_mask[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
}
|
||||
|
||||
// Load elements into threadgroup
|
||||
// Store results to device memory
|
||||
mma_op.store_result(D, params->ldd);
|
||||
return;
|
||||
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MN unaligned loop
|
||||
else { // Loop over K - unaligned case
|
||||
short tgp_bm = min(BM, params->M - c_row);
|
||||
short tgp_bn = min(BN, params->N - c_col);
|
||||
short lbk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
|
||||
bool M_aligned = (tgp_bm == BM);
|
||||
bool N_aligned = (tgp_bn == BN);
|
||||
|
||||
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
|
||||
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
|
||||
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (!has_operand_mask ||
|
||||
(lhs_mask
|
||||
[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
|
||||
rhs_mask
|
||||
[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
|
||||
// Load elements into threadgroup
|
||||
if (M_aligned) {
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
}
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Loop tail
|
||||
if (!K_aligned) {
|
||||
|
||||
if(!has_operand_mask ||
|
||||
(lhs_mask[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
|
||||
rhs_mask[(params->K / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
|
||||
|
||||
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
||||
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
||||
|
||||
} else {
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
}
|
||||
|
||||
if (N_aligned) {
|
||||
loader_b.load_unsafe();
|
||||
} else {
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
if (!K_aligned) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (!has_operand_mask ||
|
||||
(lhs_mask
|
||||
[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
|
||||
rhs_mask
|
||||
[(params->K / BM) * mask_strides[5] +
|
||||
tid_x * mask_strides[4]])) {
|
||||
short2 tile_dims_A_last =
|
||||
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
|
||||
short2 tile_dims_B_last =
|
||||
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
|
||||
|
||||
loader_a.load_safe(tile_dims_A_last);
|
||||
loader_b.load_safe(tile_dims_B_last);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
}
|
||||
|
||||
if (M_aligned && N_aligned) {
|
||||
mma_op.store_result(D, params->ldd);
|
||||
return;
|
||||
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MN unaligned loop
|
||||
else { // Loop over K - unaligned case
|
||||
short tgp_bm = min(BM, params->M - c_row);
|
||||
short tgp_bn = min(BN, params->N - c_col);
|
||||
short lbk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
|
||||
bool M_aligned = (tgp_bm == BM);
|
||||
bool N_aligned = (tgp_bn == BN);
|
||||
|
||||
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
|
||||
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
|
||||
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if(!has_operand_mask ||
|
||||
(lhs_mask[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
|
||||
rhs_mask[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
|
||||
|
||||
// Load elements into threadgroup
|
||||
if (M_aligned) {
|
||||
loader_a.load_unsafe();
|
||||
} else {
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
}
|
||||
|
||||
if (N_aligned) {
|
||||
loader_b.load_unsafe();
|
||||
} else {
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
}
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
if (!K_aligned) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if(!has_operand_mask ||
|
||||
(lhs_mask[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
|
||||
rhs_mask[(params->K / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
|
||||
|
||||
short2 tile_dims_A_last =
|
||||
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
|
||||
short2 tile_dims_B_last =
|
||||
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
|
||||
|
||||
loader_a.load_safe(tile_dims_A_last);
|
||||
loader_b.load_safe(tile_dims_B_last);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
if(M_aligned && N_aligned) {
|
||||
mma_op.store_result(D, params->ldd);
|
||||
} else {
|
||||
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
}
|
||||
} else {
|
||||
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernel initializations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, omname, op_mask) \
|
||||
template [[host_name("steel_block_masked_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_op_mask_" #omname)]] \
|
||||
[[kernel]] void block_masked_gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, op_mask>( \
|
||||
const device itype *A [[buffer(0)]], \
|
||||
const device itype *B [[buffer(1)]], \
|
||||
device itype *D [[buffer(3)]], \
|
||||
const constant GEMMParams* params [[buffer(4)]], \
|
||||
const constant int* batch_shape [[buffer(6)]], \
|
||||
const constant size_t* batch_strides [[buffer(7)]], \
|
||||
const device bool *out_mask [[buffer(10)]], \
|
||||
const device bool *lhs_mask [[buffer(11)]], \
|
||||
const device bool *rhs_mask [[buffer(12)]], \
|
||||
const constant int* mask_strides [[buffer(13)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
#define instantiate_gemm( \
|
||||
tname, \
|
||||
trans_a, \
|
||||
trans_b, \
|
||||
iname, \
|
||||
itype, \
|
||||
oname, \
|
||||
otype, \
|
||||
bm, \
|
||||
bn, \
|
||||
bk, \
|
||||
wm, \
|
||||
wn, \
|
||||
aname, \
|
||||
mn_aligned, \
|
||||
kname, \
|
||||
k_aligned, \
|
||||
omname, \
|
||||
op_mask) \
|
||||
template [[host_name("steel_block_masked_gemm_" #tname "_" #iname "_" #oname \
|
||||
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
|
||||
"_MN_" #aname "_K_" #kname \
|
||||
"_op_mask_" #omname)]] [[kernel]] void \
|
||||
block_masked_gemm< \
|
||||
itype, \
|
||||
bm, \
|
||||
bn, \
|
||||
bk, \
|
||||
wm, \
|
||||
wn, \
|
||||
trans_a, \
|
||||
trans_b, \
|
||||
mn_aligned, \
|
||||
k_aligned, \
|
||||
op_mask>( \
|
||||
const device itype* A [[buffer(0)]], \
|
||||
const device itype* B [[buffer(1)]], \
|
||||
device itype* D [[buffer(3)]], \
|
||||
const constant GEMMParams* params [[buffer(4)]], \
|
||||
const constant int* batch_shape [[buffer(6)]], \
|
||||
const constant size_t* batch_strides [[buffer(7)]], \
|
||||
const device bool* out_mask [[buffer(10)]], \
|
||||
const device bool* lhs_mask [[buffer(11)]], \
|
||||
const device bool* rhs_mask [[buffer(12)]], \
|
||||
const constant int* mask_strides [[buffer(13)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, N, false) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, T, true)
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, N, false) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, T, true) // clang-format on
|
||||
|
||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||
// clang-format off
|
||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
||||
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
||||
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
||||
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
|
||||
|
||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
// clang-format off
|
||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
|
||||
|
||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||
// clang-format off
|
||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2)
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
||||
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
|
||||
|
@ -10,81 +10,95 @@ using namespace mlx::steel;
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T,
|
||||
typename U,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm_splitk(
|
||||
const device T *A [[buffer(0)]],
|
||||
const device T *B [[buffer(1)]],
|
||||
device U *C [[buffer(2)]],
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm_splitk(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
device U* C [[buffer(2)]],
|
||||
const constant GEMMSpiltKParams* params [[buffer(3)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
(void)lid;
|
||||
|
||||
(void)lid;
|
||||
|
||||
using gemm_kernel = GEMMKernel<T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
|
||||
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||
using mma_t = typename gemm_kernel::mma_t;
|
||||
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
using gemm_kernel = GEMMKernel<
|
||||
T,
|
||||
U,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
MN_aligned,
|
||||
K_aligned>;
|
||||
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||
using mma_t = typename gemm_kernel::mma_t;
|
||||
|
||||
const int tid_x = tid.x;
|
||||
const int tid_y = tid.y;
|
||||
const int tid_z = tid.z;
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||
return;
|
||||
}
|
||||
const int tid_x = tid.x;
|
||||
const int tid_y = tid.y;
|
||||
const int tid_z = tid.z;
|
||||
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid_y * BM;
|
||||
const int c_col = tid_x * BN;
|
||||
const int k_start = params->split_k_partition_size * tid_z;
|
||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||
return;
|
||||
}
|
||||
|
||||
A += transpose_a ? (c_row + k_start * params->lda) : (k_start + c_row * params->lda);
|
||||
B += transpose_b ? (k_start + c_col * params->ldb) : (c_col + k_start * params->ldb);
|
||||
C += (params->split_k_partition_stride * tid_z) + (c_row * params->ldc + c_col);
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid_y * BM;
|
||||
const int c_col = tid_x * BN;
|
||||
const int k_start = params->split_k_partition_size * tid_z;
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||
A += transpose_a ? (c_row + k_start * params->lda)
|
||||
: (k_start + c_row * params->lda);
|
||||
B += transpose_b ? (k_start + c_col * params->ldb)
|
||||
: (c_col + k_start * params->ldb);
|
||||
C += (params->split_k_partition_stride * tid_z) +
|
||||
(c_row * params->ldc + c_col);
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||
|
||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
|
||||
short tgp_bm = min(BM, params->M - c_row);
|
||||
short tgp_bn = min(BN, params->N - c_col);
|
||||
short leftover_bk = params->K % BK;
|
||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
|
||||
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<true, true, true>{});
|
||||
} else if (tgp_bn == BN) {
|
||||
gemm_kernel::gemm_loop(
|
||||
short tgp_bm = min(BM, params->M - c_row);
|
||||
short tgp_bn = min(BN, params->N - c_col);
|
||||
short leftover_bk = params->K % BK;
|
||||
|
||||
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<true, true, true>{});
|
||||
} else if (tgp_bn == BN) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
@ -95,37 +109,38 @@ template <typename T,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<false, true, true>{});
|
||||
} else if (tgp_bm == BM) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<true, false, true>{});
|
||||
} else {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<false, false, true>{});
|
||||
}
|
||||
} else if (tgp_bm == BM) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<true, false, true>{});
|
||||
} else {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<false, false, true>{});
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if ((tid_z + 1) == (params->split_k_partitions)) {
|
||||
int gemm_k_iter_remaining = (params->K - (k_start + params->split_k_partition_size)) / BK;
|
||||
if(!K_aligned || gemm_k_iter_remaining > 0)
|
||||
if ((tid_z + 1) == (params->split_k_partitions)) {
|
||||
int gemm_k_iter_remaining =
|
||||
(params->K - (k_start + params->split_k_partition_size)) / BK;
|
||||
if (!K_aligned || gemm_k_iter_remaining > 0)
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
@ -137,69 +152,102 @@ template <typename T,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<false, false, K_aligned>{});
|
||||
}
|
||||
}
|
||||
|
||||
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
||||
mma_op.store_result(C, params->ldc);
|
||||
} else {
|
||||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
||||
}
|
||||
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
||||
mma_op.store_result(C, params->ldc);
|
||||
} else {
|
||||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernel initializations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
||||
template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
|
||||
[[kernel]] void gemm_splitk<itype, otype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
|
||||
const device itype *A [[buffer(0)]], \
|
||||
const device itype *B [[buffer(1)]], \
|
||||
device otype *C [[buffer(2)]], \
|
||||
const constant GEMMSpiltKParams* params [[buffer(3)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
#define instantiate_gemm( \
|
||||
tname, \
|
||||
trans_a, \
|
||||
trans_b, \
|
||||
iname, \
|
||||
itype, \
|
||||
oname, \
|
||||
otype, \
|
||||
bm, \
|
||||
bn, \
|
||||
bk, \
|
||||
wm, \
|
||||
wn, \
|
||||
aname, \
|
||||
mn_aligned, \
|
||||
kname, \
|
||||
k_aligned) \
|
||||
template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname \
|
||||
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
|
||||
"_MN_" #aname "_K_" #kname)]] [[kernel]] void \
|
||||
gemm_splitk< \
|
||||
itype, \
|
||||
otype, \
|
||||
bm, \
|
||||
bn, \
|
||||
bk, \
|
||||
wm, \
|
||||
wn, \
|
||||
trans_a, \
|
||||
trans_b, \
|
||||
mn_aligned, \
|
||||
k_aligned>( \
|
||||
const device itype* A [[buffer(0)]], \
|
||||
const device itype* B [[buffer(1)]], \
|
||||
device otype* C [[buffer(2)]], \
|
||||
const constant GEMMSpiltKParams* params [[buffer(3)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||
// clang-format off
|
||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
|
||||
|
||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
// clang-format off
|
||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
|
||||
|
||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||
// clang-format off
|
||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_gemm_shapes_helper(float16, half, float32, float);
|
||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
|
||||
|
||||
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
||||
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Split k accumulation kernel
|
||||
// Split k accumulation kernel
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename AccT,
|
||||
typename OutT,
|
||||
typename Epilogue = TransformNone<OutT, AccT>>
|
||||
template <
|
||||
typename AccT,
|
||||
typename OutT,
|
||||
typename Epilogue = TransformNone<OutT, AccT>>
|
||||
[[kernel]] void gemm_splitk_accum(
|
||||
const device AccT *C_split [[buffer(0)]],
|
||||
device OutT *D [[buffer(1)]],
|
||||
const device AccT* C_split [[buffer(0)]],
|
||||
device OutT* D [[buffer(1)]],
|
||||
const constant int& k_partitions [[buffer(2)]],
|
||||
const constant int& partition_stride [[buffer(3)]],
|
||||
const constant int& ldd [[buffer(4)]],
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
|
||||
// Ajust D and C
|
||||
D += gid.x + gid.y * ldd;
|
||||
C_split += gid.x + gid.y * ldd;
|
||||
@ -207,32 +255,31 @@ template <typename AccT,
|
||||
int offset = 0;
|
||||
AccT out = 0;
|
||||
|
||||
for(int i = 0; i < k_partitions; i++) {
|
||||
for (int i = 0; i < k_partitions; i++) {
|
||||
out += C_split[offset];
|
||||
offset += partition_stride;
|
||||
}
|
||||
|
||||
// Write output
|
||||
// Write output
|
||||
D[0] = Epilogue::apply(out);
|
||||
|
||||
}
|
||||
|
||||
template <typename AccT,
|
||||
typename OutT,
|
||||
typename Epilogue = TransformAxpby<OutT, AccT>>
|
||||
template <
|
||||
typename AccT,
|
||||
typename OutT,
|
||||
typename Epilogue = TransformAxpby<OutT, AccT>>
|
||||
[[kernel]] void gemm_splitk_accum_axpby(
|
||||
const device AccT *C_split [[buffer(0)]],
|
||||
device OutT *D [[buffer(1)]],
|
||||
const device AccT* C_split [[buffer(0)]],
|
||||
device OutT* D [[buffer(1)]],
|
||||
const constant int& k_partitions [[buffer(2)]],
|
||||
const constant int& partition_stride [[buffer(3)]],
|
||||
const constant int& ldd [[buffer(4)]],
|
||||
const device OutT *C [[buffer(5)]],
|
||||
const device OutT* C [[buffer(5)]],
|
||||
const constant int& ldc [[buffer(6)]],
|
||||
const constant int& fdc [[buffer(7)]],
|
||||
const constant float& alpha [[buffer(8)]],
|
||||
const constant float& beta [[buffer(9)]],
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
|
||||
// Ajust D and C
|
||||
C += gid.x * fdc + gid.y * ldc;
|
||||
D += gid.x + gid.y * ldd;
|
||||
@ -241,40 +288,42 @@ template <typename AccT,
|
||||
int offset = 0;
|
||||
AccT out = 0;
|
||||
|
||||
for(int i = 0; i < k_partitions; i++) {
|
||||
for (int i = 0; i < k_partitions; i++) {
|
||||
out += C_split[offset];
|
||||
offset += partition_stride;
|
||||
}
|
||||
|
||||
// Write output
|
||||
// Write output
|
||||
Epilogue op(alpha, beta);
|
||||
D[0] = op.apply(out, *C);
|
||||
|
||||
}
|
||||
|
||||
#define instantiate_accum(oname, otype, aname, atype) \
|
||||
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname)]] \
|
||||
[[kernel]] void gemm_splitk_accum<atype, otype>( \
|
||||
const device atype *C_split [[buffer(0)]], \
|
||||
device otype *D [[buffer(1)]], \
|
||||
const constant int& k_partitions [[buffer(2)]], \
|
||||
const constant int& partition_stride [[buffer(3)]], \
|
||||
const constant int& ldd [[buffer(4)]], \
|
||||
uint2 gid [[thread_position_in_grid]]); \
|
||||
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname "_axpby")]] \
|
||||
[[kernel]] void gemm_splitk_accum_axpby<atype, otype>( \
|
||||
const device atype *C_split [[buffer(0)]], \
|
||||
device otype *D [[buffer(1)]], \
|
||||
const constant int& k_partitions [[buffer(2)]], \
|
||||
const constant int& partition_stride [[buffer(3)]], \
|
||||
const constant int& ldd [[buffer(4)]], \
|
||||
const device otype *C [[buffer(5)]], \
|
||||
const constant int& ldc [[buffer(6)]], \
|
||||
const constant int& fdc [[buffer(7)]], \
|
||||
const constant float& alpha [[buffer(8)]], \
|
||||
const constant float& beta [[buffer(9)]], \
|
||||
#define instantiate_accum(oname, otype, aname, atype) \
|
||||
template [[host_name("steel_gemm_splitk_accum_" #oname \
|
||||
"_" #aname)]] [[kernel]] void \
|
||||
gemm_splitk_accum<atype, otype>( \
|
||||
const device atype* C_split [[buffer(0)]], \
|
||||
device otype* D [[buffer(1)]], \
|
||||
const constant int& k_partitions [[buffer(2)]], \
|
||||
const constant int& partition_stride [[buffer(3)]], \
|
||||
const constant int& ldd [[buffer(4)]], \
|
||||
uint2 gid [[thread_position_in_grid]]); \
|
||||
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname \
|
||||
"_axpby")]] [[kernel]] void \
|
||||
gemm_splitk_accum_axpby<atype, otype>( \
|
||||
const device atype* C_split [[buffer(0)]], \
|
||||
device otype* D [[buffer(1)]], \
|
||||
const constant int& k_partitions [[buffer(2)]], \
|
||||
const constant int& partition_stride [[buffer(3)]], \
|
||||
const constant int& ldd [[buffer(4)]], \
|
||||
const device otype* C [[buffer(5)]], \
|
||||
const constant int& ldc [[buffer(6)]], \
|
||||
const constant int& fdc [[buffer(7)]], \
|
||||
const constant float& alpha [[buffer(8)]], \
|
||||
const constant float& beta [[buffer(9)]], \
|
||||
uint2 gid [[thread_position_in_grid]]);
|
||||
|
||||
// clang-format off
|
||||
instantiate_accum(bfloat16, bfloat16_t, float32, float);
|
||||
instantiate_accum(float16, half, float32, float);
|
||||
instantiate_accum(float32, float, float32, float);
|
||||
instantiate_accum(float32, float, float32, float); // clang-format on
|
@ -3,9 +3,9 @@
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/ternary.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
template <typename T, typename Op>
|
||||
[[kernel]] void ternary_op_v(
|
||||
@ -65,7 +65,8 @@ template <typename T, typename Op>
|
||||
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||
auto c_idx = elem_to_loc_3(index, c_strides);
|
||||
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
size_t out_idx =
|
||||
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
|
||||
}
|
||||
|
||||
@ -81,8 +82,10 @@ template <typename T, typename Op, int DIM>
|
||||
constant const size_t c_strides[DIM],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
|
||||
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
auto idx =
|
||||
elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
|
||||
size_t out_idx =
|
||||
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
|
||||
}
|
||||
|
||||
@ -99,103 +102,104 @@ template <typename T, typename Op>
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
|
||||
auto idx =
|
||||
elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
|
||||
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
|
||||
d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
|
||||
}
|
||||
|
||||
#define instantiate_ternary_v(name, type, op) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void ternary_op_v<type, op>( \
|
||||
device const bool* a, \
|
||||
device const type* b, \
|
||||
device const type* c, \
|
||||
device type* d, \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
#define instantiate_ternary_v(name, type, op) \
|
||||
template [[host_name(name)]] [[kernel]] void ternary_op_v<type, op>( \
|
||||
device const bool* a, \
|
||||
device const type* b, \
|
||||
device const type* c, \
|
||||
device type* d, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_ternary_g(name, type, op) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void ternary_op_g<type, op>( \
|
||||
device const bool* a, \
|
||||
device const type* b, \
|
||||
device const type* c, \
|
||||
device type* d, \
|
||||
constant const int* shape, \
|
||||
constant const size_t* a_strides, \
|
||||
constant const size_t* b_strides, \
|
||||
constant const size_t* c_strides, \
|
||||
constant const int& ndim, \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
#define instantiate_ternary_g(name, type, op) \
|
||||
template [[host_name(name)]] [[kernel]] void ternary_op_g<type, op>( \
|
||||
device const bool* a, \
|
||||
device const type* b, \
|
||||
device const type* c, \
|
||||
device type* d, \
|
||||
constant const int* shape, \
|
||||
constant const size_t* a_strides, \
|
||||
constant const size_t* b_strides, \
|
||||
constant const size_t* c_strides, \
|
||||
constant const int& ndim, \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
#define instantiate_ternary_g_dim(name, type, op, dims) \
|
||||
template [[host_name(name "_" #dims)]] \
|
||||
[[kernel]] void ternary_op_g_nd<type, op, dims>( \
|
||||
device const bool* a, \
|
||||
device const type* b, \
|
||||
device const type* c, \
|
||||
device type* d, \
|
||||
constant const int shape[dims], \
|
||||
constant const size_t a_strides[dims], \
|
||||
constant const size_t b_strides[dims], \
|
||||
constant const size_t c_strides[dims], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
#define instantiate_ternary_g_dim(name, type, op, dims) \
|
||||
template [[host_name(name "_" #dims)]] [[kernel]] void \
|
||||
ternary_op_g_nd<type, op, dims>( \
|
||||
device const bool* a, \
|
||||
device const type* b, \
|
||||
device const type* c, \
|
||||
device type* d, \
|
||||
constant const int shape[dims], \
|
||||
constant const size_t a_strides[dims], \
|
||||
constant const size_t b_strides[dims], \
|
||||
constant const size_t c_strides[dims], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
#define instantiate_ternary_g_nd(name, type, op) \
|
||||
template [[host_name(name "_1")]] \
|
||||
[[kernel]] void ternary_op_g_nd1<type, op>( \
|
||||
device const bool* a, \
|
||||
device const type* b, \
|
||||
device const type* c, \
|
||||
device type* d, \
|
||||
constant const size_t& a_strides, \
|
||||
constant const size_t& b_strides, \
|
||||
constant const size_t& c_strides, \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name(name "_2")]] \
|
||||
[[kernel]] void ternary_op_g_nd2<type, op>( \
|
||||
device const bool* a, \
|
||||
device const type* b, \
|
||||
device const type* c, \
|
||||
device type* d, \
|
||||
constant const size_t a_strides[2], \
|
||||
constant const size_t b_strides[2], \
|
||||
constant const size_t c_strides[2], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name(name "_3")]] \
|
||||
[[kernel]] void ternary_op_g_nd3<type, op>( \
|
||||
device const bool* a, \
|
||||
device const type* b, \
|
||||
device const type* c, \
|
||||
device type* d, \
|
||||
constant const size_t a_strides[3], \
|
||||
constant const size_t b_strides[3], \
|
||||
constant const size_t c_strides[3], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
instantiate_ternary_g_dim(name, type, op, 4) \
|
||||
instantiate_ternary_g_dim(name, type, op, 5) \
|
||||
#define instantiate_ternary_g_nd(name, type, op) \
|
||||
template [[host_name(name "_1")]] [[kernel]] void \
|
||||
ternary_op_g_nd1<type, op>( \
|
||||
device const bool* a, \
|
||||
device const type* b, \
|
||||
device const type* c, \
|
||||
device type* d, \
|
||||
constant const size_t& a_strides, \
|
||||
constant const size_t& b_strides, \
|
||||
constant const size_t& c_strides, \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name(name "_2")]] [[kernel]] void \
|
||||
ternary_op_g_nd2<type, op>( \
|
||||
device const bool* a, \
|
||||
device const type* b, \
|
||||
device const type* c, \
|
||||
device type* d, \
|
||||
constant const size_t a_strides[2], \
|
||||
constant const size_t b_strides[2], \
|
||||
constant const size_t c_strides[2], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name(name "_3")]] [[kernel]] void \
|
||||
ternary_op_g_nd3<type, op>( \
|
||||
device const bool* a, \
|
||||
device const type* b, \
|
||||
device const type* c, \
|
||||
device type* d, \
|
||||
constant const size_t a_strides[3], \
|
||||
constant const size_t b_strides[3], \
|
||||
constant const size_t c_strides[3], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
instantiate_ternary_g_dim(name, type, op, 4) \
|
||||
instantiate_ternary_g_dim(name, type, op, 5)
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_ternary_all(name, tname, type, op) \
|
||||
instantiate_ternary_v("v" #name #tname, type, op) \
|
||||
instantiate_ternary_g("g" #name #tname, type, op) \
|
||||
instantiate_ternary_g_nd("g" #name #tname, type, op) \
|
||||
instantiate_ternary_v("v" #name #tname, type, op) \
|
||||
instantiate_ternary_g("g" #name #tname, type, op) \
|
||||
instantiate_ternary_g_nd("g" #name #tname, type, op) // clang-format on
|
||||
|
||||
#define instantiate_ternary_types(name, op) \
|
||||
instantiate_ternary_all(name, bool_, bool, op) \
|
||||
instantiate_ternary_all(name, uint8, uint8_t, op) \
|
||||
instantiate_ternary_all(name, uint16, uint16_t, op) \
|
||||
instantiate_ternary_all(name, uint32, uint32_t, op) \
|
||||
instantiate_ternary_all(name, uint64, uint64_t, op) \
|
||||
instantiate_ternary_all(name, int8, int8_t, op) \
|
||||
instantiate_ternary_all(name, int16, int16_t, op) \
|
||||
instantiate_ternary_all(name, int32, int32_t, op) \
|
||||
instantiate_ternary_all(name, int64, int64_t, op) \
|
||||
instantiate_ternary_all(name, float16, half, op) \
|
||||
instantiate_ternary_all(name, float32, float, op) \
|
||||
// clang-format off
|
||||
#define instantiate_ternary_types(name, op) \
|
||||
instantiate_ternary_all(name, bool_, bool, op) \
|
||||
instantiate_ternary_all(name, uint8, uint8_t, op) \
|
||||
instantiate_ternary_all(name, uint16, uint16_t, op) \
|
||||
instantiate_ternary_all(name, uint32, uint32_t, op) \
|
||||
instantiate_ternary_all(name, uint64, uint64_t, op) \
|
||||
instantiate_ternary_all(name, int8, int8_t, op) \
|
||||
instantiate_ternary_all(name, int16, int16_t, op) \
|
||||
instantiate_ternary_all(name, int32, int32_t, op) \
|
||||
instantiate_ternary_all(name, int64, int64_t, op) \
|
||||
instantiate_ternary_all(name, float16, half, op) \
|
||||
instantiate_ternary_all(name, float32, float, op) \
|
||||
instantiate_ternary_all(name, bfloat16, bfloat16_t, op) \
|
||||
instantiate_ternary_all(name, complex64, complex64_t, op) \
|
||||
instantiate_ternary_all(name, complex64, complex64_t, op) // clang-format on
|
||||
|
||||
instantiate_ternary_types(select, Select)
|
||||
instantiate_ternary_types(select, Select)
|
@ -22,44 +22,46 @@ template <typename T, typename Op>
|
||||
out[index] = Op()(in[idx]);
|
||||
}
|
||||
|
||||
#define instantiate_unary_v(name, type, op) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void unary_op_v<type, op>( \
|
||||
device const type* in, \
|
||||
device type* out, \
|
||||
#define instantiate_unary_v(name, type, op) \
|
||||
template [[host_name(name)]] [[kernel]] void unary_op_v<type, op>( \
|
||||
device const type* in, \
|
||||
device type* out, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_unary_g(name, type, op) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void unary_op_g<type, op>( \
|
||||
device const type* in, \
|
||||
device type* out, \
|
||||
device const int* in_shape, \
|
||||
device const size_t* in_strides, \
|
||||
device const int& ndim, \
|
||||
#define instantiate_unary_g(name, type, op) \
|
||||
template [[host_name(name)]] [[kernel]] void unary_op_g<type, op>( \
|
||||
device const type* in, \
|
||||
device type* out, \
|
||||
device const int* in_shape, \
|
||||
device const size_t* in_strides, \
|
||||
device const int& ndim, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_unary_all(name, tname, type, op) \
|
||||
instantiate_unary_v("v" #name #tname, type, op) \
|
||||
instantiate_unary_g("g" #name #tname, type, op)
|
||||
instantiate_unary_v("v" #name #tname, type, op) \
|
||||
instantiate_unary_g("g" #name #tname, type, op) // clang-format on
|
||||
|
||||
#define instantiate_unary_float(name, op) \
|
||||
instantiate_unary_all(name, float16, half, op) \
|
||||
instantiate_unary_all(name, float32, float, op) \
|
||||
instantiate_unary_all(name, bfloat16, bfloat16_t, op) \
|
||||
// clang-format off
|
||||
#define instantiate_unary_float(name, op) \
|
||||
instantiate_unary_all(name, float16, half, op) \
|
||||
instantiate_unary_all(name, float32, float, op) \
|
||||
instantiate_unary_all(name, bfloat16, bfloat16_t, op) // clang-format on
|
||||
|
||||
#define instantiate_unary_types(name, op) \
|
||||
instantiate_unary_all(name, bool_, bool, op) \
|
||||
instantiate_unary_all(name, uint8, uint8_t, op) \
|
||||
// clang-format off
|
||||
#define instantiate_unary_types(name, op) \
|
||||
instantiate_unary_all(name, bool_, bool, op) \
|
||||
instantiate_unary_all(name, uint8, uint8_t, op) \
|
||||
instantiate_unary_all(name, uint16, uint16_t, op) \
|
||||
instantiate_unary_all(name, uint32, uint32_t, op) \
|
||||
instantiate_unary_all(name, uint64, uint64_t, op) \
|
||||
instantiate_unary_all(name, int8, int8_t, op) \
|
||||
instantiate_unary_all(name, int16, int16_t, op) \
|
||||
instantiate_unary_all(name, int32, int32_t, op) \
|
||||
instantiate_unary_all(name, int64, int64_t, op) \
|
||||
instantiate_unary_float(name, op)
|
||||
instantiate_unary_all(name, int8, int8_t, op) \
|
||||
instantiate_unary_all(name, int16, int16_t, op) \
|
||||
instantiate_unary_all(name, int32, int32_t, op) \
|
||||
instantiate_unary_all(name, int64, int64_t, op) \
|
||||
instantiate_unary_float(name, op) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_unary_types(abs, Abs)
|
||||
instantiate_unary_float(arccos, ArcCos)
|
||||
instantiate_unary_float(arccosh, ArcCosh)
|
||||
@ -102,4 +104,4 @@ instantiate_unary_all(tan, complex64, complex64_t, Tan)
|
||||
instantiate_unary_all(tanh, complex64, complex64_t, Tanh)
|
||||
instantiate_unary_all(round, complex64, complex64_t, Round)
|
||||
|
||||
instantiate_unary_all(lnot, bool_, bool, LogicalNot)
|
||||
instantiate_unary_all(lnot, bool_, bool, LogicalNot) // clang-format on
|
||||
|
@ -13,7 +13,7 @@ struct Device {
|
||||
static constexpr DeviceType cpu = DeviceType::cpu;
|
||||
static constexpr DeviceType gpu = DeviceType::gpu;
|
||||
|
||||
Device(DeviceType type, int index = 0) : type(type), index(index){};
|
||||
Device(DeviceType type, int index = 0) : type(type), index(index) {};
|
||||
|
||||
DeviceType type;
|
||||
int index;
|
||||
|
@ -51,7 +51,7 @@ struct Dtype {
|
||||
|
||||
Val val;
|
||||
const uint8_t size;
|
||||
constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size){};
|
||||
constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size) {};
|
||||
constexpr operator Val() const {
|
||||
return val;
|
||||
};
|
||||
|
@ -10,7 +10,7 @@ namespace mlx::core {
|
||||
|
||||
class Event {
|
||||
public:
|
||||
Event(){};
|
||||
Event() {};
|
||||
|
||||
Event(const Stream& steam);
|
||||
|
||||
|
@ -12,7 +12,7 @@ class Custom : public Primitive {
|
||||
explicit Custom(
|
||||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback)
|
||||
: Primitive(stream), fallback_(fallback){};
|
||||
: Primitive(stream), fallback_(fallback) {};
|
||||
|
||||
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
@ -39,7 +39,7 @@ class RMSNorm : public Custom {
|
||||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
float eps)
|
||||
: Custom(stream, fallback), eps_(eps){};
|
||||
: Custom(stream, fallback), eps_(eps) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
@ -68,7 +68,7 @@ class RMSNormVJP : public Custom {
|
||||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
float eps)
|
||||
: Custom(stream, fallback), eps_(eps){};
|
||||
: Custom(stream, fallback), eps_(eps) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
@ -91,7 +91,7 @@ class LayerNorm : public Custom {
|
||||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
float eps)
|
||||
: Custom(stream, fallback), eps_(eps){};
|
||||
: Custom(stream, fallback), eps_(eps) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
@ -120,7 +120,7 @@ class LayerNormVJP : public Custom {
|
||||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
float eps)
|
||||
: Custom(stream, fallback), eps_(eps){};
|
||||
: Custom(stream, fallback), eps_(eps) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
@ -154,7 +154,7 @@ class RoPE : public Custom {
|
||||
base_(base),
|
||||
scale_(scale),
|
||||
offset_(offset),
|
||||
forward_(forward){};
|
||||
forward_(forward) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
@ -189,7 +189,7 @@ class ScaledDotProductAttention : public Custom {
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
const float scale,
|
||||
const bool needs_mask)
|
||||
: Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask){};
|
||||
: Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
|
168
mlx/primitives.h
168
mlx/primitives.h
@ -154,7 +154,7 @@ class UnaryPrimitive : public Primitive {
|
||||
|
||||
class Abs : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Abs(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Abs(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -171,7 +171,7 @@ class Abs : public UnaryPrimitive {
|
||||
|
||||
class Add : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Add(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Add(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -189,7 +189,7 @@ class Add : public UnaryPrimitive {
|
||||
class AddMM : public UnaryPrimitive {
|
||||
public:
|
||||
explicit AddMM(Stream stream, float alpha, float beta)
|
||||
: UnaryPrimitive(stream), alpha_(alpha), beta_(beta){};
|
||||
: UnaryPrimitive(stream), alpha_(alpha), beta_(beta) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -213,7 +213,7 @@ class AddMM : public UnaryPrimitive {
|
||||
class Arange : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Arange(Stream stream, double start, double stop, double step)
|
||||
: UnaryPrimitive(stream), start_(start), stop_(stop), step_(step){};
|
||||
: UnaryPrimitive(stream), start_(start), stop_(stop), step_(step) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -231,7 +231,7 @@ class Arange : public UnaryPrimitive {
|
||||
|
||||
class ArcCos : public UnaryPrimitive {
|
||||
public:
|
||||
explicit ArcCos(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit ArcCos(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -248,7 +248,7 @@ class ArcCos : public UnaryPrimitive {
|
||||
|
||||
class ArcCosh : public UnaryPrimitive {
|
||||
public:
|
||||
explicit ArcCosh(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit ArcCosh(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -265,7 +265,7 @@ class ArcCosh : public UnaryPrimitive {
|
||||
|
||||
class ArcSin : public UnaryPrimitive {
|
||||
public:
|
||||
explicit ArcSin(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit ArcSin(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -282,7 +282,7 @@ class ArcSin : public UnaryPrimitive {
|
||||
|
||||
class ArcSinh : public UnaryPrimitive {
|
||||
public:
|
||||
explicit ArcSinh(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit ArcSinh(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -299,7 +299,7 @@ class ArcSinh : public UnaryPrimitive {
|
||||
|
||||
class ArcTan : public UnaryPrimitive {
|
||||
public:
|
||||
explicit ArcTan(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit ArcTan(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -316,7 +316,7 @@ class ArcTan : public UnaryPrimitive {
|
||||
|
||||
class ArcTanh : public UnaryPrimitive {
|
||||
public:
|
||||
explicit ArcTanh(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit ArcTanh(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -334,7 +334,7 @@ class ArcTanh : public UnaryPrimitive {
|
||||
class ArgPartition : public UnaryPrimitive {
|
||||
public:
|
||||
explicit ArgPartition(Stream stream, int kth, int axis)
|
||||
: UnaryPrimitive(stream), kth_(kth), axis_(axis){};
|
||||
: UnaryPrimitive(stream), kth_(kth), axis_(axis) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -359,7 +359,7 @@ class ArgReduce : public UnaryPrimitive {
|
||||
};
|
||||
|
||||
explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis)
|
||||
: UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis){};
|
||||
: UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -380,7 +380,7 @@ class ArgReduce : public UnaryPrimitive {
|
||||
class ArgSort : public UnaryPrimitive {
|
||||
public:
|
||||
explicit ArgSort(Stream stream, int axis)
|
||||
: UnaryPrimitive(stream), axis_(axis){};
|
||||
: UnaryPrimitive(stream), axis_(axis) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -399,7 +399,7 @@ class ArgSort : public UnaryPrimitive {
|
||||
class AsType : public UnaryPrimitive {
|
||||
public:
|
||||
explicit AsType(Stream stream, Dtype dtype)
|
||||
: UnaryPrimitive(stream), dtype_(dtype){};
|
||||
: UnaryPrimitive(stream), dtype_(dtype) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -426,7 +426,7 @@ class AsStrided : public UnaryPrimitive {
|
||||
: UnaryPrimitive(stream),
|
||||
shape_(std::move(shape)),
|
||||
strides_(std::move(strides)),
|
||||
offset_(offset){};
|
||||
offset_(offset) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -448,7 +448,7 @@ class BitwiseBinary : public UnaryPrimitive {
|
||||
enum Op { And, Or, Xor, LeftShift, RightShift };
|
||||
|
||||
explicit BitwiseBinary(Stream stream, Op op)
|
||||
: UnaryPrimitive(stream), op_(op){};
|
||||
: UnaryPrimitive(stream), op_(op) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -465,7 +465,7 @@ class BitwiseBinary : public UnaryPrimitive {
|
||||
class BlockMaskedMM : public UnaryPrimitive {
|
||||
public:
|
||||
explicit BlockMaskedMM(Stream stream, int block_size)
|
||||
: UnaryPrimitive(stream), block_size_(block_size){};
|
||||
: UnaryPrimitive(stream), block_size_(block_size) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -488,7 +488,7 @@ class BlockMaskedMM : public UnaryPrimitive {
|
||||
class Broadcast : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Broadcast(Stream stream, const std::vector<int>& shape)
|
||||
: UnaryPrimitive(stream), shape_(shape){};
|
||||
: UnaryPrimitive(stream), shape_(shape) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -506,7 +506,7 @@ class Broadcast : public UnaryPrimitive {
|
||||
|
||||
class Ceil : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Ceil(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Ceil(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -567,7 +567,7 @@ class Compiled : public Primitive {
|
||||
class Concatenate : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Concatenate(Stream stream, int axis)
|
||||
: UnaryPrimitive(stream), axis_(axis){};
|
||||
: UnaryPrimitive(stream), axis_(axis) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -599,7 +599,7 @@ class Convolution : public UnaryPrimitive {
|
||||
kernel_dilation_(kernel_dilation),
|
||||
input_dilation_(input_dilation),
|
||||
groups_(groups),
|
||||
flip_(flip){};
|
||||
flip_(flip) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -626,7 +626,7 @@ class Convolution : public UnaryPrimitive {
|
||||
|
||||
class Copy : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Copy(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Copy(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -643,7 +643,7 @@ class Copy : public UnaryPrimitive {
|
||||
|
||||
class Cos : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Cos(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Cos(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -660,7 +660,7 @@ class Cos : public UnaryPrimitive {
|
||||
|
||||
class Cosh : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Cosh(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Cosh(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -731,7 +731,7 @@ class Depends : public Primitive {
|
||||
|
||||
class Divide : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Divide(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Divide(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -748,7 +748,7 @@ class Divide : public UnaryPrimitive {
|
||||
|
||||
class DivMod : public Primitive {
|
||||
public:
|
||||
explicit DivMod(Stream stream) : Primitive(stream){};
|
||||
explicit DivMod(Stream stream) : Primitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
@ -770,7 +770,7 @@ class DivMod : public Primitive {
|
||||
|
||||
class Select : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Select(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Select(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -787,7 +787,7 @@ class Select : public UnaryPrimitive {
|
||||
|
||||
class Remainder : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Remainder(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Remainder(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -805,7 +805,7 @@ class Remainder : public UnaryPrimitive {
|
||||
class Equal : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Equal(Stream stream, bool equal_nan = false)
|
||||
: UnaryPrimitive(stream), equal_nan_(equal_nan){};
|
||||
: UnaryPrimitive(stream), equal_nan_(equal_nan) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -830,7 +830,7 @@ class Equal : public UnaryPrimitive {
|
||||
|
||||
class Erf : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Erf(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Erf(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -847,7 +847,7 @@ class Erf : public UnaryPrimitive {
|
||||
|
||||
class ErfInv : public UnaryPrimitive {
|
||||
public:
|
||||
explicit ErfInv(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit ErfInv(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -864,7 +864,7 @@ class ErfInv : public UnaryPrimitive {
|
||||
|
||||
class Exp : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Exp(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Exp(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -881,7 +881,7 @@ class Exp : public UnaryPrimitive {
|
||||
|
||||
class Expm1 : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Expm1(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Expm1(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -902,7 +902,7 @@ class FFT : public UnaryPrimitive {
|
||||
const std::vector<size_t>& axes,
|
||||
bool inverse,
|
||||
bool real)
|
||||
: UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real){};
|
||||
: UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -923,7 +923,7 @@ class FFT : public UnaryPrimitive {
|
||||
|
||||
class Floor : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Floor(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Floor(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -940,7 +940,7 @@ class Floor : public UnaryPrimitive {
|
||||
|
||||
class Full : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Full(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Full(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -960,7 +960,7 @@ class Gather : public UnaryPrimitive {
|
||||
Stream stream,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<int>& slice_sizes)
|
||||
: UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes){};
|
||||
: UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -978,7 +978,7 @@ class Gather : public UnaryPrimitive {
|
||||
|
||||
class Greater : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Greater(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Greater(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -995,7 +995,7 @@ class Greater : public UnaryPrimitive {
|
||||
|
||||
class GreaterEqual : public UnaryPrimitive {
|
||||
public:
|
||||
explicit GreaterEqual(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit GreaterEqual(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1012,7 +1012,7 @@ class GreaterEqual : public UnaryPrimitive {
|
||||
|
||||
class Less : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Less(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Less(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1029,7 +1029,7 @@ class Less : public UnaryPrimitive {
|
||||
|
||||
class LessEqual : public UnaryPrimitive {
|
||||
public:
|
||||
explicit LessEqual(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit LessEqual(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1054,7 +1054,7 @@ class Load : public UnaryPrimitive {
|
||||
: UnaryPrimitive(stream),
|
||||
reader_(reader),
|
||||
offset_(offset),
|
||||
swap_endianness_(swap_endianness){};
|
||||
swap_endianness_(swap_endianness) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1073,7 +1073,7 @@ class Log : public UnaryPrimitive {
|
||||
enum Base { two, ten, e };
|
||||
|
||||
explicit Log(Stream stream, Base base)
|
||||
: UnaryPrimitive(stream), base_(base){};
|
||||
: UnaryPrimitive(stream), base_(base) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1104,7 +1104,7 @@ class Log : public UnaryPrimitive {
|
||||
|
||||
class Log1p : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Log1p(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Log1p(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1120,7 +1120,7 @@ class Log1p : public UnaryPrimitive {
|
||||
|
||||
class LogicalNot : public UnaryPrimitive {
|
||||
public:
|
||||
explicit LogicalNot(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit LogicalNot(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1137,7 +1137,7 @@ class LogicalNot : public UnaryPrimitive {
|
||||
|
||||
class LogicalAnd : public UnaryPrimitive {
|
||||
public:
|
||||
explicit LogicalAnd(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit LogicalAnd(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1154,7 +1154,7 @@ class LogicalAnd : public UnaryPrimitive {
|
||||
|
||||
class LogicalOr : public UnaryPrimitive {
|
||||
public:
|
||||
explicit LogicalOr(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit LogicalOr(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1171,7 +1171,7 @@ class LogicalOr : public UnaryPrimitive {
|
||||
|
||||
class LogAddExp : public UnaryPrimitive {
|
||||
public:
|
||||
explicit LogAddExp(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit LogAddExp(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1188,7 +1188,7 @@ class LogAddExp : public UnaryPrimitive {
|
||||
|
||||
class Matmul : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Matmul(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Matmul(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1206,7 +1206,7 @@ class Matmul : public UnaryPrimitive {
|
||||
|
||||
class Maximum : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Maximum(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Maximum(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1223,7 +1223,7 @@ class Maximum : public UnaryPrimitive {
|
||||
|
||||
class Minimum : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Minimum(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Minimum(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1240,7 +1240,7 @@ class Minimum : public UnaryPrimitive {
|
||||
|
||||
class Multiply : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Multiply(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Multiply(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1257,7 +1257,7 @@ class Multiply : public UnaryPrimitive {
|
||||
|
||||
class Negative : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Negative(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Negative(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1274,7 +1274,7 @@ class Negative : public UnaryPrimitive {
|
||||
|
||||
class NotEqual : public UnaryPrimitive {
|
||||
public:
|
||||
explicit NotEqual(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit NotEqual(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1330,7 +1330,7 @@ class Pad : public UnaryPrimitive {
|
||||
: UnaryPrimitive(stream),
|
||||
axes_(axes),
|
||||
low_pad_size_(low_pad_size),
|
||||
high_pad_size_(high_pad_size){};
|
||||
high_pad_size_(high_pad_size) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1351,7 +1351,7 @@ class Pad : public UnaryPrimitive {
|
||||
class Partition : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Partition(Stream stream, int kth, int axis)
|
||||
: UnaryPrimitive(stream), kth_(kth), axis_(axis){};
|
||||
: UnaryPrimitive(stream), kth_(kth), axis_(axis) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1371,7 +1371,7 @@ class Partition : public UnaryPrimitive {
|
||||
|
||||
class Power : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Power(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Power(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1396,7 +1396,7 @@ class QuantizedMatmul : public UnaryPrimitive {
|
||||
: UnaryPrimitive(stream),
|
||||
group_size_(group_size),
|
||||
bits_(bits),
|
||||
transpose_(transpose){};
|
||||
transpose_(transpose) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1417,7 +1417,7 @@ class QuantizedMatmul : public UnaryPrimitive {
|
||||
class RandomBits : public UnaryPrimitive {
|
||||
public:
|
||||
explicit RandomBits(Stream stream, const std::vector<int>& shape, int width)
|
||||
: UnaryPrimitive(stream), shape_(shape), width_(width){};
|
||||
: UnaryPrimitive(stream), shape_(shape), width_(width) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1436,7 +1436,7 @@ class RandomBits : public UnaryPrimitive {
|
||||
class Reshape : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Reshape(Stream stream, const std::vector<int>& shape)
|
||||
: UnaryPrimitive(stream), shape_(shape){};
|
||||
: UnaryPrimitive(stream), shape_(shape) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1468,7 +1468,7 @@ class Reduce : public UnaryPrimitive {
|
||||
Stream stream,
|
||||
ReduceType reduce_type,
|
||||
const std::vector<int>& axes)
|
||||
: UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes){};
|
||||
: UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1517,7 +1517,7 @@ class Reduce : public UnaryPrimitive {
|
||||
|
||||
class Round : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Round(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Round(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1546,7 +1546,7 @@ class Scan : public UnaryPrimitive {
|
||||
reduce_type_(reduce_type),
|
||||
axis_(axis),
|
||||
reverse_(reverse),
|
||||
inclusive_(inclusive){};
|
||||
inclusive_(inclusive) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1591,7 +1591,7 @@ class Scatter : public UnaryPrimitive {
|
||||
Stream stream,
|
||||
ReduceType reduce_type,
|
||||
const std::vector<int>& axes)
|
||||
: UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes){};
|
||||
: UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1626,7 +1626,7 @@ class Scatter : public UnaryPrimitive {
|
||||
|
||||
class Sigmoid : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Sigmoid(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Sigmoid(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1643,7 +1643,7 @@ class Sigmoid : public UnaryPrimitive {
|
||||
|
||||
class Sign : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Sign(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Sign(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1660,7 +1660,7 @@ class Sign : public UnaryPrimitive {
|
||||
|
||||
class Sin : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Sin(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Sin(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1677,7 +1677,7 @@ class Sin : public UnaryPrimitive {
|
||||
|
||||
class Sinh : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Sinh(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Sinh(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1702,7 +1702,7 @@ class Slice : public UnaryPrimitive {
|
||||
: UnaryPrimitive(stream),
|
||||
start_indices_(start_indices),
|
||||
end_indices_(end_indices),
|
||||
strides_(strides){};
|
||||
strides_(strides) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1738,7 +1738,7 @@ class SliceUpdate : public UnaryPrimitive {
|
||||
: UnaryPrimitive(stream),
|
||||
start_indices_(start_indices),
|
||||
end_indices_(end_indices),
|
||||
strides_(strides){};
|
||||
strides_(strides) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1761,7 +1761,7 @@ class SliceUpdate : public UnaryPrimitive {
|
||||
class Softmax : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Softmax(Stream stream, bool precise)
|
||||
: UnaryPrimitive(stream), precise_(precise){};
|
||||
: UnaryPrimitive(stream), precise_(precise) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1781,7 +1781,7 @@ class Softmax : public UnaryPrimitive {
|
||||
class Sort : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Sort(Stream stream, int axis)
|
||||
: UnaryPrimitive(stream), axis_(axis){};
|
||||
: UnaryPrimitive(stream), axis_(axis) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1801,7 +1801,7 @@ class Sort : public UnaryPrimitive {
|
||||
class Split : public Primitive {
|
||||
public:
|
||||
explicit Split(Stream stream, const std::vector<int>& indices, int axis)
|
||||
: Primitive(stream), indices_(indices), axis_(axis){};
|
||||
: Primitive(stream), indices_(indices), axis_(axis) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
@ -1822,7 +1822,7 @@ class Split : public Primitive {
|
||||
|
||||
class Square : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Square(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Square(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1840,7 +1840,7 @@ class Square : public UnaryPrimitive {
|
||||
class Sqrt : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Sqrt(Stream stream, bool recip = false)
|
||||
: UnaryPrimitive(stream), recip_(recip){};
|
||||
: UnaryPrimitive(stream), recip_(recip) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1865,7 +1865,7 @@ class Sqrt : public UnaryPrimitive {
|
||||
|
||||
class StopGradient : public UnaryPrimitive {
|
||||
public:
|
||||
explicit StopGradient(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit StopGradient(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1881,7 +1881,7 @@ class StopGradient : public UnaryPrimitive {
|
||||
|
||||
class Subtract : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Subtract(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Subtract(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1898,7 +1898,7 @@ class Subtract : public UnaryPrimitive {
|
||||
|
||||
class Tan : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Tan(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Tan(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1915,7 +1915,7 @@ class Tan : public UnaryPrimitive {
|
||||
|
||||
class Tanh : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Tanh(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Tanh(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1932,7 +1932,7 @@ class Tanh : public UnaryPrimitive {
|
||||
|
||||
class Uniform : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Uniform(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Uniform(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1948,7 +1948,7 @@ class Uniform : public UnaryPrimitive {
|
||||
class Transpose : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Transpose(Stream stream, const std::vector<int>& axes)
|
||||
: UnaryPrimitive(stream), axes_(axes){};
|
||||
: UnaryPrimitive(stream), axes_(axes) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1967,7 +1967,7 @@ class Transpose : public UnaryPrimitive {
|
||||
/* QR Factorization primitive. */
|
||||
class QRF : public Primitive {
|
||||
public:
|
||||
explicit QRF(Stream stream) : Primitive(stream){};
|
||||
explicit QRF(Stream stream) : Primitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
@ -1983,7 +1983,7 @@ class QRF : public Primitive {
|
||||
/* SVD primitive. */
|
||||
class SVD : public Primitive {
|
||||
public:
|
||||
explicit SVD(Stream stream) : Primitive(stream){};
|
||||
explicit SVD(Stream stream) : Primitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
@ -2000,7 +2000,7 @@ class SVD : public Primitive {
|
||||
/* Matrix inversion primitive. */
|
||||
class Inverse : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Inverse(Stream stream) : UnaryPrimitive(stream){};
|
||||
explicit Inverse(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& output) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& output) override;
|
||||
|
@ -22,7 +22,7 @@ namespace mlx::core {
|
||||
* for synchronizing with the main thread. */
|
||||
class Synchronizer : public Primitive {
|
||||
public:
|
||||
explicit Synchronizer(Stream stream) : Primitive(stream){};
|
||||
explicit Synchronizer(Stream stream) : Primitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>&, std::vector<array>&) override {};
|
||||
void eval_gpu(const std::vector<array>&, std::vector<array>&) override {};
|
||||
|
@ -14,8 +14,8 @@ inline constexpr bool can_convert_to_complex128 =
|
||||
!std::is_same_v<T, complex128_t> && std::is_convertible_v<T, double>;
|
||||
|
||||
struct complex128_t : public std::complex<double> {
|
||||
complex128_t(double v, double u) : std::complex<double>(v, u){};
|
||||
complex128_t(std::complex<double> v) : std::complex<double>(v){};
|
||||
complex128_t(double v, double u) : std::complex<double>(v, u) {};
|
||||
complex128_t(std::complex<double> v) : std::complex<double>(v) {};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
@ -32,8 +32,8 @@ inline constexpr bool can_convert_to_complex64 =
|
||||
!std::is_same_v<T, complex64_t> && std::is_convertible_v<T, float>;
|
||||
|
||||
struct complex64_t : public std::complex<float> {
|
||||
complex64_t(float v, float u) : std::complex<float>(v, u){};
|
||||
complex64_t(std::complex<float> v) : std::complex<float>(v){};
|
||||
complex64_t(float v, float u) : std::complex<float>(v, u) {};
|
||||
complex64_t(std::complex<float> v) : std::complex<float>(v) {};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
|
Loading…
Reference in New Issue
Block a user