diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ae9db3839..10b7290ca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,11 +1,11 @@ repos: - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v18.1.3 + rev: v18.1.4 hooks: - id: clang-format # Using this mirror lets us use mypyc-compiled black, which is about 2x faster - repo: https://github.com/psf/black-pre-commit-mirror - rev: 24.3.0 + rev: 24.4.2 hooks: - id: black - repo: https://github.com/pycqa/isort diff --git a/examples/extensions/axpby/axpby.h b/examples/extensions/axpby/axpby.h index 1fb705d61..a2c34123e 100644 --- a/examples/extensions/axpby/axpby.h +++ b/examples/extensions/axpby/axpby.h @@ -33,7 +33,7 @@ array axpby( class Axpby : public Primitive { public: explicit Axpby(Stream stream, float alpha, float beta) - : Primitive(stream), alpha_(alpha), beta_(beta){}; + : Primitive(stream), alpha_(alpha), beta_(beta) {}; /** * A primitive must know how to evaluate itself on the CPU/GPU diff --git a/examples/extensions/axpby/axpby.metal b/examples/extensions/axpby/axpby.metal index 03b373c99..503ad7444 100644 --- a/examples/extensions/axpby/axpby.metal +++ b/examples/extensions/axpby/axpby.metal @@ -19,7 +19,7 @@ template uint index [[thread_position_in_grid]]) { auto x_offset = elem_to_loc(index, shape, x_strides, ndim); auto y_offset = elem_to_loc(index, shape, y_strides, ndim); - out[index] = + out[index] = static_cast(alpha) * x[x_offset] + static_cast(beta) * y[y_offset]; } @@ -31,30 +31,30 @@ template constant const float& alpha [[buffer(3)]], constant const float& beta [[buffer(4)]], uint index [[thread_position_in_grid]]) { - out[index] = + out[index] = static_cast(alpha) * x[index] + static_cast(beta) * y[index]; } -#define instantiate_axpby(type_name, type) \ - template [[host_name("axpby_general_" #type_name)]] \ - [[kernel]] void axpby_general( \ - device const type* x [[buffer(0)]], \ - device const type* y [[buffer(1)]], \ - device type* out [[buffer(2)]], \ - constant const float& alpha [[buffer(3)]], \ - constant const float& beta [[buffer(4)]], \ - constant const int* shape [[buffer(5)]], \ - constant const size_t* x_strides [[buffer(6)]], \ - constant const size_t* y_strides [[buffer(7)]], \ - constant const int& ndim [[buffer(8)]], \ - uint index [[thread_position_in_grid]]); \ - template [[host_name("axpby_contiguous_" #type_name)]] \ - [[kernel]] void axpby_contiguous( \ - device const type* x [[buffer(0)]], \ - device const type* y [[buffer(1)]], \ - device type* out [[buffer(2)]], \ - constant const float& alpha [[buffer(3)]], \ - constant const float& beta [[buffer(4)]], \ +#define instantiate_axpby(type_name, type) \ + template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \ + axpby_general( \ + device const type* x [[buffer(0)]], \ + device const type* y [[buffer(1)]], \ + device type* out [[buffer(2)]], \ + constant const float& alpha [[buffer(3)]], \ + constant const float& beta [[buffer(4)]], \ + constant const int* shape [[buffer(5)]], \ + constant const size_t* x_strides [[buffer(6)]], \ + constant const size_t* y_strides [[buffer(7)]], \ + constant const int& ndim [[buffer(8)]], \ + uint index [[thread_position_in_grid]]); \ + template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \ + axpby_contiguous( \ + device const type* x [[buffer(0)]], \ + device const type* y [[buffer(1)]], \ + device type* out [[buffer(2)]], \ + constant const float& alpha [[buffer(3)]], \ + constant const float& beta [[buffer(4)]], \ uint 
index [[thread_position_in_grid]]); instantiate_axpby(float32, float); diff --git a/mlx/allocator.h b/mlx/allocator.h index 1061d6cce..42dd7e180 100644 --- a/mlx/allocator.h +++ b/mlx/allocator.h @@ -14,7 +14,7 @@ class Buffer { void* ptr_; public: - Buffer(void* ptr) : ptr_(ptr){}; + Buffer(void* ptr) : ptr_(ptr) {}; // Get the raw data pointer from the buffer void* raw_ptr(); diff --git a/mlx/array.h b/mlx/array.h index 42ebca35d..b576ff2c6 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -209,7 +209,7 @@ class array { allocator::Buffer buffer; deleter_t d; Data(allocator::Buffer buffer, deleter_t d = allocator::free) - : buffer(buffer), d(d){}; + : buffer(buffer), d(d) {}; // Not copyable Data(const Data& d) = delete; Data& operator=(const Data& d) = delete; diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 029b2cc92..9e5518af6 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -38,7 +38,7 @@ using MTLFCList = struct CommandEncoder { CommandEncoder(MTL::ComputeCommandEncoder* enc) - : enc(enc), concurrent(false){}; + : enc(enc), concurrent(false) {}; CommandEncoder(const CommandEncoder&) = delete; CommandEncoder& operator=(const CommandEncoder&) = delete; diff --git a/mlx/backend/metal/kernels/arange.metal b/mlx/backend/metal/kernels/arange.metal index d7ba56aae..b896e4226 100644 --- a/mlx/backend/metal/kernels/arange.metal +++ b/mlx/backend/metal/kernels/arange.metal @@ -11,22 +11,22 @@ template out[index] = start + index * step; } -#define instantiate_arange(tname, type) \ - template [[host_name("arange" #tname)]] \ - [[kernel]] void arange( \ - constant const type& start, \ - constant const type& step, \ - device type* out, \ - uint index [[thread_position_in_grid]]); +#define instantiate_arange(tname, type) \ + template [[host_name("arange" #tname)]] [[kernel]] void arange( \ + constant const type& start, \ + constant const type& step, \ + device type* out, \ + uint index [[thread_position_in_grid]]); -instantiate_arange(uint8, uint8_t) +// clang-format off +instantiate_arange(uint8, uint8_t) instantiate_arange(uint16, uint16_t) -instantiate_arange(uint32, uint32_t) +instantiate_arange(uint32, uint32_t) instantiate_arange(uint64, uint64_t) -instantiate_arange(int8, int8_t) +instantiate_arange(int8, int8_t) instantiate_arange(int16, int16_t) instantiate_arange(int32, int32_t) instantiate_arange(int64, int64_t) instantiate_arange(float16, half) instantiate_arange(float32, float) -instantiate_arange(bfloat16, bfloat16_t) \ No newline at end of file +instantiate_arange(bfloat16, bfloat16_t) // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/arg_reduce.metal b/mlx/backend/metal/kernels/arg_reduce.metal index f24a32ce8..6adeb12ab 100644 --- a/mlx/backend/metal/kernels/arg_reduce.metal +++ b/mlx/backend/metal/kernels/arg_reduce.metal @@ -18,7 +18,8 @@ struct ArgMin { static constexpr constant U init = Limits::max; IndexValPair reduce(IndexValPair best, IndexValPair current) { - if (best.val > current.val || (best.val == current.val && best.index > current.index)) { + if (best.val > current.val || + (best.val == current.val && best.index > current.index)) { return current; } else { return best; @@ -26,11 +27,12 @@ struct ArgMin { } template - IndexValPair reduce_many(IndexValPair best, thread U* vals, uint32_t offset) { - for (int i=0; i + reduce_many(IndexValPair best, thread U* vals, uint32_t offset) { + for (int i = 0; i < N; i++) { if (vals[i] < best.val) { best.val = vals[i]; - best.index = offset+i; + 
best.index = offset + i; } } return best; @@ -42,7 +44,8 @@ struct ArgMax { static constexpr constant U init = Limits::min; IndexValPair reduce(IndexValPair best, IndexValPair current) { - if (best.val < current.val || (best.val == current.val && best.index > current.index)) { + if (best.val < current.val || + (best.val == current.val && best.index > current.index)) { return current; } else { return best; @@ -50,11 +53,12 @@ struct ArgMax { } template - IndexValPair reduce_many(IndexValPair best, thread U* vals, uint32_t offset) { - for (int i=0; i + reduce_many(IndexValPair best, thread U* vals, uint32_t offset) { + for (int i = 0; i < N; i++) { if (vals[i] > best.val) { best.val = vals[i]; - best.index = offset+i; + best.index = offset + i; } } return best; @@ -64,19 +68,16 @@ struct ArgMax { template IndexValPair simd_shuffle_down(IndexValPair data, uint16_t delta) { return IndexValPair{ - simd_shuffle_down(data.index, delta), - simd_shuffle_down(data.val, delta) - }; + simd_shuffle_down(data.index, delta), simd_shuffle_down(data.val, delta)}; } - template [[kernel]] void arg_reduce_general( - const device T *in [[buffer(0)]], - device uint32_t *out [[buffer(1)]], - const device int *shape [[buffer(2)]], - const device size_t *in_strides [[buffer(3)]], - const device size_t *out_strides [[buffer(4)]], + const device T* in [[buffer(0)]], + device uint32_t* out [[buffer(1)]], + const device int* shape [[buffer(2)]], + const device size_t* in_strides [[buffer(3)]], + const device size_t* out_strides [[buffer(4)]], const device size_t& ndim [[buffer(5)]], const device size_t& axis_stride [[buffer(6)]], const device size_t& axis_size [[buffer(7)]], @@ -86,7 +87,6 @@ template uint simd_size [[threads_per_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - // Shapes and strides *do not* contain the reduction axis. The reduction size // and stride are provided in axis_stride and axis_size. // @@ -113,13 +113,13 @@ template threadgroup IndexValPair local_data[32]; // Loop over the reduction axis in lsize*N_READS buckets - for (uint r=0; r < ceildiv(axis_size, N_READS*lsize); r++) { + for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) { // Read the current value - uint32_t current_index = r*lsize*N_READS + lid*N_READS; + uint32_t current_index = r * lsize * N_READS + lid * N_READS; uint32_t offset = current_index; - const device T * current_in = in + in_idx + current_index * axis_stride; + const device T* current_in = in + in_idx + current_index * axis_stride; T vals[N_READS]; - for (int i=0; i // need to reduce across the thread group. // First per simd reduction. 
- for (uint offset=simd_size/2; offset>0; offset/=2) { + for (uint offset = simd_size / 2; offset > 0; offset /= 2) { IndexValPair neighbor = simd_shuffle_down(best, offset); best = op.reduce(best, neighbor); } @@ -149,7 +149,7 @@ template if (simd_lane_id < simd_groups) { best = local_data[simd_lane_id]; } - for (uint offset=simd_size/2; offset>0; offset/=2) { + for (uint offset = simd_size / 2; offset > 0; offset /= 2) { IndexValPair neighbor = simd_shuffle_down(best, offset); best = op.reduce(best, neighbor); } @@ -161,24 +161,25 @@ template } #define instantiate_arg_reduce_helper(name, itype, op) \ - template [[host_name(name)]] \ - [[kernel]] void arg_reduce_general, 4>( \ - const device itype *in [[buffer(0)]], \ - device uint32_t * out [[buffer(1)]], \ - const device int *shape [[buffer(2)]], \ - const device size_t *in_strides [[buffer(3)]], \ - const device size_t *out_strides [[buffer(4)]], \ - const device size_t& ndim [[buffer(5)]], \ - const device size_t& axis_stride [[buffer(6)]], \ - const device size_t& axis_size [[buffer(7)]], \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint lsize [[threads_per_threadgroup]], \ - uint simd_size [[threads_per_simdgroup]], \ + template [[host_name(name)]] [[kernel]] void \ + arg_reduce_general, 4>( \ + const device itype* in [[buffer(0)]], \ + device uint32_t* out [[buffer(1)]], \ + const device int* shape [[buffer(2)]], \ + const device size_t* in_strides [[buffer(3)]], \ + const device size_t* out_strides [[buffer(4)]], \ + const device size_t& ndim [[buffer(5)]], \ + const device size_t& axis_stride [[buffer(6)]], \ + const device size_t& axis_size [[buffer(7)]], \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint lsize [[threads_per_threadgroup]], \ + uint simd_size [[threads_per_simdgroup]], \ uint simd_lane_id [[thread_index_in_simdgroup]], \ uint simd_group_id [[simdgroup_index_in_threadgroup]]); -#define instantiate_arg_reduce(name, itype) \ +// clang-format off +#define instantiate_arg_reduce(name, itype) \ instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin) \ instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax) @@ -193,4 +194,4 @@ instantiate_arg_reduce(int32, int32_t) instantiate_arg_reduce(int64, int64_t) instantiate_arg_reduce(float16, half) instantiate_arg_reduce(float32, float) -instantiate_arg_reduce(bfloat16, bfloat16_t) \ No newline at end of file +instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal index 7674a13f1..8dba35958 100644 --- a/mlx/backend/metal/kernels/binary.metal +++ b/mlx/backend/metal/kernels/binary.metal @@ -77,7 +77,8 @@ template uint3 grid_dim [[threads_per_grid]]) { auto a_idx = elem_to_loc_3(index, a_strides); auto b_idx = elem_to_loc_3(index, b_strides); - size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + size_t out_idx = + index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); c[out_idx] = Op()(a[a_idx], b[b_idx]); } @@ -92,7 +93,8 @@ template uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides); - size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + size_t out_idx = + index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); c[out_idx] = 
Op()(a[idx.x], b[idx.y]); } @@ -112,114 +114,118 @@ template c[out_idx] = Op()(a[idx.x], b[idx.y]); } -#define instantiate_binary(name, itype, otype, op, bopt) \ - template [[host_name(name)]] \ - [[kernel]] void binary_op_##bopt( \ - device const itype* a, \ - device const itype* b, \ - device otype* c, \ - uint index [[thread_position_in_grid]]); +#define instantiate_binary(name, itype, otype, op, bopt) \ + template \ + [[host_name(name)]] [[kernel]] void binary_op_##bopt( \ + device const itype* a, \ + device const itype* b, \ + device otype* c, \ + uint index [[thread_position_in_grid]]); #define instantiate_binary_g_dim(name, itype, otype, op, dims) \ - template [[host_name(name "_" #dims)]] \ - [[kernel]] void binary_op_g_nd( \ - device const itype* a, \ - device const itype* b, \ - device otype* c, \ - constant const int shape[dims], \ - constant const size_t a_strides[dims], \ - constant const size_t b_strides[dims], \ - uint3 index [[thread_position_in_grid]], \ + template [[host_name(name "_" #dims)]] [[kernel]] void \ + binary_op_g_nd( \ + device const itype* a, \ + device const itype* b, \ + device otype* c, \ + constant const int shape[dims], \ + constant const size_t a_strides[dims], \ + constant const size_t b_strides[dims], \ + uint3 index [[thread_position_in_grid]], \ uint3 grid_dim [[threads_per_grid]]); #define instantiate_binary_g_nd(name, itype, otype, op) \ - template [[host_name(name "_1")]] \ - [[kernel]] void binary_op_g_nd1( \ - device const itype* a, \ - device const itype* b, \ - device otype* c, \ - constant const size_t& a_stride, \ - constant const size_t& b_stride, \ - uint index [[thread_position_in_grid]]); \ - template [[host_name(name "_2")]] \ - [[kernel]] void binary_op_g_nd2( \ - device const itype* a, \ - device const itype* b, \ - device otype* c, \ - constant const size_t a_strides[2], \ - constant const size_t b_strides[2], \ - uint2 index [[thread_position_in_grid]], \ - uint2 grid_dim [[threads_per_grid]]); \ - template [[host_name(name "_3")]] \ - [[kernel]] void binary_op_g_nd3( \ - device const itype* a, \ - device const itype* b, \ - device otype* c, \ - constant const size_t a_strides[3], \ - constant const size_t b_strides[3], \ - uint3 index [[thread_position_in_grid]], \ - uint3 grid_dim [[threads_per_grid]]); \ - instantiate_binary_g_dim(name, itype, otype, op, 4) \ - instantiate_binary_g_dim(name, itype, otype, op, 5) + template [[host_name(name "_1")]] [[kernel]] void \ + binary_op_g_nd1( \ + device const itype* a, \ + device const itype* b, \ + device otype* c, \ + constant const size_t& a_stride, \ + constant const size_t& b_stride, \ + uint index [[thread_position_in_grid]]); \ + template [[host_name(name "_2")]] [[kernel]] void \ + binary_op_g_nd2( \ + device const itype* a, \ + device const itype* b, \ + device otype* c, \ + constant const size_t a_strides[2], \ + constant const size_t b_strides[2], \ + uint2 index [[thread_position_in_grid]], \ + uint2 grid_dim [[threads_per_grid]]); \ + template [[host_name(name "_3")]] [[kernel]] void \ + binary_op_g_nd3( \ + device const itype* a, \ + device const itype* b, \ + device otype* c, \ + constant const size_t a_strides[3], \ + constant const size_t b_strides[3], \ + uint3 index [[thread_position_in_grid]], \ + uint3 grid_dim [[threads_per_grid]]); \ + instantiate_binary_g_dim(name, itype, otype, op, 4) \ + instantiate_binary_g_dim(name, itype, otype, op, 5) - -#define instantiate_binary_g(name, itype, otype, op) \ - template [[host_name(name)]] \ - [[kernel]] void binary_op_g( \ - 
device const itype* a, \ - device const itype* b, \ - device otype* c, \ - constant const int* shape, \ - constant const size_t* a_strides, \ - constant const size_t* b_strides, \ - constant const int& ndim, \ - uint3 index [[thread_position_in_grid]], \ +#define instantiate_binary_g(name, itype, otype, op) \ + template [[host_name(name)]] [[kernel]] void binary_op_g( \ + device const itype* a, \ + device const itype* b, \ + device otype* c, \ + constant const int* shape, \ + constant const size_t* a_strides, \ + constant const size_t* b_strides, \ + constant const int& ndim, \ + uint3 index [[thread_position_in_grid]], \ uint3 grid_dim [[threads_per_grid]]); +// clang-format off #define instantiate_binary_all(name, tname, itype, otype, op) \ instantiate_binary("ss" #name #tname, itype, otype, op, ss) \ instantiate_binary("sv" #name #tname, itype, otype, op, sv) \ instantiate_binary("vs" #name #tname, itype, otype, op, vs) \ instantiate_binary("vv" #name #tname, itype, otype, op, vv) \ - instantiate_binary_g("g" #name #tname, itype, otype, op) \ - instantiate_binary_g_nd("g" #name #tname, itype, otype, op) + instantiate_binary_g("g" #name #tname, itype, otype, op) \ + instantiate_binary_g_nd("g" #name #tname, itype, otype, op) // clang-format on -#define instantiate_binary_integer(name, op) \ - instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \ +// clang-format off +#define instantiate_binary_integer(name, op) \ + instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \ instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \ instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \ instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \ - instantiate_binary_all(name, int8, int8_t, int8_t, op) \ - instantiate_binary_all(name, int16, int16_t, int16_t, op) \ - instantiate_binary_all(name, int32, int32_t, int32_t, op) \ - instantiate_binary_all(name, int64, int64_t, int64_t, op) \ + instantiate_binary_all(name, int8, int8_t, int8_t, op) \ + instantiate_binary_all(name, int16, int16_t, int16_t, op) \ + instantiate_binary_all(name, int32, int32_t, int32_t, op) \ + instantiate_binary_all(name, int64, int64_t, int64_t, op) // clang-format on -#define instantiate_binary_float(name, op) \ - instantiate_binary_all(name, float16, half, half, op) \ +// clang-format off +#define instantiate_binary_float(name, op) \ + instantiate_binary_all(name, float16, half, half, op) \ instantiate_binary_all(name, float32, float, float, op) \ - instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op) + instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op) // clang-format on -#define instantiate_binary_types(name, op) \ - instantiate_binary_all(name, bool_, bool, bool, op) \ - instantiate_binary_integer(name, op) \ +// clang-format off +#define instantiate_binary_types(name, op) \ + instantiate_binary_all(name, bool_, bool, bool, op) \ + instantiate_binary_integer(name, op) \ instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \ - instantiate_binary_float(name, op) + instantiate_binary_float(name, op) // clang-format on -#define instantiate_binary_types_bool(name, op) \ - instantiate_binary_all(name, bool_, bool, bool, op) \ - instantiate_binary_all(name, uint8, uint8_t, bool, op) \ - instantiate_binary_all(name, uint16, uint16_t, bool, op) \ - instantiate_binary_all(name, uint32, uint32_t, bool, op) \ - instantiate_binary_all(name, uint64, uint64_t, bool, op) \ - instantiate_binary_all(name, int8, int8_t, bool, op) \ - 
instantiate_binary_all(name, int16, int16_t, bool, op) \ - instantiate_binary_all(name, int32, int32_t, bool, op) \ - instantiate_binary_all(name, int64, int64_t, bool, op) \ - instantiate_binary_all(name, float16, half, bool, op) \ - instantiate_binary_all(name, float32, float, bool, op) \ +// clang-format off +#define instantiate_binary_types_bool(name, op) \ + instantiate_binary_all(name, bool_, bool, bool, op) \ + instantiate_binary_all(name, uint8, uint8_t, bool, op) \ + instantiate_binary_all(name, uint16, uint16_t, bool, op) \ + instantiate_binary_all(name, uint32, uint32_t, bool, op) \ + instantiate_binary_all(name, uint64, uint64_t, bool, op) \ + instantiate_binary_all(name, int8, int8_t, bool, op) \ + instantiate_binary_all(name, int16, int16_t, bool, op) \ + instantiate_binary_all(name, int32, int32_t, bool, op) \ + instantiate_binary_all(name, int64, int64_t, bool, op) \ + instantiate_binary_all(name, float16, half, bool, op) \ + instantiate_binary_all(name, float32, float, bool, op) \ instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \ - instantiate_binary_all(name, complex64, complex64_t, bool, op) + instantiate_binary_all(name, complex64, complex64_t, bool, op) // clang-format on +// clang-format off instantiate_binary_types(add, Add) instantiate_binary_types(div, Divide) instantiate_binary_types_bool(eq, Equal) @@ -253,4 +259,4 @@ instantiate_binary_all(bitwise_or, bool_, bool, bool, BitwiseOr) instantiate_binary_integer(bitwise_xor, BitwiseXor) instantiate_binary_all(bitwise_xor, bool_, bool, bool, BitwiseXor) instantiate_binary_integer(left_shift, LeftShift) -instantiate_binary_integer(right_shift, RightShift) +instantiate_binary_integer(right_shift, RightShift) // clang-format on diff --git a/mlx/backend/metal/kernels/binary_two.metal b/mlx/backend/metal/kernels/binary_two.metal index 245ced024..c192561d7 100644 --- a/mlx/backend/metal/kernels/binary_two.metal +++ b/mlx/backend/metal/kernels/binary_two.metal @@ -3,28 +3,42 @@ #include #include -#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/utils.h" struct FloorDivide { - template T operator()(T x, T y) { return x / y; } - template <> float operator()(float x, float y) { return trunc(x / y); } - template <> half operator()(half x, half y) { return trunc(x / y); } - template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return trunc(x / y); } + template + T operator()(T x, T y) { + return x / y; + } + template <> + float operator()(float x, float y) { + return trunc(x / y); + } + template <> + half operator()(half x, half y) { + return trunc(x / y); + } + template <> + bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { + return trunc(x / y); + } }; struct Remainder { template - metal::enable_if_t & !metal::is_signed_v, T> operator()(T x, T y) { + metal::enable_if_t & !metal::is_signed_v, T> + operator()(T x, T y) { return x % y; } template - metal::enable_if_t & metal::is_signed_v, T> operator()(T x, T y) { + metal::enable_if_t & metal::is_signed_v, T> + operator()(T x, T y) { auto r = x % y; if (r != 0 && (r < 0 != y < 0)) { r += y; } - return r; + return r; } template metal::enable_if_t, T> operator()(T x, T y) { @@ -32,10 +46,11 @@ struct Remainder { if (r != 0 && (r < 0 != y < 0)) { r += y; } - return r; + return r; } - template <> complex64_t operator()(complex64_t x, complex64_t y) { - return x % y; + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + return x % y; } }; @@ -50,7 +65,6 @@ template 
d[index] = Op2()(a[0], b[0]); } - template [[kernel]] void binary_op_ss( device const T* a, @@ -139,7 +153,8 @@ template uint3 grid_dim [[threads_per_grid]]) { auto a_idx = elem_to_loc_3(index, a_strides); auto b_idx = elem_to_loc_3(index, b_strides); - size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + size_t out_idx = + index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); c[out_idx] = Op1()(a[a_idx], b[b_idx]); d[out_idx] = Op2()(a[a_idx], b[b_idx]); } @@ -156,7 +171,8 @@ template uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides); - size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + size_t out_idx = + index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); c[out_idx] = Op1()(a[idx.x], b[idx.y]); d[out_idx] = Op2()(a[idx.x], b[idx.y]); } @@ -180,99 +196,102 @@ template } #define instantiate_binary(name, itype, otype, op1, op2, bopt) \ - template [[host_name(name)]] \ - [[kernel]] void binary_op_##bopt( \ - device const itype* a, \ - device const itype* b, \ - device otype* c, \ - device otype* d, \ - uint index [[thread_position_in_grid]]); + template [[host_name(name)]] [[kernel]] void \ + binary_op_##bopt( \ + device const itype* a, \ + device const itype* b, \ + device otype* c, \ + device otype* d, \ + uint index [[thread_position_in_grid]]); #define instantiate_binary_g_dim(name, itype, otype, op1, op2, dims) \ - template [[host_name(name "_" #dims)]] \ - [[kernel]] void binary_op_g_nd( \ - device const itype* a, \ - device const itype* b, \ - device otype* c, \ - device otype* d, \ - constant const int shape[dims], \ - constant const size_t a_strides[dims], \ - constant const size_t b_strides[dims], \ - uint3 index [[thread_position_in_grid]], \ + template [[host_name(name "_" #dims)]] [[kernel]] void \ + binary_op_g_nd( \ + device const itype* a, \ + device const itype* b, \ + device otype* c, \ + device otype* d, \ + constant const int shape[dims], \ + constant const size_t a_strides[dims], \ + constant const size_t b_strides[dims], \ + uint3 index [[thread_position_in_grid]], \ uint3 grid_dim [[threads_per_grid]]); +// clang-format off #define instantiate_binary_g_nd(name, itype, otype, op1, op2) \ - template [[host_name(name "_1")]] \ - [[kernel]] void binary_op_g_nd1( \ - device const itype* a, \ - device const itype* b, \ - device otype* c, \ - device otype* d, \ - constant const size_t& a_stride, \ - constant const size_t& b_stride, \ - uint index [[thread_position_in_grid]]); \ - template [[host_name(name "_2")]] \ - [[kernel]] void binary_op_g_nd2( \ - device const itype* a, \ - device const itype* b, \ - device otype* c, \ - device otype* d, \ - constant const size_t a_strides[2], \ - constant const size_t b_strides[2], \ - uint2 index [[thread_position_in_grid]], \ - uint2 grid_dim [[threads_per_grid]]); \ - template [[host_name(name "_3")]] \ - [[kernel]] void binary_op_g_nd3( \ - device const itype* a, \ - device const itype* b, \ - device otype* c, \ - device otype* d, \ - constant const size_t a_strides[3], \ - constant const size_t b_strides[3], \ - uint3 index [[thread_position_in_grid]], \ - uint3 grid_dim [[threads_per_grid]]); \ - instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \ - instantiate_binary_g_dim(name, itype, otype, op1, op2, 5) - + template [[host_name(name "_1")]] [[kernel]] void \ + binary_op_g_nd1( \ + device const itype* a, \ + 
device const itype* b, \ + device otype* c, \ + device otype* d, \ + constant const size_t& a_stride, \ + constant const size_t& b_stride, \ + uint index [[thread_position_in_grid]]); \ + template [[host_name(name "_2")]] [[kernel]] void \ + binary_op_g_nd2( \ + device const itype* a, \ + device const itype* b, \ + device otype* c, \ + device otype* d, \ + constant const size_t a_strides[2], \ + constant const size_t b_strides[2], \ + uint2 index [[thread_position_in_grid]], \ + uint2 grid_dim [[threads_per_grid]]); \ + template [[host_name(name "_3")]] [[kernel]] void \ + binary_op_g_nd3( \ + device const itype* a, \ + device const itype* b, \ + device otype* c, \ + device otype* d, \ + constant const size_t a_strides[3], \ + constant const size_t b_strides[3], \ + uint3 index [[thread_position_in_grid]], \ + uint3 grid_dim [[threads_per_grid]]); \ + instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \ + instantiate_binary_g_dim(name, itype, otype, op1, op2, 5) // clang-format on #define instantiate_binary_g(name, itype, otype, op1, op2) \ - template [[host_name(name)]] \ - [[kernel]] void binary_op_g( \ - device const itype* a, \ - device const itype* b, \ - device otype* c, \ - device otype* d, \ - constant const int* shape, \ - constant const size_t* a_strides, \ - constant const size_t* b_strides, \ - constant const int& ndim, \ - uint3 index [[thread_position_in_grid]], \ + template [[host_name(name)]] [[kernel]] void \ + binary_op_g( \ + device const itype* a, \ + device const itype* b, \ + device otype* c, \ + device otype* d, \ + constant const int* shape, \ + constant const size_t* a_strides, \ + constant const size_t* b_strides, \ + constant const int& ndim, \ + uint3 index [[thread_position_in_grid]], \ uint3 grid_dim [[threads_per_grid]]); +// clang-format off #define instantiate_binary_all(name, tname, itype, otype, op1, op2) \ instantiate_binary("ss" #name #tname, itype, otype, op1, op2, ss) \ instantiate_binary("sv" #name #tname, itype, otype, op1, op2, sv) \ instantiate_binary("vs" #name #tname, itype, otype, op1, op2, vs) \ instantiate_binary("vv" #name #tname, itype, otype, op1, op2, vv) \ - instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \ - instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2) + instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \ + instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2) // clang-format on -#define instantiate_binary_float(name, op1, op2) \ - instantiate_binary_all(name, float16, half, half, op1, op2) \ +// clang-format off +#define instantiate_binary_float(name, op1, op2) \ + instantiate_binary_all(name, float16, half, half, op1, op2) \ instantiate_binary_all(name, float32, float, float, op1, op2) \ - instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2) + instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2) // clang-format on -#define instantiate_binary_types(name, op1, op2) \ - instantiate_binary_all(name, bool_, bool, bool, op1, op2) \ - instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \ - instantiate_binary_all(name, uint16, uint16_t, uint16_t, op1, op2) \ - instantiate_binary_all(name, uint32, uint32_t, uint32_t, op1, op2) \ - instantiate_binary_all(name, uint64, uint64_t, uint64_t, op1, op2) \ - instantiate_binary_all(name, int8, int8_t, int8_t, op1, op2) \ - instantiate_binary_all(name, int16, int16_t, int16_t, op1, op2) \ - instantiate_binary_all(name, int32, int32_t, int32_t, op1, op2) \ - instantiate_binary_all(name, 
int64, int64_t, int64_t, op1, op2) \ +// clang-format off +#define instantiate_binary_types(name, op1, op2) \ + instantiate_binary_all(name, bool_, bool, bool, op1, op2) \ + instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \ + instantiate_binary_all(name, uint16, uint16_t, uint16_t, op1, op2) \ + instantiate_binary_all(name, uint32, uint32_t, uint32_t, op1, op2) \ + instantiate_binary_all(name, uint64, uint64_t, uint64_t, op1, op2) \ + instantiate_binary_all(name, int8, int8_t, int8_t, op1, op2) \ + instantiate_binary_all(name, int16, int16_t, int16_t, op1, op2) \ + instantiate_binary_all(name, int32, int32_t, int32_t, op1, op2) \ + instantiate_binary_all(name, int64, int64_t, int64_t, op1, op2) \ instantiate_binary_all(name, complex64, complex64_t, complex64_t, op1, op2) \ instantiate_binary_float(name, op1, op2) -instantiate_binary_types(divmod, FloorDivide, Remainder) +instantiate_binary_types(divmod, FloorDivide, Remainder) // clang-format on diff --git a/mlx/backend/metal/kernels/complex.h b/mlx/backend/metal/kernels/complex.h index 9cb27c5a3..df69bfe9f 100644 --- a/mlx/backend/metal/kernels/complex.h +++ b/mlx/backend/metal/kernels/complex.h @@ -22,7 +22,7 @@ struct complex64_t { float imag; // Constructors - constexpr complex64_t(float real, float imag) : real(real), imag(imag){}; + constexpr complex64_t(float real, float imag) : real(real), imag(imag) {}; // Conversions to complex64_t template < diff --git a/mlx/backend/metal/kernels/conv.metal b/mlx/backend/metal/kernels/conv.metal index 563002997..4c65a3677 100644 --- a/mlx/backend/metal/kernels/conv.metal +++ b/mlx/backend/metal/kernels/conv.metal @@ -1,13 +1,11 @@ // Copyright © 2023-2024 Apple Inc. -#include #include #include #include - -#include "mlx/backend/metal/kernels/steel/conv/params.h" #include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/steel/conv/params.h" #define MLX_MTL_CONST static constant constexpr const @@ -23,14 +21,15 @@ template device T* out [[buffer(1)]], const constant MLXConvParams* params [[buffer(2)]], uint3 gid [[thread_position_in_grid]]) { - int filter_size = params->C; - for(short i = 0; i < N; i++) filter_size *= params->wS[i]; + for (short i = 0; i < N; i++) + filter_size *= params->wS[i]; int out_pixels = 1; - for(short i = 0; i < N; i++) out_pixels *= params->oS[i]; + for (short i = 0; i < N; i++) + out_pixels *= params->oS[i]; - // Set out + // Set out out += gid.z * filter_size + gid.y * (params->C); // Coordinates in input @@ -46,11 +45,11 @@ template bool valid = n < params->N; - // Unroll dimensions + // Unroll dimensions for (int i = N - 1; i >= 0; --i) { int os_ = (oS % params->oS[i]); int ws_ = (wS % params->wS[i]); - + ws_ = params->flip ? 
params->wS[i] - ws_ - 1 : ws_; int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i]; @@ -64,10 +63,10 @@ template wS /= params->wS[i]; } - if(valid) { + if (valid) { size_t in_offset = n * params->in_strides[0]; - for(int i = 0; i < N; ++i) { + for (int i = 0; i < N; ++i) { in_offset += is[i] * params->in_strides[i + 1]; } @@ -85,12 +84,13 @@ template device T* out [[buffer(1)]], const constant MLXConvParams* params [[buffer(2)]], uint3 gid [[thread_position_in_grid]]) { - int filter_size = params->C; - for(short i = 0; i < N; i++) filter_size *= params->wS[i]; + for (short i = 0; i < N; i++) + filter_size *= params->wS[i]; int out_pixels = 1; - for(short i = 0; i < N; i++) out_pixels *= params->oS[i]; + for (short i = 0; i < N; i++) + out_pixels *= params->oS[i]; // Set out out += gid.z * filter_size + gid.x * (filter_size / params->C); @@ -128,10 +128,10 @@ template out += ws_ * params->str[i]; } - if(valid) { + if (valid) { size_t in_offset = n * params->in_strides[0]; - for(int i = 0; i < N; ++i) { + for (int i = 0; i < N; ++i) { in_offset += is[i] * params->in_strides[i + 1]; } @@ -141,24 +141,24 @@ template } } -#define instantiate_naive_unfold_nd(name, itype, n) \ - template [[host_name("naive_unfold_nd_" #name "_" #n)]] \ - [[kernel]] void naive_unfold_Nd( \ - const device itype* in [[buffer(0)]], \ - device itype* out [[buffer(1)]], \ - const constant MLXConvParams* params [[buffer(2)]], \ - uint3 gid [[thread_position_in_grid]]); \ - template [[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] \ - [[kernel]] void naive_unfold_transpose_Nd( \ - const device itype* in [[buffer(0)]], \ - device itype* out [[buffer(1)]], \ - const constant MLXConvParams* params [[buffer(2)]], \ - uint3 gid [[thread_position_in_grid]]); +#define instantiate_naive_unfold_nd(name, itype, n) \ + template [[host_name("naive_unfold_nd_" #name "_" #n)]] [[kernel]] void \ + naive_unfold_Nd( \ + const device itype* in [[buffer(0)]], \ + device itype* out [[buffer(1)]], \ + const constant MLXConvParams* params [[buffer(2)]], \ + uint3 gid [[thread_position_in_grid]]); \ + template \ + [[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] [[kernel]] void \ + naive_unfold_transpose_Nd( \ + const device itype* in [[buffer(0)]], \ + device itype* out [[buffer(1)]], \ + const constant MLXConvParams* params [[buffer(2)]], \ + uint3 gid [[thread_position_in_grid]]); -#define instantiate_naive_unfold_nd_dims(name, itype) \ - instantiate_naive_unfold_nd(name, itype, 1) \ - instantiate_naive_unfold_nd(name, itype, 2) \ - instantiate_naive_unfold_nd(name, itype, 3) +#define instantiate_naive_unfold_nd_dims(name, itype) \ + instantiate_naive_unfold_nd(name, itype, 1) instantiate_naive_unfold_nd( \ + name, itype, 2) instantiate_naive_unfold_nd(name, itype, 3) instantiate_naive_unfold_nd_dims(float32, float); instantiate_naive_unfold_nd_dims(float16, half); @@ -168,12 +168,13 @@ instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t); /// Slow and naive conv2d kernels /////////////////////////////////////////////////////////////////////////////// -template +template < + typename T, + const int BM, /* Threadgroup rows (in threads) */ + const int BN, /* Threadgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + const int BC = 16> [[kernel]] void naive_conv_2d( const device T* in [[buffer(0)]], const device T* wt [[buffer(1)]], @@ -183,7 +184,6 @@ template = 0 && i < params.iS[0] && j >= 0 && j < params.iS[1]; - in_local[m] = 
valid ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] : T(0); + in_local[m] = valid + ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] + : T(0); } // Load weight for (int n = 0; n < TN; ++n) { int o = out_o + n; - wt_local[n] = o < params.O ? wt[o * params.wt_strides[0] + - h * params.wt_strides[1] + - w * params.wt_strides[2] + c] : T(0); + wt_local[n] = o < params.O + ? wt[o * params.wt_strides[0] + h * params.wt_strides[1] + + w * params.wt_strides[2] + c] + : T(0); } // Accumulate - for(int m = 0; m < TM; ++m) { - for(int n = 0; n < TN; ++n) { + for (int m = 0; m < TM; ++m) { + for (int n = 0; n < TN; ++n) { out_local[m * TN + n] += in_local[m] * wt_local[n]; } } - } } } - for(int m = 0; m < TM; ++m) { - for(int n = 0; n < TN; ++n) { - if(out_h[m] < params.oS[0] && out_w[m] < params.oS[1] && (out_o + n) < params.O) - out[out_h[m] * params.out_strides[1] + - out_w[m] * params.out_strides[2] + out_o + n] = out_local[m * TN + n]; + for (int m = 0; m < TM; ++m) { + for (int n = 0; n < TN; ++n) { + if (out_h[m] < params.oS[0] && out_w[m] < params.oS[1] && + (out_o + n) < params.O) + out[out_h[m] * params.out_strides[1] + + out_w[m] * params.out_strides[2] + out_o + n] = + out_local[m * TN + n]; } } - } // Instantiations -#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \ - template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \ - [[kernel]] void naive_conv_2d( \ - const device itype* in [[buffer(0)]], \ - const device itype* wt [[buffer(1)]], \ - device itype* out [[buffer(2)]], \ - const constant MLXConvParams<2>& params [[buffer(3)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ +#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \ + template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm \ + "_tn" #tn)]] [[kernel]] void \ + naive_conv_2d( \ + const device itype* in [[buffer(0)]], \ + const device itype* wt [[buffer(1)]], \ + device itype* out [[buffer(2)]], \ + const constant MLXConvParams<2>& params [[buffer(3)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ uint simd_lid [[thread_index_in_simdgroup]]); #define instantiate_naive_conv_2d_blocks(name, itype) \ - instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \ - instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4) + instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \ + instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4) instantiate_naive_conv_2d_blocks(float32, float); instantiate_naive_conv_2d_blocks(float16, half); @@ -276,9 +278,7 @@ instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t); /////////////////////////////////////////////////////////////////////////////// template -struct WinogradTransforms { - -}; +struct WinogradTransforms {}; template <> struct WinogradTransforms<6, 3, 8> { @@ -287,36 +287,36 @@ struct WinogradTransforms<6, 3, 8> { MLX_MTL_CONST int IN_TILE_SIZE = OUT_TILE_SIZE + FILTER_SIZE - 1; MLX_MTL_CONST int SIMD_MATRIX_SIZE = 8; MLX_MTL_CONST float in_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = { - { 1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f}, - { 0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f}, - {-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f}, - { 0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f}, - { 5.25f, -4.25f, -4.25f, -1.25f, -1.25f, 
-5.00f, -5.00f, 0.00f}, - { 0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f}, - {-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f}, - { 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f}, + {1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f}, + {0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f}, + {-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f}, + {0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f}, + {5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f, 0.00f}, + {0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f}, + {-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f}, + {0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f}, }; MLX_MTL_CONST float out_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = { - { 1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f}, - { 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f}, - { 1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f}, - { 1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f}, - { 1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f}, - { 1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f}, - { 1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f}, - { 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f}, + {1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f}, + {1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f}, + {1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f}, + {1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f}, + {1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f}, + {1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f}, + {1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f}, + {0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f}, }; MLX_MTL_CONST float wt_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = { - { 1.00, 0.00, 0.00}, - { -2.0/9.00, -2.0/9.00, -2.0/9.00}, - { -2.0/9.00, 2.0/9.00, -2.0/9.00}, - { 1.0/90.0, 1.0/45.0, 2.0/45.0}, - { 1.0/90.0, -1.0/45.0, 2.0/45.0}, - { 32.0/45.0, 16.0/45.0, 8.0/45.0}, - { 32.0/45.0, -16.0/45.0, 8.0/45.0}, - { 0.00, 0.00, 1.00}, + {1.00, 0.00, 0.00}, + {-2.0 / 9.00, -2.0 / 9.00, -2.0 / 9.00}, + {-2.0 / 9.00, 2.0 / 9.00, -2.0 / 9.00}, + {1.0 / 90.0, 1.0 / 45.0, 2.0 / 45.0}, + {1.0 / 90.0, -1.0 / 45.0, 2.0 / 45.0}, + {32.0 / 45.0, 16.0 / 45.0, 8.0 / 45.0}, + {32.0 / 45.0, -16.0 / 45.0, 8.0 / 45.0}, + {0.00, 0.00, 1.00}, }; }; @@ -324,12 +324,9 @@ constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8]; constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8]; constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8]; -template -[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void winograd_conv_2d_weight_transform( +template +[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void +winograd_conv_2d_weight_transform( const device T* wt_in [[buffer(0)]], device T* wt_out [[buffer(1)]], const constant int& C [[buffer(2)]], @@ -337,7 +334,6 @@ template ; // Get lane position in simdgroup @@ -357,35 +353,37 @@ template g; - g.thread_elements()[0] = sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0); - g.thread_elements()[1] = sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0); + g.thread_elements()[0] = + sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0); + g.thread_elements()[1] = + sm < R && sn + 1 < R ? 
Ws[simd_group_id][sm][sn + 1][c] : T(0); simdgroup_matrix g_out = (G * g) * Gt; wt_out_0[c * O] = g_out.thread_elements()[0]; @@ -396,27 +394,23 @@ template (\ - const device itype* wt_in [[buffer(0)]],\ - device itype* wt_out [[buffer(1)]],\ - const constant int& C [[buffer(2)]],\ - const constant int& O [[buffer(3)]],\ - uint tid [[threadgroup_position_in_grid]],\ - uint simd_group_id [[simdgroup_index_in_threadgroup]],\ + template [[host_name("winograd_conv_2d_weight_transform_" #name \ + "_bc" #bc)]] [[kernel]] void \ + winograd_conv_2d_weight_transform( \ + const device itype* wt_in [[buffer(0)]], \ + device itype* wt_out [[buffer(1)]], \ + const constant int& C [[buffer(2)]], \ + const constant int& O [[buffer(3)]], \ + uint tid [[threadgroup_position_in_grid]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ uint simd_lane_id [[thread_index_in_simdgroup]]); -template -[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_input_transform( +template +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void +winograd_conv_2d_input_transform( const device T* inp_in [[buffer(0)]], device T* inp_out [[buffer(1)]], const constant MLXConvParams<2>& params [[buffer(2)]], @@ -425,7 +419,6 @@ template ; @@ -456,46 +449,48 @@ template I; I.thread_elements()[0] = Is[sm][sn][c]; I.thread_elements()[1] = Is[sm][sn + 1][c]; @@ -509,28 +504,24 @@ template (\ - const device itype* inp_in [[buffer(0)]],\ - device itype* inp_out [[buffer(1)]],\ - const constant MLXConvParams<2>& params [[buffer(2)]],\ - uint3 tid [[threadgroup_position_in_grid]],\ - uint3 lid [[thread_position_in_threadgroup]],\ - uint3 tgp_per_grid [[threadgroups_per_grid]],\ - uint simd_group_id [[simdgroup_index_in_threadgroup]],\ + template [[host_name("winograd_conv_2d_input_transform_" #name \ + "_bc" #bc)]] [[kernel]] void \ + winograd_conv_2d_input_transform( \ + const device itype* inp_in [[buffer(0)]], \ + device itype* inp_out [[buffer(1)]], \ + const constant MLXConvParams<2>& params [[buffer(2)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 tgp_per_grid [[threadgroups_per_grid]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ uint simd_lane_id [[thread_index_in_simdgroup]]); -template -[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_output_transform( +template +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void +winograd_conv_2d_output_transform( const device T* out_in [[buffer(0)]], device T* out_out [[buffer(1)]], const constant MLXConvParams<2>& params [[buffer(2)]], @@ -539,7 +530,6 @@ template ; @@ -572,57 +562,59 @@ template O_mat; O_mat.thread_elements()[0] = out_in_0[c]; O_mat.thread_elements()[1] = out_in_1[c]; simdgroup_matrix O_out = (Bt * (O_mat * B)); - if((sm < M) && (sn < M)) { + if ((sm < M) && (sn < M)) { Os[sm][sn][c] = O_out.thread_elements()[0]; } - if((sm < M) && ((sn + 1) < M)) { + if ((sm < M) && ((sn + 1) < M)) { Os[sm][sn + 1][c] = O_out.thread_elements()[1]; } } threadgroup_barrier(mem_flags::mem_threadgroup); // Read out from shared memory - for(int h = 0; h < TH; h++) { - for(int w = 0; w < TW; w++) { - if(jump_in[h][w] >= 0) { + for (int h = 0; h < TH; h++) { + for (int w = 0; w < TW; w++) { + if (jump_in[h][w] >= 0) { device T* out_ptr = out_out + jump_in[h][w]; - for(int c = simd_lane_id; c < BO; c += 32) { + for (int c = simd_lane_id; c < BO; c += 32) { out_ptr[c] = Os[kh + h][kw + w][c]; } } @@ -633,25 +625,27 @@ 
template (\ - const device itype* out_in [[buffer(0)]],\ - device itype* out_out [[buffer(1)]],\ - const constant MLXConvParams<2>& params [[buffer(2)]],\ - uint3 tid [[threadgroup_position_in_grid]],\ - uint3 lid [[thread_position_in_threadgroup]],\ - uint3 tgp_per_grid [[threadgroups_per_grid]],\ - uint simd_group_id [[simdgroup_index_in_threadgroup]],\ + template [[host_name("winograd_conv_2d_output_transform_" #name \ + "_bo" #bo)]] [[kernel]] void \ + winograd_conv_2d_output_transform( \ + const device itype* out_in [[buffer(0)]], \ + device itype* out_out [[buffer(1)]], \ + const constant MLXConvParams<2>& params [[buffer(2)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 tgp_per_grid [[threadgroups_per_grid]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ uint simd_lane_id [[thread_index_in_simdgroup]]); -#define instantiate_winograd_conv_2d(name, itype) \ +// clang-format off +#define instantiate_winograd_conv_2d(name, itype) \ instantiate_winograd_conv_2d_weight_transform_base(name, itype, 32) \ - instantiate_winograd_conv_2d_input_transform(name, itype, 32) \ - instantiate_winograd_conv_2d_output_transform(name, itype, 32) + instantiate_winograd_conv_2d_input_transform(name, itype, 32) \ + instantiate_winograd_conv_2d_output_transform(name, itype, 32) // clang-format on +// clang-format off instantiate_winograd_conv_2d(float32, float); -instantiate_winograd_conv_2d(float16, half); \ No newline at end of file +instantiate_winograd_conv_2d(float16, half); // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/copy.metal b/mlx/backend/metal/kernels/copy.metal index 4cfc3b68f..01518144b 100644 --- a/mlx/backend/metal/kernels/copy.metal +++ b/mlx/backend/metal/kernels/copy.metal @@ -49,7 +49,8 @@ template uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { auto src_idx = elem_to_loc_3(index, src_strides); - int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); + int64_t dst_idx = + index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); dst[dst_idx] = static_cast(src[src_idx]); } @@ -62,7 +63,8 @@ template uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { auto src_idx = elem_to_loc_nd(index, src_shape, src_strides); - int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); + int64_t dst_idx = + index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); dst[dst_idx] = static_cast(src[src_idx]); } @@ -76,7 +78,8 @@ template uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim); - int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); + int64_t dst_idx = + index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); dst[dst_idx] = static_cast(src[src_idx]); } @@ -143,116 +146,110 @@ template dst[dst_idx] = static_cast(src[src_idx]); } -#define instantiate_copy(name, itype, otype, ctype) \ - template [[host_name(name)]] \ - [[kernel]] void copy_##ctype( \ - device const itype* src [[buffer(0)]], \ - device otype* dst [[buffer(1)]], \ +#define instantiate_copy(name, itype, otype, ctype) \ + template [[host_name(name)]] [[kernel]] void copy_##ctype( \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ uint index [[thread_position_in_grid]]); -#define 
instantiate_copy_g_dim(name, itype, otype, dims) \ - template [[host_name(name "_" #dims)]] \ - [[kernel]] void copy_g_nd( \ - device const itype* src [[buffer(0)]], \ - device otype* dst [[buffer(1)]], \ - constant const int* src_shape [[buffer(2)]], \ - constant const int64_t* src_strides [[buffer(3)]], \ - uint3 index [[thread_position_in_grid]], \ - uint3 grid_dim [[threads_per_grid]]); \ - template [[host_name("g" name "_" #dims)]] \ - [[kernel]] void copy_gg_nd( \ - device const itype* src [[buffer(0)]], \ - device otype* dst [[buffer(1)]], \ - constant const int* src_shape [[buffer(2)]], \ - constant const int64_t* src_strides [[buffer(3)]], \ - constant const int64_t* dst_strides [[buffer(4)]], \ +#define instantiate_copy_g_dim(name, itype, otype, dims) \ + template [[host_name(name "_" #dims)]] [[kernel]] void \ + copy_g_nd( \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int* src_shape [[buffer(2)]], \ + constant const int64_t* src_strides [[buffer(3)]], \ + uint3 index [[thread_position_in_grid]], \ + uint3 grid_dim [[threads_per_grid]]); \ + template [[host_name("g" name "_" #dims)]] [[kernel]] void \ + copy_gg_nd( \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int* src_shape [[buffer(2)]], \ + constant const int64_t* src_strides [[buffer(3)]], \ + constant const int64_t* dst_strides [[buffer(4)]], \ uint3 index [[thread_position_in_grid]]); +#define instantiate_copy_g_nd(name, itype, otype) \ + template [[host_name(name "_1")]] [[kernel]] void copy_g_nd1( \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int64_t& src_stride [[buffer(3)]], \ + uint index [[thread_position_in_grid]]); \ + template [[host_name(name "_2")]] [[kernel]] void copy_g_nd2( \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int64_t* src_strides [[buffer(3)]], \ + uint2 index [[thread_position_in_grid]], \ + uint2 grid_dim [[threads_per_grid]]); \ + template [[host_name(name "_3")]] [[kernel]] void copy_g_nd3( \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int64_t* src_strides [[buffer(3)]], \ + uint3 index [[thread_position_in_grid]], \ + uint3 grid_dim [[threads_per_grid]]); \ + template [[host_name("g" name "_1")]] [[kernel]] void \ + copy_gg_nd1( \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int64_t& src_stride [[buffer(3)]], \ + constant const int64_t& dst_stride [[buffer(4)]], \ + uint index [[thread_position_in_grid]]); \ + template [[host_name("g" name "_2")]] [[kernel]] void \ + copy_gg_nd2( \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int64_t* src_strides [[buffer(3)]], \ + constant const int64_t* dst_strides [[buffer(4)]], \ + uint2 index [[thread_position_in_grid]]); \ + template [[host_name("g" name "_3")]] [[kernel]] void \ + copy_gg_nd3( \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int64_t* src_strides [[buffer(3)]], \ + constant const int64_t* dst_strides [[buffer(4)]], \ + uint3 index [[thread_position_in_grid]]); \ + instantiate_copy_g_dim(name, itype, otype, 4) \ + instantiate_copy_g_dim(name, itype, otype, 5) -#define instantiate_copy_g_nd(name, itype, otype) \ - template [[host_name(name "_1")]] \ - [[kernel]] void copy_g_nd1( \ - device const itype* src [[buffer(0)]], \ - 
device otype* dst [[buffer(1)]], \ - constant const int64_t& src_stride [[buffer(3)]], \ - uint index [[thread_position_in_grid]]); \ - template [[host_name(name "_2")]] \ - [[kernel]] void copy_g_nd2( \ - device const itype* src [[buffer(0)]], \ - device otype* dst [[buffer(1)]], \ - constant const int64_t* src_strides [[buffer(3)]], \ - uint2 index [[thread_position_in_grid]], \ - uint2 grid_dim [[threads_per_grid]]); \ - template [[host_name(name "_3")]] \ - [[kernel]] void copy_g_nd3( \ - device const itype* src [[buffer(0)]], \ - device otype* dst [[buffer(1)]], \ - constant const int64_t* src_strides [[buffer(3)]], \ - uint3 index [[thread_position_in_grid]], \ - uint3 grid_dim [[threads_per_grid]]); \ - template [[host_name("g" name "_1")]] \ - [[kernel]] void copy_gg_nd1( \ - device const itype* src [[buffer(0)]], \ - device otype* dst [[buffer(1)]], \ - constant const int64_t& src_stride [[buffer(3)]], \ - constant const int64_t& dst_stride [[buffer(4)]], \ - uint index [[thread_position_in_grid]]); \ - template [[host_name("g" name "_2")]] \ - [[kernel]] void copy_gg_nd2( \ - device const itype* src [[buffer(0)]], \ - device otype* dst [[buffer(1)]], \ - constant const int64_t* src_strides [[buffer(3)]], \ - constant const int64_t* dst_strides [[buffer(4)]], \ - uint2 index [[thread_position_in_grid]]); \ - template [[host_name("g" name "_3")]] \ - [[kernel]] void copy_gg_nd3( \ - device const itype* src [[buffer(0)]], \ - device otype* dst [[buffer(1)]], \ - constant const int64_t* src_strides [[buffer(3)]], \ - constant const int64_t* dst_strides [[buffer(4)]], \ - uint3 index [[thread_position_in_grid]]); \ - instantiate_copy_g_dim(name, itype, otype, 4) \ - instantiate_copy_g_dim(name, itype, otype, 5) - - -#define instantiate_copy_g(name, itype, otype) \ - template [[host_name(name)]] \ - [[kernel]] void copy_g( \ - device const itype* src [[buffer(0)]], \ - device otype* dst [[buffer(1)]], \ - constant const int* src_shape [[buffer(2)]], \ - constant const int64_t* src_strides [[buffer(3)]], \ - constant const int& ndim [[buffer(5)]], \ - uint3 index [[thread_position_in_grid]], \ - uint3 grid_dim [[threads_per_grid]]); \ - template [[host_name("g" name)]] \ - [[kernel]] void copy_gg( \ - device const itype* src [[buffer(0)]], \ - device otype* dst [[buffer(1)]], \ - constant const int* src_shape [[buffer(2)]], \ - constant const int64_t* src_strides [[buffer(3)]], \ - constant const int64_t* dst_strides [[buffer(4)]], \ - constant const int& ndim [[buffer(5)]], \ +#define instantiate_copy_g(name, itype, otype) \ + template [[host_name(name)]] [[kernel]] void copy_g( \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int* src_shape [[buffer(2)]], \ + constant const int64_t* src_strides [[buffer(3)]], \ + constant const int& ndim [[buffer(5)]], \ + uint3 index [[thread_position_in_grid]], \ + uint3 grid_dim [[threads_per_grid]]); \ + template [[host_name("g" name)]] [[kernel]] void copy_gg( \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int* src_shape [[buffer(2)]], \ + constant const int64_t* src_strides [[buffer(3)]], \ + constant const int64_t* dst_strides [[buffer(4)]], \ + constant const int& ndim [[buffer(5)]], \ uint3 index [[thread_position_in_grid]]); -#define instantiate_copy_all(tname, itype, otype) \ +// clang-format off +#define instantiate_copy_all(tname, itype, otype) \ instantiate_copy("scopy" #tname, itype, otype, s) \ instantiate_copy("vcopy" #tname, itype, 
otype, v) \ - instantiate_copy_g("gcopy" #tname, itype, otype) \ - instantiate_copy_g_nd("gcopy" #tname, itype, otype) + instantiate_copy_g("gcopy" #tname, itype, otype) \ + instantiate_copy_g_nd("gcopy" #tname, itype, otype) // clang-format on -#define instantiate_copy_itype(itname, itype) \ - instantiate_copy_all(itname ##bool_, itype, bool) \ - instantiate_copy_all(itname ##uint8, itype, uint8_t) \ - instantiate_copy_all(itname ##uint16, itype, uint16_t) \ - instantiate_copy_all(itname ##uint32, itype, uint32_t) \ - instantiate_copy_all(itname ##uint64, itype, uint64_t) \ - instantiate_copy_all(itname ##int8, itype, int8_t) \ - instantiate_copy_all(itname ##int16, itype, int16_t) \ - instantiate_copy_all(itname ##int32, itype, int32_t) \ - instantiate_copy_all(itname ##int64, itype, int64_t) \ - instantiate_copy_all(itname ##float16, itype, half) \ - instantiate_copy_all(itname ##float32, itype, float) \ +// clang-format off +#define instantiate_copy_itype(itname, itype) \ + instantiate_copy_all(itname ##bool_, itype, bool) \ + instantiate_copy_all(itname ##uint8, itype, uint8_t) \ + instantiate_copy_all(itname ##uint16, itype, uint16_t) \ + instantiate_copy_all(itname ##uint32, itype, uint32_t) \ + instantiate_copy_all(itname ##uint64, itype, uint64_t) \ + instantiate_copy_all(itname ##int8, itype, int8_t) \ + instantiate_copy_all(itname ##int16, itype, int16_t) \ + instantiate_copy_all(itname ##int32, itype, int32_t) \ + instantiate_copy_all(itname ##int64, itype, int64_t) \ + instantiate_copy_all(itname ##float16, itype, half) \ + instantiate_copy_all(itname ##float32, itype, float) \ instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \ instantiate_copy_all(itname ##complex64, itype, complex64_t) @@ -268,4 +265,4 @@ instantiate_copy_itype(int64, int64_t) instantiate_copy_itype(float16, half) instantiate_copy_itype(float32, float) instantiate_copy_itype(bfloat16, bfloat16_t) -instantiate_copy_itype(complex64, complex64_t) +instantiate_copy_itype(complex64, complex64_t) // clang-format on diff --git a/mlx/backend/metal/kernels/fft.metal b/mlx/backend/metal/kernels/fft.metal index 25ceaab18..66dc0d22b 100644 --- a/mlx/backend/metal/kernels/fft.metal +++ b/mlx/backend/metal/kernels/fft.metal @@ -6,9 +6,8 @@ // - VkFFT (https://github.com/DTolm/VkFFT) // - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html) -#include #include - +#include #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/utils.h" @@ -23,7 +22,7 @@ float2 complex_mul(float2 a, float2 b) { } float2 get_twiddle(int k, int p) { - float theta = -1.0f * k * M_PI_F / (2*p); + float theta = -1.0f * k * M_PI_F / (2 * p); float2 twiddle; twiddle.x = metal::fast::cos(theta); @@ -32,7 +31,12 @@ float2 get_twiddle(int k, int p) { } // single threaded radix2 implemetation -void radix2(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) { +void radix2( + int i, + int p, + int m, + threadgroup float2* read_buf, + threadgroup float2* write_buf) { float2 x_0 = read_buf[i]; float2 x_1 = read_buf[i + m]; @@ -53,11 +57,16 @@ void radix2(int i, int p, int m, threadgroup float2* read_buf, threadgroup float } // single threaded radix4 implemetation -void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) { +void radix4( + int i, + int p, + int m, + threadgroup float2* read_buf, + threadgroup float2* write_buf) { float2 x_0 = read_buf[i]; float2 x_1 = read_buf[i + m]; - float2 x_2 = read_buf[i + 2*m]; - float2 x_3 = 
read_buf[i + 3*m]; + float2 x_2 = read_buf[i + 2 * m]; + float2 x_3 = read_buf[i + 3 * m]; // The index within this sub-DFT int k = i & (p - 1); @@ -90,11 +99,10 @@ void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float write_buf[j] = y_0; write_buf[j + p] = y_1; - write_buf[j + 2*p] = y_2; - write_buf[j + 3*p] = y_3; + write_buf[j + 2 * p] = y_2; + write_buf[j + 3 * p] = y_3; } - // Each FFT is computed entirely in shared GPU memory. // // N is decomposed into radix-2 and radix-4 DFTs: @@ -107,11 +115,10 @@ void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float // steps at compile time for a ~20% performance boost. template [[kernel]] void fft( - const device float2 *in [[buffer(0)]], - device float2 * out [[buffer(1)]], + const device float2* in [[buffer(0)]], + device float2* out [[buffer(1)]], uint3 thread_position_in_grid [[thread_position_in_grid]], uint3 threads_per_grid [[threads_per_grid]]) { - // Index of the DFT in batch int batch_idx = thread_position_in_grid.x * n; // The index in the DFT we're working on @@ -132,16 +139,16 @@ template // Copy input into shared memory shared_in[i] = in[batch_idx + i]; shared_in[i + m] = in[batch_idx + i + m]; - shared_in[i + 2*m] = in[batch_idx + i + 2*m]; - shared_in[i + 3*m] = in[batch_idx + i + 3*m]; + shared_in[i + 2 * m] = in[batch_idx + i + 2 * m]; + shared_in[i + 3 * m] = in[batch_idx + i + 3 * m]; threadgroup_barrier(mem_flags::mem_threadgroup); int p = 1; for (size_t r = 0; r < radix_2_steps; r++) { - radix2(i, p, m*2, read_buf, write_buf); - radix2(i + m, p, m*2, read_buf, write_buf); + radix2(i, p, m * 2, read_buf, write_buf); + radix2(i + m, p, m * 2, read_buf, write_buf); p *= 2; threadgroup_barrier(mem_flags::mem_threadgroup); @@ -167,29 +174,26 @@ template // Copy shared memory to output out[batch_idx + i] = read_buf[i]; out[batch_idx + i + m] = read_buf[i + m]; - out[batch_idx + i + 2*m] = read_buf[i + 2*m]; - out[batch_idx + i + 3*m] = read_buf[i + 3*m]; + out[batch_idx + i + 2 * m] = read_buf[i + 2 * m]; + out[batch_idx + i + 3 * m] = read_buf[i + 3 * m]; } -#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \ - template [[host_name("fft_" #name)]] \ - [[kernel]] void fft( \ - const device float2* in [[buffer(0)]], \ - device float2* out [[buffer(1)]], \ - uint3 thread_position_in_grid [[thread_position_in_grid]], \ - uint3 threads_per_grid [[threads_per_grid]]); - +#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \ + template [[host_name("fft_" #name)]] [[kernel]] void \ + fft( \ + const device float2* in [[buffer(0)]], \ + device float2* out [[buffer(1)]], \ + uint3 thread_position_in_grid [[thread_position_in_grid]], \ + uint3 threads_per_grid [[threads_per_grid]]); // Explicitly define kernels for each power of 2. +// clang-format off instantiate_fft(4, /* n= */ 4, /* radix_2_steps= */ 0, /* radix_4_steps= */ 1) -instantiate_fft(8, 8, 1, 1) -instantiate_fft(16, 16, 0, 2) -instantiate_fft(32, 32, 1, 2) -instantiate_fft(64, 64, 0, 3) -instantiate_fft(128, 128, 1, 3) -instantiate_fft(256, 256, 0, 4) +instantiate_fft(8, 8, 1, 1) instantiate_fft(16, 16, 0, 2) +instantiate_fft(32, 32, 1, 2) instantiate_fft(64, 64, 0, 3) +instantiate_fft(128, 128, 1, 3) instantiate_fft(256, 256, 0, 4) instantiate_fft(512, 512, 1, 4) instantiate_fft(1024, 1024, 0, 5) // 2048 is the max that will fit into 32KB of threadgroup memory. // TODO: implement 4 step FFT for larger n. 
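// [Illustrative aside, not part of the upstream fft.metal diff] The kernel
// comment above says N is decomposed into radix-2 and radix-4 DFT stages;
// concretely, every instantiation below satisfies
// n == 2^radix_2_steps * 4^radix_4_steps, e.g. 128 == 2^1 * 4^3 and
// 2048 == 2^1 * 4^5. A hypothetical compile-time check of that invariant
// (names fft_size and the asserts are assumptions for illustration only):
constexpr size_t fft_size(size_t radix_2_steps, size_t radix_4_steps) {
  // Each radix-2 stage doubles the transform size; each radix-4 stage
  // multiplies it by four (a shift by two bits).
  return (size_t(1) << radix_2_steps) << (2 * radix_4_steps);
}
static_assert(fft_size(1, 3) == 128, "128-point FFT: one radix-2, three radix-4 stages");
static_assert(fft_size(0, 4) == 256, "256-point FFT: four radix-4 stages");
static_assert(fft_size(1, 5) == 2048, "2048-point FFT: one radix-2, five radix-4 stages");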
-instantiate_fft(2048, 2048, 1, 5) +instantiate_fft(2048, 2048, 1, 5) // clang-format on diff --git a/mlx/backend/metal/kernels/gather.metal b/mlx/backend/metal/kernels/gather.metal index 793b2af62..f8e4fbb87 100644 --- a/mlx/backend/metal/kernels/gather.metal +++ b/mlx/backend/metal/kernels/gather.metal @@ -14,17 +14,16 @@ using namespace metal; template METAL_FUNC void gather_impl( - const device T *src [[buffer(0)]], - device T *out [[buffer(1)]], - const constant int *src_shape [[buffer(2)]], - const constant size_t *src_strides [[buffer(3)]], + const device T* src [[buffer(0)]], + device T* out [[buffer(1)]], + const constant int* src_shape [[buffer(2)]], + const constant size_t* src_strides [[buffer(3)]], const constant size_t& src_ndim [[buffer(4)]], - const constant int *slice_sizes [[buffer(5)]], - const constant int *axes [[buffer(6)]], + const constant int* slice_sizes [[buffer(5)]], + const constant int* axes [[buffer(6)]], const thread Indices& indices, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto ind_idx = index.x; auto ind_offset = index.y; @@ -43,93 +42,78 @@ METAL_FUNC void gather_impl( indices.ndim); } auto ax = axes[i]; - auto idx_val = offset_neg_idx( - indices.buffers[i][idx_loc], src_shape[ax]); + auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]); src_idx += idx_val * src_strides[ax]; } - auto src_offset = elem_to_loc( - ind_offset, slice_sizes, src_strides, src_ndim); + auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim); size_t out_idx = index.y + static_cast(grid_dim.y) * index.x; out[out_idx] = src[src_offset + src_idx]; - } -#define make_gather_impl(IDX_ARG, IDX_ARR) \ -template \ -[[kernel]] void gather( \ - const device T *src [[buffer(0)]], \ - device T *out [[buffer(1)]], \ - const constant int *src_shape [[buffer(2)]], \ - const constant size_t *src_strides [[buffer(3)]], \ - const constant size_t& src_ndim [[buffer(4)]], \ - const constant int *slice_sizes [[buffer(5)]], \ - const constant int *axes [[buffer(6)]], \ - const constant int *idx_shapes [[buffer(7)]], \ - const constant size_t *idx_strides [[buffer(8)]], \ - const constant int& idx_ndim [[buffer(9)]], \ - IDX_ARG(IdxT) \ - uint2 index [[thread_position_in_grid]], \ - uint2 grid_dim [[threads_per_grid]]) { \ - \ - Indices idxs{ \ - {{IDX_ARR()}}, \ - idx_shapes, \ - idx_strides, \ - idx_ndim}; \ - \ - return gather_impl( \ - src, \ - out, \ - src_shape, \ - src_strides, \ - src_ndim, \ - slice_sizes, \ - axes, \ - idxs, \ - index, \ - grid_dim); \ -} +#define make_gather_impl(IDX_ARG, IDX_ARR) \ + template \ + [[kernel]] void gather( \ + const device T* src [[buffer(0)]], \ + device T* out [[buffer(1)]], \ + const constant int* src_shape [[buffer(2)]], \ + const constant size_t* src_strides [[buffer(3)]], \ + const constant size_t& src_ndim [[buffer(4)]], \ + const constant int* slice_sizes [[buffer(5)]], \ + const constant int* axes [[buffer(6)]], \ + const constant int* idx_shapes [[buffer(7)]], \ + const constant size_t* idx_strides [[buffer(8)]], \ + const constant int& idx_ndim [[buffer(9)]], \ + IDX_ARG(IdxT) uint2 index [[thread_position_in_grid]], \ + uint2 grid_dim [[threads_per_grid]]) { \ + Indices idxs{ \ + {{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \ + \ + return gather_impl( \ + src, \ + out, \ + src_shape, \ + src_strides, \ + src_ndim, \ + slice_sizes, \ + axes, \ + idxs, \ + index, \ + grid_dim); \ + } -#define make_gather(n) make_gather_impl(IDX_ARG_ ##n, IDX_ARR_ ##n) +#define 
make_gather(n) make_gather_impl(IDX_ARG_##n, IDX_ARR_##n) -make_gather(0) -make_gather(1) -make_gather(2) -make_gather(3) -make_gather(4) -make_gather(5) -make_gather(6) -make_gather(7) -make_gather(8) -make_gather(9) -make_gather(10) +make_gather(0) make_gather(1) make_gather(2) make_gather(3) make_gather(4) + make_gather(5) make_gather(6) make_gather(7) make_gather(8) make_gather(9) + make_gather(10) ///////////////////////////////////////////////////////////////////// // Gather instantiations ///////////////////////////////////////////////////////////////////// -#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \ -template [[host_name("gather" name "_" #nidx "" #nd_name)]] \ -[[kernel]] void gather( \ - const device src_t *src [[buffer(0)]], \ - device src_t *out [[buffer(1)]], \ - const constant int *src_shape [[buffer(2)]], \ - const constant size_t *src_strides [[buffer(3)]], \ - const constant size_t& src_ndim [[buffer(4)]], \ - const constant int *slice_sizes [[buffer(5)]], \ - const constant int *axes [[buffer(6)]], \ - const constant int *idx_shapes [[buffer(7)]], \ - const constant size_t *idx_strides [[buffer(8)]], \ - const constant int& idx_ndim [[buffer(9)]], \ - IDX_ARG(idx_t) \ - uint2 index [[thread_position_in_grid]], \ - uint2 grid_dim [[threads_per_grid]]); +#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \ + template [[host_name("gather" name "_" #nidx "" #nd_name)]] [[kernel]] void \ + gather( \ + const device src_t* src [[buffer(0)]], \ + device src_t* out [[buffer(1)]], \ + const constant int* src_shape [[buffer(2)]], \ + const constant size_t* src_strides [[buffer(3)]], \ + const constant size_t& src_ndim [[buffer(4)]], \ + const constant int* slice_sizes [[buffer(5)]], \ + const constant int* axes [[buffer(6)]], \ + const constant int* idx_shapes [[buffer(7)]], \ + const constant size_t* idx_strides [[buffer(8)]], \ + const constant int& idx_ndim [[buffer(9)]], \ + IDX_ARG(idx_t) uint2 index [[thread_position_in_grid]], \ + uint2 grid_dim [[threads_per_grid]]); +// clang-format off #define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \ - instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name) + instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name) // clang-format on +// clang-format off #define instantiate_gather4(name, src_t, idx_t, nidx) \ instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \ instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \ @@ -148,29 +132,31 @@ instantiate_gather4("int32", int32_t, bool, 0) instantiate_gather4("int64", int64_t, bool, 0) instantiate_gather4("float16", half, bool, 0) instantiate_gather4("float32", float, bool, 0) -instantiate_gather4("bfloat16", bfloat16_t, bool, 0) +instantiate_gather4("bfloat16", bfloat16_t, bool, 0) // clang-format on +// clang-format off #define instantiate_gather3(name, src_type, ind_type) \ - instantiate_gather4(name, src_type, ind_type, 1) \ - instantiate_gather4(name, src_type, ind_type, 2) \ - instantiate_gather4(name, src_type, ind_type, 3) \ - instantiate_gather4(name, src_type, ind_type, 4) \ - instantiate_gather4(name, src_type, ind_type, 5) \ - instantiate_gather4(name, src_type, ind_type, 6) \ - instantiate_gather4(name, src_type, ind_type, 7) \ - instantiate_gather4(name, src_type, ind_type, 8) \ - instantiate_gather4(name, src_type, ind_type, 9) \ - instantiate_gather4(name, src_type, ind_type, 10) + instantiate_gather4(name, src_type, ind_type, 1) \ + instantiate_gather4(name, 
src_type, ind_type, 2) \ + instantiate_gather4(name, src_type, ind_type, 3) \ + instantiate_gather4(name, src_type, ind_type, 4) \ + instantiate_gather4(name, src_type, ind_type, 5) \ + instantiate_gather4(name, src_type, ind_type, 6) \ + instantiate_gather4(name, src_type, ind_type, 7) \ + instantiate_gather4(name, src_type, ind_type, 8) \ + instantiate_gather4(name, src_type, ind_type, 9) \ + instantiate_gather4(name, src_type, ind_type, 10) // clang-format on -#define instantiate_gather(name, src_type) \ - instantiate_gather3(#name "bool_", src_type, bool) \ - instantiate_gather3(#name "uint8", src_type, uint8_t) \ +// clang-format off +#define instantiate_gather(name, src_type) \ + instantiate_gather3(#name "bool_", src_type, bool) \ + instantiate_gather3(#name "uint8", src_type, uint8_t) \ instantiate_gather3(#name "uint16", src_type, uint16_t) \ instantiate_gather3(#name "uint32", src_type, uint32_t) \ instantiate_gather3(#name "uint64", src_type, uint64_t) \ - instantiate_gather3(#name "int8", src_type, int8_t) \ - instantiate_gather3(#name "int16", src_type, int16_t) \ - instantiate_gather3(#name "int32", src_type, int32_t) \ + instantiate_gather3(#name "int8", src_type, int8_t) \ + instantiate_gather3(#name "int16", src_type, int16_t) \ + instantiate_gather3(#name "int32", src_type, int32_t) \ instantiate_gather3(#name "int64", src_type, int64_t) instantiate_gather(bool_, bool) @@ -184,4 +170,4 @@ instantiate_gather(int32, int32_t) instantiate_gather(int64, int64_t) instantiate_gather(float16, half) instantiate_gather(float32, float) -instantiate_gather(bfloat16, bfloat16_t) \ No newline at end of file +instantiate_gather(bfloat16, bfloat16_t) // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/gemv.metal b/mlx/backend/metal/kernels/gemv.metal index 8b629ca6a..3f9025358 100644 --- a/mlx/backend/metal/kernels/gemv.metal +++ b/mlx/backend/metal/kernels/gemv.metal @@ -1,7 +1,7 @@ // Copyright © 2023-2024 Apple Inc. -#include #include +#include #include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/defines.h" @@ -18,32 +18,36 @@ using namespace metal; MLX_MTL_CONST int SIMD_SIZE = 32; template < - typename T, - const int BM, /* Threadgroup rows (in threads) */ - const int BN, /* Threadgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN , /* Thread cols (in elements) */ - const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ + typename T, + const int BM, /* Threadgroup rows (in threads) */ + const int BN, /* Threadgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ struct GEMVKernel { - static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE"); - // - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up + // - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up // into blocks of (BM * TM, BN * TN) divided among threadgroups // - Every thread works on a block of (TM, TN) // - We assume each thead group is launched with (BN, BM, 1) threads // - // 1. A thread loads TN elements each from mat along TM contiguous rows + // 1. A thread loads TN elements each from mat along TM contiguous rows // and the corresponding scalar from the vector - // 2. The thread then multiplies and adds to accumulate its local result for the block - // 3. 
At the end, each thread has accumulated results over all blocks across the rows + // 2. The thread then multiplies and adds to accumulate its local result for + // the block + // 3. At the end, each thread has accumulated results over all blocks across + // the rows // These are then summed up across the threadgroup // 4. Each threadgroup writes its accumulated BN * TN outputs // // Edge case handling: - // - The threadgroup with the largest tid will have blocks that exceed the matrix - // * The blocks that start outside the matrix are never read (thread results remain zero) - // * The last thread that partially overlaps with the matrix is shifted inwards + // - The threadgroup with the largest tid will have blocks that exceed the + // matrix + // * The blocks that start outside the matrix are never read (thread results + // remain zero) + // * The last thread that partially overlaps with the matrix is shifted + // inwards // such that the thread block fits exactly in the matrix MLX_MTL_CONST short tgp_mem_size = BN * TN * 2; @@ -52,7 +56,7 @@ struct GEMVKernel { const device T* mat [[buffer(0)]], const device T* in_vec [[buffer(1)]], const device T* bias [[buffer(2)]], - device T* out_vec [[buffer(3)]], + device T* out_vec [[buffer(3)]], const constant int& in_vec_size [[buffer(4)]], const constant int& out_vec_size [[buffer(5)]], const constant int& marix_ld [[buffer(6)]], @@ -64,14 +68,13 @@ struct GEMVKernel { uint3 lid [[thread_position_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - - // Appease compiler + // Appease compiler (void)lid; // Threadgroup in_vec cache threadgroup T* in_vec_block = tgp_memory + simd_lid * TN * 2; - // Thread local accumulation results + // Thread local accumulation results thread T result[TM] = {0}; thread T inter[TN]; thread T v_coeff[TN]; @@ -80,7 +83,7 @@ struct GEMVKernel { int out_row = (tid.x * BM + simd_gid) * TM; // Exit simdgroup if rows out of bound - if(out_row >= out_vec_size) + if (out_row >= out_vec_size) return; // Adjust tail simdgroup to ensure in bound reads @@ -90,89 +93,80 @@ struct GEMVKernel { mat += out_row * marix_ld; // Loop over in_vec in blocks of BN * TN - for(int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) { - + for (int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) { threadgroup_barrier(mem_flags::mem_threadgroup); // Prefetch in_vector for threadgroup use - if(simd_gid == 0) { + if (simd_gid == 0) { // Main load loop - if(bn + TN <= in_vec_size) { - - #pragma clang loop unroll(full) - for(int tn = 0; tn < TN; tn++) { + if (bn + TN <= in_vec_size) { +#pragma clang loop unroll(full) + for (int tn = 0; tn < TN; tn++) { in_vec_block[tn] = in_vec[bn + tn]; } } else { // Edgecase - #pragma clang loop unroll(full) - for(int tn = 0; tn < TN; tn++) { +#pragma clang loop unroll(full) + for (int tn = 0; tn < TN; tn++) { in_vec_block[tn] = bn + tn < in_vec_size ? 
in_vec[bn + tn] : 0; } - } } threadgroup_barrier(mem_flags::mem_threadgroup); - // Load for all rows - #pragma clang loop unroll(full) - for(int tn = 0; tn < TN; tn++) { +// Load for all rows +#pragma clang loop unroll(full) + for (int tn = 0; tn < TN; tn++) { v_coeff[tn] = in_vec_block[tn]; } - // Per thread work loop - #pragma clang loop unroll(full) - for(int tm = 0; tm < TM; tm++) { - - // Load for the row - if(bn + TN <= in_vec_size) { - #pragma clang loop unroll(full) - for(int tn = 0; tn < TN; tn++) { +// Per thread work loop +#pragma clang loop unroll(full) + for (int tm = 0; tm < TM; tm++) { + // Load for the row + if (bn + TN <= in_vec_size) { +#pragma clang loop unroll(full) + for (int tn = 0; tn < TN; tn++) { inter[tn] = mat[tm * marix_ld + bn + tn]; } } else { // Edgecase - #pragma clang loop unroll(full) - for(int tn = 0; tn < TN; tn++) { - int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1); +#pragma clang loop unroll(full) + for (int tn = 0; tn < TN; tn++) { + int col_idx = + (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1); inter[tn] = mat[tm * marix_ld + col_idx]; } } // Accumulate results - for(int tn = 0; tn < TN; tn++) { + for (int tn = 0; tn < TN; tn++) { result[tm] += inter[tn] * v_coeff[tn]; } - } } - // Simdgroup accumulations - #pragma clang loop unroll(full) - for(int tm = 0; tm < TM; tm++) { +// Simdgroup accumulations +#pragma clang loop unroll(full) + for (int tm = 0; tm < TM; tm++) { result[tm] = simd_sum(result[tm]); } // Write outputs - if(simd_lid == 0) { - - #pragma clang loop unroll(full) - for(int tm = 0; tm < TM; tm++) { - if(kDoAxpby) { - out_vec[out_row + tm] = - static_cast(alpha) * result[tm] + + if (simd_lid == 0) { +#pragma clang loop unroll(full) + for (int tm = 0; tm < TM; tm++) { + if (kDoAxpby) { + out_vec[out_row + tm] = static_cast(alpha) * result[tm] + static_cast(beta) * bias[(out_row + tm) * bias_stride]; } else { out_vec[out_row + tm] = result[tm]; } } - } - } - }; /////////////////////////////////////////////////////////////////////////////// @@ -180,40 +174,43 @@ struct GEMVKernel { /////////////////////////////////////////////////////////////////////////////// template < - typename T, - const int BM, /* Threadgroup rows (in threads) */ - const int BN, /* Threadgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN, /* Thread cols (in elements) */ - const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ + typename T, + const int BM, /* Threadgroup rows (in threads) */ + const int BN, /* Threadgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ struct GEMVTKernel { - - // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up + // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up // into blocks of (BM * TM, BN * TN) divided among threadgroups // - Every thread works on a block of (TM, TN) // - We assume each thead group is launched with (BN, BM, 1) threads // - // 1. A thread loads TN elements each from mat along TM contiguous rows + // 1. A thread loads TN elements each from mat along TM contiguous rows // and the corresponding scalar from the vector - // 2. The thread then multiplies and adds to accumulate its local result for the block - // 3. At the end, each thread has accumulated results over all blocks across the rows + // 2. 
The thread then multiplies and adds to accumulate its local result for + // the block + // 3. At the end, each thread has accumulated results over all blocks across + // the rows // These are then summed up across the threadgroup // 4. Each threadgroup writes its accumulated BN * TN outputs // // Edge case handling: - // - The threadgroup with the largest tid will have blocks that exceed the matrix - // * The blocks that start outside the matrix are never read (thread results remain zero) - // * The last thread that partially overlaps with the matrix is shifted inwards + // - The threadgroup with the largest tid will have blocks that exceed the + // matrix + // * The blocks that start outside the matrix are never read (thread results + // remain zero) + // * The last thread that partially overlaps with the matrix is shifted + // inwards // such that the thread block fits exactly in the matrix - MLX_MTL_CONST short tgp_mem_size = BN * BM * TN; static METAL_FUNC void run( const device T* mat [[buffer(0)]], const device T* in_vec [[buffer(1)]], const device T* bias [[buffer(2)]], - device T* out_vec [[buffer(3)]], + device T* out_vec [[buffer(3)]], const constant int& in_vec_size [[buffer(4)]], const constant int& out_vec_size [[buffer(5)]], const constant int& marix_ld [[buffer(6)]], @@ -225,8 +222,7 @@ struct GEMVTKernel { uint3 lid [[thread_position_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - - // Appease compiler + // Appease compiler (void)simd_gid; (void)simd_lid; @@ -243,77 +239,69 @@ struct GEMVTKernel { // Edgecase handling if (out_col < out_vec_size) { - out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN; // Per thread accumulation main loop int bm = in_row; - for(; bm < in_vec_size; bm += BM * TM) { + for (; bm < in_vec_size; bm += BM * TM) { // Adding a threadgroup_barrier improves performance slightly // This is possibly it may help exploit cache better threadgroup_barrier(mem_flags::mem_none); - if(bm + TM <= in_vec_size) { - - #pragma clang loop unroll(full) - for(int tm = 0; tm < TM; tm++) { + if (bm + TM <= in_vec_size) { +#pragma clang loop unroll(full) + for (int tm = 0; tm < TM; tm++) { v_coeff[tm] = in_vec[bm + tm]; } - #pragma clang loop unroll(full) - for(int tm = 0; tm < TM; tm++) { - for(int tn = 0; tn < TN; tn++) { +#pragma clang loop unroll(full) + for (int tm = 0; tm < TM; tm++) { + for (int tn = 0; tn < TN; tn++) { inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; } - for(int tn = 0; tn < TN; tn++) { + for (int tn = 0; tn < TN; tn++) { result[tn] += v_coeff[tm] * inter[tn]; } } - + } else { // Edgecase handling - for(int tm = 0; bm + tm < in_vec_size; tm++) { + for (int tm = 0; bm + tm < in_vec_size; tm++) { v_coeff[tm] = in_vec[bm + tm]; - for(int tn = 0; tn < TN; tn++) { + for (int tn = 0; tn < TN; tn++) { inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; } - for(int tn = 0; tn < TN; tn++) { + for (int tn = 0; tn < TN; tn++) { result[tn] += v_coeff[tm] * inter[tn]; } - } } } - } // Threadgroup collection - #pragma clang loop unroll(full) - for(int i = 0; i < TN; i++) { +#pragma clang loop unroll(full) + for (int i = 0; i < TN; i++) { tgp_results[lid.y * TN + i] = result[i]; } threadgroup_barrier(mem_flags::mem_threadgroup); // Threadgroup accumulation and writing out results - if(lid.y == 0 && out_col < out_vec_size) { - - #pragma clang loop unroll(full) - for(int i = 1; i < BM; i++) { - - #pragma clang loop unroll(full) - for(int j = 0; j < TN; j++) { + if (lid.y 
== 0 && out_col < out_vec_size) { +#pragma clang loop unroll(full) + for (int i = 1; i < BM; i++) { +#pragma clang loop unroll(full) + for (int j = 0; j < TN; j++) { result[j] += tgp_results[i * TN + j]; } } - #pragma clang loop unroll(full) - for(int j = 0; j < TN; j++) { - - if(kDoAxpby) { - out_vec[out_col + j] = - static_cast(alpha) * result[j] + +#pragma clang loop unroll(full) + for (int j = 0; j < TN; j++) { + if (kDoAxpby) { + out_vec[out_col + j] = static_cast(alpha) * result[j] + static_cast(beta) * bias[(out_col + j) * bias_stride]; } else { out_vec[out_col + j] = result[j]; @@ -328,18 +316,18 @@ struct GEMVTKernel { /////////////////////////////////////////////////////////////////////////////// template < - typename T, + typename T, const int BM, /* Threadgroup rows (in threads) */ const int BN, /* Threadgroup cols (in threads) */ const int TM, /* Thread rows (in elements) */ const int TN, /* Thread cols (in elements) */ const bool kDoNCBatch, /* Batch ndim > 1 */ const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ -[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv( +[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv( const device T* mat [[buffer(0)]], const device T* in_vec [[buffer(1)]], const device T* bias [[buffer(2)]], - device T* out_vec [[buffer(3)]], + device T* out_vec [[buffer(3)]], const constant int& in_vec_size [[buffer(4)]], const constant int& out_vec_size [[buffer(5)]], const constant int& marix_ld [[buffer(6)]], @@ -355,16 +343,15 @@ template < uint3 lid [[thread_position_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - using gemv_kernel = GEMVKernel; threadgroup T tgp_memory[gemv_kernel::tgp_mem_size]; // Update batch offsets - if(kDoNCBatch) { + if (kDoNCBatch) { in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); - if(kDoAxpby) { + if (kDoAxpby) { bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim); } @@ -372,67 +359,64 @@ template < in_vec += tid.z * vector_batch_stride[0]; mat += tid.z * matrix_batch_stride[0]; - if(kDoAxpby) { + if (kDoAxpby) { bias += tid.z * bias_batch_stride[0]; } } out_vec += tid.z * out_vec_size; - gemv_kernel::run( - mat, - in_vec, - bias, - out_vec, - in_vec_size, - out_vec_size, - marix_ld, - alpha, - beta, - bias_stride, - tgp_memory, - tid, - lid, - simd_gid, - simd_lid - ); - + gemv_kernel::run( + mat, + in_vec, + bias, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + alpha, + beta, + bias_stride, + tgp_memory, + tid, + lid, + simd_gid, + simd_lid); } +#define instantiate_gemv_helper(name, itype, bm, bn, tm, tn, nc, axpby) \ + template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn \ + "_nc" #nc "_axpby" #axpby)]] [[kernel]] void \ + gemv( \ + const device itype* mat [[buffer(0)]], \ + const device itype* in_vec [[buffer(1)]], \ + const device itype* bias [[buffer(2)]], \ + device itype* out_vec [[buffer(3)]], \ + const constant int& in_vec_size [[buffer(4)]], \ + const constant int& out_vec_size [[buffer(5)]], \ + const constant int& marix_ld [[buffer(6)]], \ + const constant float& alpha [[buffer(7)]], \ + const constant float& beta [[buffer(8)]], \ + const constant int& batch_ndim [[buffer(9)]], \ + const constant int* batch_shape [[buffer(10)]], \ + const constant size_t* vector_batch_stride [[buffer(11)]], \ + const constant size_t* matrix_batch_stride [[buffer(12)]], \ + const 
constant size_t* bias_batch_stride [[buffer(13)]], \ + const constant int& bias_stride [[buffer(14)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); -#define instantiate_gemv_helper(name, itype, bm, bn, tm, tn, nc, axpby) \ - template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby)]] \ - [[kernel]] void gemv( \ - const device itype* mat [[buffer(0)]], \ - const device itype* in_vec [[buffer(1)]], \ - const device itype* bias [[buffer(2)]], \ - device itype* out_vec [[buffer(3)]], \ - const constant int& in_vec_size [[buffer(4)]], \ - const constant int& out_vec_size [[buffer(5)]], \ - const constant int& marix_ld [[buffer(6)]], \ - const constant float& alpha [[buffer(7)]], \ - const constant float& beta [[buffer(8)]], \ - const constant int& batch_ndim [[buffer(9)]], \ - const constant int* batch_shape [[buffer(10)]], \ - const constant size_t* vector_batch_stride [[buffer(11)]], \ - const constant size_t* matrix_batch_stride [[buffer(12)]], \ - const constant size_t* bias_batch_stride [[buffer(13)]], \ - const constant int& bias_stride [[buffer(14)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); +#define instantiate_gemv(name, itype, bm, bn, tm, tn) \ + instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 0) \ + instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 1) \ + instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 0) \ + instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1) -#define instantiate_gemv(name, itype, bm, bn, tm, tn) \ - instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 0) \ - instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 1) \ - instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 0) \ - instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1) - -#define instantiate_gemv_blocks(name, itype) \ - instantiate_gemv(name, itype, 4, 32, 1, 4) \ - instantiate_gemv(name, itype, 4, 32, 4, 4) \ - instantiate_gemv(name, itype, 8, 32, 4, 4) +#define instantiate_gemv_blocks(name, itype) \ + instantiate_gemv(name, itype, 4, 32, 1, 4) instantiate_gemv( \ + name, itype, 4, 32, 4, 4) instantiate_gemv(name, itype, 8, 32, 4, 4) instantiate_gemv_blocks(float32, float); instantiate_gemv_blocks(float16, half); @@ -443,18 +427,18 @@ instantiate_gemv_blocks(bfloat16, bfloat16_t); /////////////////////////////////////////////////////////////////////////////// template < - typename T, + typename T, const int BM, /* Threadgroup rows (in threads) */ const int BN, /* Threadgroup cols (in threads) */ const int TM, /* Thread rows (in elements) */ const int TN, /* Thread cols (in elements) */ const bool kDoNCBatch, /* Batch ndim > 1 */ const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ -[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t( +[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv_t( const device T* mat [[buffer(0)]], const device T* in_vec [[buffer(1)]], const device T* bias [[buffer(2)]], - device T* out_vec [[buffer(3)]], + device T* out_vec [[buffer(3)]], const constant int& in_vec_size [[buffer(4)]], const constant int& out_vec_size [[buffer(5)]], const constant int& marix_ld [[buffer(6)]], @@ -470,16 +454,15 @@ template < uint3 lid 
[[thread_position_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - using gemv_kernel = GEMVTKernel; threadgroup T tgp_memory[gemv_kernel::tgp_mem_size]; // Update batch offsets - if(kDoNCBatch) { + if (kDoNCBatch) { in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); - if(kDoAxpby) { + if (kDoAxpby) { bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim); } @@ -487,70 +470,72 @@ template < in_vec += tid.z * vector_batch_stride[0]; mat += tid.z * matrix_batch_stride[0]; - if(kDoAxpby) { + if (kDoAxpby) { bias += tid.z * bias_batch_stride[0]; } } out_vec += tid.z * out_vec_size; - gemv_kernel::run( - mat, - in_vec, - bias, - out_vec, - in_vec_size, - out_vec_size, - marix_ld, - alpha, - beta, - bias_stride, - tgp_memory, - tid, - lid, - simd_gid, - simd_lid - ); - + gemv_kernel::run( + mat, + in_vec, + bias, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + alpha, + beta, + bias_stride, + tgp_memory, + tid, + lid, + simd_gid, + simd_lid); } -#define instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, nc, axpby) \ - template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby)]] \ - [[kernel]] void gemv_t( \ - const device itype* mat [[buffer(0)]], \ - const device itype* in_vec [[buffer(1)]], \ - const device itype* bias [[buffer(2)]], \ - device itype* out_vec [[buffer(3)]], \ - const constant int& in_vec_size [[buffer(4)]], \ - const constant int& out_vec_size [[buffer(5)]], \ - const constant int& marix_ld [[buffer(6)]], \ - const constant float& alpha [[buffer(7)]], \ - const constant float& beta [[buffer(8)]], \ - const constant int& batch_ndim [[buffer(9)]], \ - const constant int* batch_shape [[buffer(10)]], \ - const constant size_t* vector_batch_stride [[buffer(11)]], \ - const constant size_t* matrix_batch_stride [[buffer(12)]], \ - const constant size_t* bias_batch_stride [[buffer(13)]], \ - const constant int& bias_stride [[buffer(14)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); +#define instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, nc, axpby) \ + template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn \ + "_nc" #nc "_axpby" #axpby)]] [[kernel]] void \ + gemv_t( \ + const device itype* mat [[buffer(0)]], \ + const device itype* in_vec [[buffer(1)]], \ + const device itype* bias [[buffer(2)]], \ + device itype* out_vec [[buffer(3)]], \ + const constant int& in_vec_size [[buffer(4)]], \ + const constant int& out_vec_size [[buffer(5)]], \ + const constant int& marix_ld [[buffer(6)]], \ + const constant float& alpha [[buffer(7)]], \ + const constant float& beta [[buffer(8)]], \ + const constant int& batch_ndim [[buffer(9)]], \ + const constant int* batch_shape [[buffer(10)]], \ + const constant size_t* vector_batch_stride [[buffer(11)]], \ + const constant size_t* matrix_batch_stride [[buffer(12)]], \ + const constant size_t* bias_batch_stride [[buffer(13)]], \ + const constant int& bias_stride [[buffer(14)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); -#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \ 
+// clang-format off +#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \ instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 0) \ instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 1) \ instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 0) \ - instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 1) + instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 1) // clang-format on +// clang-format off #define instantiate_gemv_t_blocks(name, itype) \ - instantiate_gemv_t(name, itype, 8, 8, 4, 1) \ - instantiate_gemv_t(name, itype, 8, 8, 4, 4) \ + instantiate_gemv_t(name, itype, 8, 8, 4, 1) \ + instantiate_gemv_t(name, itype, 8, 8, 4, 4) \ instantiate_gemv_t(name, itype, 8, 16, 4, 4) \ instantiate_gemv_t(name, itype, 8, 32, 4, 4) \ instantiate_gemv_t(name, itype, 8, 64, 4, 4) \ - instantiate_gemv_t(name, itype, 8, 128, 4, 4) + instantiate_gemv_t(name, itype, 8, 128, 4, 4) // clang-format on +// clang-format off instantiate_gemv_t_blocks(float32, float); instantiate_gemv_t_blocks(float16, half); -instantiate_gemv_t_blocks(bfloat16, bfloat16_t); +instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/layer_norm.metal b/mlx/backend/metal/kernels/layer_norm.metal index 16028dce0..4c0dc7346 100644 --- a/mlx/backend/metal/kernels/layer_norm.metal +++ b/mlx/backend/metal/kernels/layer_norm.metal @@ -99,7 +99,8 @@ template for (int i = 0; i < N_READS; i++) { if ((lid * N_READS + i) < axis_size) { thread_x[i] = (thread_x[i] - mean) * normalizer; - out[i] = w[w_stride * i] * static_cast(thread_x[i]) + b[b_stride * i]; + out[i] = + w[w_stride * i] * static_cast(thread_x[i]) + b[b_stride * i]; } } } @@ -192,13 +193,15 @@ template if (r + lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { float xi = (x[r + i] - mean) * normalizer; - out[r + i] = w[w_stride * (i + r)] * static_cast(xi) + b[b_stride * (i + r)]; + out[r + i] = + w[w_stride * (i + r)] * static_cast(xi) + b[b_stride * (i + r)]; } } else { for (int i = 0; i < N_READS; i++) { if ((r + lid * N_READS + i) < axis_size) { float xi = (x[r + i] - mean) * normalizer; - out[r + i] = w[w_stride * (i + r)] * static_cast(xi) + b[b_stride * (i + r)]; + out[r + i] = w[w_stride * (i + r)] * static_cast(xi) + + b[b_stride * (i + r)]; } } } @@ -323,16 +326,18 @@ template if (lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { thread_x[i] = (thread_x[i] - mean) * normalizer; - gx[i] = static_cast(normalizer * (thread_w[i] * thread_g[i] - meanwg) - - thread_x[i] * meanwgxc * normalizer2); + gx[i] = static_cast( + normalizer * (thread_w[i] * thread_g[i] - meanwg) - + thread_x[i] * meanwgxc * normalizer2); gw[i] = static_cast(thread_g[i] * thread_x[i]); } } else { for (int i = 0; i < N_READS; i++) { if ((lid * N_READS + i) < axis_size) { thread_x[i] = (thread_x[i] - mean) * normalizer; - gx[i] = static_cast(normalizer * (thread_w[i] * thread_g[i] - meanwg) - - thread_x[i] * meanwgxc * normalizer2); + gx[i] = static_cast( + normalizer * (thread_w[i] * thread_g[i] - meanwg) - + thread_x[i] * meanwgxc * normalizer2); gw[i] = static_cast(thread_g[i] * thread_x[i]); } } @@ -460,8 +465,8 @@ template float xi = (x[i + r] - mean) * normalizer; float wi = w[(i + r) * w_stride]; float gi = g[i + r]; - gx[i + r] = static_cast(normalizer * (wi * gi - meanwg) - - xi * meanwgxc * normalizer2); + gx[i + r] = static_cast( + normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2); gw[i + r] = static_cast(gi * xi); } } 
else { @@ -470,8 +475,8 @@ template float xi = (x[i + r] - mean) * normalizer; float wi = w[(i + r) * w_stride]; float gi = g[i + r]; - gx[i + r] = static_cast(normalizer * (wi * gi - meanwg) - - xi * meanwgxc * normalizer2); + gx[i + r] = static_cast( + normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2); gw[i + r] = static_cast(gi * xi); } } @@ -548,6 +553,4 @@ template instantiate_layer_norm(float32, float) instantiate_layer_norm(float16, half) -instantiate_layer_norm(bfloat16, bfloat16_t) - // clang-format on - +instantiate_layer_norm(bfloat16, bfloat16_t) // clang-format on diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index 625d7450a..72ef5f103 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -1,7 +1,7 @@ // Copyright © 2023-2024 Apple Inc. -#include #include +#include #include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/defines.h" @@ -15,30 +15,31 @@ using namespace metal; MLX_MTL_CONST int SIMD_SIZE = 32; - template -inline U load_vector(const device T *x, thread U *x_thread) { - static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}"); +inline U load_vector(const device T* x, thread U* x_thread) { + static_assert( + bits == 2 || bits == 4 || bits == 8, + "Template undefined for bits not in {2, 4, 8}"); U sum = 0; if (bits == 2) { for (int i = 0; i < values_per_thread; i += 4) { - sum += x[i] + x[i+1] + x[i+2] + x[i+3]; + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; - x_thread[i+1] = x[i+1] / 4.0f; - x_thread[i+2] = x[i+2] / 16.0f; - x_thread[i+3] = x[i+3] / 64.0f; + x_thread[i + 1] = x[i + 1] / 4.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 64.0f; } } else if (bits == 4) { for (int i = 0; i < values_per_thread; i += 4) { - sum += x[i] + x[i+1] + x[i+2] + x[i+3]; + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; - x_thread[i+1] = x[i+1] / 16.0f; - x_thread[i+2] = x[i+2] / 256.0f; - x_thread[i+3] = x[i+3] / 4096.0f; + x_thread[i + 1] = x[i + 1] / 16.0f; + x_thread[i + 2] = x[i + 2] / 256.0f; + x_thread[i + 3] = x[i + 3] / 4096.0f; } } @@ -53,33 +54,35 @@ inline U load_vector(const device T *x, thread U *x_thread) { } template -inline U load_vector_safe(const device T *x, thread U *x_thread, int N) { - static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}"); +inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { + static_assert( + bits == 2 || bits == 4 || bits == 8, + "Template undefined for bits not in {2, 4, 8}"); U sum = 0; if (bits == 2) { for (int i = 0; i < N; i += 4) { - sum += x[i] + x[i+1] + x[i+2] + x[i+3]; + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; - x_thread[i+1] = x[i+1] / 4.0f; - x_thread[i+2] = x[i+2] / 16.0f; - x_thread[i+3] = x[i+3] / 64.0f; + x_thread[i + 1] = x[i + 1] / 4.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 64.0f; } - for (int i=N; i -inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum) { - static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}"); +inline U qdot( + const device uint8_t* w, + const thread U* x_thread, + U scale, + U bias, + U sum) { + static_assert( + bits == 2 || bits == 4 || bits == 8, + "Template undefined for bits not in {2, 4, 8}"); U accum = 0; if (bits == 2) { for (int i = 0; i < 
(values_per_thread / 4); i++) { - accum += ( - x_thread[4*i] * (w[i] & 0x03) - + x_thread[4*i+1] * (w[i] & 0x0c) - + x_thread[4*i+2] * (w[i] & 0x30) - + x_thread[4*i+3] * (w[i] & 0xc0)); + accum += + (x_thread[4 * i] * (w[i] & 0x03) + + x_thread[4 * i + 1] * (w[i] & 0x0c) + + x_thread[4 * i + 2] * (w[i] & 0x30) + + x_thread[4 * i + 3] * (w[i] & 0xc0)); } } else if (bits == 4) { const device uint16_t* ws = (const device uint16_t*)w; for (int i = 0; i < (values_per_thread / 4); i++) { - accum += ( - x_thread[4*i] * (ws[i] & 0x000f) - + x_thread[4*i+1] * (ws[i] & 0x00f0) - + x_thread[4*i+2] * (ws[i] & 0x0f00) - + x_thread[4*i+3] * (ws[i] & 0xf000)); + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] * (ws[i] & 0xf000)); } } @@ -134,29 +144,37 @@ inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias } template -inline U qdot_safe(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum, int N) { - static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}"); +inline U qdot_safe( + const device uint8_t* w, + const thread U* x_thread, + U scale, + U bias, + U sum, + int N) { + static_assert( + bits == 2 || bits == 4 || bits == 8, + "Template undefined for bits not in {2, 4, 8}"); U accum = 0; if (bits == 2) { for (int i = 0; i < (N / 4); i++) { - accum += ( - x_thread[4*i] * (w[i] & 0x03) - + x_thread[4*i+1] * (w[i] & 0x0c) - + x_thread[4*i+2] * (w[i] & 0x30) - + x_thread[4*i+3] * (w[i] & 0xc0)); + accum += + (x_thread[4 * i] * (w[i] & 0x03) + + x_thread[4 * i + 1] * (w[i] & 0x0c) + + x_thread[4 * i + 2] * (w[i] & 0x30) + + x_thread[4 * i + 3] * (w[i] & 0xc0)); } } else if (bits == 4) { const device uint16_t* ws = (const device uint16_t*)w; for (int i = 0; i < (N / 4); i++) { - accum += ( - x_thread[4*i] * (ws[i] & 0x000f) - + x_thread[4*i+1] * (ws[i] & 0x00f0) - + x_thread[4*i+2] * (ws[i] & 0x0f00) - + x_thread[4*i+3] * (ws[i] & 0xf000)); + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] * (ws[i] & 0xf000)); } } @@ -170,16 +188,19 @@ inline U qdot_safe(const device uint8_t* w, const thread U *x_thread, U scale, U } template -inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { - static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}"); +inline void +qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { + static_assert( + bits == 2 || bits == 4 || bits == 8, + "Template undefined for bits not in {2, 4, 8}"); if (bits == 2) { U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; for (int i = 0; i < (values_per_thread / 4); i++) { - result[4*i] += x * (s[0] * (w[i] & 0x03) + bias); - result[4*i+1] += x * (s[1] * (w[i] & 0x0c) + bias); - result[4*i+2] += x * (s[2] * (w[i] & 0x30) + bias); - result[4*i+3] += x * (s[3] * (w[i] & 0xc0) + bias); + result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias); + result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias); + result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias); + result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias); } } @@ -187,10 +208,10 @@ inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* resu const thread uint16_t* ws = (const thread uint16_t*)w; U s[4] = {scale, scale / 16.0f, scale / 256.0f, scale / 4096.0f}; for (int i = 0; 
i < (values_per_thread / 4); i++) { - result[4*i] += x * (s[0] * (ws[i] & 0x000f) + bias); - result[4*i+1] += x * (s[1] * (ws[i] & 0x00f0) + bias); - result[4*i+2] += x * (s[2] * (ws[i] & 0x0f00) + bias); - result[4*i+3] += x * (s[3] * (ws[i] & 0xf000) + bias); + result[4 * i] += x * (s[0] * (ws[i] & 0x000f) + bias); + result[4 * i + 1] += x * (s[1] * (ws[i] & 0x00f0) + bias); + result[4 * i + 2] += x * (s[2] * (ws[i] & 0x0f00) + bias); + result[4 * i + 3] += x * (s[3] * (ws[i] & 0xf000) + bias); } } @@ -202,27 +223,38 @@ inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* resu } template -inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { - static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}"); +inline void +dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { + static_assert( + bits == 2 || bits == 4 || bits == 8, + "Template undefined for bits not in {2, 4, 8}"); if (bits == 2) { - U s[4] = {scale, scale / static_cast(4.0f), scale / static_cast(16.0f), scale / static_cast(64.0f)}; + U s[4] = { + scale, + scale / static_cast(4.0f), + scale / static_cast(16.0f), + scale / static_cast(64.0f)}; for (int i = 0; i < (N / 4); i++) { - w_local[4*i] = s[0] * (w[i] & 0x03) + bias; - w_local[4*i+1] = s[1] * (w[i] & 0x0c) + bias; - w_local[4*i+2] = s[2] * (w[i] & 0x30) + bias; - w_local[4*i+3] = s[3] * (w[i] & 0xc0) + bias; + w_local[4 * i] = s[0] * (w[i] & 0x03) + bias; + w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias; + w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias; + w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias; } } else if (bits == 4) { const device uint16_t* ws = (const device uint16_t*)w; - U s[4] = {scale, scale / static_cast(16.0f), scale / static_cast(256.0f), scale / static_cast(4096.0f)}; + U s[4] = { + scale, + scale / static_cast(16.0f), + scale / static_cast(256.0f), + scale / static_cast(4096.0f)}; for (int i = 0; i < (N / 4); i++) { - w_local[4*i] = s[0] * (ws[i] & 0x000f) + bias; - w_local[4*i+1] = s[1] * (ws[i] & 0x00f0) + bias; - w_local[4*i+2] = s[2] * (ws[i] & 0x0f00) + bias; - w_local[4*i+3] = s[3] * (ws[i] & 0xf000) + bias; + w_local[4 * i] = s[0] * (ws[i] & 0x000f) + bias; + w_local[4 * i + 1] = s[1] * (ws[i] & 0x00f0) + bias; + w_local[4 * i + 2] = s[2] * (ws[i] & 0x0f00) + bias; + w_local[4 * i + 3] = s[3] * (ws[i] & 0xf000) + bias; } } @@ -243,13 +275,20 @@ template < short group_size, short bits> struct QuantizedBlockLoader { - static_assert(BCOLS <= group_size, "The group size should be larger than the columns"); - static_assert(group_size % BCOLS == 0, "The group size should be divisible by the columns"); - static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}"); + static_assert( + BCOLS <= group_size, + "The group size should be larger than the columns"); + static_assert( + group_size % BCOLS == 0, + "The group size should be divisible by the columns"); + static_assert( + bits == 2 || bits == 4 || bits == 8, + "Template undefined for bits not in {2, 4, 8}"); MLX_MTL_CONST short pack_factor = 32 / bits; MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; - MLX_MTL_CONST short n_reads = (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; + MLX_MTL_CONST short n_reads = + (BCOLS_PACKED * BROWS < tgp_size) ? 
1 : (BCOLS_PACKED * BROWS) / tgp_size; MLX_MTL_CONST short group_steps = group_size / BCOLS; const int src_ld; @@ -275,7 +314,8 @@ struct QuantizedBlockLoader { ushort simd_group_id [[simdgroup_index_in_threadgroup]], ushort simd_lane_id [[thread_index_in_simdgroup]]) : src_ld(src_ld_), - tile_stride(reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor), + tile_stride( + reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor), group_step_cnt(0), group_stride(BROWS * src_ld / group_size), thread_idx(simd_group_id * 32 + simd_lane_id), @@ -293,8 +333,9 @@ struct QuantizedBlockLoader { T scale = *scales; T bias = *biases; - for (int i=0; i((device uint8_t*)(src + i), scale, bias, dst + i * pack_factor); + for (int i = 0; i < n_reads; i++) { + dequantize( + (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor); } } @@ -304,14 +345,14 @@ struct QuantizedBlockLoader { } if (reduction_dim == 1 && bi >= src_tile_dim.y) { - for (int i=0; i= src_tile_dim.x) { - for (int i=0; i((device uint8_t*)(src + i), scale, bias, dst + i * pack_factor); + for (int i = 0; i < n_reads; i++) { + dequantize( + (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor); } } @@ -357,7 +399,6 @@ template uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; constexpr int pack_factor = 32 / bits; @@ -373,7 +414,8 @@ template // Adjust positions const int in_vec_size_w = in_vec_size / pack_factor; const int in_vec_size_g = in_vec_size / group_size; - const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; w += out_row * in_vec_size_w + simd_lid * packs_per_thread; scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; @@ -384,7 +426,8 @@ template U sum = load_vector(x, x_thread); for (int row = 0; row < results_per_simdgroup; row++) { - const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w); + const device uint8_t* wl = + (const device uint8_t*)(w + row * in_vec_size_w); const device T* sl = scales + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g; @@ -407,7 +450,6 @@ template } } - template [[kernel]] void qmv( const device uint32_t* w [[buffer(0)]], @@ -420,7 +462,6 @@ template uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; constexpr int packs_per_thread = 1; @@ -437,7 +478,8 @@ template // Adjust positions const int in_vec_size_w = in_vec_size / pack_factor; const int in_vec_size_g = in_vec_size / group_size; - const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row); if (out_row >= out_vec_size) { @@ -454,17 +496,19 @@ template y += tid.z * out_vec_size + out_row; int k = 0; - for (; k < in_vec_size-block_size; k += block_size) { + for (; k < in_vec_size - block_size; k += block_size) { U sum = load_vector(x, x_thread); for (int row = 0; out_row + row 
< out_vec_size; row++) { - const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w); + const device uint8_t* wl = + (const device uint8_t*)(w + row * in_vec_size_w); const device T* sl = scales + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g; U s = sl[0]; U b = bl[0]; - result[row] += qdot(wl, x_thread, s, b, sum); + result[row] += + qdot(wl, x_thread, s, b, sum); } w += block_size / pack_factor; @@ -472,11 +516,16 @@ template biases += block_size / group_size; x += block_size; } - const int remaining = clamp(static_cast(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread); - U sum = load_vector_safe(x, x_thread, remaining); + const int remaining = clamp( + static_cast(in_vec_size - k - simd_lid * values_per_thread), + 0, + values_per_thread); + U sum = + load_vector_safe(x, x_thread, remaining); for (int row = 0; out_row + row < out_vec_size; row++) { - const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w); + const device uint8_t* wl = + (const device uint8_t*)(w + row * in_vec_size_w); const device T* sl = scales + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g; @@ -502,17 +551,19 @@ template y += tid.z * out_vec_size + used_out_row; int k = 0; - for (; k < in_vec_size-block_size; k += block_size) { + for (; k < in_vec_size - block_size; k += block_size) { U sum = load_vector(x, x_thread); for (int row = 0; row < results_per_simdgroup; row++) { - const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w); + const device uint8_t* wl = + (const device uint8_t*)(w + row * in_vec_size_w); const device T* sl = scales + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g; U s = sl[0]; U b = bl[0]; - result[row] += qdot(wl, x_thread, s, b, sum); + result[row] += + qdot(wl, x_thread, s, b, sum); } w += block_size / pack_factor; @@ -520,17 +571,23 @@ template biases += block_size / group_size; x += block_size; } - const int remaining = clamp(static_cast(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread); - U sum = load_vector_safe(x, x_thread, remaining); + const int remaining = clamp( + static_cast(in_vec_size - k - simd_lid * values_per_thread), + 0, + values_per_thread); + U sum = + load_vector_safe(x, x_thread, remaining); for (int row = 0; row < results_per_simdgroup; row++) { - const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w); + const device uint8_t* wl = + (const device uint8_t*)(w + row * in_vec_size_w); const device T* sl = scales + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g; U s = sl[0]; U b = bl[0]; - result[row] += qdot_safe(wl, x_thread, s, b, sum, remaining); + result[row] += qdot_safe( + wl, x_thread, s, b, sum, remaining); } for (int row = 0; row < results_per_simdgroup; row++) { @@ -542,7 +599,6 @@ template } } - template [[kernel]] void qvm( const device T* x [[buffer(0)]], @@ -555,7 +611,6 @@ template uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int num_simdgroups = 8; constexpr int pack_factor = 32 / bits; constexpr int blocksize = SIMD_SIZE; @@ -590,7 +645,8 @@ template bias = biases[(i + simd_lid) * out_vec_size_g]; w_local = w[(i + simd_lid) * out_vec_size_w]; - qouter((thread uint8_t *)&w_local, x_local, scale, bias, result); + qouter( + (thread uint8_t*)&w_local, x_local, scale, bias, result); } if (static_cast(i + simd_lid) < in_vec_size) { 
x_local = x[i + simd_lid]; @@ -603,25 +659,32 @@ template bias = 0; w_local = 0; } - qouter((thread uint8_t *)&w_local, x_local, scale, bias, result); + qouter( + (thread uint8_t*)&w_local, x_local, scale, bias, result); - // Accumulate in the simdgroup - #pragma clang loop unroll(full) - for (int k=0; k(result[k]); } } } - -template +template < + typename T, + const int BM, + const int BK, + const int BN, + const int group_size, + const int bits, + const bool aligned_N> [[kernel]] void qmm_t( const device T* x [[buffer(0)]], const device uint32_t* w [[buffer(1)]], @@ -635,7 +698,6 @@ template = SIMD_SIZE, "BK should be larger than SIMD_SIZE"); static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); @@ -647,9 +709,19 @@ template ; - using loader_x_t = mlx::steel::BlockLoader; - using loader_w_t = QuantizedBlockLoader; + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = + mlx::steel::BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + BN, + BK, + BK_padded, + 1, + WM * WN * SIMD_SIZE, + group_size, + bits>; threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BN * BK_padded]; @@ -675,7 +747,7 @@ template +template < + typename T, + const int BM, + const int BK, + const int BN, + const int group_size, + const int bits> [[kernel]] void qmm_n( const device T* x [[buffer(0)]], const device uint32_t* w [[buffer(1)]], @@ -743,7 +820,6 @@ template = SIMD_SIZE, "BK should be larger than SIMD_SIZE"); static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); @@ -756,9 +832,19 @@ template ; - using loader_x_t = mlx::steel::BlockLoader; - using loader_w_t = QuantizedBlockLoader; + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = mlx::steel:: + BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + BK, + BN, + BN_padded, + 0, + WM * WN * SIMD_SIZE, + group_size, + bits>; threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BK * BN_padded]; @@ -780,8 +866,8 @@ template ( \ - const device uint32_t* w [[buffer(0)]], \ - const device itype* scales [[buffer(1)]], \ - const device itype* biases [[buffer(2)]], \ - const device itype* x [[buffer(3)]], \ - device itype* y [[buffer(4)]], \ - const constant int& in_vec_size [[buffer(5)]], \ - const constant int& out_vec_size [[buffer(6)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); + template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits \ + "_fast")]] [[kernel]] void \ + qmv_fast( \ + const device uint32_t* w [[buffer(0)]], \ + const device itype* scales [[buffer(1)]], \ + const device itype* biases [[buffer(2)]], \ + const device itype* x [[buffer(3)]], \ + device itype* y [[buffer(4)]], \ + const constant int& in_vec_size [[buffer(5)]], \ + const constant int& out_vec_size [[buffer(6)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); -#define instantiate_qmv_fast_types(group_size, bits, packs_per_thread) \ +// clang-format off +#define instantiate_qmv_fast_types(group_size, bits, packs_per_thread) \ instantiate_qmv_fast(float32, float, group_size, bits, packs_per_thread) \ - instantiate_qmv_fast(float16, half, group_size, bits, packs_per_thread) \ - instantiate_qmv_fast(bfloat16, bfloat16_t, group_size, bits, packs_per_thread) + instantiate_qmv_fast(float16, half, group_size, bits, packs_per_thread) \ + instantiate_qmv_fast(bfloat16, bfloat16_t, group_size, bits, 
packs_per_thread) // clang-format on +// clang-format off instantiate_qmv_fast_types(128, 2, 1) instantiate_qmv_fast_types(128, 4, 2) instantiate_qmv_fast_types(128, 8, 2) @@ -875,27 +963,30 @@ instantiate_qmv_fast_types( 64, 4, 2) instantiate_qmv_fast_types( 64, 8, 2) instantiate_qmv_fast_types( 32, 2, 1) instantiate_qmv_fast_types( 32, 4, 2) -instantiate_qmv_fast_types( 32, 8, 2) +instantiate_qmv_fast_types( 32, 8, 2) // clang-format on -#define instantiate_qmv(name, itype, group_size, bits) \ - template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits)]] \ - [[kernel]] void qmv( \ - const device uint32_t* w [[buffer(0)]], \ - const device itype* scales [[buffer(1)]], \ - const device itype* biases [[buffer(2)]], \ - const device itype* x [[buffer(3)]], \ - device itype* y [[buffer(4)]], \ - const constant int& in_vec_size [[buffer(5)]], \ - const constant int& out_vec_size [[buffer(6)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); +#define instantiate_qmv(name, itype, group_size, bits) \ + template [[host_name("qmv_" #name "_gs_" #group_size \ + "_b_" #bits)]] [[kernel]] void \ + qmv( \ + const device uint32_t* w [[buffer(0)]], \ + const device itype* scales [[buffer(1)]], \ + const device itype* biases [[buffer(2)]], \ + const device itype* x [[buffer(3)]], \ + device itype* y [[buffer(4)]], \ + const constant int& in_vec_size [[buffer(5)]], \ + const constant int& out_vec_size [[buffer(6)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); -#define instantiate_qmv_types(group_size, bits) \ +// clang-format off +#define instantiate_qmv_types(group_size, bits) \ instantiate_qmv(float32, float, group_size, bits) \ - instantiate_qmv(float16, half, group_size, bits) \ - instantiate_qmv(bfloat16, bfloat16_t, group_size, bits) + instantiate_qmv(float16, half, group_size, bits) \ + instantiate_qmv(bfloat16, bfloat16_t, group_size, bits) // clang-format on + // clang-format off instantiate_qmv_types(128, 2) instantiate_qmv_types(128, 4) instantiate_qmv_types(128, 8) @@ -904,27 +995,30 @@ instantiate_qmv_types( 64, 4) instantiate_qmv_types( 64, 8) instantiate_qmv_types( 32, 2) instantiate_qmv_types( 32, 4) -instantiate_qmv_types( 32, 8) +instantiate_qmv_types( 32, 8) // clang-format on -#define instantiate_qvm(name, itype, group_size, bits) \ - template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \ - [[kernel]] void qvm( \ - const device itype* x [[buffer(0)]], \ - const device uint32_t* w [[buffer(1)]], \ - const device itype* scales [[buffer(2)]], \ - const device itype* biases [[buffer(3)]], \ - device itype* y [[buffer(4)]], \ - const constant int& in_vec_size [[buffer(5)]], \ - const constant int& out_vec_size [[buffer(6)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); +#define instantiate_qvm(name, itype, group_size, bits) \ + template [[host_name("qvm_" #name "_gs_" #group_size \ + "_b_" #bits)]] [[kernel]] void \ + qvm( \ + const device itype* x [[buffer(0)]], \ + const device uint32_t* w [[buffer(1)]], \ + const device itype* scales [[buffer(2)]], \ + const device itype* biases [[buffer(3)]], \ + device itype* y [[buffer(4)]], \ + const constant int& in_vec_size [[buffer(5)]], \ + const constant int& out_vec_size [[buffer(6)]], \ + uint3 tid 
[[threadgroup_position_in_grid]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); -#define instantiate_qvm_types(group_size, bits) \ +// clang-format off +#define instantiate_qvm_types(group_size, bits) \ instantiate_qvm(float32, float, group_size, bits) \ - instantiate_qvm(float16, half, group_size, bits) \ - instantiate_qvm(bfloat16, bfloat16_t, group_size, bits) + instantiate_qvm(float16, half, group_size, bits) \ + instantiate_qvm(bfloat16, bfloat16_t, group_size, bits) // clang-format on + // clang-format off instantiate_qvm_types(128, 2) instantiate_qvm_types(128, 4) instantiate_qvm_types(128, 8) @@ -933,32 +1027,35 @@ instantiate_qvm_types( 64, 4) instantiate_qvm_types( 64, 8) instantiate_qvm_types( 32, 2) instantiate_qvm_types( 32, 4) -instantiate_qvm_types( 32, 8) +instantiate_qvm_types( 32, 8) // clang-format on -#define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \ - template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N)]] \ - [[kernel]] void qmm_t( \ - const device itype* x [[buffer(0)]], \ - const device uint32_t* w [[buffer(1)]], \ - const device itype* scales [[buffer(2)]], \ - const device itype* biases [[buffer(3)]], \ - device itype* y [[buffer(4)]], \ - const constant int& M [[buffer(5)]], \ - const constant int& N [[buffer(6)]], \ - const constant int& K [[buffer(7)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint lid [[thread_index_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ +#define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \ + template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits \ + "_alN_" #aligned_N)]] [[kernel]] void \ + qmm_t( \ + const device itype* x [[buffer(0)]], \ + const device uint32_t* w [[buffer(1)]], \ + const device itype* scales [[buffer(2)]], \ + const device itype* biases [[buffer(3)]], \ + device itype* y [[buffer(4)]], \ + const constant int& M [[buffer(5)]], \ + const constant int& N [[buffer(6)]], \ + const constant int& K [[buffer(7)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint lid [[thread_index_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ uint simd_lid [[thread_index_in_simdgroup]]); -#define instantiate_qmm_t_types(group_size, bits) \ - instantiate_qmm_t(float32, float, group_size, bits, false) \ - instantiate_qmm_t(float16, half, group_size, bits, false) \ +// clang-format off +#define instantiate_qmm_t_types(group_size, bits) \ + instantiate_qmm_t(float32, float, group_size, bits, false) \ + instantiate_qmm_t(float16, half, group_size, bits, false) \ instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, false) \ - instantiate_qmm_t(float32, float, group_size, bits, true) \ - instantiate_qmm_t(float16, half, group_size, bits, true) \ - instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, true) + instantiate_qmm_t(float32, float, group_size, bits, true) \ + instantiate_qmm_t(float16, half, group_size, bits, true) \ + instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, true) // clang-format on + // clang-format off instantiate_qmm_t_types(128, 2) instantiate_qmm_t_types(128, 4) instantiate_qmm_t_types(128, 8) @@ -967,29 +1064,32 @@ instantiate_qmm_t_types( 64, 4) instantiate_qmm_t_types( 64, 8) instantiate_qmm_t_types( 32, 2) instantiate_qmm_t_types( 32, 4) -instantiate_qmm_t_types( 32, 8) +instantiate_qmm_t_types( 32, 8) // clang-format on #define instantiate_qmm_n(name, itype, group_size, bits) \ - template 
[[host_name("qmm_n_" #name "_gs_" #group_size "_b_" #bits)]] \ - [[kernel]] void qmm_n( \ - const device itype* x [[buffer(0)]], \ - const device uint32_t* w [[buffer(1)]], \ - const device itype* scales [[buffer(2)]], \ - const device itype* biases [[buffer(3)]], \ - device itype* y [[buffer(4)]], \ - const constant int& M [[buffer(5)]], \ - const constant int& N [[buffer(6)]], \ - const constant int& K [[buffer(7)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint lid [[thread_index_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ + template [[host_name("qmm_n_" #name "_gs_" #group_size \ + "_b_" #bits)]] [[kernel]] void \ + qmm_n( \ + const device itype* x [[buffer(0)]], \ + const device uint32_t* w [[buffer(1)]], \ + const device itype* scales [[buffer(2)]], \ + const device itype* biases [[buffer(3)]], \ + device itype* y [[buffer(4)]], \ + const constant int& M [[buffer(5)]], \ + const constant int& N [[buffer(6)]], \ + const constant int& K [[buffer(7)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint lid [[thread_index_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ uint simd_lid [[thread_index_in_simdgroup]]); -#define instantiate_qmm_n_types(group_size, bits) \ +// clang-format off +#define instantiate_qmm_n_types(group_size, bits) \ instantiate_qmm_n(float32, float, group_size, bits) \ - instantiate_qmm_n(float16, half, group_size, bits) \ - instantiate_qmm_n(bfloat16, bfloat16_t, group_size, bits) + instantiate_qmm_n(float16, half, group_size, bits) \ + instantiate_qmm_n(bfloat16, bfloat16_t, group_size, bits) // clang-format on + // clang-format off instantiate_qmm_n_types(128, 2) instantiate_qmm_n_types(128, 4) instantiate_qmm_n_types(128, 8) @@ -998,4 +1098,4 @@ instantiate_qmm_n_types( 64, 4) instantiate_qmm_n_types( 64, 8) instantiate_qmm_n_types( 32, 2) instantiate_qmm_n_types( 32, 4) -instantiate_qmm_n_types( 32, 8) +instantiate_qmm_n_types( 32, 8) // clang-format on diff --git a/mlx/backend/metal/kernels/random.metal b/mlx/backend/metal/kernels/random.metal index cc9c5fe30..b0397dd66 100644 --- a/mlx/backend/metal/kernels/random.metal +++ b/mlx/backend/metal/kernels/random.metal @@ -3,9 +3,8 @@ #include "mlx/backend/metal/kernels/utils.h" static constexpr constant uint32_t rotations[2][4] = { - {13, 15, 26, 6}, - {17, 29, 16, 24} -}; + {13, 15, 26, 6}, + {17, 29, 16, 24}}; union rbits { uint2 val; @@ -13,7 +12,6 @@ union rbits { }; rbits threefry2x32_hash(const thread uint2& key, uint2 count) { - uint4 ks = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA}; rbits v; @@ -51,7 +49,7 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) { out[4 * count.x + i] = bits.bytes[0][i]; } if (!drop_last) { - if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { + if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { int edge_bytes = (bytes_per_key % 4); for (int i = 0; i < edge_bytes; ++i) { out[4 * count.y + i] = bits.bytes[1][i]; @@ -87,7 +85,7 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) { out[4 * count.x + i] = bits.bytes[0][i]; } if (!drop_last) { - if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { + if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { int edge_bytes = (bytes_per_key % 4); for (int i = 0; i < edge_bytes; ++i) { out[4 * count.y + i] = bits.bytes[1][i]; diff --git a/mlx/backend/metal/kernels/reduction/kernels/reduce_all.metal b/mlx/backend/metal/kernels/reduction/kernels/reduce_all.metal index 46c75301a..83ed3e982 100644 --- 
a/mlx/backend/metal/kernels/reduction/kernels/reduce_all.metal +++ b/mlx/backend/metal/kernels/reduction/kernels/reduce_all.metal @@ -1,8 +1,8 @@ // Copyright © 2023-2024 Apple Inc. -#include "mlx/backend/metal/kernels/reduction/utils.h" #include "mlx/backend/metal/kernels/reduction/ops.h" #include "mlx/backend/metal/kernels/reduction/reduce_inst.h" +#include "mlx/backend/metal/kernels/reduction/utils.h" using namespace metal; @@ -60,14 +60,13 @@ METAL_FUNC U per_thread_all_reduce( // All reduce kernel /////////////////////////////////////////////////////////////////////////////// - // NB: This kernel assumes threads_per_threadgroup is at most // 1024. This way with a simd_size of 32, we are guaranteed to // complete the reduction in two steps of simd-level reductions. -template +template [[kernel]] void all_reduce( - const device T *in [[buffer(0)]], - device mlx_atomic *out [[buffer(1)]], + const device T* in [[buffer(0)]], + device mlx_atomic* out [[buffer(1)]], const device size_t& in_size [[buffer(2)]], uint gid [[thread_position_in_grid]], uint lid [[thread_position_in_threadgroup]], @@ -75,11 +74,11 @@ template uint simd_per_group [[simdgroups_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - Op op; threadgroup U local_vals[simd_size]; - U total_val = per_thread_all_reduce(in, in_size, gid, grid_size); + U total_val = + per_thread_all_reduce(in, in_size, gid, grid_size); // Reduction within simd group total_val = op.simd_reduce(total_val); @@ -98,10 +97,10 @@ template } } -template +template [[kernel]] void all_reduce_no_atomics( - const device T *in [[buffer(0)]], - device U *out [[buffer(1)]], + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], const device size_t& in_size [[buffer(2)]], uint gid [[thread_position_in_grid]], uint lid [[thread_position_in_threadgroup]], @@ -110,14 +109,16 @@ template uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint thread_group_id [[threadgroup_position_in_grid]]) { - Op op; threadgroup U local_vals[simd_size]; - U total_val = per_thread_all_reduce(in, in_size, gid, grid_size); + U total_val = + per_thread_all_reduce(in, in_size, gid, grid_size); - // Reduction within simd group (simd_add isn't supported for uint64/int64 types) - for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) { + // Reduction within simd group (simd_add isn't supported for uint64/int64 + // types) + for (uint16_t lane_offset = simd_size / 2; lane_offset > 0; + lane_offset /= 2) { total_val = op(total_val, simd_shuffle_down(total_val, lane_offset)); } // Write simd group reduction results to local memory @@ -128,7 +129,8 @@ template // Reduction of simdgroup reduction results within threadgroup. total_val = lid < simd_per_group ? 
local_vals[lid] : op.init; - for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) { + for (uint16_t lane_offset = simd_size / 2; lane_offset > 0; + lane_offset /= 2) { total_val = op(total_val, simd_shuffle_down(total_val, lane_offset)); } @@ -138,31 +140,31 @@ template } } -#define instantiate_all_reduce(name, itype, otype, op) \ - template [[host_name("all_reduce_" #name)]] \ - [[kernel]] void all_reduce( \ - const device itype *in [[buffer(0)]], \ - device mlx_atomic *out [[buffer(1)]], \ - const device size_t& in_size [[buffer(2)]], \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint grid_size [[threads_per_grid]], \ - uint simd_per_group [[simdgroups_per_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ +#define instantiate_all_reduce(name, itype, otype, op) \ + template [[host_name("all_reduce_" #name)]] [[kernel]] void \ + all_reduce( \ + const device itype* in [[buffer(0)]], \ + device mlx_atomic* out [[buffer(1)]], \ + const device size_t& in_size [[buffer(2)]], \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint grid_size [[threads_per_grid]], \ + uint simd_per_group [[simdgroups_per_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ uint simd_group_id [[simdgroup_index_in_threadgroup]]); -#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \ - template [[host_name("all_reduce_no_atomics_" #name)]] \ - [[kernel]] void all_reduce_no_atomics( \ - const device itype *in [[buffer(0)]], \ - device otype *out [[buffer(1)]], \ - const device size_t& in_size [[buffer(2)]], \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint grid_size [[threads_per_grid]], \ - uint simd_per_group [[simdgroups_per_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ +#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \ + template [[host_name("all_reduce_no_atomics_" #name)]] [[kernel]] void \ + all_reduce_no_atomics( \ + const device itype* in [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const device size_t& in_size [[buffer(2)]], \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint grid_size [[threads_per_grid]], \ + uint simd_per_group [[simdgroups_per_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ uint thread_group_id [[threadgroup_position_in_grid]]); /////////////////////////////////////////////////////////////////////////////// @@ -170,11 +172,12 @@ template /////////////////////////////////////////////////////////////////////////////// #define instantiate_same_all_reduce_helper(name, tname, type, op) \ - instantiate_all_reduce(name ##tname, type, type, op) + instantiate_all_reduce(name##tname, type, type, op) #define instantiate_same_all_reduce_na_helper(name, tname, type, op) \ - instantiate_all_reduce_no_atomics(name ##tname, type, type, op) + instantiate_all_reduce_no_atomics(name##tname, type, type, op) +// clang-format off instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types) instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b) @@ -182,4 +185,4 @@ instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And) instantiate_reduce_from_types(instantiate_all_reduce, or, 
bool, Or) // special case bool with larger output type -instantiate_all_reduce(sumbool_, bool, uint32_t, Sum) \ No newline at end of file +instantiate_all_reduce(sumbool_, bool, uint32_t, Sum) // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/reduction/kernels/reduce_col.metal b/mlx/backend/metal/kernels/reduction/kernels/reduce_col.metal index d757ee6dd..7e28b11ca 100644 --- a/mlx/backend/metal/kernels/reduction/kernels/reduce_col.metal +++ b/mlx/backend/metal/kernels/reduction/kernels/reduce_col.metal @@ -1,8 +1,8 @@ // Copyright © 2023-2024 Apple Inc. -#include "mlx/backend/metal/kernels/reduction/utils.h" #include "mlx/backend/metal/kernels/reduction/ops.h" #include "mlx/backend/metal/kernels/reduction/reduce_inst.h" +#include "mlx/backend/metal/kernels/reduction/utils.h" using namespace metal; @@ -12,8 +12,8 @@ using namespace metal; template [[kernel]] void col_reduce_small( - const device T *in [[buffer(0)]], - device U *out [[buffer(1)]], + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], const constant size_t& reduction_size [[buffer(2)]], const constant size_t& reduction_stride [[buffer(3)]], const constant size_t& out_size [[buffer(4)]], @@ -25,7 +25,6 @@ template const constant size_t* non_col_strides [[buffer(10)]], const constant int& non_col_ndim [[buffer(11)]], uint tid [[thread_position_in_grid]]) { - // Appease the compiler (void)out_size; @@ -35,15 +34,16 @@ template auto out_idx = tid; in += elem_to_loc( - out_idx, - shape + non_col_ndim, - strides + non_col_ndim, - ndim - non_col_ndim); + out_idx, + shape + non_col_ndim, + strides + non_col_ndim, + ndim - non_col_ndim); - for(uint i = 0; i < non_col_reductions; i++) { - size_t in_idx = elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim); + for (uint i = 0; i < non_col_reductions; i++) { + size_t in_idx = + elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim); - for(uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) { + for (uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) { U val = static_cast(in[in_idx]); total_val = op(total_val, val); } @@ -52,21 +52,21 @@ template out[out_idx] = total_val; } -#define instantiate_col_reduce_small(name, itype, otype, op) \ - template [[host_name("col_reduce_small_" #name)]] \ - [[kernel]] void col_reduce_small( \ - const device itype *in [[buffer(0)]], \ - device otype *out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& reduction_stride [[buffer(3)]], \ - const constant size_t& out_size [[buffer(4)]], \ - const constant int* shape [[buffer(5)]], \ - const constant size_t* strides [[buffer(6)]], \ - const constant int& ndim [[buffer(7)]], \ - const constant size_t& non_col_reductions [[buffer(8)]], \ - const constant int* non_col_shapes [[buffer(9)]], \ - const constant size_t* non_col_strides [[buffer(10)]], \ - const constant int& non_col_ndim [[buffer(11)]], \ +#define instantiate_col_reduce_small(name, itype, otype, op) \ + template [[host_name("col_reduce_small_" #name)]] [[kernel]] void \ + col_reduce_small( \ + const device itype* in [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& reduction_stride [[buffer(3)]], \ + const constant size_t& out_size [[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ + const constant size_t& 
non_col_reductions [[buffer(8)]], \ + const constant int* non_col_shapes [[buffer(9)]], \ + const constant size_t* non_col_strides [[buffer(10)]], \ + const constant int& non_col_ndim [[buffer(11)]], \ uint tid [[thread_position_in_grid]]); /////////////////////////////////////////////////////////////////////////////// @@ -112,39 +112,35 @@ METAL_FUNC U _contiguous_strided_reduce( template [[kernel]] void col_reduce_general( - const device T *in [[buffer(0)]], - device mlx_atomic *out [[buffer(1)]], + const device T* in [[buffer(0)]], + device mlx_atomic* out [[buffer(1)]], const constant size_t& reduction_size [[buffer(2)]], const constant size_t& reduction_stride [[buffer(3)]], const constant size_t& out_size [[buffer(4)]], const constant int* shape [[buffer(5)]], const constant size_t* strides [[buffer(6)]], const constant int& ndim [[buffer(7)]], - threadgroup U *local_data [[threadgroup(0)]], + threadgroup U* local_data [[threadgroup(0)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint3 lsize [[threads_per_threadgroup]]) { auto out_idx = tid.x * lsize.x + lid.x; - auto in_idx = elem_to_loc( - out_idx + tid.z * out_size, - shape, - strides, - ndim - ); + auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim); Op op; - if(out_idx < out_size) { + if (out_idx < out_size) { U val = _contiguous_strided_reduce( - in, - local_data, - in_idx, - reduction_size, - reduction_stride, - tid.xy, - lid.xy, - lsize.xy); + in, + local_data, + in_idx, + reduction_size, + reduction_stride, + tid.xy, + lid.xy, + lsize.xy); - // Write out reduction results generated by threadgroups working on specific output element, contiguously. + // Write out reduction results generated by threadgroups working on specific + // output element, contiguously. if (lid.y == 0) { op.atomic_update(out, val, out_idx); } @@ -153,40 +149,36 @@ template template [[kernel]] void col_reduce_general_no_atomics( - const device T *in [[buffer(0)]], - device U *out [[buffer(1)]], + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], const constant size_t& reduction_size [[buffer(2)]], const constant size_t& reduction_stride [[buffer(3)]], const constant size_t& out_size [[buffer(4)]], const constant int* shape [[buffer(5)]], const constant size_t* strides [[buffer(6)]], const constant int& ndim [[buffer(7)]], - threadgroup U *local_data [[threadgroup(0)]], + threadgroup U* local_data [[threadgroup(0)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint3 gid [[thread_position_in_grid]], uint3 lsize [[threads_per_threadgroup]], uint3 gsize [[threads_per_grid]]) { auto out_idx = tid.x * lsize.x + lid.x; - auto in_idx = elem_to_loc( - out_idx + tid.z * out_size, - shape, - strides, - ndim - ); + auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim); - if(out_idx < out_size) { + if (out_idx < out_size) { U val = _contiguous_strided_reduce( - in, - local_data, - in_idx, - reduction_size, - reduction_stride, - tid.xy, - lid.xy, - lsize.xy); + in, + local_data, + in_idx, + reduction_size, + reduction_stride, + tid.xy, + lid.xy, + lsize.xy); - // Write out reduction results generated by threadgroups working on specific output element, contiguously. + // Write out reduction results generated by threadgroups working on specific + // output element, contiguously. 
if (lid.y == 0) { uint tgsize_y = ceildiv(gsize.y, lsize.y); uint tgsize_z = ceildiv(gsize.z, lsize.z); @@ -195,52 +187,56 @@ template } } -#define instantiate_col_reduce_general(name, itype, otype, op) \ - template [[host_name("col_reduce_general_" #name)]] \ - [[kernel]] void col_reduce_general( \ - const device itype *in [[buffer(0)]], \ - device mlx_atomic *out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& reduction_stride [[buffer(3)]], \ - const constant size_t& out_size [[buffer(4)]], \ - const constant int* shape [[buffer(5)]], \ - const constant size_t* strides [[buffer(6)]], \ - const constant int& ndim [[buffer(7)]], \ - threadgroup otype *local_data [[threadgroup(0)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ +#define instantiate_col_reduce_general(name, itype, otype, op) \ + template [[host_name("col_reduce_general_" #name)]] [[kernel]] void \ + col_reduce_general( \ + const device itype* in [[buffer(0)]], \ + device mlx_atomic* out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& reduction_stride [[buffer(3)]], \ + const constant size_t& out_size [[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ + threadgroup otype* local_data [[threadgroup(0)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ uint3 lsize [[threads_per_threadgroup]]); -#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \ - template [[host_name("col_reduce_general_no_atomics_" #name)]] \ - [[kernel]] void col_reduce_general_no_atomics( \ - const device itype *in [[buffer(0)]], \ - device otype *out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& reduction_stride [[buffer(3)]], \ - const constant size_t& out_size [[buffer(4)]], \ - const constant int* shape [[buffer(5)]], \ - const constant size_t* strides [[buffer(6)]], \ - const constant int& ndim [[buffer(7)]], \ - threadgroup otype *local_data [[threadgroup(0)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint3 gid [[thread_position_in_grid]], \ - uint3 lsize [[threads_per_threadgroup]], \ - uint3 gsize [[threads_per_grid]]); +#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \ + template \ + [[host_name("col_reduce_general_no_atomics_" #name)]] [[kernel]] void \ + col_reduce_general_no_atomics( \ + const device itype* in [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& reduction_stride [[buffer(3)]], \ + const constant size_t& out_size [[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ + threadgroup otype* local_data [[threadgroup(0)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 gid [[thread_position_in_grid]], \ + uint3 lsize [[threads_per_threadgroup]], \ + uint3 gsize [[threads_per_grid]]); /////////////////////////////////////////////////////////////////////////////// // Instantiations /////////////////////////////////////////////////////////////////////////////// -#define instantiate_same_col_reduce_helper(name, tname, type, op) \ 
+// clang-format off +#define instantiate_same_col_reduce_helper(name, tname, type, op) \ instantiate_col_reduce_small(name ##tname, type, type, op) \ - instantiate_col_reduce_general(name ##tname, type, type, op) + instantiate_col_reduce_general(name ##tname, type, type, op) // clang-format on +// clang-format off #define instantiate_same_col_reduce_na_helper(name, tname, type, op) \ - instantiate_col_reduce_small(name ##tname, type, type, op) \ - instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op) + instantiate_col_reduce_small(name ##tname, type, type, op) \ + instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op) // clang-format on +// clang-format off instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types) instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b) @@ -250,4 +246,4 @@ instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or) instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum) instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And) -instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or) \ No newline at end of file +instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or) // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/reduction/kernels/reduce_init.metal b/mlx/backend/metal/kernels/reduction/kernels/reduce_init.metal index 7e9bd06da..f0ca1c8b4 100644 --- a/mlx/backend/metal/kernels/reduction/kernels/reduce_init.metal +++ b/mlx/backend/metal/kernels/reduction/kernels/reduce_init.metal @@ -1,8 +1,8 @@ // Copyright © 2023-2024 Apple Inc. -#include "mlx/backend/metal/kernels/reduction/utils.h" #include "mlx/backend/metal/kernels/reduction/ops.h" #include "mlx/backend/metal/kernels/reduction/reduce_inst.h" +#include "mlx/backend/metal/kernels/reduction/utils.h" using namespace metal; @@ -12,22 +12,21 @@ using namespace metal; template [[kernel]] void init_reduce( - device T *out [[buffer(0)]], + device T* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) { out[tid] = Op::init; } -#define instantiate_init_reduce(name, otype, op) \ - template [[host_name("i" #name)]] \ - [[kernel]] void init_reduce( \ - device otype *out [[buffer(1)]], \ - uint tid [[thread_position_in_grid]]); +#define instantiate_init_reduce(name, otype, op) \ + template [[host_name("i" #name)]] [[kernel]] void init_reduce( \ + device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]); #define instantiate_init_reduce_helper(name, tname, type, op) \ - instantiate_init_reduce(name ##tname, type, op) + instantiate_init_reduce(name##tname, type, op) +// clang-format off instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types) instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b) instantiate_init_reduce(andbool_, bool, And) -instantiate_init_reduce(orbool_, bool, Or) \ No newline at end of file +instantiate_init_reduce(orbool_, bool, Or) // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/reduction/kernels/reduce_row.metal b/mlx/backend/metal/kernels/reduction/kernels/reduce_row.metal index 499e791e2..22810ca20 100644 --- a/mlx/backend/metal/kernels/reduction/kernels/reduce_row.metal +++ b/mlx/backend/metal/kernels/reduction/kernels/reduce_row.metal @@ -1,8 +1,8 @@ // Copyright © 2023-2024 Apple Inc. 
-#include "mlx/backend/metal/kernels/reduction/utils.h" #include "mlx/backend/metal/kernels/reduction/ops.h" #include "mlx/backend/metal/kernels/reduction/reduce_inst.h" +#include "mlx/backend/metal/kernels/reduction/utils.h" using namespace metal; @@ -13,8 +13,8 @@ using namespace metal; // Each thread reduces for one output template [[kernel]] void row_reduce_general_small( - const device T *in [[buffer(0)]], - device U *out [[buffer(1)]], + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], const constant size_t& reduction_size [[buffer(2)]], const constant size_t& out_size [[buffer(3)]], const constant size_t& non_row_reductions [[buffer(4)]], @@ -22,22 +22,21 @@ template const constant size_t* strides [[buffer(6)]], const constant int& ndim [[buffer(7)]], uint lid [[thread_position_in_grid]]) { - Op op; - + uint out_idx = lid; - if(out_idx >= out_size) { + if (out_idx >= out_size) { return; } U total_val = Op::init; - for(short r = 0; r < short(non_row_reductions); r++) { + for (short r = 0; r < short(non_row_reductions); r++) { uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim); - const device T * in_row = in + in_idx; - - for(short i = 0; i < short(reduction_size); i++) { + const device T* in_row = in + in_idx; + + for (short i = 0; i < short(reduction_size); i++) { total_val = op(static_cast(in_row[i]), total_val); } } @@ -48,8 +47,8 @@ template // Each simdgroup reduces for one output template [[kernel]] void row_reduce_general_med( - const device T *in [[buffer(0)]], - device U *out [[buffer(1)]], + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], const constant size_t& reduction_size [[buffer(2)]], const constant size_t& out_size [[buffer(3)]], const constant size_t& non_row_reductions [[buffer(4)]], @@ -60,45 +59,42 @@ template uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - Op op; - + uint out_idx = simd_per_group * tid + simd_group_id; - if(out_idx >= out_size) { + if (out_idx >= out_size) { return; } U total_val = Op::init; - if(short(non_row_reductions) == 1) { + if (short(non_row_reductions) == 1) { uint in_idx = elem_to_loc(out_idx, shape, strides, ndim); - const device T * in_row = in + in_idx; + const device T* in_row = in + in_idx; - for(short i = simd_lane_id; i < short(reduction_size); i += 32) { + for (short i = simd_lane_id; i < short(reduction_size); i += 32) { total_val = op(static_cast(in_row[i]), total_val); } } else if (short(non_row_reductions) >= 32) { - - for(short r = simd_lane_id; r < short(non_row_reductions); r+=32) { - + for (short r = simd_lane_id; r < short(non_row_reductions); r += 32) { uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim); - const device T * in_row = in + in_idx; + const device T* in_row = in + in_idx; - for(short i = 0; i < short(reduction_size); i++) { + for (short i = 0; i < short(reduction_size); i++) { total_val = op(static_cast(in_row[i]), total_val); } - } } else { - - const short n_reductions = short(reduction_size) * short(non_row_reductions); - const short reductions_per_thread = (n_reductions + simd_size - 1) / simd_size; + const short n_reductions = + short(reduction_size) * short(non_row_reductions); + const short reductions_per_thread = + (n_reductions + simd_size - 1) / simd_size; const short r_st = simd_lane_id / reductions_per_thread; const short r_ed = short(non_row_reductions); @@ -108,54 +104,50 @@ template const short i_ed 
= short(reduction_size); const short i_jump = reductions_per_thread; - if(r_st < r_jump) { - for(short r = r_st; r < r_ed; r += r_jump) { - + if (r_st < r_jump) { + for (short r = r_st; r < r_ed; r += r_jump) { uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim); - const device T * in_row = in + in_idx; + const device T* in_row = in + in_idx; - for(short i = i_st; i < i_ed; i += i_jump) { + for (short i = i_st; i < i_ed; i += i_jump) { total_val = op(static_cast(in_row[i]), total_val); } - } } - } - total_val = op.simd_reduce(total_val); - if(simd_lane_id == 0) { + if (simd_lane_id == 0) { out[out_idx] = total_val; } } -#define instantiate_row_reduce_small(name, itype, otype, op) \ - template[[host_name("row_reduce_general_small_" #name)]] \ - [[kernel]] void row_reduce_general_small( \ - const device itype *in [[buffer(0)]], \ - device otype *out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& out_size [[buffer(3)]], \ - const constant size_t& non_row_reductions [[buffer(4)]], \ - const constant int* shape [[buffer(5)]], \ - const constant size_t* strides [[buffer(6)]], \ - const constant int& ndim [[buffer(7)]], \ - uint lid [[thread_position_in_grid]]); \ - template[[host_name("row_reduce_general_med_" #name)]] \ - [[kernel]] void row_reduce_general_med( \ - const device itype *in [[buffer(0)]], \ - device otype *out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& out_size [[buffer(3)]], \ - const constant size_t& non_row_reductions [[buffer(4)]], \ - const constant int* shape [[buffer(5)]], \ - const constant size_t* strides [[buffer(6)]], \ - const constant int& ndim [[buffer(7)]], \ - uint tid [[threadgroup_position_in_grid]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \ +#define instantiate_row_reduce_small(name, itype, otype, op) \ + template [[host_name("row_reduce_general_small_" #name)]] [[kernel]] void \ + row_reduce_general_small( \ + const device itype* in [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& out_size [[buffer(3)]], \ + const constant size_t& non_row_reductions [[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ + uint lid [[thread_position_in_grid]]); \ + template [[host_name("row_reduce_general_med_" #name)]] [[kernel]] void \ + row_reduce_general_med( \ + const device itype* in [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& out_size [[buffer(3)]], \ + const constant size_t& non_row_reductions [[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ + uint tid [[threadgroup_position_in_grid]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \ uint simd_group_id [[simdgroup_index_in_threadgroup]]); /////////////////////////////////////////////////////////////////////////////// @@ -217,10 +209,10 @@ METAL_FUNC U per_thread_row_reduce( return total_val; } -template +template [[kernel]] void row_reduce_general( - const device T *in [[buffer(0)]], - device mlx_atomic *out [[buffer(1)]], + const device T* in [[buffer(0)]], + device 
mlx_atomic* out [[buffer(1)]], const constant size_t& reduction_size [[buffer(2)]], const constant size_t& out_size [[buffer(3)]], const constant size_t& non_row_reductions [[buffer(4)]], @@ -233,25 +225,33 @@ template uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_per_group [[simdgroups_per_threadgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - (void)non_row_reductions; Op op; threadgroup U local_vals[simd_size]; - U total_val = per_thread_row_reduce(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy); + U total_val = per_thread_row_reduce( + in, + reduction_size, + out_size, + shape, + strides, + ndim, + lsize.x, + lid.x, + tid.xy); total_val = op.simd_reduce(total_val); - + // Prepare next level if (simd_lane_id == 0) { local_vals[simd_group_id] = total_val; } threadgroup_barrier(mem_flags::mem_threadgroup); - + // Reduction within thread group // Only needed if multiple simd groups - if(reduction_size > simd_size) { + if (reduction_size > simd_size) { total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init; total_val = op.simd_reduce(total_val); } @@ -261,10 +261,10 @@ template } } -template +template [[kernel]] void row_reduce_general_no_atomics( - const device T *in [[buffer(0)]], - device U *out [[buffer(1)]], + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], const constant size_t& reduction_size [[buffer(2)]], const constant size_t& out_size [[buffer(3)]], const constant size_t& non_row_reductions [[buffer(4)]], @@ -278,16 +278,24 @@ template uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_per_group [[simdgroups_per_threadgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - (void)non_row_reductions; Op op; threadgroup U local_vals[simd_size]; - U total_val = per_thread_row_reduce(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy); + U total_val = per_thread_row_reduce( + in, + reduction_size, + out_size, + shape, + strides, + ndim, + lsize.x, + lid.x, + tid.xy); // Reduction within simd group - simd_add isn't supported for int64 types - for (uint16_t i = simd_size/2; i > 0; i /= 2) { + for (uint16_t i = simd_size / 2; i > 0; i /= 2) { total_val = op(total_val, simd_shuffle_down(total_val, i)); } @@ -299,9 +307,9 @@ template // Reduction within thread group // Only needed if thread group has multiple simd groups - if(ceildiv(reduction_size, N_READS) > simd_size) { + if (ceildiv(reduction_size, N_READS) > simd_size) { total_val = lid.x < simd_per_group ? 
local_vals[lid.x] : op.init; - for (uint16_t i = simd_size/2; i > 0; i /= 2) { + for (uint16_t i = simd_size / 2; i > 0; i /= 2) { total_val = op(total_val, simd_shuffle_down(total_val, i)); } } @@ -311,61 +319,60 @@ template } } -#define instantiate_row_reduce_general(name, itype, otype, op) \ - instantiate_row_reduce_small(name, itype, otype, op) \ - template [[host_name("row_reduce_general_" #name)]] \ - [[kernel]] void row_reduce_general( \ - const device itype *in [[buffer(0)]], \ - device mlx_atomic *out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& out_size [[buffer(3)]], \ - const constant size_t& non_row_reductions [[buffer(4)]], \ - const constant int* shape [[buffer(5)]], \ - const constant size_t* strides [[buffer(6)]], \ - const constant int& ndim [[buffer(7)]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint3 lsize [[threads_per_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_per_group [[simdgroups_per_threadgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); - -#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \ - instantiate_row_reduce_small(name, itype, otype, op) \ - template [[host_name("row_reduce_general_no_atomics_" #name)]] \ - [[kernel]] void row_reduce_general_no_atomics( \ - const device itype *in [[buffer(0)]], \ - device otype *out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& out_size [[buffer(3)]], \ - const constant size_t& non_row_reductions [[buffer(4)]], \ - const constant int* shape [[buffer(5)]], \ - const constant size_t* strides [[buffer(6)]], \ - const constant int& ndim [[buffer(7)]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint3 lsize [[threads_per_threadgroup]], \ - uint3 gsize [[threads_per_grid]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_per_group [[simdgroups_per_threadgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); +#define instantiate_row_reduce_general(name, itype, otype, op) \ + instantiate_row_reduce_small(name, itype, otype, op) template \ + [[host_name("row_reduce_general_" #name)]] [[kernel]] void \ + row_reduce_general( \ + const device itype* in [[buffer(0)]], \ + device mlx_atomic* out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& out_size [[buffer(3)]], \ + const constant size_t& non_row_reductions [[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 lsize [[threads_per_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_per_group [[simdgroups_per_threadgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); +#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \ + instantiate_row_reduce_small(name, itype, otype, op) template \ + [[host_name("row_reduce_general_no_atomics_" #name)]] [[kernel]] void \ + row_reduce_general_no_atomics( \ + const device itype* in [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& out_size [[buffer(3)]], \ + const constant size_t& non_row_reductions 
[[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 lsize [[threads_per_threadgroup]], \ + uint3 gsize [[threads_per_grid]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_per_group [[simdgroups_per_threadgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); /////////////////////////////////////////////////////////////////////////////// // Instantiations /////////////////////////////////////////////////////////////////////////////// #define instantiate_same_row_reduce_helper(name, tname, type, op) \ - instantiate_row_reduce_general(name ##tname, type, type, op) + instantiate_row_reduce_general(name##tname, type, type, op) #define instantiate_same_row_reduce_na_helper(name, tname, type, op) \ - instantiate_row_reduce_general_no_atomics(name ##tname, type, type, op) + instantiate_row_reduce_general_no_atomics(name##tname, type, type, op) +// clang-format off instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types) instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b) - instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And) instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or) -instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum) \ No newline at end of file +instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum) // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/rms_norm.metal b/mlx/backend/metal/kernels/rms_norm.metal index 475184925..b38aaa4da 100644 --- a/mlx/backend/metal/kernels/rms_norm.metal +++ b/mlx/backend/metal/kernels/rms_norm.metal @@ -237,13 +237,17 @@ template gw += gid * axis_size + lid * N_READS; if (lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { - gx[i] = static_cast(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3); + gx[i] = static_cast( + thread_g[i] * thread_w[i] * normalizer - + thread_x[i] * meangwx * normalizer3); gw[i] = static_cast(thread_g[i] * thread_x[i] * normalizer); } } else { for (int i = 0; i < N_READS; i++) { if ((lid * N_READS + i) < axis_size) { - gx[i] = static_cast(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3); + gx[i] = static_cast( + thread_g[i] * thread_w[i] * normalizer - + thread_x[i] * meangwx * normalizer3); gw[i] = static_cast(thread_g[i] * thread_x[i] * normalizer); } } @@ -342,7 +346,8 @@ template float wi = w[w_stride * (i + r)]; float gi = g[i + r]; - gx[i + r] = static_cast(gi * wi * normalizer - xi * meangwx * normalizer3); + gx[i + r] = + static_cast(gi * wi * normalizer - xi * meangwx * normalizer3); gw[i + r] = static_cast(gi * xi * normalizer); } } else { @@ -352,7 +357,8 @@ template float wi = w[w_stride * (i + r)]; float gi = g[i + r]; - gx[i + r] = static_cast(gi * wi * normalizer - xi * meangwx * normalizer3); + gx[i + r] = + static_cast(gi * wi * normalizer - xi * meangwx * normalizer3); gw[i + r] = static_cast(gi * xi * normalizer); } } @@ -431,5 +437,4 @@ template instantiate_rms(float32, float) instantiate_rms(float16, half) -instantiate_rms(bfloat16, bfloat16_t) - // clang-format on +instantiate_rms(bfloat16, bfloat16_t) // clang-format on diff --git a/mlx/backend/metal/kernels/rope.metal b/mlx/backend/metal/kernels/rope.metal index 
8188807d7..d53ec3845 100644 --- a/mlx/backend/metal/kernels/rope.metal +++ b/mlx/backend/metal/kernels/rope.metal @@ -7,8 +7,8 @@ template [[kernel]] void rope( - const device T *in [[buffer(0)]], - device T * out [[buffer(1)]], + const device T* in [[buffer(0)]], + device T* out [[buffer(1)]], constant const size_t strides[3], constant const size_t out_strides[3], constant const int& offset, @@ -20,12 +20,15 @@ template uint in_index_1, in_index_2; uint out_index_1, out_index_2; if (traditional) { - out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0]; + out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + + pos.z * out_strides[0]; out_index_2 = out_index_1 + 1; - in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0]; + in_index_1 = + 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0]; in_index_2 = in_index_1 + strides[2]; } else { - out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0]; + out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + + pos.z * out_strides[0]; out_index_2 = out_index_1 + grid.x * out_strides[2]; in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0]; in_index_2 = in_index_1 + grid.x * strides[2]; @@ -57,18 +60,19 @@ template } #define instantiate_rope(name, type, traditional, forward) \ - template [[host_name("rope_" #name)]] \ - [[kernel]] void rope( \ - const device type* in [[buffer(0)]], \ - device type* out [[buffer(1)]], \ - constant const size_t strides[3], \ - constant const size_t out_strides[3], \ - constant const int& offset, \ - constant const float& base, \ - constant const float& scale, \ - uint3 pos [[thread_position_in_grid]], \ - uint3 grid [[threads_per_grid]]); + template [[host_name("rope_" #name)]] [[kernel]] void \ + rope( \ + const device type* in [[buffer(0)]], \ + device type* out [[buffer(1)]], \ + constant const size_t strides[3], \ + constant const size_t out_strides[3], \ + constant const int& offset, \ + constant const float& base, \ + constant const float& scale, \ + uint3 pos [[thread_position_in_grid]], \ + uint3 grid [[threads_per_grid]]); +// clang-format off instantiate_rope(traditional_float16, half, true, true) instantiate_rope(traditional_bfloat16, bfloat16_t, true, true) instantiate_rope(traditional_float32, float, true, true) @@ -80,4 +84,4 @@ instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false) instantiate_rope(vjp_traditional_float32, float, true, false) instantiate_rope(vjp_float16, half, false, false) instantiate_rope(vjp_bfloat16, bfloat16_t, false, false) -instantiate_rope(vjp_float32, float, false, false) +instantiate_rope(vjp_float32, float, false, false) // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index fb9f0a111..1d8d9f9d6 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -1,451 +1,551 @@ -#include #include +#include #include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h" using namespace metal; -template -[[kernel]] void fast_inference_sdpa_compute_partials_template(const device T *Q [[buffer(0)]], - const device T *K [[buffer(1)]], - const device T *V [[buffer(2)]], - const device uint64_t& L [[buffer(3)]], - const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], - device float* 
O_partials [[buffer(5)]], - device float* p_lse [[buffer(6)]], - device float* p_maxes [[buffer(7)]], - threadgroup T* threadgroup_block [[threadgroup(0)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]]) { - constexpr const size_t DK = 128; - constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8; - constexpr const size_t THREADS_PER_SIMDGROUP = 32; - constexpr const uint iter_offset = NSIMDGROUPS * 4; - const bool is_gqa = params.N_KV_HEADS != params.N_Q_HEADS; - uint kv_head_offset_factor = tid.x; - if(is_gqa) { - int q_kv_head_ratio = params.N_Q_HEADS / params.N_KV_HEADS; - kv_head_offset_factor = tid.x / q_kv_head_ratio; +template < + typename T, + typename T2, + typename T4, + uint16_t TILE_SIZE_CONST, + uint16_t NSIMDGROUPS> +[[kernel]] void fast_inference_sdpa_compute_partials_template( + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + const device uint64_t& L [[buffer(3)]], + const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], + device float* O_partials [[buffer(5)]], + device float* p_lse [[buffer(6)]], + device float* p_maxes [[buffer(7)]], + threadgroup T* threadgroup_block [[threadgroup(0)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + constexpr const size_t DK = 128; + constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8; + constexpr const size_t THREADS_PER_SIMDGROUP = 32; + constexpr const uint iter_offset = NSIMDGROUPS * 4; + const bool is_gqa = params.N_KV_HEADS != params.N_Q_HEADS; + uint kv_head_offset_factor = tid.x; + if (is_gqa) { + int q_kv_head_ratio = params.N_Q_HEADS / params.N_KV_HEADS; + kv_head_offset_factor = tid.x / q_kv_head_ratio; + } + constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / NSIMDGROUPS / 4; + constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP = + TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS); + constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR; + constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR * + SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) * + NSIMDGROUPS; + + threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block; +#pragma clang loop unroll(full) + for (uint i = 0; i < 8; i++) { + smemFlush + [simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP + + i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + // TODO: multiple query sequence length for speculative decoding + const uint tgroup_query_head_offset = + tid.x * DK + tid.z * (params.N_Q_HEADS * DK); + + const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L; + const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK; + const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK; + + const device T* baseK = + K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset; + const device T* baseQ = Q + tgroup_query_head_offset; + + device T4* simdgroupQueryData = (device T4*)baseQ; + + constexpr const size_t ACCUM_PER_GROUP = TILE_SIZE_CONST / NSIMDGROUPS; + float threadAccum[ACCUM_PER_GROUP]; + +#pragma clang loop unroll(full) + for (size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP; + threadAccumIndex++) { + threadAccum[threadAccumIndex] = -INFINITY; + } + + uint KROW_ACCUM_INDEX = 0; + + const int32_t 
SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST; + const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L; + const bool LAST_TILE_ALIGNED = + (SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST)); + + T4 thread_data_x4; + T4 thread_data_y4; + if (!LAST_TILE || LAST_TILE_ALIGNED) { + thread_data_x4 = *(simdgroupQueryData + simd_lane_id); +#pragma clang loop unroll(full) + for (size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST; + KROW += NSIMDGROUPS) { + const uint KROW_OFFSET = KROW * DK; + const device T* baseKRow = baseK + KROW_OFFSET; + device T4* keysData = (device T4*)baseKRow; + thread_data_y4 = *(keysData + simd_lane_id); + T kq_scalar = dot(thread_data_x4, thread_data_y4); + threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar); + KROW_ACCUM_INDEX++; } - constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / NSIMDGROUPS / 4; - constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP = TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS); - constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR; - constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR * SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) * NSIMDGROUPS; + } else { + thread_data_x4 = *(simdgroupQueryData + simd_lane_id); + const uint START_ROW = tid.y * TILE_SIZE_CONST; + const device T* baseKThisHead = + K + tgroup_k_batch_offset + tgroup_k_head_offset; - threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block; - #pragma clang loop unroll(full) - for(uint i = 0; i < 8; i++) { - smemFlush[simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP + i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f); + for (size_t KROW = START_ROW + simd_group_id; KROW < L; + KROW += NSIMDGROUPS) { + const uint KROW_OFFSET = KROW * DK; + const device T* baseKRow = baseKThisHead + KROW_OFFSET; + device T4* keysData = (device T4*)baseKRow; + thread_data_y4 = *(keysData + simd_lane_id); + T kq_scalar = dot(thread_data_x4, thread_data_y4); + threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar); + KROW_ACCUM_INDEX++; } - threadgroup_barrier(mem_flags::mem_threadgroup); - // TODO: multiple query sequence length for speculative decoding - const uint tgroup_query_head_offset = tid.x * DK + tid.z * (params.N_Q_HEADS * DK); + } + threadgroup float* smemP = (threadgroup float*)threadgroup_block; - const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L; - const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK; - const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK; - - const device T* baseK = K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset; - const device T* baseQ = Q + tgroup_query_head_offset; - - device T4* simdgroupQueryData = (device T4*)baseQ; - - constexpr const size_t ACCUM_PER_GROUP = TILE_SIZE_CONST / NSIMDGROUPS; - float threadAccum[ACCUM_PER_GROUP]; - - #pragma clang loop unroll(full) - for(size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP; threadAccumIndex++) { - threadAccum[threadAccumIndex] = -INFINITY; +#pragma clang loop unroll(full) + for (size_t i = 0; i < P_VEC4; i++) { + thread_data_x4 = + T4(threadAccum[4 * i], + threadAccum[4 * i + 1], + threadAccum[4 * i + 2], + threadAccum[4 * i + 3]); + simdgroup_barrier(mem_flags::mem_none); + thread_data_y4 = simd_sum(thread_data_x4); + if (simd_lane_id == 0) { + const uint base_smem_p_offset = i * iter_offset + simd_group_id; + smemP[base_smem_p_offset + NSIMDGROUPS * 0] = float(thread_data_y4.x); + smemP[base_smem_p_offset + NSIMDGROUPS * 1] = 
float(thread_data_y4.y); + smemP[base_smem_p_offset + NSIMDGROUPS * 2] = float(thread_data_y4.z); + smemP[base_smem_p_offset + NSIMDGROUPS * 3] = float(thread_data_y4.w); } + } - uint KROW_ACCUM_INDEX = 0; + threadgroup_barrier(mem_flags::mem_threadgroup); - const int32_t SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST; - const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L; - const bool LAST_TILE_ALIGNED = (SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST)); + float groupMax; + float lse = 0.f; - T4 thread_data_x4; - T4 thread_data_y4; - if(!LAST_TILE || LAST_TILE_ALIGNED) { - thread_data_x4 = *(simdgroupQueryData + simd_lane_id); - #pragma clang loop unroll(full) - for(size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST; KROW += NSIMDGROUPS) { - const uint KROW_OFFSET = KROW * DK; - const device T* baseKRow = baseK + KROW_OFFSET; - device T4* keysData = (device T4*)baseKRow; - thread_data_y4 = *(keysData + simd_lane_id); - T kq_scalar = dot(thread_data_x4, thread_data_y4); - threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar); - KROW_ACCUM_INDEX++; + constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32; + constexpr const size_t ACCUM_ARRAY_LENGTH = + TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1; + float4 pvals[ACCUM_ARRAY_LENGTH]; + +#pragma clang loop unroll(full) + for (uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH; + accum_array_iter++) { + pvals[accum_array_iter] = float4(-INFINITY); + } + + if (TILE_SIZE_CONST == 64) { + threadgroup float2* smemPtrFlt2 = (threadgroup float2*)threadgroup_block; + float2 vals = smemPtrFlt2[simd_lane_id]; + vals *= params.INV_ALPHA; + float maxval = max(vals.x, vals.y); + simdgroup_barrier(mem_flags::mem_none); + groupMax = simd_max(maxval); + + float2 expf_shifted = exp(vals - groupMax); + float sumExpLocal = expf_shifted.x + expf_shifted.y; + simdgroup_barrier(mem_flags::mem_none); + float tgroupExpSum = simd_sum(sumExpLocal); + + lse = log(tgroupExpSum); + float2 local_p_hat = expf_shifted / tgroupExpSum; + pvals[0].x = local_p_hat.x; + pvals[0].y = local_p_hat.y; + smemPtrFlt2[simd_lane_id] = float2(0.f); + } + constexpr const bool TILE_SIZE_LARGER_THAN_64 = TILE_SIZE_CONST > 64; + constexpr const int TILE_SIZE_ITERS_128 = TILE_SIZE_CONST / 128; + + if (TILE_SIZE_LARGER_THAN_64) { + float maxval = -INFINITY; + threadgroup float4* smemPtrFlt4 = (threadgroup float4*)threadgroup_block; +#pragma clang loop unroll(full) + for (int i = 0; i < TILE_SIZE_ITERS_128; i++) { + float4 vals = smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP]; + vals *= params.INV_ALPHA; + pvals[i] = vals; + maxval = fmax3(vals.x, vals.y, maxval); + maxval = fmax3(vals.z, vals.w, maxval); + } + simdgroup_barrier(mem_flags::mem_none); + groupMax = simd_max(maxval); + + float sumExpLocal = 0.f; +#pragma clang loop unroll(full) + for (int i = 0; i < TILE_SIZE_ITERS_128; i++) { + pvals[i] = exp(pvals[i] - groupMax); + sumExpLocal += pvals[i].x + pvals[i].y + pvals[i].z + pvals[i].w; + } + simdgroup_barrier(mem_flags::mem_none); + float tgroupExpSum = simd_sum(sumExpLocal); + lse = log(tgroupExpSum); +#pragma clang loop unroll(full) + for (int i = 0; i < TILE_SIZE_ITERS_128; i++) { + pvals[i] = pvals[i] / tgroupExpSum; + smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP] = float4(0.f); + } + } + + threadgroup T* smemV = (threadgroup T*)threadgroup_block; + + const size_t v_batch_offset = tid.z * params.N_KV_HEADS * L * DK; + const size_t v_head_offset = kv_head_offset_factor * L * DK; + + const size_t v_tile_offset 
= tid.y * TILE_SIZE_CONST * DK; + const size_t v_offset = v_batch_offset + v_head_offset + v_tile_offset; + device T* baseV = (device T*)V + v_offset; + + threadgroup float* smemOpartial = (threadgroup float*)(smemV + totalSmemV); + + if (!LAST_TILE || LAST_TILE_ALIGNED) { +#pragma clang loop unroll(full) + for (size_t col = 0; col < MATRIX_COLS; col++) { + uint matrix_load_loop_iter = 0; + constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8; + + for (size_t tile_start = simd_group_id; + tile_start < TILE_SIZE_CONST_DIV_8; + tile_start += NSIMDGROUPS) { + simdgroup_matrix tmp; + ulong simdgroup_matrix_offset = + matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR + + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; + ulong2 matrixOrigin = + ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset); + simdgroup_load(tmp, baseV, DK, matrixOrigin, true); + const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0); + const ulong elemsPerRowSmem = TILE_SIZE_CONST; + simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, false); + matrix_load_loop_iter++; + }; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (TILE_SIZE_CONST == 64) { + T2 local_p_hat = T2(pvals[0].x, pvals[0].y); + uint loop_iter = 0; + threadgroup float* oPartialSmem = + smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; + +#pragma clang loop unroll(full) + for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; + row += NSIMDGROUPS) { + threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row); + threadgroup T2* smemV2 = (threadgroup T2*)smemV_row; + T2 v_local = *(smemV2 + simd_lane_id); + + T val = dot(local_p_hat, v_local); + simdgroup_barrier(mem_flags::mem_none); + + T row_sum = simd_sum(val); + oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = + float(row_sum); + loop_iter++; } - } else { - thread_data_x4 = *(simdgroupQueryData + simd_lane_id); - const uint START_ROW = tid.y * TILE_SIZE_CONST; - const device T* baseKThisHead = K + tgroup_k_batch_offset + tgroup_k_head_offset; + } - for(size_t KROW = START_ROW + simd_group_id; KROW < L; KROW += NSIMDGROUPS) { - const uint KROW_OFFSET = KROW * DK; - const device T* baseKRow = baseKThisHead + KROW_OFFSET; - device T4* keysData = (device T4*)baseKRow; - thread_data_y4 = *(keysData + simd_lane_id); - T kq_scalar = dot(thread_data_x4, thread_data_y4); - threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar); - KROW_ACCUM_INDEX++; + if (TILE_SIZE_CONST > 64) { + constexpr const size_t TILE_SIZE_CONST_DIV_128 = + (TILE_SIZE_CONST + 1) / 128; + threadgroup float* oPartialSmem = + smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; + uint loop_iter = 0; + for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; + row += NSIMDGROUPS) { + threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row); + + T row_sum = 0.f; + for (size_t i = 0; i < TILE_SIZE_CONST_DIV_128; i++) { + threadgroup T4* smemV2 = (threadgroup T4*)smemV_row; + T4 v_local = *(smemV2 + simd_lane_id + i * THREADS_PER_SIMDGROUP); + T4 p_local = T4(pvals[i]); + T val = dot(p_local, v_local); + row_sum += val; + } + simdgroup_barrier(mem_flags::mem_none); + row_sum = simd_sum(row_sum); + oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = + float(row_sum); + loop_iter++; } + } } - threadgroup float* smemP = (threadgroup float*)threadgroup_block; + } else { + const int32_t START_ROW = tid.y * TILE_SIZE_CONST; + const int32_t MAX_START_ROW = L - SIMDGROUP_MATRIX_LOAD_FACTOR + 1; + const device T* baseVThisHead = V + 
v_batch_offset + v_head_offset; + constexpr const int ROWS_PER_ITER = 8; +#pragma clang loop unroll(full) + for (size_t col = 0; col < MATRIX_COLS; col++) { + uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; + int32_t tile_start; + for (tile_start = + START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; + tile_start < MAX_START_ROW; + tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) { + simdgroup_matrix tmp; + ulong2 matrixOrigin = + ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start); + simdgroup_load( + tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true); + const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0); + constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST; + simdgroup_store( + tmp, + smemV, + elemsPerRowSmem, + matrixOriginSmem, + /* transpose */ false); + smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR; + }; - #pragma clang loop unroll(full) - for(size_t i = 0; i < P_VEC4; i++) { - thread_data_x4 = T4(threadAccum[4 * i], threadAccum[4 * i + 1], threadAccum[4 * i + 2], threadAccum[4 * i + 3]); - simdgroup_barrier(mem_flags::mem_none); - thread_data_y4 = simd_sum(thread_data_x4); - if(simd_lane_id == 0) { - const uint base_smem_p_offset = i * iter_offset + simd_group_id; - smemP[base_smem_p_offset + NSIMDGROUPS * 0] = float(thread_data_y4.x); - smemP[base_smem_p_offset + NSIMDGROUPS * 1] = float(thread_data_y4.y); - smemP[base_smem_p_offset + NSIMDGROUPS * 2] = float(thread_data_y4.z); - smemP[base_smem_p_offset + NSIMDGROUPS * 3] = float(thread_data_y4.w); + tile_start = + ((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR); + + const int32_t INT_L = int32_t(L); + for (int row_index = tile_start + simd_group_id; row_index < INT_L; + row_index += NSIMDGROUPS) { + if (simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) { + const uint elems_per_row_gmem = DK; + const uint col_index_v_gmem = + col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id; + const uint row_index_v_gmem = row_index; + + const uint elems_per_row_smem = TILE_SIZE_CONST; + const uint col_index_v_smem = row_index % TILE_SIZE_CONST; + const uint row_index_v_smem = simd_lane_id; + + const uint scalar_offset_gmem = + row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem; + const uint scalar_offset_smem = + row_index_v_smem * elems_per_row_smem + col_index_v_smem; + T vdata = T(*(baseVThisHead + scalar_offset_gmem)); + smemV[scalar_offset_smem] = vdata; + smem_col_index += NSIMDGROUPS; } - } + } - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); - float groupMax; - float lse = 0.f; - - constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32; - constexpr const size_t ACCUM_ARRAY_LENGTH = TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1; - float4 pvals[ACCUM_ARRAY_LENGTH]; - - #pragma clang loop unroll(full) - for(uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH; accum_array_iter++) { - pvals[accum_array_iter] = float4(-INFINITY); - } - - if (TILE_SIZE_CONST == 64) { - threadgroup float2* smemPtrFlt2 = (threadgroup float2*)threadgroup_block; - float2 vals = smemPtrFlt2[simd_lane_id]; - vals *= params.INV_ALPHA; - float maxval = max(vals.x, vals.y); - simdgroup_barrier(mem_flags::mem_none); - groupMax = simd_max(maxval); - - float2 expf_shifted = exp(vals - groupMax); - float sumExpLocal = expf_shifted.x + expf_shifted.y; - simdgroup_barrier(mem_flags::mem_none); - float tgroupExpSum = simd_sum(sumExpLocal); - - lse = log(tgroupExpSum); - float2 local_p_hat = 
expf_shifted / tgroupExpSum; - pvals[0].x = local_p_hat.x; - pvals[0].y = local_p_hat.y; - smemPtrFlt2[simd_lane_id] = float2(0.f); - } - constexpr const bool TILE_SIZE_LARGER_THAN_64 = TILE_SIZE_CONST > 64; - constexpr const int TILE_SIZE_ITERS_128 = TILE_SIZE_CONST / 128; - - if (TILE_SIZE_LARGER_THAN_64) { - float maxval = -INFINITY; - threadgroup float4* smemPtrFlt4 = (threadgroup float4*)threadgroup_block; - #pragma clang loop unroll(full) - for(int i = 0; i < TILE_SIZE_ITERS_128; i++) { - float4 vals = smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP]; - vals *= params.INV_ALPHA; - pvals[i] = vals; - maxval = fmax3(vals.x, vals.y, maxval); - maxval = fmax3(vals.z, vals.w, maxval); + if (TILE_SIZE_CONST == 64) { + T2 local_p_hat = T2(pvals[0].x, pvals[0].y); + threadgroup float* oPartialSmem = + smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; + for (size_t smem_row_index = simd_group_id; + smem_row_index < ROWS_PER_ITER; + smem_row_index += NSIMDGROUPS) { + threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index); + threadgroup T2* smemV2 = (threadgroup T2*)smemV_row; + T2 v_local = *(smemV2 + simd_lane_id); + T val = dot(local_p_hat, v_local); + simdgroup_barrier(mem_flags::mem_none); + T row_sum = simd_sum(val); + oPartialSmem[smem_row_index] = float(row_sum); } - simdgroup_barrier(mem_flags::mem_none); - groupMax = simd_max(maxval); + } - float sumExpLocal = 0.f; - #pragma clang loop unroll(full) - for(int i = 0; i < TILE_SIZE_ITERS_128; i++) { - pvals[i] = exp(pvals[i] - groupMax); - sumExpLocal += pvals[i].x + pvals[i].y + pvals[i].z + pvals[i].w; - } - simdgroup_barrier(mem_flags::mem_none); - float tgroupExpSum = simd_sum(sumExpLocal); - lse = log(tgroupExpSum); - #pragma clang loop unroll(full) - for(int i = 0; i < TILE_SIZE_ITERS_128; i++) { - pvals[i] = pvals[i] / tgroupExpSum; - smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP] = float4(0.f); + if (TILE_SIZE_CONST > 64) { + threadgroup float* oPartialSmem = + smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; + uint loop_count = 0; + for (size_t row_index = simd_group_id; row_index < ROWS_PER_ITER; + row_index += NSIMDGROUPS) { + T row_sum = 0.f; + for (size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128; + tile_iters++) { + threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index); + threadgroup T4* smemV2 = (threadgroup T4*)smemV_row; + T4 v_local = + *(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP); + T4 p_local = T4(pvals[tile_iters]); + row_sum += dot(p_local, v_local); + } + simdgroup_barrier(mem_flags::mem_none); + row_sum = simd_sum(row_sum); + oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] = + float(row_sum); + loop_count++; } + } } + } - threadgroup T* smemV = (threadgroup T*)threadgroup_block; + threadgroup_barrier(mem_flags::mem_threadgroup); - const size_t v_batch_offset = tid.z * params.N_KV_HEADS * L * DK; - const size_t v_head_offset = kv_head_offset_factor * L * DK; + if (simd_group_id == 0) { + threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial; + float4 vals = *(oPartialVec4 + simd_lane_id); + device float* oPartialGmem = + O_partials + tid.x * DK * params.KV_TILES + tid.y * DK; + device float4* oPartialGmemVec4 = (device float4*)oPartialGmem; + oPartialGmemVec4[simd_lane_id] = vals; + } - const size_t v_tile_offset = tid.y * TILE_SIZE_CONST * DK; - const size_t v_offset = v_batch_offset + v_head_offset + v_tile_offset; - device T* baseV = (device T*)V + v_offset; - - threadgroup float* smemOpartial = (threadgroup float*)(smemV + 
totalSmemV); - - if (!LAST_TILE || LAST_TILE_ALIGNED) { - #pragma clang loop unroll(full) - for(size_t col = 0; col < MATRIX_COLS; col++) { - uint matrix_load_loop_iter = 0; - constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8; - - for(size_t tile_start = simd_group_id; tile_start < TILE_SIZE_CONST_DIV_8; tile_start += NSIMDGROUPS) { - simdgroup_matrix tmp; - ulong simdgroup_matrix_offset = matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; - ulong2 matrixOrigin = ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset); - simdgroup_load(tmp, baseV, DK, matrixOrigin, true); - const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0); - const ulong elemsPerRowSmem = TILE_SIZE_CONST; - simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, false); - matrix_load_loop_iter++; - }; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (TILE_SIZE_CONST == 64) { - T2 local_p_hat = T2(pvals[0].x, pvals[0].y); - uint loop_iter = 0; - threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; - - #pragma clang loop unroll(full) - for(size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; row += NSIMDGROUPS) { - threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row); - threadgroup T2* smemV2 = (threadgroup T2*)smemV_row; - T2 v_local = *(smemV2 + simd_lane_id); - - T val = dot(local_p_hat, v_local); - simdgroup_barrier(mem_flags::mem_none); - - T row_sum = simd_sum(val); - oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = float(row_sum); - loop_iter++; - } - } - - if (TILE_SIZE_CONST > 64) { - constexpr const size_t TILE_SIZE_CONST_DIV_128 = (TILE_SIZE_CONST + 1) / 128; - threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; - uint loop_iter = 0; - for(size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; row += NSIMDGROUPS) { - threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row); - - T row_sum = 0.f; - for(size_t i = 0; i < TILE_SIZE_CONST_DIV_128; i++) { - threadgroup T4* smemV2 = (threadgroup T4*)smemV_row; - T4 v_local = *(smemV2 + simd_lane_id + i * THREADS_PER_SIMDGROUP); - T4 p_local = T4(pvals[i]); - T val = dot(p_local, v_local); - row_sum += val; - } - simdgroup_barrier(mem_flags::mem_none); - row_sum = simd_sum(row_sum); - oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = float(row_sum); - loop_iter++; - } - } - } - } else { - const int32_t START_ROW = tid.y * TILE_SIZE_CONST; - const int32_t MAX_START_ROW = L - SIMDGROUP_MATRIX_LOAD_FACTOR + 1; - const device T* baseVThisHead = V + v_batch_offset + v_head_offset; - constexpr const int ROWS_PER_ITER = 8; - #pragma clang loop unroll(full) - for(size_t col = 0; col < MATRIX_COLS; col++) { - uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; - int32_t tile_start; - for(tile_start = START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; tile_start < MAX_START_ROW; tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) { - simdgroup_matrix tmp; - ulong2 matrixOrigin = ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start); - simdgroup_load(tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true); - const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0); - constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST; - simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, /* transpose */ false); - smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR; - }; - - tile_start = ((L / 
SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR); - - const int32_t INT_L = int32_t(L); - for(int row_index = tile_start + simd_group_id ; row_index < INT_L; row_index += NSIMDGROUPS) { - if(simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) { - const uint elems_per_row_gmem = DK; - const uint col_index_v_gmem = col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id; - const uint row_index_v_gmem = row_index; - - const uint elems_per_row_smem = TILE_SIZE_CONST; - const uint col_index_v_smem = row_index % TILE_SIZE_CONST; - const uint row_index_v_smem = simd_lane_id; - - const uint scalar_offset_gmem = row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem; - const uint scalar_offset_smem = row_index_v_smem * elems_per_row_smem + col_index_v_smem; - T vdata = T(*(baseVThisHead + scalar_offset_gmem)); - smemV[scalar_offset_smem] = vdata; - smem_col_index += NSIMDGROUPS; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (TILE_SIZE_CONST == 64) { - T2 local_p_hat = T2(pvals[0].x, pvals[0].y); - threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; - for(size_t smem_row_index = simd_group_id; - smem_row_index < ROWS_PER_ITER; smem_row_index += NSIMDGROUPS) { - threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index); - threadgroup T2* smemV2 = (threadgroup T2*)smemV_row; - T2 v_local = *(smemV2 + simd_lane_id); - T val = dot(local_p_hat, v_local); - simdgroup_barrier(mem_flags::mem_none); - T row_sum = simd_sum(val); - oPartialSmem[smem_row_index] = float(row_sum); - } - } - - if (TILE_SIZE_CONST > 64) { - threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; - uint loop_count = 0; - for(size_t row_index = simd_group_id; - row_index < ROWS_PER_ITER; row_index += NSIMDGROUPS) { - T row_sum = 0.f; - for(size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128; tile_iters++) { - threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index); - threadgroup T4* smemV2 = (threadgroup T4*)smemV_row; - T4 v_local = *(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP); - T4 p_local = T4(pvals[tile_iters]); - row_sum += dot(p_local, v_local); - - } - simdgroup_barrier(mem_flags::mem_none); - row_sum = simd_sum(row_sum); - oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] = float(row_sum); - loop_count++; - } - } - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if(simd_group_id == 0) { - threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial; - float4 vals = *(oPartialVec4 + simd_lane_id); - device float* oPartialGmem = O_partials + tid.x * DK * params.KV_TILES + tid.y * DK; - device float4* oPartialGmemVec4 = (device float4*)oPartialGmem; - oPartialGmemVec4[simd_lane_id] = vals; - } - - if(simd_group_id == 0 && simd_lane_id == 0) { - const uint tileIndex = tid.y; - const uint gmem_partial_scalar_offset = tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES + tileIndex; - p_lse[gmem_partial_scalar_offset] = lse; - p_maxes[gmem_partial_scalar_offset] = groupMax; - } + if (simd_group_id == 0 && simd_lane_id == 0) { + const uint tileIndex = tid.y; + const uint gmem_partial_scalar_offset = + tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES + + tileIndex; + p_lse[gmem_partial_scalar_offset] = lse; + p_maxes[gmem_partial_scalar_offset] = groupMax; + } } -#define instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, nsimdgroups) \ -template [[host_name("fast_inference_sdpa_compute_partials_" #itype "_" 
#tile_size "_" #nsimdgroups )]] \ -[[kernel]] void fast_inference_sdpa_compute_partials_template( \ - const device itype *Q [[buffer(0)]], \ - const device itype *K [[buffer(1)]], \ - const device itype *V [[buffer(2)]], \ - const device uint64_t& L [[buffer(3)]], \ - const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], \ - device float* O_partials [[buffer(5)]], \ - device float* p_lse [[buffer(6)]], \ - device float* p_maxes [[buffer(7)]], \ - threadgroup itype *threadgroup_block [[threadgroup(0)]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]]); +#define instantiate_fast_inference_sdpa_to_partials_kernel( \ + itype, itype2, itype4, tile_size, nsimdgroups) \ + template [[host_name("fast_inference_sdpa_compute_partials_" #itype \ + "_" #tile_size "_" #nsimdgroups)]] [[kernel]] void \ + fast_inference_sdpa_compute_partials_template< \ + itype, \ + itype2, \ + itype4, \ + tile_size, \ + nsimdgroups>( \ + const device itype* Q [[buffer(0)]], \ + const device itype* K [[buffer(1)]], \ + const device itype* V [[buffer(2)]], \ + const device uint64_t& L [[buffer(3)]], \ + const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], \ + device float* O_partials [[buffer(5)]], \ + device float* p_lse [[buffer(6)]], \ + device float* p_maxes [[buffer(7)]], \ + threadgroup itype* threadgroup_block [[threadgroup(0)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]]); +// clang-format off +#define instantiate_fast_inference_sdpa_to_partials_shapes_helper( \ + itype, itype2, itype4, tile_size) \ + instantiate_fast_inference_sdpa_to_partials_kernel( \ + itype, itype2, itype4, tile_size, 4) \ + instantiate_fast_inference_sdpa_to_partials_kernel( \ + itype, itype2, itype4, tile_size, 8) // clang-format on -#define instantiate_fast_inference_sdpa_to_partials_shapes_helper(itype, itype2, itype4, tile_size) \ - instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, 4) \ - instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, 8) \ - -instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 64); -instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 128); -instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 256); -instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 512); - -instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 64); -instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 128); -instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 256); -instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 512); +instantiate_fast_inference_sdpa_to_partials_shapes_helper( + float, + float2, + float4, + 64); +instantiate_fast_inference_sdpa_to_partials_shapes_helper( + float, + float2, + float4, + 128); +instantiate_fast_inference_sdpa_to_partials_shapes_helper( + float, + float2, + float4, + 256); +instantiate_fast_inference_sdpa_to_partials_shapes_helper( + float, + float2, + float4, + 512); +instantiate_fast_inference_sdpa_to_partials_shapes_helper( + half, + half2, + half4, + 64); +instantiate_fast_inference_sdpa_to_partials_shapes_helper( + half, + half2, + half4, + 128); 
+instantiate_fast_inference_sdpa_to_partials_shapes_helper( + half, + half2, + half4, + 256); +instantiate_fast_inference_sdpa_to_partials_shapes_helper( + half, + half2, + half4, + 512); template void fast_inference_sdpa_reduce_tiles_template( - const device float *O_partials [[buffer(0)]], - const device float *p_lse[[buffer(1)]], - const device float *p_maxes [[buffer(2)]], + const device float* O_partials [[buffer(0)]], + const device float* p_lse [[buffer(1)]], + const device float* p_maxes [[buffer(2)]], const device MLXScaledDotProductAttentionParams& params [[buffer(3)]], device T* O [[buffer(4)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) { + constexpr const int DK = 128; + const ulong offset_rows = + tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES; + const device float* p_lse_row = p_lse + offset_rows; + const device float* p_rowmax_row = p_maxes + offset_rows; + // reserve some number of registers. this constitutes an assumption on max + // value of KV TILES. + constexpr const uint8_t reserve = 128; + float p_lse_regs[reserve]; + float p_rowmax_regs[reserve]; + float weights[reserve]; - constexpr const int DK = 128; - const ulong offset_rows = tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES; - const device float* p_lse_row = p_lse + offset_rows; - const device float* p_rowmax_row = p_maxes + offset_rows; - // reserve some number of registers. this constitutes an assumption on max value of KV TILES. - constexpr const uint8_t reserve = 128; - float p_lse_regs[reserve]; - float p_rowmax_regs[reserve]; - float weights[reserve]; + float true_max = -INFINITY; + for (size_t i = 0; i < params.KV_TILES; i++) { + p_lse_regs[i] = float(*(p_lse_row + i)); + p_rowmax_regs[i] = float(*(p_rowmax_row + i)); + true_max = fmax(p_rowmax_regs[i], true_max); + weights[i] = exp(p_lse_regs[i]); + } - float true_max = -INFINITY; - for(size_t i = 0; i < params.KV_TILES; i++) { - p_lse_regs[i] = float(*(p_lse_row + i)); - p_rowmax_regs[i] = float(*(p_rowmax_row + i)); - true_max = fmax(p_rowmax_regs[i], true_max); - weights[i] = exp(p_lse_regs[i]); - } + float denom = 0.f; + for (size_t i = 0; i < params.KV_TILES; i++) { + weights[i] *= exp(p_rowmax_regs[i] - true_max); + denom += weights[i]; + } - float denom = 0.f; - for(size_t i = 0; i < params.KV_TILES; i++) { - weights[i] *= exp(p_rowmax_regs[i]-true_max); - denom += weights[i]; - } + const device float* O_partials_with_offset = O_partials + + tid.z * params.N_Q_HEADS * DK * params.KV_TILES + + tid.x * DK * params.KV_TILES; - const device float* O_partials_with_offset = O_partials + tid.z * params.N_Q_HEADS * DK * params.KV_TILES + tid.x * DK * params.KV_TILES; - - float o_value = 0.f; - for(size_t i = 0; i < params.KV_TILES; i++) { - float val = *(O_partials_with_offset + i * DK + lid.x); - o_value += val * weights[i] / denom; - } - device T* O_gmem = O + tid.z * params.N_Q_HEADS * DK + tid.x * DK; - O_gmem[lid.x] = T(o_value); - return; + float o_value = 0.f; + for (size_t i = 0; i < params.KV_TILES; i++) { + float val = *(O_partials_with_offset + i * DK + lid.x); + o_value += val * weights[i] / denom; + } + device T* O_gmem = O + tid.z * params.N_Q_HEADS * DK + tid.x * DK; + O_gmem[lid.x] = T(o_value); + return; } - kernel void fast_inference_sdpa_reduce_tiles_float( - const device float *O_partials [[buffer(0)]], - const device float *p_lse[[buffer(1)]], - const device float *p_maxes [[buffer(2)]], + const device float* O_partials [[buffer(0)]], + const 
device float* p_lse [[buffer(1)]], + const device float* p_maxes [[buffer(2)]], const device MLXScaledDotProductAttentionParams& params [[buffer(3)]], device float* O [[buffer(4)]], uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) -{ - fast_inference_sdpa_reduce_tiles_template(O_partials, p_lse, p_maxes, params, - O, tid, lid); + uint3 lid [[thread_position_in_threadgroup]]) { + fast_inference_sdpa_reduce_tiles_template( + O_partials, p_lse, p_maxes, params, O, tid, lid); } kernel void fast_inference_sdpa_reduce_tiles_half( - const device float *O_partials [[buffer(0)]], - const device float *p_lse[[buffer(1)]], - const device float *p_maxes [[buffer(2)]], + const device float* O_partials [[buffer(0)]], + const device float* p_lse [[buffer(1)]], + const device float* p_maxes [[buffer(2)]], const device MLXScaledDotProductAttentionParams& params [[buffer(3)]], device half* O [[buffer(4)]], uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) -{ - fast_inference_sdpa_reduce_tiles_template(O_partials, p_lse, p_maxes, params, - O, tid, lid); + uint3 lid [[thread_position_in_threadgroup]]) { + fast_inference_sdpa_reduce_tiles_template( + O_partials, p_lse, p_maxes, params, O, tid, lid); } diff --git a/mlx/backend/metal/kernels/scan.metal b/mlx/backend/metal/kernels/scan.metal index 13b7e7adc..42533c7ea 100644 --- a/mlx/backend/metal/kernels/scan.metal +++ b/mlx/backend/metal/kernels/scan.metal @@ -54,7 +54,7 @@ struct CumProd { } bool simd_scan(bool x) { - for (int i=1; i<=16; i*=2) { + for (int i = 1; i <= 16; i *= 2) { bool other = simd_shuffle_up(x, i); x &= other; } @@ -77,7 +77,7 @@ struct CumMax { } U simd_scan(U x) { - for (int i=1; i<=16; i*=2) { + for (int i = 1; i <= 16; i *= 2) { U other = simd_shuffle_up(x, i); x = (x >= other) ? x : other; } @@ -100,7 +100,7 @@ struct CumMin { } U simd_scan(U x) { - for (int i=1; i<=16; i*=2) { + for (int i = 1; i <= 16; i *= 2) { U other = simd_shuffle_up(x, i); x = (x <= other) ? 
x : other; } @@ -114,54 +114,60 @@ struct CumMin { }; template -inline void load_unsafe(U values[N_READS], const device T * input) { +inline void load_unsafe(U values[N_READS], const device T* input) { if (reverse) { - for (int i=0; i -inline void load_safe(U values[N_READS], const device T * input, int start, int total, U init) { +inline void load_safe( + U values[N_READS], + const device T* input, + int start, + int total, + U init) { if (reverse) { - for (int i=0; i -inline void write_unsafe(U values[N_READS], device U * out) { +inline void write_unsafe(U values[N_READS], device U* out) { if (reverse) { - for (int i=0; i -inline void write_safe(U values[N_READS], device U * out, int start, int total) { +inline void write_safe(U values[N_READS], device U* out, int start, int total) { if (reverse) { - for (int i=0; i +template < + typename T, + typename U, + typename Op, + int N_READS, + bool inclusive, + bool reverse> [[kernel]] void contiguous_scan( const device T* in [[buffer(0)]], device U* out [[buffer(1)]], - const constant size_t & axis_size [[buffer(2)]], + const constant size_t& axis_size [[buffer(2)]], uint gid [[thread_position_in_grid]], uint lid [[thread_position_in_threadgroup]], uint lsize [[threads_per_threadgroup]], @@ -195,42 +206,51 @@ template (values, in + axis_size - offset - N_READS); + load_unsafe( + values, in + axis_size - offset - N_READS); } else { - load_safe(values, in + axis_size - offset - N_READS, offset, axis_size, Op::init); + load_safe( + values, + in + axis_size - offset - N_READS, + offset, + axis_size, + Op::init); } } else { if ((offset + N_READS) < axis_size) { load_unsafe(values, in + offset); } else { - load_safe(values, in + offset, offset, axis_size, Op::init); + load_safe( + values, in + offset, offset, axis_size, Op::init); } } // Compute an inclusive scan per thread - for (int i=1; i(values, out + axis_size - offset - N_READS); + write_unsafe( + values, out + axis_size - offset - N_READS); } else { - write_safe(values, out + axis_size - offset - N_READS, offset, axis_size); + write_safe( + values, out + axis_size - offset - N_READS, offset, axis_size); } } else { if (lid == 0 && offset == 0) { - out[axis_size-1] = Op::init; + out[axis_size - 1] = Op::init; } if ((offset + N_READS + 1) < axis_size) { - write_unsafe(values, out + axis_size - offset - 1 - N_READS); + write_unsafe( + values, out + axis_size - offset - 1 - N_READS); } else { - write_safe(values, out + axis_size - offset - 1 - N_READS, offset + 1, axis_size); + write_safe( + values, + out + axis_size - offset - 1 - N_READS, + offset + 1, + axis_size); } } } else { @@ -275,7 +302,8 @@ template (values, out + offset); } else { - write_safe(values, out + offset, offset, axis_size); + write_safe( + values, out + offset, offset, axis_size); } } else { if (lid == 0 && offset == 0) { @@ -284,26 +312,33 @@ template (values, out + offset + 1); } else { - write_safe(values, out + offset + 1, offset + 1, axis_size); + write_safe( + values, out + offset + 1, offset + 1, axis_size); } } } // Share the prefix if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) { - simdgroup_sums[0] = values[N_READS-1]; + simdgroup_sums[0] = values[N_READS - 1]; } threadgroup_barrier(mem_flags::mem_threadgroup); prefix = simdgroup_sums[0]; } } -template +template < + typename T, + typename U, + typename Op, + int N_READS, + bool inclusive, + bool reverse> [[kernel]] void strided_scan( const device T* in [[buffer(0)]], device U* out [[buffer(1)]], - const constant size_t & axis_size 
[[buffer(2)]], - const constant size_t & stride [[buffer(3)]], + const constant size_t& axis_size [[buffer(2)]], + const constant size_t& stride [[buffer(3)]], uint2 gid [[threadgroup_position_in_grid]], uint2 lid [[thread_position_in_threadgroup]], uint2 lsize [[threads_per_threadgroup]], @@ -311,10 +346,10 @@ template , nreads, inclusive, reverse>( \ - const device itype* in [[buffer(0)]], \ - device otype* out [[buffer(1)]], \ - const constant size_t & axis_size [[buffer(2)]], \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint lsize [[threads_per_threadgroup]], \ - uint simd_size [[threads_per_simdgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); +#define instantiate_contiguous_scan( \ + name, itype, otype, op, inclusive, reverse, nreads) \ + template [[host_name("contiguous_scan_" #name)]] [[kernel]] void \ + contiguous_scan, nreads, inclusive, reverse>( \ + const device itype* in [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const constant size_t& axis_size [[buffer(2)]], \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint lsize [[threads_per_threadgroup]], \ + uint simd_size [[threads_per_simdgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); -#define instantiate_strided_scan(name, itype, otype, op, inclusive, reverse, nreads) \ - template [[host_name("strided_scan_" #name)]] \ - [[kernel]] void strided_scan, nreads, inclusive, reverse>( \ - const device itype* in [[buffer(0)]], \ - device otype* out [[buffer(1)]], \ - const constant size_t & axis_size [[buffer(2)]], \ - const constant size_t & stride [[buffer(3)]], \ - uint2 gid [[thread_position_in_grid]], \ - uint2 lid [[thread_position_in_threadgroup]], \ - uint2 lsize [[threads_per_threadgroup]], \ - uint simd_size [[threads_per_simdgroup]]); +#define instantiate_strided_scan( \ + name, itype, otype, op, inclusive, reverse, nreads) \ + template [[host_name("strided_scan_" #name)]] [[kernel]] void \ + strided_scan, nreads, inclusive, reverse>( \ + const device itype* in [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const constant size_t& axis_size [[buffer(2)]], \ + const constant size_t& stride [[buffer(3)]], \ + uint2 gid [[thread_position_in_grid]], \ + uint2 lid [[thread_position_in_threadgroup]], \ + uint2 lsize [[threads_per_threadgroup]], \ + uint simd_size [[threads_per_simdgroup]]); - -#define instantiate_scan_helper(name, itype, otype, op, nreads) \ - instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \ - instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads) \ - instantiate_contiguous_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \ +// clang-format off +#define instantiate_scan_helper(name, itype, otype, op, nreads) \ + instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \ + instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads) \ + instantiate_contiguous_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \ instantiate_contiguous_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) \ - instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads) \ - instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads) \ - 
instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \ - instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) + instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads) \ + instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads) \ + instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \ + instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) // clang-format on +// clang-format off instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4) instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4) instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4) @@ -491,4 +537,4 @@ instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMi instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4) instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4) instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4) -//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin) +//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin) // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/scatter.metal b/mlx/backend/metal/kernels/scatter.metal index fd2bd1950..89cc6f22d 100644 --- a/mlx/backend/metal/kernels/scatter.metal +++ b/mlx/backend/metal/kernels/scatter.metal @@ -13,67 +13,55 @@ using namespace metal; // Scatter kernel ///////////////////////////////////////////////////////////////////// -template \ +template METAL_FUNC void scatter_1d_index_impl( - const device T *updates [[buffer(1)]], - device mlx_atomic *out [[buffer(2)]], - const constant int* out_shape [[buffer(3)]], - const constant size_t* out_strides [[buffer(4)]], - const constant size_t& upd_size [[buffer(5)]], - const thread array& idx_buffers, - uint2 gid [[thread_position_in_grid]]) { - + const device T* updates [[buffer(1)]], + device mlx_atomic* out [[buffer(2)]], + const constant int* out_shape [[buffer(3)]], + const constant size_t* out_strides [[buffer(4)]], + const constant size_t& upd_size [[buffer(5)]], + const thread array& idx_buffers, + uint2 gid [[thread_position_in_grid]]) { Op op; uint out_idx = 0; for (int i = 0; i < NIDX; i++) { - auto idx_val = offset_neg_idx( - idx_buffers[i][gid.y], out_shape[i]); + auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]); out_idx += idx_val * out_strides[i]; } op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x); } -#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \ -template \ -[[kernel]] void scatter_1d_index( \ - const device T *updates [[buffer(1)]], \ - device mlx_atomic *out [[buffer(2)]], \ - const constant int* out_shape [[buffer(3)]], \ - const constant size_t* out_strides [[buffer(4)]], \ - const constant size_t& upd_size [[buffer(5)]], \ - IDX_ARG(IdxT) \ - uint2 gid [[thread_position_in_grid]]) { \ - \ - const array idx_buffers = {IDX_ARR()}; \ - \ - return scatter_1d_index_impl( \ - updates, \ - out, \ - out_shape, \ - out_strides, \ - upd_size, \ - idx_buffers, \ - gid); \ - \ -} +#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \ + template \ + [[kernel]] void scatter_1d_index( \ + const device T* updates [[buffer(1)]], \ + device mlx_atomic* out [[buffer(2)]], \ + const constant int* out_shape [[buffer(3)]], \ + const constant size_t* out_strides [[buffer(4)]], \ + const 
constant size_t& upd_size [[buffer(5)]], \ + IDX_ARG(IdxT) uint2 gid [[thread_position_in_grid]]) { \ + const array idx_buffers = {IDX_ARR()}; \ + \ + return scatter_1d_index_impl( \ + updates, out, out_shape, out_strides, upd_size, idx_buffers, gid); \ + } template METAL_FUNC void scatter_impl( - const device T *updates [[buffer(1)]], - device mlx_atomic *out [[buffer(2)]], - const constant int *upd_shape [[buffer(3)]], - const constant size_t *upd_strides [[buffer(4)]], + const device T* updates [[buffer(1)]], + device mlx_atomic* out [[buffer(2)]], + const constant int* upd_shape [[buffer(3)]], + const constant size_t* upd_strides [[buffer(4)]], const constant size_t& upd_ndim [[buffer(5)]], const constant size_t& upd_size [[buffer(6)]], - const constant int *out_shape [[buffer(7)]], - const constant size_t *out_strides [[buffer(8)]], + const constant int* out_shape [[buffer(7)]], + const constant size_t* out_strides [[buffer(8)]], const constant size_t& out_ndim [[buffer(9)]], const constant int* axes [[buffer(10)]], const thread Indices& indices, uint2 gid [[thread_position_in_grid]]) { - Op op; auto ind_idx = gid.y; auto ind_offset = gid.x; @@ -86,8 +74,7 @@ METAL_FUNC void scatter_impl( &indices.strides[indices.ndim * i], indices.ndim); auto ax = axes[i]; - auto idx_val = offset_neg_idx( - indices.buffers[i][idx_loc], out_shape[ax]); + auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]); out_idx += idx_val * out_strides[ax]; } @@ -97,142 +84,134 @@ METAL_FUNC void scatter_impl( out_idx += out_offset; } - auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim); + auto upd_idx = + elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim); op.atomic_update(out, updates[upd_idx], out_idx); } -#define make_scatter_impl(IDX_ARG, IDX_ARR) \ -template \ -[[kernel]] void scatter( \ - const device T *updates [[buffer(1)]], \ - device mlx_atomic *out [[buffer(2)]], \ - const constant int *upd_shape [[buffer(3)]], \ - const constant size_t *upd_strides [[buffer(4)]], \ - const constant size_t& upd_ndim [[buffer(5)]], \ - const constant size_t& upd_size [[buffer(6)]], \ - const constant int *out_shape [[buffer(7)]], \ - const constant size_t *out_strides [[buffer(8)]], \ - const constant size_t& out_ndim [[buffer(9)]], \ - const constant int* axes [[buffer(10)]], \ - const constant int *idx_shapes [[buffer(11)]], \ - const constant size_t *idx_strides [[buffer(12)]], \ - const constant int& idx_ndim [[buffer(13)]], \ - IDX_ARG(IdxT) \ - uint2 gid [[thread_position_in_grid]]) { \ - \ - Indices idxs{ \ - {{IDX_ARR()}}, \ - idx_shapes, \ - idx_strides, \ - idx_ndim}; \ - \ - return scatter_impl( \ - updates, \ - out, \ - upd_shape, \ - upd_strides, \ - upd_ndim, \ - upd_size, \ - out_shape, \ - out_strides, \ - out_ndim, \ - axes, \ - idxs, \ - gid); \ -} +#define make_scatter_impl(IDX_ARG, IDX_ARR) \ + template \ + [[kernel]] void scatter( \ + const device T* updates [[buffer(1)]], \ + device mlx_atomic* out [[buffer(2)]], \ + const constant int* upd_shape [[buffer(3)]], \ + const constant size_t* upd_strides [[buffer(4)]], \ + const constant size_t& upd_ndim [[buffer(5)]], \ + const constant size_t& upd_size [[buffer(6)]], \ + const constant int* out_shape [[buffer(7)]], \ + const constant size_t* out_strides [[buffer(8)]], \ + const constant size_t& out_ndim [[buffer(9)]], \ + const constant int* axes [[buffer(10)]], \ + const constant int* idx_shapes [[buffer(11)]], \ + const constant size_t* idx_strides [[buffer(12)]], \ + 
const constant int& idx_ndim [[buffer(13)]], \ + IDX_ARG(IdxT) uint2 gid [[thread_position_in_grid]]) { \ + Indices idxs{ \ + {{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \ + \ + return scatter_impl( \ + updates, \ + out, \ + upd_shape, \ + upd_strides, \ + upd_ndim, \ + upd_size, \ + out_shape, \ + out_strides, \ + out_ndim, \ + axes, \ + idxs, \ + gid); \ + } -#define make_scatter(n) \ -make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n) \ -make_scatter_1d_index(IDX_ARG_ ##n, IDX_ARR_ ##n) +#define make_scatter(n) \ + make_scatter_impl(IDX_ARG_##n, IDX_ARR_##n) \ + make_scatter_1d_index(IDX_ARG_##n, IDX_ARR_##n) -make_scatter(0) -make_scatter(1) -make_scatter(2) -make_scatter(3) -make_scatter(4) -make_scatter(5) -make_scatter(6) -make_scatter(7) -make_scatter(8) -make_scatter(9) -make_scatter(10) +make_scatter(0) make_scatter(1) make_scatter(2) make_scatter(3) make_scatter(4) + make_scatter(5) make_scatter(6) make_scatter(7) make_scatter(8) + make_scatter(9) make_scatter(10) ///////////////////////////////////////////////////////////////////// // Scatter instantiations ///////////////////////////////////////////////////////////////////// #define instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG) \ -template [[host_name("scatter" name "_" #nidx)]] \ -[[kernel]] void scatter( \ - const device src_t *updates [[buffer(1)]], \ - device mlx_atomic *out [[buffer(2)]], \ - const constant int *upd_shape [[buffer(3)]], \ - const constant size_t *upd_strides [[buffer(4)]], \ - const constant size_t& upd_ndim [[buffer(5)]], \ - const constant size_t& upd_size [[buffer(6)]], \ - const constant int *out_shape [[buffer(7)]], \ - const constant size_t *out_strides [[buffer(8)]], \ - const constant size_t& out_ndim [[buffer(9)]], \ - const constant int* axes [[buffer(10)]], \ - const constant int *idx_shapes [[buffer(11)]], \ - const constant size_t *idx_strides [[buffer(12)]], \ - const constant int& idx_ndim [[buffer(13)]], \ - IDX_ARG(idx_t) \ - uint2 gid [[thread_position_in_grid]]); + template [[host_name("scatter" name "_" #nidx)]] [[kernel]] void \ + scatter( \ + const device src_t* updates [[buffer(1)]], \ + device mlx_atomic* out [[buffer(2)]], \ + const constant int* upd_shape [[buffer(3)]], \ + const constant size_t* upd_strides [[buffer(4)]], \ + const constant size_t& upd_ndim [[buffer(5)]], \ + const constant size_t& upd_size [[buffer(6)]], \ + const constant int* out_shape [[buffer(7)]], \ + const constant size_t* out_strides [[buffer(8)]], \ + const constant size_t& out_ndim [[buffer(9)]], \ + const constant int* axes [[buffer(10)]], \ + const constant int* idx_shapes [[buffer(11)]], \ + const constant size_t* idx_strides [[buffer(12)]], \ + const constant int& idx_ndim [[buffer(13)]], \ + IDX_ARG(idx_t) uint2 gid [[thread_position_in_grid]]); -#define instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG) \ -template [[host_name("scatter_1d_index" name "_" #nidx)]] \ -[[kernel]] void scatter_1d_index( \ - const device src_t *updates [[buffer(1)]], \ - device mlx_atomic *out [[buffer(2)]], \ - const constant int* out_shape [[buffer(3)]], \ - const constant size_t* out_strides [[buffer(4)]], \ - const constant size_t& upd_size [[buffer(5)]], \ - IDX_ARG(idx_t) \ - uint2 gid [[thread_position_in_grid]]); +#define instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG) \ + template [[host_name("scatter_1d_index" name "_" #nidx)]] [[kernel]] void \ + scatter_1d_index( \ + const device src_t* updates [[buffer(1)]], \ + device mlx_atomic* out [[buffer(2)]], \ + const 
constant int* out_shape [[buffer(3)]], \ + const constant size_t* out_strides [[buffer(4)]], \ + const constant size_t& upd_size [[buffer(5)]], \ + IDX_ARG(idx_t) uint2 gid [[thread_position_in_grid]]); -#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \ +// clang-format off +#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \ instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) \ - instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) + instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) // clang-format on +// clang-format off // Special case NINDEX=0 -#define instantiate_scatter_nd0(name, type) \ - instantiate_scatter4(#name "none", type, bool, None, 0) \ - instantiate_scatter4(#name "_sum", type, bool, Sum, 0) \ +#define instantiate_scatter_nd0(name, type) \ + instantiate_scatter4(#name "none", type, bool, None, 0) \ + instantiate_scatter4(#name "_sum", type, bool, Sum, 0) \ instantiate_scatter4(#name "_prod", type, bool, Prod, 0) \ - instantiate_scatter4(#name "_max", type, bool, Max, 0) \ - instantiate_scatter4(#name "_min", type, bool, Min, 0) + instantiate_scatter4(#name "_max", type, bool, Max, 0) \ + instantiate_scatter4(#name "_min", type, bool, Min, 0) // clang-format on +// clang-format off #define instantiate_scatter3(name, type, ind_type, op_type) \ - instantiate_scatter4(name, type, ind_type, op_type, 1) \ - instantiate_scatter4(name, type, ind_type, op_type, 2) \ - instantiate_scatter4(name, type, ind_type, op_type, 3) \ - instantiate_scatter4(name, type, ind_type, op_type, 4) \ - instantiate_scatter4(name, type, ind_type, op_type, 5) \ - instantiate_scatter4(name, type, ind_type, op_type, 6) \ - instantiate_scatter4(name, type, ind_type, op_type, 7) \ - instantiate_scatter4(name, type, ind_type, op_type, 8) \ - instantiate_scatter4(name, type, ind_type, op_type, 9) \ - instantiate_scatter4(name, type, ind_type, op_type, 10) + instantiate_scatter4(name, type, ind_type, op_type, 1) \ + instantiate_scatter4(name, type, ind_type, op_type, 2) \ + instantiate_scatter4(name, type, ind_type, op_type, 3) \ + instantiate_scatter4(name, type, ind_type, op_type, 4) \ + instantiate_scatter4(name, type, ind_type, op_type, 5) \ + instantiate_scatter4(name, type, ind_type, op_type, 6) \ + instantiate_scatter4(name, type, ind_type, op_type, 7) \ + instantiate_scatter4(name, type, ind_type, op_type, 8) \ + instantiate_scatter4(name, type, ind_type, op_type, 9) \ + instantiate_scatter4(name, type, ind_type, op_type, 10) // clang-format on -#define instantiate_scatter2(name, type, ind_type) \ - instantiate_scatter3(name "_none", type, ind_type, None) \ - instantiate_scatter3(name "_sum", type, ind_type, Sum) \ +// clang-format off +#define instantiate_scatter2(name, type, ind_type) \ + instantiate_scatter3(name "_none", type, ind_type, None) \ + instantiate_scatter3(name "_sum", type, ind_type, Sum) \ instantiate_scatter3(name "_prod", type, ind_type, Prod) \ - instantiate_scatter3(name "_max", type, ind_type, Max) \ - instantiate_scatter3(name "_min", type, ind_type, Min) + instantiate_scatter3(name "_max", type, ind_type, Max) \ + instantiate_scatter3(name "_min", type, ind_type, Min) // clang-format on -#define instantiate_scatter(name, type) \ - instantiate_scatter2(#name "bool_", type, bool) \ - instantiate_scatter2(#name "uint8", type, uint8_t) \ +// clang-format off +#define instantiate_scatter(name, type) \ + instantiate_scatter2(#name "bool_", type, bool) \ + instantiate_scatter2(#name "uint8", type, uint8_t) \ 
instantiate_scatter2(#name "uint16", type, uint16_t) \ instantiate_scatter2(#name "uint32", type, uint32_t) \ instantiate_scatter2(#name "uint64", type, uint64_t) \ - instantiate_scatter2(#name "int8", type, int8_t) \ - instantiate_scatter2(#name "int16", type, int16_t) \ - instantiate_scatter2(#name "int32", type, int32_t) \ - instantiate_scatter2(#name "int64", type, int64_t) + instantiate_scatter2(#name "int8", type, int8_t) \ + instantiate_scatter2(#name "int16", type, int16_t) \ + instantiate_scatter2(#name "int32", type, int32_t) \ + instantiate_scatter2(#name "int64", type, int64_t) // clang-format on + // clang-format off // TODO uint64 and int64 unsupported instantiate_scatter_nd0(bool_, bool) instantiate_scatter_nd0(uint8, uint8_t) @@ -254,4 +233,4 @@ instantiate_scatter(int16, int16_t) instantiate_scatter(int32, int32_t) instantiate_scatter(float16, half) instantiate_scatter(float32, float) -instantiate_scatter(bfloat16, bfloat16_t) +instantiate_scatter(bfloat16, bfloat16_t) // clang-format on diff --git a/mlx/backend/metal/kernels/softmax.metal b/mlx/backend/metal/kernels/softmax.metal index 4bf9f2916..aaea6a279 100644 --- a/mlx/backend/metal/kernels/softmax.metal +++ b/mlx/backend/metal/kernels/softmax.metal @@ -198,17 +198,16 @@ template } } -// clang-format off -#define instantiate_softmax(name, itype) \ - template [[host_name("softmax_" #name)]] [[kernel]] void \ - softmax_single_row( \ - const device itype* in, \ - device itype* out, \ - constant int& axis_size, \ - uint gid [[thread_position_in_grid]], \ - uint _lid [[thread_position_in_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ +#define instantiate_softmax(name, itype) \ + template [[host_name("softmax_" #name)]] [[kernel]] void \ + softmax_single_row( \ + const device itype* in, \ + device itype* out, \ + constant int& axis_size, \ + uint gid [[thread_position_in_grid]], \ + uint _lid [[thread_position_in_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ template [[host_name("softmax_looped_" #name)]] [[kernel]] void \ softmax_looped( \ const device itype* in, \ @@ -220,16 +219,16 @@ template uint simd_lane_id [[thread_index_in_simdgroup]], \ uint simd_group_id [[simdgroup_index_in_threadgroup]]); -#define instantiate_softmax_precise(name, itype) \ - template [[host_name("softmax_precise_" #name)]] [[kernel]] void \ - softmax_single_row( \ - const device itype* in, \ - device itype* out, \ - constant int& axis_size, \ - uint gid [[thread_position_in_grid]], \ - uint _lid [[thread_position_in_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ +#define instantiate_softmax_precise(name, itype) \ + template [[host_name("softmax_precise_" #name)]] [[kernel]] void \ + softmax_single_row( \ + const device itype* in, \ + device itype* out, \ + constant int& axis_size, \ + uint gid [[thread_position_in_grid]], \ + uint _lid [[thread_position_in_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ template [[host_name("softmax_looped_precise_" #name)]] [[kernel]] void \ softmax_looped( \ const device itype* in, \ @@ -241,9 +240,9 @@ template uint simd_lane_id [[thread_index_in_simdgroup]], \ uint simd_group_id [[simdgroup_index_in_threadgroup]]); +// clang-format off instantiate_softmax(float32, float) 
instantiate_softmax(float16, half) instantiate_softmax(bfloat16, bfloat16_t) instantiate_softmax_precise(float16, half) -instantiate_softmax_precise(bfloat16, bfloat16_t) -// clang-format on +instantiate_softmax_precise(bfloat16, bfloat16_t) // clang-format on diff --git a/mlx/backend/metal/kernels/sort.metal b/mlx/backend/metal/kernels/sort.metal index 50b1cfbb6..b2d8ae815 100644 --- a/mlx/backend/metal/kernels/sort.metal +++ b/mlx/backend/metal/kernels/sort.metal @@ -11,7 +11,8 @@ using namespace metal; -// Based on GPU merge sort algorithm at https://github.com/NVIDIA/cccl/tree/main/cub/cub +// Based on GPU merge sort algorithm at +// https://github.com/NVIDIA/cccl/tree/main/cub/cub /////////////////////////////////////////////////////////////////////////////// // Thread-level sort @@ -43,20 +44,18 @@ struct ThreadSort { static METAL_FUNC void sort( thread val_t (&vals)[N_PER_THREAD], thread idx_t (&idxs)[N_PER_THREAD]) { - CompareOp op; MLX_MTL_LOOP_UNROLL - for(short i = 0; i < N_PER_THREAD; ++i) { - MLX_MTL_LOOP_UNROLL - for(short j = i & 1; j < N_PER_THREAD - 1; j += 2) { - if(op(vals[j + 1], vals[j])) { + for (short i = 0; i < N_PER_THREAD; ++i) { + MLX_MTL_LOOP_UNROLL + for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) { + if (op(vals[j + 1], vals[j])) { thread_swap(vals[j + 1], vals[j]); thread_swap(idxs[j + 1], idxs[j]); } } } - } }; @@ -72,25 +71,25 @@ template < short N_PER_THREAD, typename CompareOp> struct BlockMergeSort { - using thread_sort_t = ThreadSort; + using thread_sort_t = + ThreadSort; static METAL_FUNC int merge_partition( const threadgroup val_t* As, const threadgroup val_t* Bs, short A_sz, short B_sz, short sort_md) { - CompareOp op; short A_st = max(0, sort_md - B_sz); short A_ed = min(sort_md, A_sz); - while(A_st < A_ed) { + while (A_st < A_ed) { short md = A_st + (A_ed - A_st) / 2; auto a = As[md]; auto b = Bs[sort_md - 1 - md]; - if(op(b, a)) { + if (op(b, a)) { A_ed = md; } else { A_st = md + 1; @@ -98,7 +97,6 @@ struct BlockMergeSort { } return A_ed; - } static METAL_FUNC void merge_step( @@ -110,12 +108,11 @@ struct BlockMergeSort { short B_sz, thread val_t (&vals)[N_PER_THREAD], thread idx_t (&idxs)[N_PER_THREAD]) { - CompareOp op; short a_idx = 0; short b_idx = 0; - for(int i = 0; i < N_PER_THREAD; ++i) { + for (int i = 0; i < N_PER_THREAD; ++i) { auto a = As[a_idx]; auto b = Bs[b_idx]; bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a)); @@ -126,7 +123,6 @@ struct BlockMergeSort { b_idx += short(pred); a_idx += short(!pred); } - } static METAL_FUNC void sort( @@ -134,32 +130,32 @@ struct BlockMergeSort { threadgroup idx_t* tgp_idxs [[threadgroup(1)]], int size_sorted_axis, uint3 lid [[thread_position_in_threadgroup]]) { - // Get thread location int idx = lid.x * N_PER_THREAD; // Load from shared memory thread val_t thread_vals[N_PER_THREAD]; thread idx_t thread_idxs[N_PER_THREAD]; - for(int i = 0; i < N_PER_THREAD; ++i) { + for (int i = 0; i < N_PER_THREAD; ++i) { thread_vals[i] = tgp_vals[idx + i]; - if(ARG_SORT) { + if (ARG_SORT) { thread_idxs[i] = tgp_idxs[idx + i]; } } - // Per thread sort - if(idx < size_sorted_axis) { + // Per thread sort + if (idx < size_sorted_axis) { thread_sort_t::sort(thread_vals, thread_idxs); } // Do merges using threadgroup memory - for (int merge_threads = 2; merge_threads <= BLOCK_THREADS; merge_threads *= 2) { + for (int merge_threads = 2; merge_threads <= BLOCK_THREADS; + merge_threads *= 2) { // Update threadgroup memory threadgroup_barrier(mem_flags::mem_threadgroup); - for(int i = 0; i < N_PER_THREAD; ++i) 
{ + for (int i = 0; i < N_PER_THREAD; ++i) { tgp_vals[idx + i] = thread_vals[i]; - if(ARG_SORT) { + if (ARG_SORT) { tgp_idxs[idx + i] = thread_idxs[i]; } } @@ -167,7 +163,7 @@ struct BlockMergeSort { // Find location in merge step int merge_group = lid.x / merge_threads; - int merge_lane = lid.x % merge_threads; + int merge_lane = lid.x % merge_threads; int sort_sz = N_PER_THREAD * merge_threads; int sort_st = N_PER_THREAD * merge_threads * merge_group; @@ -185,16 +181,11 @@ struct BlockMergeSort { int B_sz = B_ed - B_st; // Find a partition of merge elements - // Ci = merge(As[partition:], Bs[sort_md - partition:]) + // Ci = merge(As[partition:], Bs[sort_md - partition:]) // of size N_PER_THREAD for each merge lane i - // C = [Ci] is sorted + // C = [Ci] is sorted int sort_md = N_PER_THREAD * merge_lane; - int partition = merge_partition( - As, - Bs, - A_sz, - B_sz, - sort_md); + int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md); As += partition; Bs += sort_md - partition; @@ -202,27 +193,20 @@ struct BlockMergeSort { A_sz -= partition; B_sz -= sort_md - partition; - const threadgroup idx_t* As_idx = ARG_SORT ? tgp_idxs + A_st + partition : nullptr; - const threadgroup idx_t* Bs_idx = ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr; + const threadgroup idx_t* As_idx = + ARG_SORT ? tgp_idxs + A_st + partition : nullptr; + const threadgroup idx_t* Bs_idx = + ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr; // Merge starting at the partition and store results in thread registers - merge_step( - As, - Bs, - As_idx, - Bs_idx, - A_sz, - B_sz, - thread_vals, - thread_idxs); - + merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs); } // Write out to shared memory threadgroup_barrier(mem_flags::mem_threadgroup); - for(int i = 0; i < N_PER_THREAD; ++i) { + for (int i = 0; i < N_PER_THREAD; ++i) { tgp_vals[idx + i] = thread_vals[i]; - if(ARG_SORT) { + if (ARG_SORT) { tgp_idxs[idx + i] = thread_idxs[i]; } } @@ -235,7 +219,7 @@ struct BlockMergeSort { template < typename T, - typename U, + typename U, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD, @@ -244,13 +228,13 @@ struct KernelMergeSort { using val_t = T; using idx_t = uint; using block_merge_sort_t = BlockMergeSort< - val_t, + val_t, idx_t, ARG_SORT, - BLOCK_THREADS, + BLOCK_THREADS, N_PER_THREAD, CompareOp>; - + MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD; static METAL_FUNC void block_sort( @@ -263,15 +247,15 @@ struct KernelMergeSort { threadgroup idx_t* tgp_idxs, uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) { - // tid.y tells us the segment index inp += tid.y * stride_segment_axis; out += tid.y * stride_segment_axis; // Copy into threadgroup memory - for(short i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) { - tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis] : val_t(CompareOp::init); - if(ARG_SORT) { + for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { + tgp_vals[i] = i < size_sorted_axis ? 
inp[i * stride_sorted_axis] + : val_t(CompareOp::init); + if (ARG_SORT) { tgp_idxs[i] = i; } } @@ -284,8 +268,8 @@ struct KernelMergeSort { threadgroup_barrier(mem_flags::mem_threadgroup); // Write output - for(int i = lid.x; i < size_sorted_axis; i+= BLOCK_THREADS) { - if(ARG_SORT) { + for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) { + if (ARG_SORT) { out[i * stride_sorted_axis] = tgp_idxs[i]; } else { out[i * stride_sorted_axis] = tgp_vals[i]; @@ -296,7 +280,7 @@ struct KernelMergeSort { template < typename T, - typename U, + typename U, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD> @@ -308,12 +292,12 @@ template < const constant int& stride_segment_axis [[buffer(4)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) { - - using sort_kernel = KernelMergeSort; + using sort_kernel = + KernelMergeSort; using val_t = typename sort_kernel::val_t; using idx_t = typename sort_kernel::idx_t; - if(ARG_SORT) { + if (ARG_SORT) { threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; sort_kernel::block_sort( @@ -339,14 +323,13 @@ template < tid, lid); } - } constant constexpr const int zero_helper = 0; template < typename T, - typename U, + typename U, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD> @@ -360,8 +343,8 @@ template < const device size_t* nc_strides [[buffer(6)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) { - - using sort_kernel = KernelMergeSort; + using sort_kernel = + KernelMergeSort; using val_t = typename sort_kernel::val_t; using idx_t = typename sort_kernel::idx_t; @@ -369,7 +352,7 @@ template < inp += block_idx; out += block_idx; - if(ARG_SORT) { + if (ARG_SORT) { threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; sort_kernel::block_sort( @@ -395,50 +378,55 @@ template < tid, lid); } - } /////////////////////////////////////////////////////////////////////////////// // Instantiations /////////////////////////////////////////////////////////////////////////////// - -#define instantiate_block_sort(name, itname, itype, otname, otype, arg_sort, bn, tn) \ - template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn)]] \ - [[kernel]] void block_sort( \ - const device itype* inp [[buffer(0)]], \ - device otype* out [[buffer(1)]], \ - const constant int& size_sorted_axis [[buffer(2)]], \ - const constant int& stride_sorted_axis [[buffer(3)]], \ - const constant int& stride_segment_axis [[buffer(4)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]]); \ - template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn "_nc")]] \ - [[kernel]] void block_sort_nc( \ - const device itype* inp [[buffer(0)]], \ - device otype* out [[buffer(1)]], \ - const constant int& size_sorted_axis [[buffer(2)]], \ - const constant int& stride_sorted_axis [[buffer(3)]], \ - const constant int& nc_dim [[buffer(4)]], \ - const device int* nc_shape [[buffer(5)]], \ - const device size_t* nc_strides [[buffer(6)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]]); +#define instantiate_block_sort( \ + name, itname, itype, otname, otype, arg_sort, bn, tn) \ + template [[host_name(#name "_" #itname "_" #otname "_bn" #bn \ + "_tn" #tn)]] [[kernel]] void \ + block_sort( \ + const device itype* inp [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const constant int& 
size_sorted_axis [[buffer(2)]], \ + const constant int& stride_sorted_axis [[buffer(3)]], \ + const constant int& stride_segment_axis [[buffer(4)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); \ + template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn \ + "_nc")]] [[kernel]] void \ + block_sort_nc( \ + const device itype* inp [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const constant int& size_sorted_axis [[buffer(2)]], \ + const constant int& stride_sorted_axis [[buffer(3)]], \ + const constant int& nc_dim [[buffer(4)]], \ + const device int* nc_shape [[buffer(5)]], \ + const device size_t* nc_strides [[buffer(6)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); #define instantiate_arg_block_sort_base(itname, itype, bn, tn) \ - instantiate_block_sort(arg_block_merge_sort, itname, itype, uint32, uint32_t, true, bn, tn) + instantiate_block_sort( \ + arg_block_merge_sort, itname, itype, uint32, uint32_t, true, bn, tn) #define instantiate_block_sort_base(itname, itype, bn, tn) \ - instantiate_block_sort(block_merge_sort, itname, itype, itname, itype, false, bn, tn) + instantiate_block_sort( \ + block_merge_sort, itname, itype, itname, itype, false, bn, tn) +// clang-format off #define instantiate_block_sort_tn(itname, itype, bn) \ - instantiate_block_sort_base(itname, itype, bn, 8) \ - instantiate_arg_block_sort_base(itname, itype, bn, 8) + instantiate_block_sort_base(itname, itype, bn, 8) \ + instantiate_arg_block_sort_base(itname, itype, bn, 8) // clang-format on +// clang-format off #define instantiate_block_sort_bn(itname, itype) \ - instantiate_block_sort_tn(itname, itype, 128) \ - instantiate_block_sort_tn(itname, itype, 256) \ - instantiate_block_sort_tn(itname, itype, 512) + instantiate_block_sort_tn(itname, itype, 128) \ + instantiate_block_sort_tn(itname, itype, 256) \ + instantiate_block_sort_tn(itname, itype, 512) instantiate_block_sort_bn(uint8, uint8_t) instantiate_block_sort_bn(uint16, uint16_t) @@ -448,35 +436,35 @@ instantiate_block_sort_bn(int16, int16_t) instantiate_block_sort_bn(int32, int32_t) instantiate_block_sort_bn(float16, half) instantiate_block_sort_bn(float32, float) -instantiate_block_sort_bn(bfloat16, bfloat16_t) - +instantiate_block_sort_bn(bfloat16, bfloat16_t) // clang-format on +// clang-format off #define instantiate_block_sort_long(itname, itype) \ - instantiate_block_sort_tn(itname, itype, 128) \ + instantiate_block_sort_tn(itname, itype, 128) \ instantiate_block_sort_tn(itname, itype, 256) instantiate_block_sort_long(uint64, uint64_t) -instantiate_block_sort_long(int64, int64_t) +instantiate_block_sort_long(int64, int64_t) // clang-format on -/////////////////////////////////////////////////////////////////////////////// -// Multi block merge sort -/////////////////////////////////////////////////////////////////////////////// + /////////////////////////////////////////////////////////////////////////////// + // Multi block merge sort + /////////////////////////////////////////////////////////////////////////////// -template < - typename val_t, - typename idx_t, - bool ARG_SORT, - short BLOCK_THREADS, - short N_PER_THREAD, - typename CompareOp = LessThan> -struct KernelMultiBlockMergeSort { + template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp = LessThan> + struct KernelMultiBlockMergeSort { using block_merge_sort_t = BlockMergeSort< - val_t, + 
val_t, idx_t, ARG_SORT, - BLOCK_THREADS, + BLOCK_THREADS, N_PER_THREAD, CompareOp>; - + MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD; static METAL_FUNC void block_sort( @@ -489,14 +477,14 @@ struct KernelMultiBlockMergeSort { threadgroup idx_t* tgp_idxs, uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) { - // tid.y tells us the segment index int base_idx = tid.x * N_PER_BLOCK; // Copy into threadgroup memory - for(short i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) { + for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { int idx = base_idx + i; - tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis] : val_t(CompareOp::init); + tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis] + : val_t(CompareOp::init); tgp_idxs[i] = idx; } @@ -508,9 +496,9 @@ struct KernelMultiBlockMergeSort { threadgroup_barrier(mem_flags::mem_threadgroup); // Write output - for(int i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) { + for (int i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { int idx = base_idx + i; - if(idx < size_sorted_axis) { + if (idx < size_sorted_axis) { out_vals[idx] = tgp_vals[i]; out_idxs[idx] = tgp_idxs[i]; } @@ -523,18 +511,17 @@ struct KernelMultiBlockMergeSort { int A_sz, int B_sz, int sort_md) { - CompareOp op; int A_st = max(0, sort_md - B_sz); int A_ed = min(sort_md, A_sz); - while(A_st < A_ed) { + while (A_st < A_ed) { int md = A_st + (A_ed - A_st) / 2; auto a = As[md]; auto b = Bs[sort_md - 1 - md]; - if(op(b, a)) { + if (op(b, a)) { A_ed = md; } else { A_st = md + 1; @@ -542,7 +529,6 @@ struct KernelMultiBlockMergeSort { } return A_ed; - } }; @@ -563,8 +549,12 @@ template < const device size_t* nc_strides [[buffer(7)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) { - - using sort_kernel = KernelMultiBlockMergeSort; + using sort_kernel = KernelMultiBlockMergeSort< + val_t, + idx_t, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD>; auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim); inp += block_idx; @@ -575,12 +565,12 @@ template < threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; sort_kernel::block_sort( - inp, - out_vals, - out_idxs, - size_sorted_axis, - stride_sorted_axis, - tgp_vals, + inp, + out_vals, + out_idxs, + size_sorted_axis, + stride_sorted_axis, + tgp_vals, tgp_idxs, tid, lid); @@ -592,7 +582,8 @@ template < bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD> -[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partition( +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void +mb_block_partition( device idx_t* block_partitions [[buffer(0)]], const device val_t* dev_vals [[buffer(1)]], const device idx_t* dev_idxs [[buffer(2)]], @@ -601,21 +592,20 @@ template < uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint3 tgp_dims [[threads_per_threadgroup]]) { - using sort_kernel = KernelMultiBlockMergeSort< - val_t, - idx_t, - ARG_SORT, - BLOCK_THREADS, + val_t, + idx_t, + ARG_SORT, + BLOCK_THREADS, N_PER_THREAD>; - + block_partitions += tid.y * tgp_dims.x; dev_vals += tid.y * size_sorted_axis; dev_idxs += tid.y * size_sorted_axis; // Find location in merge step int merge_group = lid.x / merge_tiles; - int merge_lane = lid.x % merge_tiles; + int merge_lane = lid.x % merge_tiles; int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles; int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group; @@ -627,14 +617,9 @@ template < int 
partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane); int partition = sort_kernel::merge_partition( - dev_vals + A_st, - dev_vals + B_st, - A_ed - A_st, - B_ed - B_st, - partition_at); + dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at); block_partitions[lid.x] = A_st + partition; - } template < @@ -644,7 +629,8 @@ template < short BLOCK_THREADS, short N_PER_THREAD, typename CompareOp = LessThan> -[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_merge( +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void +mb_block_merge( const device idx_t* block_partitions [[buffer(0)]], const device val_t* dev_vals_in [[buffer(1)]], const device idx_t* dev_idxs_in [[buffer(2)]], @@ -655,20 +641,19 @@ template < const constant int& num_tiles [[buffer(7)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) { - using sort_kernel = KernelMultiBlockMergeSort< - val_t, - idx_t, - ARG_SORT, - BLOCK_THREADS, - N_PER_THREAD, + val_t, + idx_t, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD, CompareOp>; - + using block_sort_t = typename sort_kernel::block_merge_sort_t; block_partitions += tid.y * (num_tiles + 1); - dev_vals_in += tid.y * size_sorted_axis; - dev_idxs_in += tid.y * size_sorted_axis; + dev_vals_in += tid.y * size_sorted_axis; + dev_idxs_in += tid.y * size_sorted_axis; dev_vals_out += tid.y * size_sorted_axis; dev_idxs_out += tid.y * size_sorted_axis; @@ -680,25 +665,29 @@ template < int A_st = block_partitions[block_idx + 0]; int A_ed = block_partitions[block_idx + 1]; - int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz/2 + sort_md - A_st); - int B_ed = min(size_sorted_axis, 2 * sort_st + sort_sz/2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed); + int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st); + int B_ed = min( + size_sorted_axis, + 2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed); - if((block_idx % merge_tiles) == merge_tiles - 1) { - A_ed = min(size_sorted_axis, sort_st + sort_sz/2); + if ((block_idx % merge_tiles) == merge_tiles - 1) { + A_ed = min(size_sorted_axis, sort_st + sort_sz / 2); B_ed = min(size_sorted_axis, sort_st + sort_sz); } - + int A_sz = A_ed - A_st; int B_sz = B_ed - B_st; // Load from global memory thread val_t thread_vals[N_PER_THREAD]; thread idx_t thread_idxs[N_PER_THREAD]; - for(int i = 0; i < N_PER_THREAD; i++) { + for (int i = 0; i < N_PER_THREAD; i++) { int idx = BLOCK_THREADS * i + lid.x; - if(idx < (A_sz + B_sz)) { - thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx] : dev_vals_in[B_st + idx - A_sz]; - thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx] : dev_idxs_in[B_st + idx - A_sz]; + if (idx < (A_sz + B_sz)) { + thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx] + : dev_vals_in[B_st + idx - A_sz]; + thread_idxs[i] = (idx < A_sz) ? 
dev_idxs_in[A_st + idx] + : dev_idxs_in[B_st + idx - A_sz]; } else { thread_vals[i] = CompareOp::init; thread_idxs[i] = 0; @@ -709,7 +698,7 @@ template < threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; threadgroup_barrier(mem_flags::mem_threadgroup); - for(int i = 0; i < N_PER_THREAD; i++) { + for (int i = 0; i < N_PER_THREAD; i++) { int idx = BLOCK_THREADS * i + lid.x; tgp_vals[idx] = thread_vals[i]; tgp_idxs[idx] = thread_idxs[i]; @@ -720,11 +709,7 @@ template < int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x)); int A_st_local = block_sort_t::merge_partition( - tgp_vals, - tgp_vals + A_sz, - A_sz, - B_sz, - sort_md_local); + tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local); int A_ed_local = A_sz; int B_st_local = sort_md_local - A_st_local; @@ -733,7 +718,7 @@ template < int A_sz_local = A_ed_local - A_st_local; int B_sz_local = B_ed_local - B_st_local; - // Do merge + // Do merge block_sort_t::merge_step( tgp_vals + A_st_local, tgp_vals + A_ed_local + B_st_local, @@ -745,61 +730,65 @@ template < thread_idxs); threadgroup_barrier(mem_flags::mem_threadgroup); - for(int i = 0; i < N_PER_THREAD; ++i) { + for (int i = 0; i < N_PER_THREAD; ++i) { int idx = lid.x * N_PER_THREAD; tgp_vals[idx + i] = thread_vals[i]; tgp_idxs[idx + i] = thread_idxs[i]; } - + threadgroup_barrier(mem_flags::mem_threadgroup); // Write output int base_idx = tid.x * sort_kernel::N_PER_BLOCK; - for(int i = lid.x; i < sort_kernel::N_PER_BLOCK; i+= BLOCK_THREADS) { + for (int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) { int idx = base_idx + i; - if(idx < size_sorted_axis) { + if (idx < size_sorted_axis) { dev_vals_out[idx] = tgp_vals[i]; dev_idxs_out[idx] = tgp_idxs[i]; } } - } -#define instantiate_multi_block_sort(vtname, vtype, itname, itype, arg_sort, bn, tn) \ - template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \ - [[kernel]] void mb_block_sort( \ - const device vtype* inp [[buffer(0)]], \ - device vtype* out_vals [[buffer(1)]], \ - device itype* out_idxs [[buffer(2)]], \ - const constant int& size_sorted_axis [[buffer(3)]], \ - const constant int& stride_sorted_axis [[buffer(4)]], \ - const constant int& nc_dim [[buffer(5)]], \ - const device int* nc_shape [[buffer(6)]], \ - const device size_t* nc_strides [[buffer(7)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]]); \ - template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \ - [[kernel]] void mb_block_partition( \ - device itype* block_partitions [[buffer(0)]], \ - const device vtype* dev_vals [[buffer(1)]], \ - const device itype* dev_idxs [[buffer(2)]], \ - const constant int& size_sorted_axis [[buffer(3)]], \ - const constant int& merge_tiles [[buffer(4)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint3 tgp_dims [[threads_per_threadgroup]]); \ - template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \ - [[kernel]] void mb_block_merge( \ - const device itype* block_partitions [[buffer(0)]], \ - const device vtype* dev_vals_in [[buffer(1)]], \ - const device itype* dev_idxs_in [[buffer(2)]], \ - device vtype* dev_vals_out [[buffer(3)]], \ - device itype* dev_idxs_out [[buffer(4)]], \ - const constant int& size_sorted_axis [[buffer(5)]], \ - const constant int& merge_tiles [[buffer(6)]], \ - const constant int& num_tiles [[buffer(7)]], \ - uint3 tid 
[[threadgroup_position_in_grid]], \ +#define instantiate_multi_block_sort( \ + vtname, vtype, itname, itype, arg_sort, bn, tn) \ + template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn \ + "_tn" #tn)]] [[kernel]] void \ + mb_block_sort( \ + const device vtype* inp [[buffer(0)]], \ + device vtype* out_vals [[buffer(1)]], \ + device itype* out_idxs [[buffer(2)]], \ + const constant int& size_sorted_axis [[buffer(3)]], \ + const constant int& stride_sorted_axis [[buffer(4)]], \ + const constant int& nc_dim [[buffer(5)]], \ + const device int* nc_shape [[buffer(6)]], \ + const device size_t* nc_strides [[buffer(7)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); \ + template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn \ + "_tn" #tn)]] [[kernel]] void \ + mb_block_partition( \ + device itype * block_partitions [[buffer(0)]], \ + const device vtype* dev_vals [[buffer(1)]], \ + const device itype* dev_idxs [[buffer(2)]], \ + const constant int& size_sorted_axis [[buffer(3)]], \ + const constant int& merge_tiles [[buffer(4)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 tgp_dims [[threads_per_threadgroup]]); \ + template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn \ + "_tn" #tn)]] [[kernel]] void \ + mb_block_merge( \ + const device itype* block_partitions [[buffer(0)]], \ + const device vtype* dev_vals_in [[buffer(1)]], \ + const device itype* dev_idxs_in [[buffer(2)]], \ + device vtype* dev_vals_out [[buffer(3)]], \ + device itype* dev_idxs_out [[buffer(4)]], \ + const constant int& size_sorted_axis [[buffer(5)]], \ + const constant int& merge_tiles [[buffer(6)]], \ + const constant int& num_tiles [[buffer(7)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ uint3 lid [[thread_position_in_threadgroup]]); +// clang-format off #define instantiate_multi_block_sort_base(vtname, vtype) \ instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8) @@ -811,10 +800,11 @@ instantiate_multi_block_sort_base(int16, int16_t) instantiate_multi_block_sort_base(int32, int32_t) instantiate_multi_block_sort_base(float16, half) instantiate_multi_block_sort_base(float32, float) -instantiate_multi_block_sort_base(bfloat16, bfloat16_t) +instantiate_multi_block_sort_base(bfloat16, bfloat16_t) // clang-format on +// clang-format off #define instantiate_multi_block_sort_long(vtname, vtype) \ instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8) instantiate_multi_block_sort_long(uint64, uint64_t) -instantiate_multi_block_sort_long(int64, int64_t) \ No newline at end of file +instantiate_multi_block_sort_long(int64, int64_t) // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal index 6f80622ad..39953c2be 100644 --- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal @@ -4,21 +4,23 @@ #include "mlx/backend/metal/kernels/steel/gemm/mma.h" +#include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/steel/conv/conv.h" #include "mlx/backend/metal/kernels/steel/conv/params.h" -#include "mlx/backend/metal/kernels/bf16.h" using namespace metal; -template -[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d( +template < + typename T, + int BM, + int BN, + int BK, 
+ int WM, + int WN, + int N_CHANNELS = 0, + bool SMALL_FILTER = false> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void +implicit_gemm_conv_2d( const device T* A [[buffer(0)]], const device T* B [[buffer(1)]], device T* C [[buffer(2)]], @@ -28,12 +30,10 @@ template , + // Go to small channel specialization + Conv2DInputBlockLoaderSmallChannels< + T, + BM, + BN, + BK, + tgp_size, + N_CHANNELS, + tgp_padding_a>, - // Else go to general loader - typename metal::conditional_t< - // Check if filter size is small enough - SMALL_FILTER, + // Else go to general loader + typename metal::conditional_t< + // Check if filter size is small enough + SMALL_FILTER, - // Go to small filter specialization - Conv2DInputBlockLoaderSmallFilter< - T, BM, BN, BK, tgp_size, tgp_padding_a>, - - // Else go to large filter generalization - Conv2DInputBlockLoaderLargeFilter< - T, BM, BN, BK, tgp_size, tgp_padding_a> - > - >; + // Go to small filter specialization + Conv2DInputBlockLoaderSmallFilter< + T, + BM, + BN, + BK, + tgp_size, + tgp_padding_a>, + // Else go to large filter generalization + Conv2DInputBlockLoaderLargeFilter< + T, + BM, + BN, + BK, + tgp_size, + tgp_padding_a>>>; // Weight loader using loader_b_t = typename metal::conditional_t< // Check for small channel specialization N_CHANNELS != 0 && N_CHANNELS <= 4, - // Go to small channel specialization - Conv2DWeightBlockLoaderSmallChannels< - T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_b>, + // Go to small channel specialization + Conv2DWeightBlockLoaderSmallChannels< + T, + BM, + BN, + BK, + tgp_size, + N_CHANNELS, + tgp_padding_b>, + + // Else go to general loader + Conv2DWeightBlockLoader>; - // Else go to general loader - Conv2DWeightBlockLoader - >; - using mma_t = BlockMMA< T, T, @@ -99,12 +117,12 @@ template ; - + threadgroup T As[tgp_mem_size_a]; threadgroup T Bs[tgp_mem_size_b]; const int tid_y = ((tid.y) << gemm_params->swizzle_log) + - ((tid.x) & ((1 << gemm_params->swizzle_log) - 1)); + ((tid.x) & ((1 << gemm_params->swizzle_log) - 1)); const int tid_x = (tid.x) >> gemm_params->swizzle_log; if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) { @@ -123,8 +141,10 @@ template M - c_row); short tgp_bn = min(BN, gemm_params->N - c_col); mma_op.store_result_safe(C, N, short2(tgp_bn, tgp_bm)); - } -#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, channel_name, n_channels, filter_name, small_filter) \ - template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_channel_" #channel_name "_filter_" #filter_name)]] \ - [[kernel]] void implicit_gemm_conv_2d( \ - const device itype* A [[buffer(0)]], \ - const device itype* B [[buffer(1)]], \ - device itype* C [[buffer(2)]], \ - const constant MLXConvParams<2>* params [[buffer(3)]], \ - const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ +#define instantiate_implicit_conv_2d( \ + name, \ + itype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + channel_name, \ + n_channels, \ + filter_name, \ + small_filter) \ + template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn \ + "_bk" #bk "_wm" #wm "_wn" #wn "_channel_" #channel_name \ + "_filter_" #filter_name)]] [[kernel]] void \ + implicit_gemm_conv_2d( \ + const device itype* A [[buffer(0)]], \ + const device itype* B [[buffer(1)]], \ + device itype* C [[buffer(2)]], \ + const 
constant MLXConvParams<2>* params [[buffer(3)]], \ + const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ uint simd_lid [[thread_index_in_simdgroup]]); -#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \ - instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \ +// clang-format off +#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \ + instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \ instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, l, false) \ instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 1, 1, l, false) \ instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 2, 2, l, false) \ instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 3, 3, l, false) \ - instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false) + instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false) // clang-format on -#define instantiate_implicit_2d_blocks(name, itype) \ +// clang-format off +#define instantiate_implicit_2d_blocks(name, itype) \ instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \ instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \ instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \ instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \ instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \ - instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) + instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on +// clang-format off instantiate_implicit_2d_blocks(float32, float); instantiate_implicit_2d_blocks(float16, half); -instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); \ No newline at end of file +instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal index 4f355af23..e902918f9 100644 --- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal @@ -4,23 +4,25 @@ #include "mlx/backend/metal/kernels/steel/gemm/mma.h" -#include "mlx/backend/metal/kernels/steel/conv/conv.h" -#include "mlx/backend/metal/kernels/steel/conv/params.h" -#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h" #include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/steel/conv/conv.h" +#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h" +#include "mlx/backend/metal/kernels/steel/conv/params.h" using namespace metal; using namespace mlx::steel; -template > -[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d_general( +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + typename AccumType = float, + typename Epilogue = TransformNone> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void +implicit_gemm_conv_2d_general( const device T* A [[buffer(0)]], const device T* B [[buffer(1)]], device T* C [[buffer(2)]], @@ -33,9 +35,8 @@ template ; + + // Input loader + using loader_a_t = + Conv2DInputBlockLoaderGeneral; // Weight loader - using loader_b_t = Conv2DWeightBlockLoaderGeneral< - T, 
BM, BN, BK, tgp_size, tgp_padding_b>; - + using loader_b_t = + Conv2DWeightBlockLoaderGeneral; + using mma_t = BlockMMA< T, T, @@ -70,12 +71,12 @@ template ; - + threadgroup T As[tgp_mem_size_a]; threadgroup T Bs[tgp_mem_size_b]; const int tid_y = ((tid.y) << gemm_params->swizzle_log) + - ((tid.x) & ((1 << gemm_params->swizzle_log) - 1)); + ((tid.x) & ((1 << gemm_params->swizzle_log) - 1)); const int tid_x = (tid.x) >> gemm_params->swizzle_log; if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) { @@ -103,13 +104,32 @@ template gemm_k_iterations; + int gemm_k_iterations = + base_wh_size * base_ww_size * gemm_params->gemm_k_iterations; for (int k = 0; k < gemm_k_iterations; k++) { threadgroup_barrier(mem_flags::mem_threadgroup); @@ -143,22 +163,24 @@ template adj_out_hw; int hw = cm % jump_params->adj_out_hw; - int oh = (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh; - int ow = (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow; + int oh = + (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh; + int ow = + (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow; - if(n < params->N && oh < params->oS[0] && ow < params->oS[1]) { - - int offset_cm = n * params->out_strides[0] + oh * params->out_strides[1] + ow * params->out_strides[2]; + if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) { + int offset_cm = n * params->out_strides[0] + + oh * params->out_strides[1] + ow * params->out_strides[2]; STEEL_PRAGMA_UNROLL for (int j = 0; j < mma_t::TN; j++) { // Get accumulated result and associated offset in C - thread const auto& accum = mma_op.results[i * mma_t::TN + j].thread_elements(); + thread const auto& accum = + mma_op.results[i * mma_t::TN + j].thread_elements(); int offset = offset_cm + (j * mma_t::TN_stride); // Apply epilogue and output C @@ -170,40 +192,42 @@ template ( \ - const device itype* A [[buffer(0)]], \ - const device itype* B [[buffer(1)]], \ - device itype* C [[buffer(2)]], \ - const constant MLXConvParams<2>* params [[buffer(3)]], \ - const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \ - const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], \ - const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], \ - const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); +#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \ + template \ + [[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn \ + "_bk" #bk "_wm" #wm "_wn" #wn)]] [[kernel]] void \ + implicit_gemm_conv_2d_general( \ + const device itype* A [[buffer(0)]], \ + const device itype* B [[buffer(1)]], \ + device itype* C [[buffer(2)]], \ + const constant MLXConvParams<2>* params [[buffer(3)]], \ + const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \ + const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], \ + const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], \ + const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); #define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \ - instantiate_implicit_conv_2d(name, 
itype, bm, bn, bk, wm, wn) + instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) -#define instantiate_implicit_2d_blocks(name, itype) \ +// clang-format off +#define instantiate_implicit_2d_blocks(name, itype) \ instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \ instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \ instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \ instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \ instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \ - instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) + instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on +// clang-format off instantiate_implicit_2d_blocks(float32, float); instantiate_implicit_2d_blocks(float16, half); -instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); \ No newline at end of file +instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm.metal index 189042fbe..b7445f2b1 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm.metal +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm.metal @@ -1,8 +1,8 @@ // Copyright © 2024 Apple Inc. #include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h" +#include "mlx/backend/metal/kernels/utils.h" using namespace metal; using namespace mlx::steel; @@ -11,97 +11,126 @@ using namespace mlx::steel; // GEMM kernels /////////////////////////////////////////////////////////////////////////////// -template -[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm( - const device T *A [[buffer(0)]], - const device T *B [[buffer(1)]], - device T *D [[buffer(3)]], +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* D [[buffer(3)]], const constant GEMMParams* params [[buffer(4)]], const constant int* batch_shape [[buffer(6)]], const constant size_t* batch_strides [[buffer(7)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - - using gemm_kernel = GEMMKernel; - - threadgroup T As[gemm_kernel::tgp_mem_size_a]; - threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + uint3 lid [[thread_position_in_threadgroup]]) { + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + MN_aligned, + K_aligned>; - // Adjust for batch - if(params->batch_ndim > 1) { - const constant size_t* A_bstrides = batch_strides; - const constant size_t* B_bstrides = batch_strides + params->batch_ndim; + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + // Adjust for batch + if (params->batch_ndim > 1) { + const constant size_t* A_bstrides = batch_strides; + const constant size_t* B_bstrides = batch_strides + params->batch_ndim; - A += batch_offsets.x; - B += batch_offsets.y; - - } else { - A += 
params->batch_stride_a * tid.z; - B += params->batch_stride_b * tid.z; - } - - D += params->batch_stride_d * tid.z; + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); - gemm_kernel::run( - A, B, D, - params, - As, Bs, - simd_lane_id, simd_group_id, tid, lid - ); + A += batch_offsets.x; + B += batch_offsets.y; + + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + } + + D += params->batch_stride_d * tid.z; + + gemm_kernel::run( + A, B, D, params, As, Bs, simd_lane_id, simd_group_id, tid, lid); } /////////////////////////////////////////////////////////////////////////////// // GEMM kernel initializations /////////////////////////////////////////////////////////////////////////////// -#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ - template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \ - [[kernel]] void gemm( \ - const device itype *A [[buffer(0)]], \ - const device itype *B [[buffer(1)]], \ - device itype *D [[buffer(3)]], \ - const constant GEMMParams* params [[buffer(4)]], \ - const constant int* batch_shape [[buffer(6)]], \ - const constant size_t* batch_strides [[buffer(7)]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]], \ +#define instantiate_gemm( \ + tname, \ + trans_a, \ + trans_b, \ + iname, \ + itype, \ + oname, \ + otype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + aname, \ + mn_aligned, \ + kname, \ + k_aligned) \ + template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm \ + "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname \ + "_K_" #kname)]] [[kernel]] void \ + gemm( \ + const device itype* A [[buffer(0)]], \ + const device itype* B [[buffer(1)]], \ + device itype* D [[buffer(3)]], \ + const constant GEMMParams* params [[buffer(4)]], \ + const constant int* batch_shape [[buffer(6)]], \ + const constant size_t* batch_strides [[buffer(7)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ uint3 lid [[thread_position_in_threadgroup]]); -#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ - instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ +// clang-format off +#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \ instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \ - instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on -#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ +// clang-format off 
+#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ - instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) + instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on -#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ +// clang-format off +#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \ - instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on +// clang-format off instantiate_gemm_shapes_helper(float16, half, float16, half); instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); -instantiate_gemm_shapes_helper(float32, float, float32, float); \ No newline at end of file +instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_addmm.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_addmm.metal index ec6efa10e..5989c5602 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_addmm.metal +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_addmm.metal @@ -10,23 +10,24 @@ using namespace mlx::steel; // GEMM kernels /////////////////////////////////////////////////////////////////////////////// -template > -[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void addmm( - const device T *A [[buffer(0)]], - const device T *B [[buffer(1)]], - const device T *C [[buffer(2)]], - device T *D [[buffer(3)]], +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned, + typename AccumType = float, + typename Epilogue = TransformAdd> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void addmm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device T* C [[buffer(2)]], + device T* D [[buffer(3)]], const constant GEMMParams* params [[buffer(4)]], const constant GEMMAddMMParams* addmm_params [[buffer(5)]], const constant int* batch_shape [[buffer(6)]], @@ -34,243 +35,304 @@ template ; - - using loader_a_t = typename gemm_kernel::loader_a_t; - using loader_b_t = typename gemm_kernel::loader_b_t; - using mma_t = typename gemm_kernel::mma_t; - - threadgroup T As[gemm_kernel::tgp_mem_size_a]; - threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + uint3 lid [[thread_position_in_threadgroup]]) { + // Pacifying compiler + (void)lid; - // Adjust for batch - if(params->batch_ndim > 1) { - const constant size_t* A_bstrides = batch_strides; - const constant size_t* B_bstrides = batch_strides + params->batch_ndim; - const constant size_t* C_bstrides = B_bstrides + params->batch_ndim; + 
using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + MN_aligned, + K_aligned, + AccumType, + Epilogue>; - ulong3 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, A_bstrides, B_bstrides, C_bstrides, params->batch_ndim); + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; - A += batch_offsets.x; - B += batch_offsets.y; - C += batch_offsets.z; - - } else { - A += params->batch_stride_a * tid.z; - B += params->batch_stride_b * tid.z; - C += addmm_params->batch_stride_c * tid.z; - } + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; - D += params->batch_stride_d * tid.z; + // Adjust for batch + if (params->batch_ndim > 1) { + const constant size_t* A_bstrides = batch_strides; + const constant size_t* B_bstrides = batch_strides + params->batch_ndim; + const constant size_t* C_bstrides = B_bstrides + params->batch_ndim; - const int tid_y = ((tid.y) << params->swizzle_log) + - ((tid.x) & ((1 << params->swizzle_log) - 1)); - const int tid_x = (tid.x) >> params->swizzle_log; + ulong3 batch_offsets = elem_to_loc_broadcast( + tid.z, + batch_shape, + A_bstrides, + B_bstrides, + C_bstrides, + params->batch_ndim); - if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { - return; + A += batch_offsets.x; + B += batch_offsets.y; + C += batch_offsets.z; + + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + C += addmm_params->batch_stride_c * tid.z; + } + + D += params->batch_stride_d * tid.z; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + + A += transpose_a ? c_row : c_row * params->lda; + B += transpose_b ? c_col * params->ldb : c_col; + D += c_row * params->ldd + c_col; + + C += c_row * addmm_params->ldc + c_col * addmm_params->fdc; + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + const Epilogue epilogue_op(addmm_params->alpha, addmm_params->beta); + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); } threadgroup_barrier(mem_flags::mem_none); - // Find block in A, B, C - const int c_row = tid_y * BM; - const int c_col = tid_x * BN; + // Loop tail + if (!K_aligned) { + int lbk = params->K - params->gemm_k_iterations_aligned * BK; + short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); + short2 tile_dims_B = transpose_b ? 
short2(lbk, BN) : short2(BN, lbk); - A += transpose_a ? c_row : c_row * params->lda; - B += transpose_b ? c_col * params->ldb : c_col; - D += c_row * params->ldd + c_col; + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); - C += c_row * addmm_params->ldc + c_col * addmm_params->fdc; + threadgroup_barrier(mem_flags::mem_threadgroup); - // Prepare threadgroup loading operations - thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); - thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + mma_op.mma(As, Bs); + } - // Prepare threadgroup mma operation - thread mma_t mma_op(simd_group_id, simd_lane_id); + // Store results to device memory + mma_op.store_result( + D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op); + return; - int gemm_k_iterations = params->gemm_k_iterations_aligned; + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; - const Epilogue epilogue_op(addmm_params->alpha, addmm_params->beta); + if (tgp_bm == BM && tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); - /////////////////////////////////////////////////////////////////////////////// - // MNK aligned loop - if (MN_aligned) { - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - - threadgroup_barrier(mem_flags::mem_none); - - // Loop tail - if (!K_aligned) { - int lbk = params->K - params->gemm_k_iterations_aligned * BK; - short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); - short2 tile_dims_B = transpose_b ? 
short2(lbk, BN) : short2(BN, lbk); - - loader_a.load_safe(tile_dims_A); - loader_b.load_safe(tile_dims_B); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - mma_op.mma(As, Bs); - } - - // Store results to device memory - mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op); + mma_op.store_result( + D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op); return; + } else if (tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + return mma_op.store_result_safe( + D, + params->ldd, + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op); + + } else if (tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + return mma_op.store_result_safe( + D, + params->ldd, + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op); + + } else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + return mma_op.store_result_safe( + D, + params->ldd, + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op); } - /////////////////////////////////////////////////////////////////////////////// - // MN unaligned loop - else { // Loop over K - unaligned case - short tgp_bm = min(BM, params->M - c_row); - short tgp_bn = min(BN, params->N - c_col); - short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; - - if (tgp_bm == BM && tgp_bn == BN) { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - - mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op); - return; - - } else if (tgp_bn == BN) { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - - return mma_op.store_result_safe( - D, params->ldd, - C, addmm_params->ldc, addmm_params->fdc, - short2(tgp_bn, tgp_bm), - epilogue_op); - - } else if (tgp_bm == BM) { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - - return mma_op.store_result_safe( - D, params->ldd, - C, addmm_params->ldc, addmm_params->fdc, - short2(tgp_bn, tgp_bm), - epilogue_op); - - } else { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - - return mma_op.store_result_safe( - D, params->ldd, - C, addmm_params->ldc, addmm_params->fdc, - short2(tgp_bn, tgp_bm), - epilogue_op); - } - } + } } /////////////////////////////////////////////////////////////////////////////// // GEMM kernel initializations /////////////////////////////////////////////////////////////////////////////// -#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, ep_name, epilogue) \ - template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_" #ep_name)]] \ - [[kernel]] void addmm>( \ - const device itype *A [[buffer(0)]], \ - const device itype *B [[buffer(1)]], 
\ - const device itype *C [[buffer(2)]], \ - device itype *D [[buffer(3)]], \ - const constant GEMMParams* gemm_params [[buffer(4)]], \ - const constant GEMMAddMMParams* params [[buffer(5)]], \ - const constant int* batch_shape [[buffer(6)]], \ - const constant size_t* batch_strides [[buffer(7)]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]], \ +#define instantiate_gemm( \ + tname, \ + trans_a, \ + trans_b, \ + iname, \ + itype, \ + oname, \ + otype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + aname, \ + mn_aligned, \ + kname, \ + k_aligned, \ + ep_name, \ + epilogue) \ + template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm \ + "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname \ + "_K_" #kname "_" #ep_name)]] [[kernel]] void \ + addmm< \ + itype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + trans_a, \ + trans_b, \ + mn_aligned, \ + k_aligned, \ + float, \ + epilogue>( \ + const device itype* A [[buffer(0)]], \ + const device itype* B [[buffer(1)]], \ + const device itype* C [[buffer(2)]], \ + device itype* D [[buffer(3)]], \ + const constant GEMMParams* gemm_params [[buffer(4)]], \ + const constant GEMMAddMMParams* params [[buffer(5)]], \ + const constant int* batch_shape [[buffer(6)]], \ + const constant size_t* batch_strides [[buffer(7)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ uint3 lid [[thread_position_in_threadgroup]]); -#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ - instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \ - instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby) +// clang-format off +#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby) // clang-format on -#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ - instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ +// clang-format off +#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \ instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \ - instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) + instantiate_gemm_bias_helper(tname, trans_a, trans_b, 
iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on -#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ +// clang-format off +#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ - instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) + instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on -#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ +// clang-format off +#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \ - instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on +// clang-format off instantiate_gemm_shapes_helper(float16, half, float16, half); instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); -instantiate_gemm_shapes_helper(float32, float, float32, float); \ No newline at end of file +instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal index 522e9653f..24710f5fd 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal @@ -1,8 +1,8 @@ // Copyright © 2024 Apple Inc. 
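For orientation before the reformatted block_masked_gemm kernel below: it computes D = A·B in BM x BN output tiles, zero-filling any tile whose out_mask entry is false and, when operand masks are present, skipping every K-block whose lhs_mask/rhs_mask entries are not both set. The following CPU-side sketch is a reading aid only, not part of the patch; the row-major layouts, the block-granular mask indexing, and the helper name block_masked_gemm_ref are illustrative assumptions rather than MLX API.

```cpp
#include <algorithm>
#include <vector>

// Reference for D = A(MxK) * B(KxN) with block-granular masks.
// out_mask: one flag per BM x BN output tile (false => tile written as zeros).
// lhs_mask: one flag per BM x BK block of A; rhs_mask: one flag per BK x BN block of B.
// A K-block contributes only when both operand flags are set.
void block_masked_gemm_ref(
    const std::vector<float>& A,
    const std::vector<float>& B,
    std::vector<float>& D,
    const std::vector<bool>& out_mask,
    const std::vector<bool>& lhs_mask,
    const std::vector<bool>& rhs_mask,
    int M, int N, int K, int BM, int BN, int BK) {
  const int tiles_n = (N + BN - 1) / BN;
  const int tiles_k = (K + BK - 1) / BK;
  for (int i = 0; i < M; i++) {
    for (int j = 0; j < N; j++) {
      const int tm = i / BM, tn = j / BN;
      float acc = 0.f;
      if (out_mask[tm * tiles_n + tn]) {   // otherwise fall through to the zero-fill write
        for (int tk = 0; tk < tiles_k; tk++) {
          if (!lhs_mask[tm * tiles_k + tk] || !rhs_mask[tk * tiles_n + tn])
            continue;                      // K-block skipped by the operand masks
          const int k_end = std::min(K, (tk + 1) * BK);
          for (int k = tk * BK; k < k_end; k++)
            acc += A[i * K + k] * B[k * N + j];
        }
      }
      D[i * N + j] = acc;                  // masked-out tiles end up all zeros
    }
  }
}
```

In the Metal kernel the same decisions are made per threadgroup: the out_mask test corresponds to the early zero-fill-and-return path, and the operand-mask test gates each load_unsafe/load_safe plus mma iteration.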
#include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h" +#include "mlx/backend/metal/kernels/utils.h" using namespace metal; using namespace mlx::steel; @@ -11,319 +11,378 @@ using namespace mlx::steel; // GEMM kernels /////////////////////////////////////////////////////////////////////////////// -template -[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void block_masked_gemm( - const device T *A [[buffer(0)]], - const device T *B [[buffer(1)]], - device T *D [[buffer(3)]], +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned, + bool has_operand_mask = false> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void +block_masked_gemm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* D [[buffer(3)]], const constant GEMMParams* params [[buffer(4)]], const constant int* batch_shape [[buffer(6)]], const constant size_t* batch_strides [[buffer(7)]], - const device bool *out_mask [[buffer(10)]], - const device bool *lhs_mask [[buffer(11)]], - const device bool *rhs_mask [[buffer(12)]], + const device bool* out_mask [[buffer(10)]], + const device bool* lhs_mask [[buffer(11)]], + const device bool* rhs_mask [[buffer(12)]], const constant int* mask_strides [[buffer(13)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { + uint3 lid [[thread_position_in_threadgroup]]) { + // Appease the compiler + (void)lid; - // Appease the compiler - (void)lid; - - using gemm_kernel = GEMMKernel; - - const int tid_y = ((tid.y) << params->swizzle_log) + - ((tid.x) & ((1 << params->swizzle_log) - 1)); - const int tid_x = (tid.x) >> params->swizzle_log; + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + MN_aligned, + K_aligned>; - if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { - return; - } + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; - if(params->batch_ndim > 1) { - const constant size_t* mask_batch_strides = batch_strides + 2 * params->batch_ndim; - out_mask += elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim); + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } - if(has_operand_mask) { - const constant size_t* mask_strides_lhs = mask_batch_strides + params->batch_ndim; - const constant size_t* mask_strides_rhs = mask_strides_lhs + params->batch_ndim; + if (params->batch_ndim > 1) { + const constant size_t* mask_batch_strides = + batch_strides + 2 * params->batch_ndim; + out_mask += + elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim); - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, mask_strides_lhs, mask_strides_rhs, params->batch_ndim); - - lhs_mask += batch_offsets.x; - rhs_mask += batch_offsets.y; - } - } else { - out_mask += tid.z * batch_strides[2 * params->batch_ndim]; - if(has_operand_mask) { - lhs_mask += tid.z * batch_strides[3 * params->batch_ndim]; - rhs_mask += tid.z * batch_strides[4 * params->batch_ndim]; - } - } - - // Adjust for batch - if(params->batch_ndim > 1) { - const constant size_t* A_bstrides = batch_strides; - const constant size_t* B_bstrides 
= batch_strides + params->batch_ndim; + if (has_operand_mask) { + const constant size_t* mask_strides_lhs = + mask_batch_strides + params->batch_ndim; + const constant size_t* mask_strides_rhs = + mask_strides_lhs + params->batch_ndim; ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + tid.z, + batch_shape, + mask_strides_lhs, + mask_strides_rhs, + params->batch_ndim); - A += batch_offsets.x; - B += batch_offsets.y; - - } else { - A += params->batch_stride_a * tid.z; - B += params->batch_stride_b * tid.z; + lhs_mask += batch_offsets.x; + rhs_mask += batch_offsets.y; } - - D += params->batch_stride_d * tid.z; + } else { + out_mask += tid.z * batch_strides[2 * params->batch_ndim]; + if (has_operand_mask) { + lhs_mask += tid.z * batch_strides[3 * params->batch_ndim]; + rhs_mask += tid.z * batch_strides[4 * params->batch_ndim]; + } + } - // Find block in A, B, C - const int c_row = tid_y * BM; - const int c_col = tid_x * BN; + // Adjust for batch + if (params->batch_ndim > 1) { + const constant size_t* A_bstrides = batch_strides; + const constant size_t* B_bstrides = batch_strides + params->batch_ndim; - A += transpose_a ? c_row : c_row * params->lda; - B += transpose_b ? c_col * params->ldb : c_col; - D += c_row * params->ldd + c_col; + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + A += batch_offsets.x; + B += batch_offsets.y; - bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]]; + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + } - // Write zeros and return - if(!mask_out) { - constexpr short tgp_size = WM * WN * 32; - constexpr short vec_size = 4; + D += params->batch_stride_d * tid.z; - // Tile threads in threadgroup - constexpr short TN = BN / vec_size; - constexpr short TM = tgp_size / TN; + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; - const short thread_idx = simd_group_id * 32 + simd_lane_id; - const short bi = thread_idx / TN; - const short bj = vec_size * (thread_idx % TN); + A += transpose_a ? c_row : c_row * params->lda; + B += transpose_b ? c_col * params->ldb : c_col; + D += c_row * params->ldd + c_col; - D += bi * params->ldd + bj; + bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]]; - short tgp_bm = min(BM, params->M - c_row); - short tgp_bn = min(BN, params->N - c_col); + // Write zeros and return + if (!mask_out) { + constexpr short tgp_size = WM * WN * 32; + constexpr short vec_size = 4; - if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { - for (short ti = 0; ti < BM; ti += TM) { - STEEL_PRAGMA_UNROLL - for(short j = 0; j < vec_size; j++) { - D[ti * params->ldd + j] = T(0.); - } - } - } else { - short jmax = tgp_bn - bj; - jmax = jmax < vec_size ? 
jmax : vec_size; - for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) { - for(short j = 0; j < jmax; j++) { - D[ti * params->ldd + j] = T(0.); - } + // Tile threads in threadgroup + constexpr short TN = BN / vec_size; + constexpr short TM = tgp_size / TN; + + const short thread_idx = simd_group_id * 32 + simd_lane_id; + const short bi = thread_idx / TN; + const short bj = vec_size * (thread_idx % TN); + + D += bi * params->ldd + bj; + + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + + if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + for (short ti = 0; ti < BM; ti += TM) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + D[ti * params->ldd + j] = T(0.); } } - - return; + } else { + short jmax = tgp_bn - bj; + jmax = jmax < vec_size ? jmax : vec_size; + for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) { + for (short j = 0; j < jmax; j++) { + D[ti * params->ldd + j] = T(0.); + } + } + } + + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Prepare threadgroup loading operations + thread typename gemm_kernel::loader_a_t loader_a( + A, params->lda, As, simd_group_id, simd_lane_id); + thread typename gemm_kernel::loader_b_t loader_b( + B, params->ldb, Bs, simd_group_id, simd_lane_id); + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (!has_operand_mask || + (lhs_mask + [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] && + rhs_mask + [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + } + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); } threadgroup_barrier(mem_flags::mem_none); - // Prepare threadgroup mma operation - thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id); + // Loop tail + if (!K_aligned) { + if (!has_operand_mask || + (lhs_mask + [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] && + rhs_mask + [(params->K / BM) * mask_strides[5] + + tid_x * mask_strides[4]])) { + int lbk = params->K - params->gemm_k_iterations_aligned * BK; + short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); + short2 tile_dims_B = transpose_b ? 
short2(lbk, BN) : short2(BN, lbk); - int gemm_k_iterations = params->gemm_k_iterations_aligned; + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); - threadgroup T As[gemm_kernel::tgp_mem_size_a]; - threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; - - // Prepare threadgroup loading operations - thread typename gemm_kernel::loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); - thread typename gemm_kernel::loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); - - /////////////////////////////////////////////////////////////////////////////// - // MNK aligned loop - if (MN_aligned) { - for (int k = 0; k < gemm_k_iterations; k++) { threadgroup_barrier(mem_flags::mem_threadgroup); - if(!has_operand_mask || - (lhs_mask[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] && - rhs_mask[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { + mma_op.mma(As, Bs); + } + } - // Load elements into threadgroup + // Store results to device memory + mma_op.store_result(D, params->ldd); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short lbk = params->K - params->gemm_k_iterations_aligned * BK; + + bool M_aligned = (tgp_bm == BM); + bool N_aligned = (tgp_bn == BN); + + short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if (!has_operand_mask || + (lhs_mask + [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] && + rhs_mask + [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { + // Load elements into threadgroup + if (M_aligned) { loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - } - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - - threadgroup_barrier(mem_flags::mem_none); - - // Loop tail - if (!K_aligned) { - - if(!has_operand_mask || - (lhs_mask[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] && - rhs_mask[(params->K / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { - - int lbk = params->K - params->gemm_k_iterations_aligned * BK; - short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); - short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk); - + } else { loader_a.load_safe(tile_dims_A); + } + + if (N_aligned) { + loader_b.load_unsafe(); + } else { loader_b.load_safe(tile_dims_B); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - mma_op.mma(As, Bs); - } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); } - // Store results to device memory + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + if (!K_aligned) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (!has_operand_mask || + (lhs_mask + [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] && + rhs_mask + [(params->K / BM) * mask_strides[5] + + tid_x * mask_strides[4]])) { + short2 tile_dims_A_last = + transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); + short2 tile_dims_B_last = + transpose_b ? 
short2(lbk, tgp_bn) : short2(tgp_bn, lbk); + + loader_a.load_safe(tile_dims_A_last); + loader_b.load_safe(tile_dims_B_last); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + } + + if (M_aligned && N_aligned) { mma_op.store_result(D, params->ldd); - return; - - } - /////////////////////////////////////////////////////////////////////////////// - // MN unaligned loop - else { // Loop over K - unaligned case - short tgp_bm = min(BM, params->M - c_row); - short tgp_bn = min(BN, params->N - c_col); - short lbk = params->K - params->gemm_k_iterations_aligned * BK; - - bool M_aligned = (tgp_bm == BM); - bool N_aligned = (tgp_bn == BN); - - short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); - short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); - - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - if(!has_operand_mask || - (lhs_mask[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] && - rhs_mask[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { - - // Load elements into threadgroup - if (M_aligned) { - loader_a.load_unsafe(); - } else { - loader_a.load_safe(tile_dims_A); - } - - if (N_aligned) { - loader_b.load_unsafe(); - } else { - loader_b.load_safe(tile_dims_B); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - } - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - - if (!K_aligned) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - if(!has_operand_mask || - (lhs_mask[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] && - rhs_mask[(params->K / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { - - short2 tile_dims_A_last = - transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); - short2 tile_dims_B_last = - transpose_b ? 
short2(lbk, tgp_bn) : short2(tgp_bn, lbk); - - loader_a.load_safe(tile_dims_A_last); - loader_b.load_safe(tile_dims_B_last); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - mma_op.mma(As, Bs); - - } - } - - if(M_aligned && N_aligned) { - mma_op.store_result(D, params->ldd); - } else { - mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); - } + } else { + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); } + } } /////////////////////////////////////////////////////////////////////////////// // GEMM kernel initializations /////////////////////////////////////////////////////////////////////////////// -#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, omname, op_mask) \ - template [[host_name("steel_block_masked_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_op_mask_" #omname)]] \ - [[kernel]] void block_masked_gemm( \ - const device itype *A [[buffer(0)]], \ - const device itype *B [[buffer(1)]], \ - device itype *D [[buffer(3)]], \ - const constant GEMMParams* params [[buffer(4)]], \ - const constant int* batch_shape [[buffer(6)]], \ - const constant size_t* batch_strides [[buffer(7)]], \ - const device bool *out_mask [[buffer(10)]], \ - const device bool *lhs_mask [[buffer(11)]], \ - const device bool *rhs_mask [[buffer(12)]], \ - const constant int* mask_strides [[buffer(13)]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]], \ +#define instantiate_gemm( \ + tname, \ + trans_a, \ + trans_b, \ + iname, \ + itype, \ + oname, \ + otype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + aname, \ + mn_aligned, \ + kname, \ + k_aligned, \ + omname, \ + op_mask) \ + template [[host_name("steel_block_masked_gemm_" #tname "_" #iname "_" #oname \ + "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \ + "_MN_" #aname "_K_" #kname \ + "_op_mask_" #omname)]] [[kernel]] void \ + block_masked_gemm< \ + itype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + trans_a, \ + trans_b, \ + mn_aligned, \ + k_aligned, \ + op_mask>( \ + const device itype* A [[buffer(0)]], \ + const device itype* B [[buffer(1)]], \ + device itype* D [[buffer(3)]], \ + const constant GEMMParams* params [[buffer(4)]], \ + const constant int* batch_shape [[buffer(6)]], \ + const constant size_t* batch_strides [[buffer(7)]], \ + const device bool* out_mask [[buffer(10)]], \ + const device bool* lhs_mask [[buffer(11)]], \ + const device bool* rhs_mask [[buffer(12)]], \ + const constant int* mask_strides [[buffer(13)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ uint3 lid [[thread_position_in_threadgroup]]); +// clang-format off #define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ - instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, N, false) \ - instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, T, true) + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, N, false) \ + instantiate_gemm(tname, trans_a, trans_b, 
iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, T, true) // clang-format on -#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ - instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ +// clang-format off +#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \ instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \ - instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) + instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on -#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ +// clang-format off +#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ - instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) + instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on -#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ +// clang-format off +#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \ - instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) // clang-format on +// clang-format off instantiate_gemm_shapes_helper(float16, half, float16, half); instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); -instantiate_gemm_shapes_helper(float32, float, float32, float); \ No newline at end of file +instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal index 873f5faf1..f99149569 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal @@ -10,81 +10,95 @@ using namespace mlx::steel; // GEMM kernels /////////////////////////////////////////////////////////////////////////////// -template -[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm_splitk( - const device T *A [[buffer(0)]], - const device T *B [[buffer(1)]], - device U *C [[buffer(2)]], +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned> +[[kernel, 
max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm_splitk( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device U* C [[buffer(2)]], const constant GEMMSpiltKParams* params [[buffer(3)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { + uint3 lid [[thread_position_in_threadgroup]]) { + (void)lid; - (void)lid; - - using gemm_kernel = GEMMKernel; - using loader_a_t = typename gemm_kernel::loader_a_t; - using loader_b_t = typename gemm_kernel::loader_b_t; - using mma_t = typename gemm_kernel::mma_t; - - threadgroup T As[gemm_kernel::tgp_mem_size_a]; - threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + using gemm_kernel = GEMMKernel< + T, + U, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + MN_aligned, + K_aligned>; + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; - const int tid_x = tid.x; - const int tid_y = tid.y; - const int tid_z = tid.z; + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; - if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { - return; - } + const int tid_x = tid.x; + const int tid_y = tid.y; + const int tid_z = tid.z; - // Find block in A, B, C - const int c_row = tid_y * BM; - const int c_col = tid_x * BN; - const int k_start = params->split_k_partition_size * tid_z; + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } - A += transpose_a ? (c_row + k_start * params->lda) : (k_start + c_row * params->lda); - B += transpose_b ? (k_start + c_col * params->ldb) : (c_col + k_start * params->ldb); - C += (params->split_k_partition_stride * tid_z) + (c_row * params->ldc + c_col); + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const int k_start = params->split_k_partition_size * tid_z; - // Prepare threadgroup loading operations - thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); - thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + A += transpose_a ? (c_row + k_start * params->lda) + : (k_start + c_row * params->lda); + B += transpose_b ? 
(k_start + c_col * params->ldb) + : (c_col + k_start * params->ldb); + C += (params->split_k_partition_stride * tid_z) + + (c_row * params->ldc + c_col); - // Prepare threadgroup mma operation - thread mma_t mma_op(simd_group_id, simd_lane_id); + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); - int gemm_k_iterations = params->gemm_k_iterations_aligned; + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); - short tgp_bm = min(BM, params->M - c_row); - short tgp_bn = min(BN, params->N - c_col); - short leftover_bk = params->K % BK; + int gemm_k_iterations = params->gemm_k_iterations_aligned; - if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - } else if (tgp_bn == BN) { - gemm_kernel::gemm_loop( + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short leftover_bk = params->K % BK; + + if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } else if (tgp_bn == BN) { + gemm_kernel::gemm_loop( As, Bs, gemm_k_iterations, @@ -95,37 +109,38 @@ template {}); - } else if (tgp_bm == BM) { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - } else { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - } + } else if (tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); - if ((tid_z + 1) == (params->split_k_partitions)) { - int gemm_k_iter_remaining = (params->K - (k_start + params->split_k_partition_size)) / BK; - if(!K_aligned || gemm_k_iter_remaining > 0) + if ((tid_z + 1) == (params->split_k_partitions)) { + int gemm_k_iter_remaining = + (params->K - (k_start + params->split_k_partition_size)) / BK; + if (!K_aligned || gemm_k_iter_remaining > 0) gemm_kernel::gemm_loop( As, Bs, @@ -137,69 +152,102 @@ template {}); - } + } - if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { - mma_op.store_result(C, params->ldc); - } else { - mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm)); - } + if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + mma_op.store_result(C, params->ldc); + } else { + mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm)); + } } /////////////////////////////////////////////////////////////////////////////// // GEMM kernel initializations /////////////////////////////////////////////////////////////////////////////// -#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ - template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" 
#aname "_K_" #kname)]] \ - [[kernel]] void gemm_splitk( \ - const device itype *A [[buffer(0)]], \ - const device itype *B [[buffer(1)]], \ - device otype *C [[buffer(2)]], \ - const constant GEMMSpiltKParams* params [[buffer(3)]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]], \ +#define instantiate_gemm( \ + tname, \ + trans_a, \ + trans_b, \ + iname, \ + itype, \ + oname, \ + otype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + aname, \ + mn_aligned, \ + kname, \ + k_aligned) \ + template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname \ + "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \ + "_MN_" #aname "_K_" #kname)]] [[kernel]] void \ + gemm_splitk< \ + itype, \ + otype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + trans_a, \ + trans_b, \ + mn_aligned, \ + k_aligned>( \ + const device itype* A [[buffer(0)]], \ + const device itype* B [[buffer(1)]], \ + device otype* C [[buffer(2)]], \ + const constant GEMMSpiltKParams* params [[buffer(3)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ uint3 lid [[thread_position_in_threadgroup]]); -#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ - instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ +// clang-format off +#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \ instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \ - instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on -#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ +// clang-format off +#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ - instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) + instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on -#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ +// clang-format off +#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \ - instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) + 
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) // clang-format on +// clang-format off instantiate_gemm_shapes_helper(float16, half, float32, float); instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float); -instantiate_gemm_shapes_helper(float32, float, float32, float); +instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on /////////////////////////////////////////////////////////////////////////////// -// Split k accumulation kernel +// Split k accumulation kernel /////////////////////////////////////////////////////////////////////////////// -template > +template < + typename AccT, + typename OutT, + typename Epilogue = TransformNone> [[kernel]] void gemm_splitk_accum( - const device AccT *C_split [[buffer(0)]], - device OutT *D [[buffer(1)]], + const device AccT* C_split [[buffer(0)]], + device OutT* D [[buffer(1)]], const constant int& k_partitions [[buffer(2)]], const constant int& partition_stride [[buffer(3)]], const constant int& ldd [[buffer(4)]], uint2 gid [[thread_position_in_grid]]) { - // Ajust D and C D += gid.x + gid.y * ldd; C_split += gid.x + gid.y * ldd; @@ -207,32 +255,31 @@ template > +template < + typename AccT, + typename OutT, + typename Epilogue = TransformAxpby> [[kernel]] void gemm_splitk_accum_axpby( - const device AccT *C_split [[buffer(0)]], - device OutT *D [[buffer(1)]], + const device AccT* C_split [[buffer(0)]], + device OutT* D [[buffer(1)]], const constant int& k_partitions [[buffer(2)]], const constant int& partition_stride [[buffer(3)]], const constant int& ldd [[buffer(4)]], - const device OutT *C [[buffer(5)]], + const device OutT* C [[buffer(5)]], const constant int& ldc [[buffer(6)]], const constant int& fdc [[buffer(7)]], const constant float& alpha [[buffer(8)]], const constant float& beta [[buffer(9)]], uint2 gid [[thread_position_in_grid]]) { - // Ajust D and C C += gid.x * fdc + gid.y * ldc; D += gid.x + gid.y * ldd; @@ -241,40 +288,42 @@ template ( \ - const device atype *C_split [[buffer(0)]], \ - device otype *D [[buffer(1)]], \ - const constant int& k_partitions [[buffer(2)]], \ - const constant int& partition_stride [[buffer(3)]], \ - const constant int& ldd [[buffer(4)]], \ - uint2 gid [[thread_position_in_grid]]); \ - template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname "_axpby")]] \ - [[kernel]] void gemm_splitk_accum_axpby( \ - const device atype *C_split [[buffer(0)]], \ - device otype *D [[buffer(1)]], \ - const constant int& k_partitions [[buffer(2)]], \ - const constant int& partition_stride [[buffer(3)]], \ - const constant int& ldd [[buffer(4)]], \ - const device otype *C [[buffer(5)]], \ - const constant int& ldc [[buffer(6)]], \ - const constant int& fdc [[buffer(7)]], \ - const constant float& alpha [[buffer(8)]], \ - const constant float& beta [[buffer(9)]], \ +#define instantiate_accum(oname, otype, aname, atype) \ + template [[host_name("steel_gemm_splitk_accum_" #oname \ + "_" #aname)]] [[kernel]] void \ + gemm_splitk_accum( \ + const device atype* C_split [[buffer(0)]], \ + device otype* D [[buffer(1)]], \ + const constant int& k_partitions [[buffer(2)]], \ + const constant int& partition_stride [[buffer(3)]], \ + const constant int& ldd [[buffer(4)]], \ + uint2 gid [[thread_position_in_grid]]); \ + template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname \ + "_axpby")]] [[kernel]] void \ + gemm_splitk_accum_axpby( \ + const device atype* C_split [[buffer(0)]], \ + device otype* D [[buffer(1)]], \ + const constant int& 
k_partitions [[buffer(2)]], \ + const constant int& partition_stride [[buffer(3)]], \ + const constant int& ldd [[buffer(4)]], \ + const device otype* C [[buffer(5)]], \ + const constant int& ldc [[buffer(6)]], \ + const constant int& fdc [[buffer(7)]], \ + const constant float& alpha [[buffer(8)]], \ + const constant float& beta [[buffer(9)]], \ uint2 gid [[thread_position_in_grid]]); +// clang-format off instantiate_accum(bfloat16, bfloat16_t, float32, float); instantiate_accum(float16, half, float32, float); -instantiate_accum(float32, float, float32, float); \ No newline at end of file +instantiate_accum(float32, float, float32, float); // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/ternary.metal b/mlx/backend/metal/kernels/ternary.metal index c351bed17..11fd87a91 100644 --- a/mlx/backend/metal/kernels/ternary.metal +++ b/mlx/backend/metal/kernels/ternary.metal @@ -3,9 +3,9 @@ #include #include -#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/ternary.h" +#include "mlx/backend/metal/kernels/utils.h" template [[kernel]] void ternary_op_v( @@ -65,7 +65,8 @@ template auto a_idx = elem_to_loc_3(index, a_strides); auto b_idx = elem_to_loc_3(index, b_strides); auto c_idx = elem_to_loc_3(index, c_strides); - size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + size_t out_idx = + index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); } @@ -81,8 +82,10 @@ template constant const size_t c_strides[DIM], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides); - size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + auto idx = + elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides); + size_t out_idx = + index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]); } @@ -99,103 +102,104 @@ template constant const int& ndim, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim); + auto idx = + elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim); size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z); d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]); } -#define instantiate_ternary_v(name, type, op) \ - template [[host_name(name)]] \ - [[kernel]] void ternary_op_v( \ - device const bool* a, \ - device const type* b, \ - device const type* c, \ - device type* d, \ - uint index [[thread_position_in_grid]]); \ +#define instantiate_ternary_v(name, type, op) \ + template [[host_name(name)]] [[kernel]] void ternary_op_v( \ + device const bool* a, \ + device const type* b, \ + device const type* c, \ + device type* d, \ + uint index [[thread_position_in_grid]]); -#define instantiate_ternary_g(name, type, op) \ - template [[host_name(name)]] \ - [[kernel]] void ternary_op_g( \ - device const bool* a, \ - device const type* b, \ - device const type* c, \ - device type* d, \ - constant const int* shape, \ - constant const size_t* a_strides, \ - constant const size_t* b_strides, \ - constant const size_t* c_strides, \ - constant const int& ndim, \ - uint3 index [[thread_position_in_grid]], \ - uint3 grid_dim [[threads_per_grid]]); \ +#define 
instantiate_ternary_g(name, type, op) \ + template [[host_name(name)]] [[kernel]] void ternary_op_g( \ + device const bool* a, \ + device const type* b, \ + device const type* c, \ + device type* d, \ + constant const int* shape, \ + constant const size_t* a_strides, \ + constant const size_t* b_strides, \ + constant const size_t* c_strides, \ + constant const int& ndim, \ + uint3 index [[thread_position_in_grid]], \ + uint3 grid_dim [[threads_per_grid]]); -#define instantiate_ternary_g_dim(name, type, op, dims) \ - template [[host_name(name "_" #dims)]] \ - [[kernel]] void ternary_op_g_nd( \ - device const bool* a, \ - device const type* b, \ - device const type* c, \ - device type* d, \ - constant const int shape[dims], \ - constant const size_t a_strides[dims], \ - constant const size_t b_strides[dims], \ - constant const size_t c_strides[dims], \ - uint3 index [[thread_position_in_grid]], \ - uint3 grid_dim [[threads_per_grid]]); \ +#define instantiate_ternary_g_dim(name, type, op, dims) \ + template [[host_name(name "_" #dims)]] [[kernel]] void \ + ternary_op_g_nd( \ + device const bool* a, \ + device const type* b, \ + device const type* c, \ + device type* d, \ + constant const int shape[dims], \ + constant const size_t a_strides[dims], \ + constant const size_t b_strides[dims], \ + constant const size_t c_strides[dims], \ + uint3 index [[thread_position_in_grid]], \ + uint3 grid_dim [[threads_per_grid]]); -#define instantiate_ternary_g_nd(name, type, op) \ - template [[host_name(name "_1")]] \ - [[kernel]] void ternary_op_g_nd1( \ - device const bool* a, \ - device const type* b, \ - device const type* c, \ - device type* d, \ - constant const size_t& a_strides, \ - constant const size_t& b_strides, \ - constant const size_t& c_strides, \ - uint index [[thread_position_in_grid]]); \ - template [[host_name(name "_2")]] \ - [[kernel]] void ternary_op_g_nd2( \ - device const bool* a, \ - device const type* b, \ - device const type* c, \ - device type* d, \ - constant const size_t a_strides[2], \ - constant const size_t b_strides[2], \ - constant const size_t c_strides[2], \ - uint2 index [[thread_position_in_grid]], \ - uint2 grid_dim [[threads_per_grid]]); \ - template [[host_name(name "_3")]] \ - [[kernel]] void ternary_op_g_nd3( \ - device const bool* a, \ - device const type* b, \ - device const type* c, \ - device type* d, \ - constant const size_t a_strides[3], \ - constant const size_t b_strides[3], \ - constant const size_t c_strides[3], \ - uint3 index [[thread_position_in_grid]], \ - uint3 grid_dim [[threads_per_grid]]); \ - instantiate_ternary_g_dim(name, type, op, 4) \ - instantiate_ternary_g_dim(name, type, op, 5) \ +#define instantiate_ternary_g_nd(name, type, op) \ + template [[host_name(name "_1")]] [[kernel]] void \ + ternary_op_g_nd1( \ + device const bool* a, \ + device const type* b, \ + device const type* c, \ + device type* d, \ + constant const size_t& a_strides, \ + constant const size_t& b_strides, \ + constant const size_t& c_strides, \ + uint index [[thread_position_in_grid]]); \ + template [[host_name(name "_2")]] [[kernel]] void \ + ternary_op_g_nd2( \ + device const bool* a, \ + device const type* b, \ + device const type* c, \ + device type* d, \ + constant const size_t a_strides[2], \ + constant const size_t b_strides[2], \ + constant const size_t c_strides[2], \ + uint2 index [[thread_position_in_grid]], \ + uint2 grid_dim [[threads_per_grid]]); \ + template [[host_name(name "_3")]] [[kernel]] void \ + ternary_op_g_nd3( \ + device const bool* a, \ + 
device const type* b, \ + device const type* c, \ + device type* d, \ + constant const size_t a_strides[3], \ + constant const size_t b_strides[3], \ + constant const size_t c_strides[3], \ + uint3 index [[thread_position_in_grid]], \ + uint3 grid_dim [[threads_per_grid]]); \ + instantiate_ternary_g_dim(name, type, op, 4) \ + instantiate_ternary_g_dim(name, type, op, 5) +// clang-format off #define instantiate_ternary_all(name, tname, type, op) \ - instantiate_ternary_v("v" #name #tname, type, op) \ - instantiate_ternary_g("g" #name #tname, type, op) \ - instantiate_ternary_g_nd("g" #name #tname, type, op) \ + instantiate_ternary_v("v" #name #tname, type, op) \ + instantiate_ternary_g("g" #name #tname, type, op) \ + instantiate_ternary_g_nd("g" #name #tname, type, op) // clang-format on -#define instantiate_ternary_types(name, op) \ - instantiate_ternary_all(name, bool_, bool, op) \ - instantiate_ternary_all(name, uint8, uint8_t, op) \ - instantiate_ternary_all(name, uint16, uint16_t, op) \ - instantiate_ternary_all(name, uint32, uint32_t, op) \ - instantiate_ternary_all(name, uint64, uint64_t, op) \ - instantiate_ternary_all(name, int8, int8_t, op) \ - instantiate_ternary_all(name, int16, int16_t, op) \ - instantiate_ternary_all(name, int32, int32_t, op) \ - instantiate_ternary_all(name, int64, int64_t, op) \ - instantiate_ternary_all(name, float16, half, op) \ - instantiate_ternary_all(name, float32, float, op) \ +// clang-format off +#define instantiate_ternary_types(name, op) \ + instantiate_ternary_all(name, bool_, bool, op) \ + instantiate_ternary_all(name, uint8, uint8_t, op) \ + instantiate_ternary_all(name, uint16, uint16_t, op) \ + instantiate_ternary_all(name, uint32, uint32_t, op) \ + instantiate_ternary_all(name, uint64, uint64_t, op) \ + instantiate_ternary_all(name, int8, int8_t, op) \ + instantiate_ternary_all(name, int16, int16_t, op) \ + instantiate_ternary_all(name, int32, int32_t, op) \ + instantiate_ternary_all(name, int64, int64_t, op) \ + instantiate_ternary_all(name, float16, half, op) \ + instantiate_ternary_all(name, float32, float, op) \ instantiate_ternary_all(name, bfloat16, bfloat16_t, op) \ - instantiate_ternary_all(name, complex64, complex64_t, op) \ + instantiate_ternary_all(name, complex64, complex64_t, op) // clang-format on -instantiate_ternary_types(select, Select) +instantiate_ternary_types(select, Select) \ No newline at end of file diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index 69b5580c5..e9b52d58d 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -22,44 +22,46 @@ template out[index] = Op()(in[idx]); } -#define instantiate_unary_v(name, type, op) \ - template [[host_name(name)]] \ - [[kernel]] void unary_op_v( \ - device const type* in, \ - device type* out, \ +#define instantiate_unary_v(name, type, op) \ + template [[host_name(name)]] [[kernel]] void unary_op_v( \ + device const type* in, \ + device type* out, \ uint index [[thread_position_in_grid]]); -#define instantiate_unary_g(name, type, op) \ - template [[host_name(name)]] \ - [[kernel]] void unary_op_g( \ - device const type* in, \ - device type* out, \ - device const int* in_shape, \ - device const size_t* in_strides, \ - device const int& ndim, \ +#define instantiate_unary_g(name, type, op) \ + template [[host_name(name)]] [[kernel]] void unary_op_g( \ + device const type* in, \ + device type* out, \ + device const int* in_shape, \ + device const size_t* in_strides, \ + device const int& ndim, 
\ uint index [[thread_position_in_grid]]); +// clang-format off #define instantiate_unary_all(name, tname, type, op) \ - instantiate_unary_v("v" #name #tname, type, op) \ - instantiate_unary_g("g" #name #tname, type, op) + instantiate_unary_v("v" #name #tname, type, op) \ + instantiate_unary_g("g" #name #tname, type, op) // clang-format on -#define instantiate_unary_float(name, op) \ - instantiate_unary_all(name, float16, half, op) \ - instantiate_unary_all(name, float32, float, op) \ - instantiate_unary_all(name, bfloat16, bfloat16_t, op) \ +// clang-format off +#define instantiate_unary_float(name, op) \ + instantiate_unary_all(name, float16, half, op) \ + instantiate_unary_all(name, float32, float, op) \ + instantiate_unary_all(name, bfloat16, bfloat16_t, op) // clang-format on -#define instantiate_unary_types(name, op) \ - instantiate_unary_all(name, bool_, bool, op) \ - instantiate_unary_all(name, uint8, uint8_t, op) \ +// clang-format off +#define instantiate_unary_types(name, op) \ + instantiate_unary_all(name, bool_, bool, op) \ + instantiate_unary_all(name, uint8, uint8_t, op) \ instantiate_unary_all(name, uint16, uint16_t, op) \ instantiate_unary_all(name, uint32, uint32_t, op) \ instantiate_unary_all(name, uint64, uint64_t, op) \ - instantiate_unary_all(name, int8, int8_t, op) \ - instantiate_unary_all(name, int16, int16_t, op) \ - instantiate_unary_all(name, int32, int32_t, op) \ - instantiate_unary_all(name, int64, int64_t, op) \ - instantiate_unary_float(name, op) + instantiate_unary_all(name, int8, int8_t, op) \ + instantiate_unary_all(name, int16, int16_t, op) \ + instantiate_unary_all(name, int32, int32_t, op) \ + instantiate_unary_all(name, int64, int64_t, op) \ + instantiate_unary_float(name, op) // clang-format on +// clang-format off instantiate_unary_types(abs, Abs) instantiate_unary_float(arccos, ArcCos) instantiate_unary_float(arccosh, ArcCosh) @@ -102,4 +104,4 @@ instantiate_unary_all(tan, complex64, complex64_t, Tan) instantiate_unary_all(tanh, complex64, complex64_t, Tanh) instantiate_unary_all(round, complex64, complex64_t, Round) -instantiate_unary_all(lnot, bool_, bool, LogicalNot) +instantiate_unary_all(lnot, bool_, bool, LogicalNot) // clang-format on diff --git a/mlx/device.h b/mlx/device.h index e11edf793..2a09195f4 100644 --- a/mlx/device.h +++ b/mlx/device.h @@ -13,7 +13,7 @@ struct Device { static constexpr DeviceType cpu = DeviceType::cpu; static constexpr DeviceType gpu = DeviceType::gpu; - Device(DeviceType type, int index = 0) : type(type), index(index){}; + Device(DeviceType type, int index = 0) : type(type), index(index) {}; DeviceType type; int index; diff --git a/mlx/dtype.h b/mlx/dtype.h index 007b09d74..1818837e6 100644 --- a/mlx/dtype.h +++ b/mlx/dtype.h @@ -51,7 +51,7 @@ struct Dtype { Val val; const uint8_t size; - constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size){}; + constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size) {}; constexpr operator Val() const { return val; }; diff --git a/mlx/event.h b/mlx/event.h index 4fee164eb..03f23c2da 100644 --- a/mlx/event.h +++ b/mlx/event.h @@ -10,7 +10,7 @@ namespace mlx::core { class Event { public: - Event(){}; + Event() {}; Event(const Stream& steam); diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 0bf5b511a..013365f11 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -12,7 +12,7 @@ class Custom : public Primitive { explicit Custom( Stream stream, std::function(std::vector)> fallback) - : Primitive(stream), 
fallback_(fallback){}; + : Primitive(stream), fallback_(fallback) {}; virtual std::pair, std::vector> vmap( const std::vector& inputs, @@ -39,7 +39,7 @@ class RMSNorm : public Custom { Stream stream, std::function(std::vector)> fallback, float eps) - : Custom(stream, fallback), eps_(eps){}; + : Custom(stream, fallback), eps_(eps) {}; void eval_cpu(const std::vector& inputs, std::vector& outputs) override { @@ -68,7 +68,7 @@ class RMSNormVJP : public Custom { Stream stream, std::function(std::vector)> fallback, float eps) - : Custom(stream, fallback), eps_(eps){}; + : Custom(stream, fallback), eps_(eps) {}; void eval_cpu(const std::vector& inputs, std::vector& outputs) override { @@ -91,7 +91,7 @@ class LayerNorm : public Custom { Stream stream, std::function(std::vector)> fallback, float eps) - : Custom(stream, fallback), eps_(eps){}; + : Custom(stream, fallback), eps_(eps) {}; void eval_cpu(const std::vector& inputs, std::vector& outputs) override { @@ -120,7 +120,7 @@ class LayerNormVJP : public Custom { Stream stream, std::function(std::vector)> fallback, float eps) - : Custom(stream, fallback), eps_(eps){}; + : Custom(stream, fallback), eps_(eps) {}; void eval_cpu(const std::vector& inputs, std::vector& outputs) override { @@ -154,7 +154,7 @@ class RoPE : public Custom { base_(base), scale_(scale), offset_(offset), - forward_(forward){}; + forward_(forward) {}; void eval_cpu(const std::vector& inputs, std::vector& outputs) override { @@ -189,7 +189,7 @@ class ScaledDotProductAttention : public Custom { std::function(std::vector)> fallback, const float scale, const bool needs_mask) - : Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask){}; + : Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {}; void eval_cpu(const std::vector& inputs, std::vector& outputs) override { diff --git a/mlx/primitives.h b/mlx/primitives.h index 390763b93..1e61bde5b 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -154,7 +154,7 @@ class UnaryPrimitive : public Primitive { class Abs : public UnaryPrimitive { public: - explicit Abs(Stream stream) : UnaryPrimitive(stream){}; + explicit Abs(Stream stream) : UnaryPrimitive(stream) {}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -171,7 +171,7 @@ class Abs : public UnaryPrimitive { class Add : public UnaryPrimitive { public: - explicit Add(Stream stream) : UnaryPrimitive(stream){}; + explicit Add(Stream stream) : UnaryPrimitive(stream) {}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -189,7 +189,7 @@ class Add : public UnaryPrimitive { class AddMM : public UnaryPrimitive { public: explicit AddMM(Stream stream, float alpha, float beta) - : UnaryPrimitive(stream), alpha_(alpha), beta_(beta){}; + : UnaryPrimitive(stream), alpha_(alpha), beta_(beta) {}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -213,7 +213,7 @@ class AddMM : public UnaryPrimitive { class Arange : public UnaryPrimitive { public: explicit Arange(Stream stream, double start, double stop, double step) - : UnaryPrimitive(stream), start_(start), stop_(stop), step_(step){}; + : UnaryPrimitive(stream), start_(start), stop_(stop), step_(step) {}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -231,7 +231,7 @@ class Arange : public UnaryPrimitive 

 class ArcCos : public UnaryPrimitive {
  public:
-  explicit ArcCos(Stream stream) : UnaryPrimitive(stream){};
+  explicit ArcCos(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -248,7 +248,7 @@ class ArcCos : public UnaryPrimitive {

 class ArcCosh : public UnaryPrimitive {
  public:
-  explicit ArcCosh(Stream stream) : UnaryPrimitive(stream){};
+  explicit ArcCosh(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -265,7 +265,7 @@ class ArcCosh : public UnaryPrimitive {

 class ArcSin : public UnaryPrimitive {
  public:
-  explicit ArcSin(Stream stream) : UnaryPrimitive(stream){};
+  explicit ArcSin(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -282,7 +282,7 @@ class ArcSin : public UnaryPrimitive {

 class ArcSinh : public UnaryPrimitive {
  public:
-  explicit ArcSinh(Stream stream) : UnaryPrimitive(stream){};
+  explicit ArcSinh(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -299,7 +299,7 @@ class ArcSinh : public UnaryPrimitive {

 class ArcTan : public UnaryPrimitive {
  public:
-  explicit ArcTan(Stream stream) : UnaryPrimitive(stream){};
+  explicit ArcTan(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -316,7 +316,7 @@ class ArcTan : public UnaryPrimitive {

 class ArcTanh : public UnaryPrimitive {
  public:
-  explicit ArcTanh(Stream stream) : UnaryPrimitive(stream){};
+  explicit ArcTanh(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -334,7 +334,7 @@ class ArcTanh : public UnaryPrimitive {
 class ArgPartition : public UnaryPrimitive {
  public:
   explicit ArgPartition(Stream stream, int kth, int axis)
-      : UnaryPrimitive(stream), kth_(kth), axis_(axis){};
+      : UnaryPrimitive(stream), kth_(kth), axis_(axis) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -359,7 +359,7 @@ class ArgReduce : public UnaryPrimitive {
   };

   explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis)
-      : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis){};
+      : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -380,7 +380,7 @@ class ArgReduce : public UnaryPrimitive {
 class ArgSort : public UnaryPrimitive {
  public:
   explicit ArgSort(Stream stream, int axis)
-      : UnaryPrimitive(stream), axis_(axis){};
+      : UnaryPrimitive(stream), axis_(axis) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -399,7 +399,7 @@ class ArgSort : public UnaryPrimitive {
 class AsType : public UnaryPrimitive {
  public:
   explicit AsType(Stream stream, Dtype dtype)
-      : UnaryPrimitive(stream), dtype_(dtype){};
+      : UnaryPrimitive(stream), dtype_(dtype) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -426,7 +426,7 @@ class AsStrided : public UnaryPrimitive {
       : UnaryPrimitive(stream),
         shape_(std::move(shape)),
         strides_(std::move(strides)),
-        offset_(offset){};
+        offset_(offset) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -448,7 +448,7 @@ class BitwiseBinary : public UnaryPrimitive {
   enum Op { And, Or, Xor, LeftShift, RightShift };

   explicit BitwiseBinary(Stream stream, Op op)
-      : UnaryPrimitive(stream), op_(op){};
+      : UnaryPrimitive(stream), op_(op) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -465,7 +465,7 @@ class BitwiseBinary : public UnaryPrimitive {
 class BlockMaskedMM : public UnaryPrimitive {
  public:
   explicit BlockMaskedMM(Stream stream, int block_size)
-      : UnaryPrimitive(stream), block_size_(block_size){};
+      : UnaryPrimitive(stream), block_size_(block_size) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -488,7 +488,7 @@ class BlockMaskedMM : public UnaryPrimitive {
 class Broadcast : public UnaryPrimitive {
  public:
   explicit Broadcast(Stream stream, const std::vector<int>& shape)
-      : UnaryPrimitive(stream), shape_(shape){};
+      : UnaryPrimitive(stream), shape_(shape) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -506,7 +506,7 @@ class Broadcast : public UnaryPrimitive {

 class Ceil : public UnaryPrimitive {
  public:
-  explicit Ceil(Stream stream) : UnaryPrimitive(stream){};
+  explicit Ceil(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -567,7 +567,7 @@ class Compiled : public Primitive {
 class Concatenate : public UnaryPrimitive {
  public:
   explicit Concatenate(Stream stream, int axis)
-      : UnaryPrimitive(stream), axis_(axis){};
+      : UnaryPrimitive(stream), axis_(axis) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -599,7 +599,7 @@ class Convolution : public UnaryPrimitive {
         kernel_dilation_(kernel_dilation),
         input_dilation_(input_dilation),
         groups_(groups),
-        flip_(flip){};
+        flip_(flip) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -626,7 +626,7 @@ class Convolution : public UnaryPrimitive {

 class Copy : public UnaryPrimitive {
  public:
-  explicit Copy(Stream stream) : UnaryPrimitive(stream){};
+  explicit Copy(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -643,7 +643,7 @@ class Copy : public UnaryPrimitive {

 class Cos : public UnaryPrimitive {
  public:
-  explicit Cos(Stream stream) : UnaryPrimitive(stream){};
+  explicit Cos(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -660,7 +660,7 @@ class Cos : public UnaryPrimitive {

 class Cosh : public UnaryPrimitive {
  public:
-  explicit Cosh(Stream stream) : UnaryPrimitive(stream){};
+  explicit Cosh(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -731,7 +731,7 @@ class Depends : public Primitive {

 class Divide : public UnaryPrimitive {
  public:
-  explicit Divide(Stream stream) : UnaryPrimitive(stream){};
+  explicit Divide(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -748,7 +748,7 @@ class Divide : public UnaryPrimitive {

 class DivMod : public Primitive {
  public:
-  explicit DivMod(Stream stream) : Primitive(stream){};
+  explicit DivMod(Stream stream) : Primitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
@@ -770,7 +770,7 @@ class DivMod : public Primitive {

 class Select : public UnaryPrimitive {
  public:
-  explicit Select(Stream stream) : UnaryPrimitive(stream){};
+  explicit Select(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -787,7 +787,7 @@ class Select : public UnaryPrimitive {

 class Remainder : public UnaryPrimitive {
  public:
-  explicit Remainder(Stream stream) : UnaryPrimitive(stream){};
+  explicit Remainder(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -805,7 +805,7 @@ class Remainder : public UnaryPrimitive {
 class Equal : public UnaryPrimitive {
  public:
   explicit Equal(Stream stream, bool equal_nan = false)
-      : UnaryPrimitive(stream), equal_nan_(equal_nan){};
+      : UnaryPrimitive(stream), equal_nan_(equal_nan) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -830,7 +830,7 @@ class Equal : public UnaryPrimitive {

 class Erf : public UnaryPrimitive {
  public:
-  explicit Erf(Stream stream) : UnaryPrimitive(stream){};
+  explicit Erf(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -847,7 +847,7 @@ class Erf : public UnaryPrimitive {

 class ErfInv : public UnaryPrimitive {
  public:
-  explicit ErfInv(Stream stream) : UnaryPrimitive(stream){};
+  explicit ErfInv(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -864,7 +864,7 @@ class ErfInv : public UnaryPrimitive {

 class Exp : public UnaryPrimitive {
  public:
-  explicit Exp(Stream stream) : UnaryPrimitive(stream){};
+  explicit Exp(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -881,7 +881,7 @@ class Exp : public UnaryPrimitive {

 class Expm1 : public UnaryPrimitive {
  public:
-  explicit Expm1(Stream stream) : UnaryPrimitive(stream){};
+  explicit Expm1(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -902,7 +902,7 @@ class FFT : public UnaryPrimitive {
       const std::vector<size_t>& axes,
       bool inverse,
       bool real)
-      : UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real){};
+      : UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -923,7 +923,7 @@ class FFT : public UnaryPrimitive {

 class Floor : public UnaryPrimitive {
  public:
-  explicit Floor(Stream stream) : UnaryPrimitive(stream){};
+  explicit Floor(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -940,7 +940,7 @@ class Floor : public UnaryPrimitive {

 class Full : public UnaryPrimitive {
  public:
-  explicit Full(Stream stream) : UnaryPrimitive(stream){};
+  explicit Full(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -960,7 +960,7 @@ class Gather : public UnaryPrimitive {
       Stream stream,
       const std::vector<int>& axes,
       const std::vector<int>& slice_sizes)
-      : UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes){};
+      : UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -978,7 +978,7 @@ class Gather : public UnaryPrimitive {

 class Greater : public UnaryPrimitive {
  public:
-  explicit Greater(Stream stream) : UnaryPrimitive(stream){};
+  explicit Greater(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -995,7 +995,7 @@ class Greater : public UnaryPrimitive {

 class GreaterEqual : public UnaryPrimitive {
  public:
-  explicit GreaterEqual(Stream stream) : UnaryPrimitive(stream){};
+  explicit GreaterEqual(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1012,7 +1012,7 @@ class GreaterEqual : public UnaryPrimitive {

 class Less : public UnaryPrimitive {
  public:
-  explicit Less(Stream stream) : UnaryPrimitive(stream){};
+  explicit Less(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1029,7 +1029,7 @@ class Less : public UnaryPrimitive {

 class LessEqual : public UnaryPrimitive {
  public:
-  explicit LessEqual(Stream stream) : UnaryPrimitive(stream){};
+  explicit LessEqual(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1054,7 +1054,7 @@ class Load : public UnaryPrimitive {
       : UnaryPrimitive(stream),
         reader_(reader),
         offset_(offset),
-        swap_endianness_(swap_endianness){};
+        swap_endianness_(swap_endianness) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1073,7 +1073,7 @@ class Log : public UnaryPrimitive {
   enum Base { two, ten, e };

   explicit Log(Stream stream, Base base)
-      : UnaryPrimitive(stream), base_(base){};
+      : UnaryPrimitive(stream), base_(base) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1104,7 +1104,7 @@ class Log : public UnaryPrimitive {

 class Log1p : public UnaryPrimitive {
  public:
-  explicit Log1p(Stream stream) : UnaryPrimitive(stream){};
+  explicit Log1p(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1120,7 +1120,7 @@ class Log1p : public UnaryPrimitive {

 class LogicalNot : public UnaryPrimitive {
  public:
-  explicit LogicalNot(Stream stream) : UnaryPrimitive(stream){};
+  explicit LogicalNot(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1137,7 +1137,7 @@ class LogicalNot : public UnaryPrimitive {

 class LogicalAnd : public UnaryPrimitive {
  public:
-  explicit LogicalAnd(Stream stream) : UnaryPrimitive(stream){};
+  explicit LogicalAnd(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1154,7 +1154,7 @@ class LogicalAnd : public UnaryPrimitive {

 class LogicalOr : public UnaryPrimitive {
  public:
-  explicit LogicalOr(Stream stream) : UnaryPrimitive(stream){};
+  explicit LogicalOr(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1171,7 +1171,7 @@ class LogicalOr : public UnaryPrimitive {

 class LogAddExp : public UnaryPrimitive {
  public:
-  explicit LogAddExp(Stream stream) : UnaryPrimitive(stream){};
+  explicit LogAddExp(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1188,7 +1188,7 @@ class LogAddExp : public UnaryPrimitive {

 class Matmul : public UnaryPrimitive {
  public:
-  explicit Matmul(Stream stream) : UnaryPrimitive(stream){};
+  explicit Matmul(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1206,7 +1206,7 @@ class Matmul : public UnaryPrimitive {

 class Maximum : public UnaryPrimitive {
  public:
-  explicit Maximum(Stream stream) : UnaryPrimitive(stream){};
+  explicit Maximum(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1223,7 +1223,7 @@ class Maximum : public UnaryPrimitive {

 class Minimum : public UnaryPrimitive {
  public:
-  explicit Minimum(Stream stream) : UnaryPrimitive(stream){};
+  explicit Minimum(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1240,7 +1240,7 @@ class Minimum : public UnaryPrimitive {

 class Multiply : public UnaryPrimitive {
  public:
-  explicit Multiply(Stream stream) : UnaryPrimitive(stream){};
+  explicit Multiply(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1257,7 +1257,7 @@ class Multiply : public UnaryPrimitive {

 class Negative : public UnaryPrimitive {
  public:
-  explicit Negative(Stream stream) : UnaryPrimitive(stream){};
+  explicit Negative(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1274,7 +1274,7 @@ class Negative : public UnaryPrimitive {

 class NotEqual : public UnaryPrimitive {
  public:
-  explicit NotEqual(Stream stream) : UnaryPrimitive(stream){};
+  explicit NotEqual(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1330,7 +1330,7 @@ class Pad : public UnaryPrimitive {
       : UnaryPrimitive(stream),
         axes_(axes),
         low_pad_size_(low_pad_size),
-        high_pad_size_(high_pad_size){};
+        high_pad_size_(high_pad_size) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1351,7 +1351,7 @@ class Pad : public UnaryPrimitive {
 class Partition : public UnaryPrimitive {
  public:
   explicit Partition(Stream stream, int kth, int axis)
-      : UnaryPrimitive(stream), kth_(kth), axis_(axis){};
+      : UnaryPrimitive(stream), kth_(kth), axis_(axis) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1371,7 +1371,7 @@ class Partition : public UnaryPrimitive {

 class Power : public UnaryPrimitive {
  public:
-  explicit Power(Stream stream) : UnaryPrimitive(stream){};
+  explicit Power(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1396,7 +1396,7 @@ class QuantizedMatmul : public UnaryPrimitive {
       : UnaryPrimitive(stream),
         group_size_(group_size),
         bits_(bits),
-        transpose_(transpose){};
+        transpose_(transpose) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1417,7 +1417,7 @@ class QuantizedMatmul : public UnaryPrimitive {
 class RandomBits : public UnaryPrimitive {
  public:
   explicit RandomBits(Stream stream, const std::vector<int>& shape, int width)
-      : UnaryPrimitive(stream), shape_(shape), width_(width){};
+      : UnaryPrimitive(stream), shape_(shape), width_(width) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1436,7 +1436,7 @@ class RandomBits : public UnaryPrimitive {
 class Reshape : public UnaryPrimitive {
  public:
   explicit Reshape(Stream stream, const std::vector<int>& shape)
-      : UnaryPrimitive(stream), shape_(shape){};
+      : UnaryPrimitive(stream), shape_(shape) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1468,7 +1468,7 @@ class Reduce : public UnaryPrimitive {
       Stream stream,
       ReduceType reduce_type,
       const std::vector<int>& axes)
-      : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes){};
+      : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1517,7 +1517,7 @@ class Reduce : public UnaryPrimitive {

 class Round : public UnaryPrimitive {
  public:
-  explicit Round(Stream stream) : UnaryPrimitive(stream){};
+  explicit Round(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1546,7 +1546,7 @@ class Scan : public UnaryPrimitive {
         reduce_type_(reduce_type),
         axis_(axis),
         reverse_(reverse),
-        inclusive_(inclusive){};
+        inclusive_(inclusive) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1591,7 +1591,7 @@ class Scatter : public UnaryPrimitive {
       Stream stream,
       ReduceType reduce_type,
       const std::vector<int>& axes)
-      : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes){};
+      : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1626,7 +1626,7 @@ class Scatter : public UnaryPrimitive {

 class Sigmoid : public UnaryPrimitive {
  public:
-  explicit Sigmoid(Stream stream) : UnaryPrimitive(stream){};
+  explicit Sigmoid(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1643,7 +1643,7 @@ class Sigmoid : public UnaryPrimitive {

 class Sign : public UnaryPrimitive {
  public:
-  explicit Sign(Stream stream) : UnaryPrimitive(stream){};
+  explicit Sign(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1660,7 +1660,7 @@ class Sign : public UnaryPrimitive {

 class Sin : public UnaryPrimitive {
  public:
-  explicit Sin(Stream stream) : UnaryPrimitive(stream){};
+  explicit Sin(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1677,7 +1677,7 @@ class Sin : public UnaryPrimitive {

 class Sinh : public UnaryPrimitive {
  public:
-  explicit Sinh(Stream stream) : UnaryPrimitive(stream){};
+  explicit Sinh(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1702,7 +1702,7 @@ class Slice : public UnaryPrimitive {
       : UnaryPrimitive(stream),
         start_indices_(start_indices),
         end_indices_(end_indices),
-        strides_(strides){};
+        strides_(strides) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1738,7 +1738,7 @@ class SliceUpdate : public UnaryPrimitive {
       : UnaryPrimitive(stream),
         start_indices_(start_indices),
         end_indices_(end_indices),
-        strides_(strides){};
+        strides_(strides) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1761,7 +1761,7 @@ class SliceUpdate : public UnaryPrimitive {
 class Softmax : public UnaryPrimitive {
  public:
   explicit Softmax(Stream stream, bool precise)
-      : UnaryPrimitive(stream), precise_(precise){};
+      : UnaryPrimitive(stream), precise_(precise) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1781,7 +1781,7 @@ class Softmax : public UnaryPrimitive {
 class Sort : public UnaryPrimitive {
  public:
   explicit Sort(Stream stream, int axis)
-      : UnaryPrimitive(stream), axis_(axis){};
+      : UnaryPrimitive(stream), axis_(axis) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1801,7 +1801,7 @@ class Sort : public UnaryPrimitive {
 class Split : public Primitive {
  public:
   explicit Split(Stream stream, const std::vector<int>& indices, int axis)
-      : Primitive(stream), indices_(indices), axis_(axis){};
+      : Primitive(stream), indices_(indices), axis_(axis) {};

   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
@@ -1822,7 +1822,7 @@ class Split : public Primitive {

 class Square : public UnaryPrimitive {
  public:
-  explicit Square(Stream stream) : UnaryPrimitive(stream){};
+  explicit Square(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1840,7 +1840,7 @@ class Square : public UnaryPrimitive {
 class Sqrt : public UnaryPrimitive {
  public:
   explicit Sqrt(Stream stream, bool recip = false)
-      : UnaryPrimitive(stream), recip_(recip){};
+      : UnaryPrimitive(stream), recip_(recip) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1865,7 +1865,7 @@ class Sqrt : public UnaryPrimitive {

 class StopGradient : public UnaryPrimitive {
  public:
-  explicit StopGradient(Stream stream) : UnaryPrimitive(stream){};
+  explicit StopGradient(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1881,7 +1881,7 @@ class StopGradient : public UnaryPrimitive {

 class Subtract : public UnaryPrimitive {
  public:
-  explicit Subtract(Stream stream) : UnaryPrimitive(stream){};
+  explicit Subtract(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1898,7 +1898,7 @@ class Subtract : public UnaryPrimitive {

 class Tan : public UnaryPrimitive {
  public:
-  explicit Tan(Stream stream) : UnaryPrimitive(stream){};
+  explicit Tan(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1915,7 +1915,7 @@ class Tan : public UnaryPrimitive {

 class Tanh : public UnaryPrimitive {
  public:
-  explicit Tanh(Stream stream) : UnaryPrimitive(stream){};
+  explicit Tanh(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1932,7 +1932,7 @@ class Tanh : public UnaryPrimitive {

 class Uniform : public UnaryPrimitive {
  public:
-  explicit Uniform(Stream stream) : UnaryPrimitive(stream){};
+  explicit Uniform(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1948,7 +1948,7 @@ class Uniform : public UnaryPrimitive {
 class Transpose : public UnaryPrimitive {
  public:
   explicit Transpose(Stream stream, const std::vector<int>& axes)
-      : UnaryPrimitive(stream), axes_(axes){};
+      : UnaryPrimitive(stream), axes_(axes) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1967,7 +1967,7 @@ class Transpose : public UnaryPrimitive {
 /* QR Factorization primitive. */
 class QRF : public Primitive {
  public:
-  explicit QRF(Stream stream) : Primitive(stream){};
+  explicit QRF(Stream stream) : Primitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
@@ -1983,7 +1983,7 @@ class QRF : public Primitive {
 /* SVD primitive. */
 class SVD : public Primitive {
  public:
-  explicit SVD(Stream stream) : Primitive(stream){};
+  explicit SVD(Stream stream) : Primitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
@@ -2000,7 +2000,7 @@ class SVD : public Primitive {
 /* Matrix inversion primitive. */
 class Inverse : public UnaryPrimitive {
  public:
-  explicit Inverse(Stream stream) : UnaryPrimitive(stream){};
+  explicit Inverse(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& output) override;
   void eval_gpu(const std::vector<array>& inputs, array& output) override;
diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp
index 4c497ac9b..ace64a14a 100644
--- a/mlx/transforms.cpp
+++ b/mlx/transforms.cpp
@@ -22,7 +22,7 @@ namespace mlx::core {
  * for synchronizing with the main thread. */
 class Synchronizer : public Primitive {
  public:
-  explicit Synchronizer(Stream stream) : Primitive(stream){};
+  explicit Synchronizer(Stream stream) : Primitive(stream) {};

   void eval_cpu(const std::vector<array>&, std::vector<array>&) override {};
   void eval_gpu(const std::vector<array>&, std::vector<array>&) override {};
diff --git a/mlx/types/complex.h b/mlx/types/complex.h
index dad0b322f..48bcdbff7 100644
--- a/mlx/types/complex.h
+++ b/mlx/types/complex.h
@@ -14,8 +14,8 @@ inline constexpr bool can_convert_to_complex128 =
     !std::is_same_v<T, complex128_t> && std::is_convertible_v<T, std::complex<double>>;

 struct complex128_t : public std::complex<double> {
-  complex128_t(double v, double u) : std::complex<double>(v, u){};
-  complex128_t(std::complex<double> v) : std::complex<double>(v){};
+  complex128_t(double v, double u) : std::complex<double>(v, u) {};
+  complex128_t(std::complex<double> v) : std::complex<double>(v) {};

   template <
       typename T,
@@ -32,8 +32,8 @@ inline constexpr bool can_convert_to_complex64 =
     !std::is_same_v<T, complex64_t> && std::is_convertible_v<T, std::complex<float>>;

 struct complex64_t : public std::complex<float> {
-  complex64_t(float v, float u) : std::complex<float>(v, u){};
-  complex64_t(std::complex<float> v) : std::complex<float>(v){};
+  complex64_t(float v, float u) : std::complex<float>(v, u) {};
+  complex64_t(std::complex<float> v) : std::complex<float>(v) {};

   template <
       typename T,