From dd7d8e5e292f57bd9cedb36949cf8609160c2b55 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Wed, 12 Jun 2024 09:47:12 -0700 Subject: [PATCH] Add Quantized Ops to the JIT (#1204) * JIT for quantized ops * remove unused imports * address comments * fix imports * second attempt to fix imports --------- Co-authored-by: Alex Barron --- mlx/backend/metal/CMakeLists.txt | 1 + mlx/backend/metal/fft.cpp | 31 +- mlx/backend/metal/jit/fft.h | 53 - mlx/backend/metal/jit/includes.h | 1 + mlx/backend/metal/jit_kernels.cpp | 53 +- mlx/backend/metal/kernels.h | 37 +- mlx/backend/metal/kernels/CMakeLists.txt | 5 +- mlx/backend/metal/kernels/defines.h | 8 + mlx/backend/metal/kernels/fft.metal | 79 +- mlx/backend/metal/kernels/quantized.h | 1455 ++++++++++++++++ mlx/backend/metal/kernels/quantized.metal | 1904 ++------------------- mlx/backend/metal/nojit_kernels.cpp | 15 +- mlx/backend/metal/quantized.cpp | 84 +- 13 files changed, 1778 insertions(+), 1948 deletions(-) delete mode 100644 mlx/backend/metal/jit/fft.h create mode 100644 mlx/backend/metal/kernels/quantized.h diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index cbc18bb3a..8839237fe 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -112,6 +112,7 @@ if (MLX_METAL_JIT) kernels/steel/defines.h kernels/steel/conv/loaders/loader_general.h ) + make_jit_source(quantized) else() target_sources( mlx diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index 394f3c272..817bf7fe4 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -661,34 +661,45 @@ void fft_op( std::ostringstream kname; std::string inv_string = inverse ? "true" : "false"; std::string real_string = real ? "true" : "false"; + std::string func_name; if (plan.bluestein_n > 0) { kname << "bluestein_fft_mem_" << threadgroup_mem_size << "_" << in_type_str << "_" << out_type_str; + func_name = "bluestein_fft"; } else if (plan.rader_n > 1) { kname << "rader_fft_mem_" << threadgroup_mem_size << "_" << in_type_str << "_" << out_type_str; + func_name = "rader_fft"; } else if (four_step_params.required) { step = four_step_params.first_step ? 0 : 1; kname << "four_step_mem_" << threadgroup_mem_size << "_" << in_type_str << "_" << out_type_str << "_" << step << "_" << real_string; + func_name = "four_step_fft"; } else { kname << "fft_mem_" << threadgroup_mem_size << "_" << in_type_str << "_" << out_type_str; + func_name = "fft"; } std::string base_name = kname.str(); // We use a specialized kernel for each FFT size kname << "_n" << fft_size << "_inv_" << inverse; std::string hash_name = kname.str(); - auto kernel = get_fft_kernel( - d, - base_name, - hash_name, - threadgroup_mem_size, - in_type_str, - out_type_str, - step, - real, - func_consts); + auto template_def = func_name == "four_step_fft" ? get_template_definition( + base_name, + func_name, + threadgroup_mem_size, + in_type_str, + out_type_str, + step, + real) + : get_template_definition( + base_name, + func_name, + threadgroup_mem_size, + in_type_str, + out_type_str); + auto kernel = + get_fft_kernel(d, base_name, hash_name, func_consts, template_def); compute_encoder->setComputePipelineState(kernel); compute_encoder.set_input_array(in_contiguous, 0); diff --git a/mlx/backend/metal/jit/fft.h b/mlx/backend/metal/jit/fft.h deleted file mode 100644 index 24f908db9..000000000 --- a/mlx/backend/metal/jit/fft.h +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright © 2024 Apple Inc. 
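For illustration (the 2048/float2 parameters below are assumptions, not taken from the patch): with this change, fft_op builds the explicit-instantiation string at runtime through get_template_definition() instead of the per-kernel fmt strings that jit/fft.h used to hold (deleted below). A call such as

    get_template_definition(
        "fft_mem_2048_float2_float2", "fft", 2048, "float2", "float2")

returns a string along the lines of

    template [[host_name("fft_mem_2048_float2_float2")]] [[kernel]]
    decltype(fft<2048, float2, float2>) fft<2048, float2, float2>;

which get_fft_kernel() simply appends to metal::fft() when assembling the JIT library source.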
- -constexpr std::string_view fft_kernel = R"( -template [[host_name("{name}")]] [[kernel]] void -fft<{tg_mem_size}, {in_T}, {out_T}>( - const device {in_T}* in [[buffer(0)]], - device {out_T}* out [[buffer(1)]], - constant const int& n, - constant const int& batch_size, - uint3 elem [[thread_position_in_grid]], - uint3 grid [[threads_per_grid]]); -)"; - -constexpr std::string_view rader_fft_kernel = R"( -template [[host_name("{name}")]] [[kernel]] void -rader_fft<{tg_mem_size}, {in_T}, {out_T}>( - const device {in_T}* in [[buffer(0)]], - device {out_T}* out [[buffer(1)]], - const device float2* raders_b_q [[buffer(2)]], - const device short* raders_g_q [[buffer(3)]], - const device short* raders_g_minus_q [[buffer(4)]], - constant const int& n, - constant const int& batch_size, - constant const int& rader_n, - uint3 elem [[thread_position_in_grid]], - uint3 grid [[threads_per_grid]]); -)"; - -constexpr std::string_view bluestein_fft_kernel = R"( -template [[host_name("{name}")]] [[kernel]] void -bluestein_fft<{tg_mem_size}, {in_T}, {out_T}>( - const device {in_T}* in [[buffer(0)]], - device {out_T}* out [[buffer(1)]], - const device float2* w_q [[buffer(2)]], - const device float2* w_k [[buffer(3)]], - constant const int& length, - constant const int& n, - constant const int& batch_size, - uint3 elem [[thread_position_in_grid]], - uint3 grid [[threads_per_grid]]); -)"; - -constexpr std::string_view four_step_fft_kernel = R"( -template [[host_name("{name}")]] [[kernel]] void -four_step_fft<{tg_mem_size}, {in_T}, {out_T}, {step}, {real}>( - const device {in_T}* in [[buffer(0)]], - device {out_T}* out [[buffer(1)]], - constant const int& n1, - constant const int& n2, - constant const int& batch_size, - uint3 elem [[thread_position_in_grid]], - uint3 grid [[threads_per_grid]]); -)"; \ No newline at end of file diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h index f6b668512..f7e25c7c1 100644 --- a/mlx/backend/metal/jit/includes.h +++ b/mlx/backend/metal/jit/includes.h @@ -18,6 +18,7 @@ const char* binary(); const char* binary_two(); const char* copy(); const char* fft(); +const char* quantized(); const char* ternary(); const char* scan(); const char* softmax(); diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 1175ddbee..90495cd72 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -1,12 +1,10 @@ // Copyright © 2024 Apple Inc. 
-#include #include "mlx/backend/common/compiled.h" #include "mlx/backend/metal/jit/arange.h" #include "mlx/backend/metal/jit/binary.h" #include "mlx/backend/metal/jit/binary_two.h" #include "mlx/backend/metal/jit/copy.h" -#include "mlx/backend/metal/jit/fft.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/jit/reduce.h" #include "mlx/backend/metal/jit/scan.h" @@ -494,47 +492,32 @@ MTL::ComputePipelineState* get_fft_kernel( metal::Device& d, const std::string& kernel_name, const std::string& hash_name, - const int tg_mem_size, - const std::string& in_type, - const std::string& out_type, - int step, - bool real, - const metal::MTLFCList& func_consts) { + const metal::MTLFCList& func_consts, + const std::string& template_def) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name); if (lib == nullptr) { std::ostringstream kernel_source; std::string kernel_string; - if (lib_name.find("bluestein") != std::string::npos) { - kernel_string = bluestein_fft_kernel; - } else if (lib_name.find("rader") != std::string::npos) { - kernel_string = rader_fft_kernel; - } else if (lib_name.find("four_step") != std::string::npos) { - kernel_string = four_step_fft_kernel; - } else { - kernel_string = fft_kernel; - } - kernel_source << metal::fft(); - if (lib_name.find("four_step") != std::string::npos) { - kernel_source << fmt::format( - kernel_string, - "name"_a = lib_name, - "tg_mem_size"_a = tg_mem_size, - "in_T"_a = in_type, - "out_T"_a = out_type, - "step"_a = step, - "real"_a = real); - } else { - kernel_source << fmt::format( - kernel_string, - "name"_a = lib_name, - "tg_mem_size"_a = tg_mem_size, - "in_T"_a = in_type, - "out_T"_a = out_type); - } + kernel_source << metal::fft() << template_def; lib = d.get_library(lib_name, kernel_source.str()); } return d.get_kernel(kernel_name, lib, hash_name, func_consts); } +MTL::ComputePipelineState* get_quantized_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& template_def) { + const auto& lib_name = kernel_name; + auto lib = d.get_library(lib_name); + if (lib == nullptr) { + std::ostringstream kernel_source; + kernel_source << metal::utils() << metal::gemm() << metal::quantized() + << template_def; + lib = d.get_library(lib_name, kernel_source.str()); + } + return d.get_kernel(kernel_name, lib); +} + } // namespace mlx::core diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index ce99464ef..936ebca24 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -1,5 +1,7 @@ // Copyright © 2024 Apple Inc. +#include + #include "mlx/array.h" #include "mlx/backend/metal/device.h" @@ -159,11 +161,34 @@ MTL::ComputePipelineState* get_fft_kernel( metal::Device& d, const std::string& kernel_name, const std::string& hash_name, - const int tg_mem_size, - const std::string& in_type, - const std::string& out_type, - int step, - bool real, - const metal::MTLFCList& func_consts); + const metal::MTLFCList& func_consts, + const std::string& template_def); + +MTL::ComputePipelineState* get_quantized_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& template_def); + +// Create a GPU kernel template definition for JIT compilation +template +std::string +get_template_definition(std::string name, std::string func, Args... 
args) { + std::ostringstream s; + s << func << "<"; + bool first = true; + auto add_arg = [&s, &first](const auto& arg) { + if (!first) { + s << ", "; + } + first = false; + s << arg; + }; + (add_arg(args), ...); + s << ">"; + std::string base_string = R"( +template [[host_name("{0}")]] [[kernel]] decltype({1}) {1}; + )"; + return fmt::format(base_string, name, s.str()); +} } // namespace mlx::core diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index f98430eb8..81751e917 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -12,9 +12,7 @@ set( KERNELS "arg_reduce" "conv" - "fft" "gemv" - "quantized" "random" "rms_norm" "layer_norm" @@ -32,6 +30,8 @@ set( "unary" "ternary" "copy" + "fft" + "quantized" "softmax" "sort" "scan" @@ -51,6 +51,7 @@ set( fft.h fft/radix.h fft/readwrite.h + quantized.h softmax.h sort.h scan.h diff --git a/mlx/backend/metal/kernels/defines.h b/mlx/backend/metal/kernels/defines.h index 3c4fcbdeb..3e98252b5 100644 --- a/mlx/backend/metal/kernels/defines.h +++ b/mlx/backend/metal/kernels/defines.h @@ -13,3 +13,11 @@ static MTL_CONST constexpr int REDUCE_N_READS = 16; static MTL_CONST constexpr int SOFTMAX_N_READS = 4; static MTL_CONST constexpr int RMS_N_READS = 4; static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096; + +// Instantiate a templated kernel. +// Extra args are used as template parameters: +// e.g. instantiate_kernel(binary_int, binary, a, b) -> +// [[host_name(binary_int)]] [kernel] binary +#define instantiate_kernel(name, func, ...) \ + template [[host_name( \ + name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>; diff --git a/mlx/backend/metal/kernels/fft.metal b/mlx/backend/metal/kernels/fft.metal index 05828f34c..590b558ef 100644 --- a/mlx/backend/metal/kernels/fft.metal +++ b/mlx/backend/metal/kernels/fft.metal @@ -1,58 +1,41 @@ // Copyright © 2024 Apple Inc. 
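For reference, a sketch of how the new instantiate_kernel macro from defines.h expands for one of the FFT instantiations below (the 2048/float2 parameters are an illustrative assumption):

    // instantiate_fft(2048, float2, float2) pastes the host name and forwards to
    // instantiate_kernel("fft_mem_2048_float2_float2", fft, 2048, float2, float2),
    // which expands to:
    template [[host_name("fft_mem_2048_float2_float2")]] [[kernel]]
    decltype(fft<2048, float2, float2>) fft<2048, float2, float2>;

This matches the string produced by get_template_definition() on the JIT path, so both build modes register the same host_name for a given kernel.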
+#include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/fft.h" -#define instantiate_fft(tg_mem_size, in_T, out_T) \ - template [[host_name("fft_mem_" #tg_mem_size "_" #in_T \ - "_" #out_T)]] [[kernel]] void \ - fft( \ - const device in_T* in [[buffer(0)]], \ - device out_T* out [[buffer(1)]], \ - constant const int& n, \ - constant const int& batch_size, \ - uint3 elem [[thread_position_in_grid]], \ - uint3 grid [[threads_per_grid]]); +#define instantiate_fft(tg_mem_size, in_T, out_T) \ + instantiate_kernel( \ + "fft_mem_" #tg_mem_size "_" #in_T "_" #out_T, \ + fft, \ + tg_mem_size, \ + in_T, \ + out_T) -#define instantiate_rader(tg_mem_size, in_T, out_T) \ - template [[host_name("rader_fft_mem_" #tg_mem_size "_" #in_T \ - "_" #out_T)]] [[kernel]] void \ - rader_fft( \ - const device in_T* in [[buffer(0)]], \ - device out_T* out [[buffer(1)]], \ - const device float2* raders_b_q [[buffer(2)]], \ - const device short* raders_g_q [[buffer(3)]], \ - const device short* raders_g_minus_q [[buffer(4)]], \ - constant const int& n, \ - constant const int& batch_size, \ - constant const int& rader_n, \ - uint3 elem [[thread_position_in_grid]], \ - uint3 grid [[threads_per_grid]]); +#define instantiate_rader(tg_mem_size, in_T, out_T) \ + instantiate_kernel( \ + "rader_fft_mem_" #tg_mem_size "_" #in_T "_" #out_T, \ + rader_fft, \ + tg_mem_size, \ + in_T, \ + out_T) -#define instantiate_bluestein(tg_mem_size, in_T, out_T) \ - template [[host_name("bluestein_fft_mem_" #tg_mem_size "_" #in_T \ - "_" #out_T)]] [[kernel]] void \ - bluestein_fft( \ - const device in_T* in [[buffer(0)]], \ - device out_T* out [[buffer(1)]], \ - const device float2* w_q [[buffer(2)]], \ - const device float2* w_k [[buffer(3)]], \ - constant const int& length, \ - constant const int& n, \ - constant const int& batch_size, \ - uint3 elem [[thread_position_in_grid]], \ - uint3 grid [[threads_per_grid]]); +#define instantiate_bluestein(tg_mem_size, in_T, out_T) \ + instantiate_kernel( \ + "bluestein_fft_mem_" #tg_mem_size "_" #in_T "_" #out_T, \ + bluestein_fft, \ + tg_mem_size, \ + in_T, \ + out_T) -#define instantiate_four_step(tg_mem_size, in_T, out_T, step, real) \ - template [[host_name("four_step_mem_" #tg_mem_size "_" #in_T "_" #out_T \ - "_" #step "_" #real)]] [[kernel]] void \ - four_step_fft( \ - const device in_T* in [[buffer(0)]], \ - device out_T* out [[buffer(1)]], \ - constant const int& n1, \ - constant const int& n2, \ - constant const int& batch_size, \ - uint3 elem [[thread_position_in_grid]], \ - uint3 grid [[threads_per_grid]]); +#define instantiate_four_step(tg_mem_size, in_T, out_T, step, real) \ + instantiate_kernel( \ + "four_step_mem_" #tg_mem_size "_" #in_T "_" #out_T "_" #step "_" #real, \ + four_step_fft, \ + tg_mem_size, \ + in_T, \ + out_T, \ + step, \ + real) // clang-format off #define instantiate_ffts(tg_mem_size) \ diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h new file mode 100644 index 000000000..28a055576 --- /dev/null +++ b/mlx/backend/metal/kernels/quantized.h @@ -0,0 +1,1455 @@ +// Copyright © 2023-2024 Apple Inc. 
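All kernels in this header assume the same packed layout: each uint32_t of w holds pack_factor = 32 / bits quantized values, and every group_size consecutive values along the reduction dimension share one scale and one bias. A small worked example with illustrative sizes (not taken from the patch):

    // bits = 4, group_size = 64, in_vec_size = 4096
    // pack_factor   = 32 / 4    = 8 values per uint32_t
    // in_vec_size_w = 4096 / 8  = 512 uint32_t per output row of w
    // in_vec_size_g = 4096 / 64 = 64 scales (and 64 biases) per output row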
+ +#include +#include + +using namespace metal; + +#define MLX_MTL_CONST static constant constexpr const + +MLX_MTL_CONST int SIMD_SIZE = 32; + +template +inline U load_vector(const device T* x, thread U* x_thread) { + static_assert( + bits == 2 || bits == 4 || bits == 8, + "Template undefined for bits not in {2, 4, 8}"); + + U sum = 0; + + if (bits == 2) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 4.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 64.0f; + } + } + + else if (bits == 4) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 16.0f; + x_thread[i + 2] = x[i + 2] / 256.0f; + x_thread[i + 3] = x[i + 3] / 4096.0f; + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + sum += x[i]; + x_thread[i] = x[i]; + } + } + + return sum; +} + +template +inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { + static_assert( + bits == 2 || bits == 4 || bits == 8, + "Template undefined for bits not in {2, 4, 8}"); + + U sum = 0; + + if (bits == 2) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 4.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 64.0f; + } + for (int i = N; i < values_per_thread; i++) { + x_thread[i] = 0; + } + } + + else if (bits == 4) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 16.0f; + x_thread[i + 2] = x[i + 2] / 256.0f; + x_thread[i + 3] = x[i + 3] / 4096.0f; + } + for (int i = N; i < values_per_thread; i++) { + x_thread[i] = 0; + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + sum += x[i]; + x_thread[i] = x[i]; + } + for (int i = N; i < values_per_thread; i++) { + x_thread[i] = 0; + } + } + + return sum; +} + +template +inline U qdot( + const device uint8_t* w, + const thread U* x_thread, + U scale, + U bias, + U sum) { + static_assert( + bits == 2 || bits == 4 || bits == 8, + "Template undefined for bits not in {2, 4, 8}"); + + U accum = 0; + + if (bits == 2) { + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * (w[i] & 0x03) + + x_thread[4 * i + 1] * (w[i] & 0x0c) + + x_thread[4 * i + 2] * (w[i] & 0x30) + + x_thread[4 * i + 3] * (w[i] & 0xc0)); + } + } + + else if (bits == 4) { + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] * (ws[i] & 0xf000)); + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + accum += x_thread[i] * w[i]; + } + } + + return scale * accum + sum * bias; +} + +template +inline U qdot_safe( + const device uint8_t* w, + const thread U* x_thread, + U scale, + U bias, + U sum, + int N) { + static_assert( + bits == 2 || bits == 4 || bits == 8, + "Template undefined for bits not in {2, 4, 8}"); + + U accum = 0; + + if (bits == 2) { + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * (w[i] & 0x03) + + x_thread[4 * i + 1] * (w[i] & 0x0c) + + x_thread[4 * i + 2] * (w[i] & 0x30) + + x_thread[4 * i + 3] * (w[i] & 0xc0)); + } + } + + else if (bits == 4) { + const 
device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] * (ws[i] & 0xf000)); + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + accum += x_thread[i] * w[i]; + } + } + + return scale * accum + sum * bias; +} + +template +inline void +qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { + static_assert( + bits == 2 || bits == 4 || bits == 8, + "Template undefined for bits not in {2, 4, 8}"); + + if (bits == 2) { + U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; + for (int i = 0; i < (values_per_thread / 4); i++) { + result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias); + result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias); + result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias); + result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias); + } + } + + else if (bits == 4) { + U s[2] = {scale, scale / 16.0f}; + for (int i = 0; i < (values_per_thread / 2); i++) { + result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias); + result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias); + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + result[i] += x * (scale * w[i] + bias); + } + } +} + +template +inline void +dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { + static_assert( + bits == 2 || bits == 4 || bits == 8, + "Template undefined for bits not in {2, 4, 8}"); + + if (bits == 2) { + U s[4] = { + scale, + scale / static_cast(4.0f), + scale / static_cast(16.0f), + scale / static_cast(64.0f)}; + for (int i = 0; i < (N / 4); i++) { + w_local[4 * i] = s[0] * (w[i] & 0x03) + bias; + w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias; + w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias; + w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias; + } + } + + else if (bits == 4) { + U s[2] = {scale, scale / static_cast(16.0f)}; + for (int i = 0; i < (N / 2); i++) { + w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias; + w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias; + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + w_local[i] = scale * w[i] + bias; + } + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short group_size, + short bits> +struct QuantizedBlockLoader { + static_assert( + BCOLS <= group_size, + "The group size should be larger than the columns"); + static_assert( + group_size % BCOLS == 0, + "The group size should be divisible by the columns"); + static_assert( + bits == 2 || bits == 4 || bits == 8, + "Template undefined for bits not in {2, 4, 8}"); + + MLX_MTL_CONST short pack_factor = 32 / bits; + MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; + MLX_MTL_CONST short n_reads = + (BCOLS_PACKED * BROWS < tgp_size) ? 
1 : (BCOLS_PACKED * BROWS) / tgp_size; + MLX_MTL_CONST short group_steps = group_size / BCOLS; + + const int src_ld; + const int tile_stride; + short group_step_cnt; + const int group_stride; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint32_t* src; + const device T* scales; + const device T* biases; + + QuantizedBlockLoader( + const device uint32_t* src_, + const device T* scales_, + const device T* biases_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride( + reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor), + group_step_cnt(0), + group_stride(BROWS * src_ld / group_size), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(n_reads * thread_idx / BCOLS_PACKED), + bj((n_reads * thread_idx) % BCOLS_PACKED), + dst(dst_ + bi * dst_ld + bj * pack_factor), + src(src_ + bi * src_ld / pack_factor + bj), + scales(scales_ + bi * src_ld / group_size), + biases(biases_ + bi * src_ld / group_size) {} + + void load_unsafe() const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + T scale = *scales; + T bias = *biases; + for (int i = 0; i < n_reads; i++) { + dequantize( + (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor); + } + } + + void load_safe(short2 src_tile_dim) const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + if (reduction_dim == 1 && bi >= src_tile_dim.y) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + if (reduction_dim == 0 && bi >= src_tile_dim.x) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + T scale = *scales; + T bias = *biases; + for (int i = 0; i < n_reads; i++) { + dequantize( + (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor); + } + } + + void next() { + src += tile_stride; + if (reduction_dim == 1) { + if (group_steps > 1) { + group_step_cnt++; + if (group_step_cnt == group_steps) { + group_step_cnt = 0; + scales++; + biases++; + } + } else { + scales++; + biases++; + } + } else { + scales += group_stride; + biases += group_stride; + } + } +}; + +template +METAL_FUNC void qmv_fast_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int packs_per_thread = bits > 2 ? 
2 : 1; + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int pack_factor = 32 / bits; + constexpr int values_per_thread = pack_factor * packs_per_thread; + constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int scale_step_per_thread = group_size / values_per_thread; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + w += out_row * in_vec_size_w + simd_lid * packs_per_thread; + scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.y * in_vec_size + simd_lid * values_per_thread; + y += tid.y * out_vec_size + out_row; + + for (int k = 0; k < in_vec_size; k += block_size) { + U sum = load_vector(x, x_thread); + + for (int row = 0; row < results_per_simdgroup; row++) { + const device uint8_t* wl = + (const device uint8_t*)(w + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += qdot(wl, x_thread, s, b, sum); + } + + w += block_size / pack_factor; + scales += block_size / group_size; + biases += block_size / group_size; + x += block_size; + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void qmv_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int packs_per_thread = 1; + constexpr int pack_factor = 32 / bits; + constexpr int values_per_thread = pack_factor * packs_per_thread; + constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int scale_step_per_thread = group_size / values_per_thread; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row); + + if (out_row >= out_vec_size) { + return; + } + + // In this case we need to properly guard all our reads because there isn't + // even 1 tile in the matrix + if (out_vec_size < (num_simdgroups * results_per_simdgroup)) { + w += out_row * in_vec_size_w + simd_lid * packs_per_thread; + scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.y * in_vec_size + simd_lid * values_per_thread; + y += tid.y * out_vec_size + out_row; + + int k = 0; + for (; k < in_vec_size - block_size; k += block_size) { + U sum = load_vector(x, x_thread); + + for (int row = 0; out_row + row < 
out_vec_size; row++) { + const device uint8_t* wl = + (const device uint8_t*)(w + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += + qdot(wl, x_thread, s, b, sum); + } + + w += block_size / pack_factor; + scales += block_size / group_size; + biases += block_size / group_size; + x += block_size; + } + const int remaining = clamp( + static_cast(in_vec_size - k - simd_lid * values_per_thread), + 0, + values_per_thread); + U sum = + load_vector_safe(x, x_thread, remaining); + + for (int row = 0; out_row + row < out_vec_size; row++) { + const device uint8_t* wl = + (const device uint8_t*)(w + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += qdot(wl, x_thread, s, b, sum); + } + + for (int row = 0; out_row + row < out_vec_size; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } + } + + // In this case the last tile is moved back to redo some output values + else { + w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread; + scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.y * in_vec_size + simd_lid * values_per_thread; + y += tid.y * out_vec_size + used_out_row; + + int k = 0; + for (; k < in_vec_size - block_size; k += block_size) { + U sum = load_vector(x, x_thread); + + for (int row = 0; row < results_per_simdgroup; row++) { + const device uint8_t* wl = + (const device uint8_t*)(w + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += + qdot(wl, x_thread, s, b, sum); + } + + w += block_size / pack_factor; + scales += block_size / group_size; + biases += block_size / group_size; + x += block_size; + } + const int remaining = clamp( + static_cast(in_vec_size - k - simd_lid * values_per_thread), + 0, + values_per_thread); + U sum = + load_vector_safe(x, x_thread, remaining); + + for (int row = 0; row < results_per_simdgroup; row++) { + const device uint8_t* wl = + (const device uint8_t*)(w + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += qdot_safe( + wl, x_thread, s, b, sum, remaining); + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } + } +} + +template +METAL_FUNC void qvm_impl( + const device T* x, + const device uint32_t* w, + const device T* scales, + const device T* biases, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int num_simdgroups = 2; + constexpr int pack_factor = 32 / bits; + constexpr int tn = 32 / pack_factor; + constexpr int blocksize = SIMD_SIZE; + + typedef float U; + typedef struct { + uint32_t wi[tn]; + } vec_w; + + thread vec_w w_local; + thread U result[tn * pack_factor] = {0}; + thread U scale = 1; + thread U bias = 0; + thread U x_local = 0; + + // Adjust positions + const 
int out_vec_size_w = out_vec_size / pack_factor; + const int out_vec_size_g = out_vec_size / group_size; + int out_col = + tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn; + w += out_col / pack_factor + simd_lid * out_vec_size_w; + scales += out_col / group_size + simd_lid * out_vec_size_g; + biases += out_col / group_size + simd_lid * out_vec_size_g; + x += tid.y * in_vec_size + simd_lid; + y += tid.y * out_vec_size + out_col; + + if (out_col >= out_vec_size) { + return; + } + + // Loop over in_vec in blocks of blocksize + int remaining = in_vec_size % blocksize; + if (remaining == 0) { + for (int i = 0; i < in_vec_size; i += blocksize) { + x_local = *x; + scale = *scales; + bias = *biases; + w_local = *((device vec_w*)w); + + qouter( + (thread uint8_t*)&w_local, x_local, scale, bias, result); + + x += blocksize; + scales += blocksize * out_vec_size_g; + biases += blocksize * out_vec_size_g; + w += blocksize * out_vec_size_w; + } + } else { + for (int i = blocksize; i < in_vec_size; i += blocksize) { + x_local = *x; + scale = *scales; + bias = *biases; + w_local = *((device vec_w*)w); + + qouter( + (thread uint8_t*)&w_local, x_local, scale, bias, result); + + x += blocksize; + scales += blocksize * out_vec_size_g; + biases += blocksize * out_vec_size_g; + w += blocksize * out_vec_size_w; + } + if (static_cast(simd_lid) < remaining) { + x_local = *x; + scale = *scales; + bias = *biases; + w_local = *((device vec_w*)w); + } else { + x_local = 0; + scale = 0; + bias = 0; + } + qouter( + (thread uint8_t*)&w_local, x_local, scale, bias, result); + } + +// Accumulate in the simdgroup +#pragma clang loop unroll(full) + for (int k = 0; k < tn * pack_factor; k++) { + result[k] = simd_sum(result[k]); + } + + // Store the result + if (simd_lid == 0) { +#pragma clang loop unroll(full) + for (int k = 0; k < tn * pack_factor; k++) { + y[k] = static_cast(result[k]); + } + } +} + +template < + typename T, + const int BM, + const int BK, + const int BN, + const int group_size, + const int bits, + const bool aligned_N> +METAL_FUNC void qmm_t_impl( + const device T* x, + const device uint32_t* w, + const device T* scales, + const device T* biases, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& M, + const constant int& N, + const constant int& K, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int pack_factor = 32 / bits; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + // Instantiate the appropriate BlockMMA and Loader + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = + mlx::steel::BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + BN, + BK, + BK_padded, + 1, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + // Set the block + const int K_w = K / pack_factor; + const int K_g = K / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + x += y_row * K; + w += y_col * K_w; + scales += y_col * K_g; + biases += y_col * K_g; + y += y_row * N + y_col; + + // Make the x loader and mma operation + const short num_els = min(BM, M - y_row); + const short num_outs = min(BN, N - y_col); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + 
loader_w_t loader_w(w, scales, biases, K, Ws, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + + if (num_els < BM) { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_safe(short2(BK, num_outs)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_safe(short2(BK, num_outs)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM || num_outs < BN) { + mma_op.store_result_safe(y, N, short2(num_outs, num_els)); + } else { + mma_op.store_result(y, N); + } +} + +template < + typename T, + const int BM, + const int BK, + const int BN, + const int group_size, + const int bits> +METAL_FUNC void qmm_n_impl( + const device T* x, + const device uint32_t* w, + const device T* scales, + const device T* biases, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& M, + const constant int& N, + const constant int& K, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int pack_factor = 32 / bits; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + // Instantiate the appropriate BlockMMA and Loader + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = mlx::steel:: + BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + BK, + BN, + BN_padded, + 0, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + // Set the block + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + x += y_row * K; + w += y_col / pack_factor; + scales += y_col / group_size; + biases += y_col / group_size; + y += y_row * N + y_col; + + // Make the x loader and mma operation + const short num_els = min(BM, M - y_row); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + loader_w_t loader_w(w, scales, biases, N, Ws, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + + if (num_els < BM) { + if ((K % BK) != 0) { + const int k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + 
loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, num_els)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if ((K % BK) != 0) { + const int k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, BM)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM) { + mma_op.store_result_safe(y, N, short2(BN, num_els)); + } else { + mma_op.store_result(y, N); + } +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device T*& scales, + const device T*& biases, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T*& y, + int output_stride, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant size_t* lhs_strides, + const constant size_t* rhs_strides, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant size_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant size_t* w_strides, + const constant size_t* s_strides, + const constant size_t* b_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx; + uint32_t w_idx; + if (batch_ndims == 1) { + x_idx = lhs_indices[tid.z * lhs_strides[0]]; + w_idx = rhs_indices[tid.z * rhs_strides[0]]; + } else { + ulong2 idx = elem_to_loc_broadcast( + tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); + x_idx = lhs_indices[idx.x]; + w_idx = rhs_indices[idx.y]; + } + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + biases += w_idx * b_strides[0]; + } else { + ulong3 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + biases += idx.z; + } + y += tid.z * output_stride; +} + +template +[[kernel]] void qmv_fast( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + uint3 tid 
[[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + qmv_fast_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void qmv( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + qmv_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void qvm( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + qvm_impl( + x, + w, + scales, + biases, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void qmm_t( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& M [[buffer(5)]], + const constant int& N [[buffer(6)]], + const constant int& K [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + + qmm_t_impl( + x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void qmm_n( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& M [[buffer(5)]], + const constant int& N [[buffer(6)]], + const constant int& K [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + + qmm_n_impl( + x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void bs_qmv_fast( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices 
[[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size [[buffer(8)]], + const constant int& batch_ndims [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* lhs_strides [[buffer(11)]], + const constant size_t* rhs_strides [[buffer(12)]], + const constant int& x_batch_ndims [[buffer(13)]], + const constant int* x_shape [[buffer(14)]], + const constant size_t* x_strides [[buffer(15)]], + const constant int& w_batch_ndims [[buffer(16)]], + const constant int* w_shape [[buffer(17)]], + const constant size_t* w_strides [[buffer(18)]], + const constant size_t* s_strides [[buffer(19)]], + const constant size_t* b_strides [[buffer(20)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + out_vec_size, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmv_fast_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void bs_qmv( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size [[buffer(8)]], + const constant int& batch_ndims [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* lhs_strides [[buffer(11)]], + const constant size_t* rhs_strides [[buffer(12)]], + const constant int& x_batch_ndims [[buffer(13)]], + const constant int* x_shape [[buffer(14)]], + const constant size_t* x_strides [[buffer(15)]], + const constant int& w_batch_ndims [[buffer(16)]], + const constant int* w_shape [[buffer(17)]], + const constant size_t* w_strides [[buffer(18)]], + const constant size_t* s_strides [[buffer(19)]], + const constant size_t* b_strides [[buffer(20)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + out_vec_size, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmv_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void bs_qvm( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size [[buffer(8)]], + const constant int& batch_ndims [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* lhs_strides [[buffer(11)]], + const constant size_t* rhs_strides [[buffer(12)]], + const constant int& x_batch_ndims [[buffer(13)]], + 
const constant int* x_shape [[buffer(14)]], + const constant size_t* x_strides [[buffer(15)]], + const constant int& w_batch_ndims [[buffer(16)]], + const constant int* w_shape [[buffer(17)]], + const constant size_t* w_strides [[buffer(18)]], + const constant size_t* s_strides [[buffer(19)]], + const constant size_t* b_strides [[buffer(20)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + out_vec_size, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qvm_impl( + x, + w, + scales, + biases, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void bs_qmm_t( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& M [[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& K [[buffer(9)]], + const constant int& batch_ndims [[buffer(10)]], + const constant int* batch_shape [[buffer(11)]], + const constant size_t* lhs_strides [[buffer(12)]], + const constant size_t* rhs_strides [[buffer(13)]], + const constant int& x_batch_ndims [[buffer(14)]], + const constant int* x_shape [[buffer(15)]], + const constant size_t* x_strides [[buffer(16)]], + const constant int& w_batch_ndims [[buffer(17)]], + const constant int* w_shape [[buffer(18)]], + const constant size_t* w_strides [[buffer(19)]], + const constant size_t* s_strides [[buffer(20)]], + const constant size_t* b_strides [[buffer(21)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmm_t_impl( + x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void bs_qmm_n( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& M [[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& K [[buffer(9)]], + const constant int& batch_ndims [[buffer(10)]], + const constant int* batch_shape [[buffer(11)]], + const constant size_t* lhs_strides [[buffer(12)]], + const constant size_t* rhs_strides 
[[buffer(13)]], + const constant int& x_batch_ndims [[buffer(14)]], + const constant int* x_shape [[buffer(15)]], + const constant size_t* x_strides [[buffer(16)]], + const constant int& w_batch_ndims [[buffer(17)]], + const constant int* w_shape [[buffer(18)]], + const constant size_t* w_strides [[buffer(19)]], + const constant size_t* s_strides [[buffer(20)]], + const constant size_t* b_strides [[buffer(21)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmm_n_impl( + x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); +} diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index 5bc612aae..0651db872 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -1,1518 +1,47 @@ // Copyright © 2023-2024 Apple Inc. -#include -#include - -#include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/defines.h" +// clang-format off #include "mlx/backend/metal/kernels/utils.h" - #include "mlx/backend/metal/kernels/steel/gemm/gemm.h" +#include "mlx/backend/metal/kernels/quantized.h" + + +#define instantiate_qmv_fast(itype, group_size, bits) \ + instantiate_kernel( \ + "qmv_" #itype "_gs_" #group_size "_b_" #bits "_fast", \ + qmv_fast, \ + itype, \ + group_size, \ + bits) + +#define instantiate_qmv_fast_types(group_size, bits) \ + instantiate_qmv_fast(float, group_size, bits) \ + instantiate_qmv_fast(float16_t, group_size, bits) \ + instantiate_qmv_fast(bfloat16_t, group_size, bits) + +instantiate_qmv_fast_types(128, 2) +instantiate_qmv_fast_types(128, 4) +instantiate_qmv_fast_types(128, 8) +instantiate_qmv_fast_types( 64, 2) +instantiate_qmv_fast_types( 64, 4) +instantiate_qmv_fast_types( 64, 8) +instantiate_qmv_fast_types( 32, 2) +instantiate_qmv_fast_types( 32, 4) +instantiate_qmv_fast_types( 32, 8) + +#define instantiate_qmv(itype, group_size, bits) \ + instantiate_kernel( \ + "qmv_" #itype "_gs_" #group_size "_b_" #bits, \ + qmv, \ + itype, \ + group_size, \ + bits) -using namespace metal; - -#define MLX_MTL_CONST static constant constexpr const - -MLX_MTL_CONST int SIMD_SIZE = 32; - -template -inline U load_vector(const device T* x, thread U* x_thread) { - static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); - - U sum = 0; - - if (bits == 2) { - for (int i = 0; i < values_per_thread; i += 4) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 4.0f; - x_thread[i + 2] = x[i + 2] / 16.0f; - x_thread[i + 3] = x[i + 3] / 64.0f; - } - } - - else if (bits == 4) { - for (int i = 0; i < values_per_thread; i += 4) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 16.0f; - x_thread[i + 2] = x[i + 2] / 256.0f; - x_thread[i + 3] = x[i + 3] / 4096.0f; - } - } - - else if (bits == 8) { - for (int i = 0; i < 
values_per_thread; i++) { - sum += x[i]; - x_thread[i] = x[i]; - } - } - - return sum; -} - -template -inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { - static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); - - U sum = 0; - - if (bits == 2) { - for (int i = 0; i < N; i += 4) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 4.0f; - x_thread[i + 2] = x[i + 2] / 16.0f; - x_thread[i + 3] = x[i + 3] / 64.0f; - } - for (int i = N; i < values_per_thread; i++) { - x_thread[i] = 0; - } - } - - else if (bits == 4) { - for (int i = 0; i < N; i += 4) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 16.0f; - x_thread[i + 2] = x[i + 2] / 256.0f; - x_thread[i + 3] = x[i + 3] / 4096.0f; - } - for (int i = N; i < values_per_thread; i++) { - x_thread[i] = 0; - } - } - - else if (bits == 8) { - for (int i = 0; i < N; i++) { - sum += x[i]; - x_thread[i] = x[i]; - } - for (int i = N; i < values_per_thread; i++) { - x_thread[i] = 0; - } - } - - return sum; -} - -template -inline U qdot( - const device uint8_t* w, - const thread U* x_thread, - U scale, - U bias, - U sum) { - static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); - - U accum = 0; - - if (bits == 2) { - for (int i = 0; i < (values_per_thread / 4); i++) { - accum += - (x_thread[4 * i] * (w[i] & 0x03) + - x_thread[4 * i + 1] * (w[i] & 0x0c) + - x_thread[4 * i + 2] * (w[i] & 0x30) + - x_thread[4 * i + 3] * (w[i] & 0xc0)); - } - } - - else if (bits == 4) { - const device uint16_t* ws = (const device uint16_t*)w; - for (int i = 0; i < (values_per_thread / 4); i++) { - accum += - (x_thread[4 * i] * (ws[i] & 0x000f) + - x_thread[4 * i + 1] * (ws[i] & 0x00f0) + - x_thread[4 * i + 2] * (ws[i] & 0x0f00) + - x_thread[4 * i + 3] * (ws[i] & 0xf000)); - } - } - - else if (bits == 8) { - for (int i = 0; i < values_per_thread; i++) { - accum += x_thread[i] * w[i]; - } - } - - return scale * accum + sum * bias; -} - -template -inline U qdot_safe( - const device uint8_t* w, - const thread U* x_thread, - U scale, - U bias, - U sum, - int N) { - static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); - - U accum = 0; - - if (bits == 2) { - for (int i = 0; i < (N / 4); i++) { - accum += - (x_thread[4 * i] * (w[i] & 0x03) + - x_thread[4 * i + 1] * (w[i] & 0x0c) + - x_thread[4 * i + 2] * (w[i] & 0x30) + - x_thread[4 * i + 3] * (w[i] & 0xc0)); - } - } - - else if (bits == 4) { - const device uint16_t* ws = (const device uint16_t*)w; - for (int i = 0; i < (N / 4); i++) { - accum += - (x_thread[4 * i] * (ws[i] & 0x000f) + - x_thread[4 * i + 1] * (ws[i] & 0x00f0) + - x_thread[4 * i + 2] * (ws[i] & 0x0f00) + - x_thread[4 * i + 3] * (ws[i] & 0xf000)); - } - } - - else if (bits == 8) { - for (int i = 0; i < N; i++) { - accum += x_thread[i] * w[i]; - } - } - - return scale * accum + sum * bias; -} - -template -inline void -qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { - static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); - - if (bits == 2) { - U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; - for (int i = 0; i < (values_per_thread / 4); i++) { - result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias); - result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias); - result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + 
bias); - result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias); - } - } - - else if (bits == 4) { - U s[2] = {scale, scale / 16.0f}; - for (int i = 0; i < (values_per_thread / 2); i++) { - result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias); - result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias); - } - } - - else if (bits == 8) { - for (int i = 0; i < values_per_thread; i++) { - result[i] += x * (scale * w[i] + bias); - } - } -} - -template -inline void -dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { - static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); - - if (bits == 2) { - U s[4] = { - scale, - scale / static_cast(4.0f), - scale / static_cast(16.0f), - scale / static_cast(64.0f)}; - for (int i = 0; i < (N / 4); i++) { - w_local[4 * i] = s[0] * (w[i] & 0x03) + bias; - w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias; - w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias; - w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias; - } - } - - else if (bits == 4) { - U s[2] = {scale, scale / static_cast(16.0f)}; - for (int i = 0; i < (N / 2); i++) { - w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias; - w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias; - } - } - - else if (bits == 8) { - for (int i = 0; i < N; i++) { - w_local[i] = scale * w[i] + bias; - } - } -} - -template < - typename T, - short BROWS, - short BCOLS, - short dst_ld, - short reduction_dim, - short tgp_size, - short group_size, - short bits> -struct QuantizedBlockLoader { - static_assert( - BCOLS <= group_size, - "The group size should be larger than the columns"); - static_assert( - group_size % BCOLS == 0, - "The group size should be divisible by the columns"); - static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); - - MLX_MTL_CONST short pack_factor = 32 / bits; - MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; - MLX_MTL_CONST short n_reads = - (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; - MLX_MTL_CONST short group_steps = group_size / BCOLS; - - const int src_ld; - const int tile_stride; - short group_step_cnt; - const int group_stride; - - const short thread_idx; - const short bi; - const short bj; - - threadgroup T* dst; - const device uint32_t* src; - const device T* scales; - const device T* biases; - - QuantizedBlockLoader( - const device uint32_t* src_, - const device T* scales_, - const device T* biases_, - const int src_ld_, - threadgroup T* dst_, - ushort simd_group_id [[simdgroup_index_in_threadgroup]], - ushort simd_lane_id [[thread_index_in_simdgroup]]) - : src_ld(src_ld_), - tile_stride( - reduction_dim ? 
BCOLS_PACKED : BROWS * src_ld / pack_factor), - group_step_cnt(0), - group_stride(BROWS * src_ld / group_size), - thread_idx(simd_group_id * 32 + simd_lane_id), - bi(n_reads * thread_idx / BCOLS_PACKED), - bj((n_reads * thread_idx) % BCOLS_PACKED), - dst(dst_ + bi * dst_ld + bj * pack_factor), - src(src_ + bi * src_ld / pack_factor + bj), - scales(scales_ + bi * src_ld / group_size), - biases(biases_ + bi * src_ld / group_size) {} - - void load_unsafe() const { - if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { - return; - } - - T scale = *scales; - T bias = *biases; - for (int i = 0; i < n_reads; i++) { - dequantize( - (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor); - } - } - - void load_safe(short2 src_tile_dim) const { - if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { - return; - } - - if (reduction_dim == 1 && bi >= src_tile_dim.y) { - for (int i = 0; i < n_reads * pack_factor; i++) { - dst[i] = T(0); - } - return; - } - - if (reduction_dim == 0 && bi >= src_tile_dim.x) { - for (int i = 0; i < n_reads * pack_factor; i++) { - dst[i] = T(0); - } - return; - } - - T scale = *scales; - T bias = *biases; - for (int i = 0; i < n_reads; i++) { - dequantize( - (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor); - } - } - - void next() { - src += tile_stride; - if (reduction_dim == 1) { - if (group_steps > 1) { - group_step_cnt++; - if (group_step_cnt == group_steps) { - group_step_cnt = 0; - scales++; - biases++; - } - } else { - scales++; - biases++; - } - } else { - scales += group_stride; - biases += group_stride; - } - } -}; - -template -METAL_FUNC void qmv_fast_impl( - const device uint32_t* w, - const device T* scales, - const device T* biases, - const device T* x, - device T* y, - const constant int& in_vec_size, - const constant int& out_vec_size, - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int num_simdgroups = 2; - constexpr int results_per_simdgroup = 4; - constexpr int pack_factor = 32 / bits; - constexpr int values_per_thread = pack_factor * packs_per_thread; - constexpr int block_size = values_per_thread * SIMD_SIZE; - constexpr int scale_step_per_thread = group_size / values_per_thread; - - typedef float U; - - thread U x_thread[values_per_thread]; - thread U result[results_per_simdgroup] = {0}; - - // Adjust positions - const int in_vec_size_w = in_vec_size / pack_factor; - const int in_vec_size_g = in_vec_size / group_size; - const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) + - simd_gid * results_per_simdgroup; - w += out_row * in_vec_size_w + simd_lid * packs_per_thread; - scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; - biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; - x += tid.y * in_vec_size + simd_lid * values_per_thread; - y += tid.y * out_vec_size + out_row; - - for (int k = 0; k < in_vec_size; k += block_size) { - U sum = load_vector(x, x_thread); - - for (int row = 0; row < results_per_simdgroup; row++) { - const device uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w); - const device T* sl = scales + row * in_vec_size_g; - const device T* bl = biases + row * in_vec_size_g; - - U s = sl[0]; - U b = bl[0]; - result[row] += qdot(wl, x_thread, s, b, sum); - } - - w += block_size / pack_factor; - scales += block_size / group_size; - biases += block_size / group_size; - x += block_size; - } - - for (int row = 0; row < results_per_simdgroup; 
row++) { - result[row] = simd_sum(result[row]); - if (simd_lid == 0) { - y[row] = static_cast(result[row]); - } - } -} - -template -METAL_FUNC void qmv_impl( - const device uint32_t* w, - const device T* scales, - const device T* biases, - const device T* x, - device T* y, - const constant int& in_vec_size, - const constant int& out_vec_size, - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int num_simdgroups = 2; - constexpr int results_per_simdgroup = 4; - constexpr int packs_per_thread = 1; - constexpr int pack_factor = 32 / bits; - constexpr int values_per_thread = pack_factor * packs_per_thread; - constexpr int block_size = values_per_thread * SIMD_SIZE; - constexpr int scale_step_per_thread = group_size / values_per_thread; - - typedef float U; - - thread U x_thread[values_per_thread]; - thread U result[results_per_simdgroup] = {0}; - - // Adjust positions - const int in_vec_size_w = in_vec_size / pack_factor; - const int in_vec_size_g = in_vec_size / group_size; - const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) + - simd_gid * results_per_simdgroup; - const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row); - - if (out_row >= out_vec_size) { - return; - } - - // In this case we need to properly guard all our reads because there isn't - // even 1 tile in the matrix - if (out_vec_size < (num_simdgroups * results_per_simdgroup)) { - w += out_row * in_vec_size_w + simd_lid * packs_per_thread; - scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; - biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; - x += tid.y * in_vec_size + simd_lid * values_per_thread; - y += tid.y * out_vec_size + out_row; - - int k = 0; - for (; k < in_vec_size - block_size; k += block_size) { - U sum = load_vector(x, x_thread); - - for (int row = 0; out_row + row < out_vec_size; row++) { - const device uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w); - const device T* sl = scales + row * in_vec_size_g; - const device T* bl = biases + row * in_vec_size_g; - - U s = sl[0]; - U b = bl[0]; - result[row] += - qdot(wl, x_thread, s, b, sum); - } - - w += block_size / pack_factor; - scales += block_size / group_size; - biases += block_size / group_size; - x += block_size; - } - const int remaining = clamp( - static_cast(in_vec_size - k - simd_lid * values_per_thread), - 0, - values_per_thread); - U sum = - load_vector_safe(x, x_thread, remaining); - - for (int row = 0; out_row + row < out_vec_size; row++) { - const device uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w); - const device T* sl = scales + row * in_vec_size_g; - const device T* bl = biases + row * in_vec_size_g; - - U s = sl[0]; - U b = bl[0]; - result[row] += qdot(wl, x_thread, s, b, sum); - } - - for (int row = 0; out_row + row < out_vec_size; row++) { - result[row] = simd_sum(result[row]); - if (simd_lid == 0) { - y[row] = static_cast(result[row]); - } - } - } - - // In this case the last tile is moved back to redo some output values - else { - w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread; - scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; - biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; - x += tid.y * in_vec_size + simd_lid * values_per_thread; - y += tid.y * out_vec_size + used_out_row; - - int k = 0; - for (; k < in_vec_size - block_size; k += block_size) { - U sum 
= load_vector(x, x_thread); - - for (int row = 0; row < results_per_simdgroup; row++) { - const device uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w); - const device T* sl = scales + row * in_vec_size_g; - const device T* bl = biases + row * in_vec_size_g; - - U s = sl[0]; - U b = bl[0]; - result[row] += - qdot(wl, x_thread, s, b, sum); - } - - w += block_size / pack_factor; - scales += block_size / group_size; - biases += block_size / group_size; - x += block_size; - } - const int remaining = clamp( - static_cast(in_vec_size - k - simd_lid * values_per_thread), - 0, - values_per_thread); - U sum = - load_vector_safe(x, x_thread, remaining); - - for (int row = 0; row < results_per_simdgroup; row++) { - const device uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w); - const device T* sl = scales + row * in_vec_size_g; - const device T* bl = biases + row * in_vec_size_g; - - U s = sl[0]; - U b = bl[0]; - result[row] += qdot_safe( - wl, x_thread, s, b, sum, remaining); - } - - for (int row = 0; row < results_per_simdgroup; row++) { - result[row] = simd_sum(result[row]); - if (simd_lid == 0) { - y[row] = static_cast(result[row]); - } - } - } -} - -template -METAL_FUNC void qvm_impl( - const device T* x, - const device uint32_t* w, - const device T* scales, - const device T* biases, - device T* y, - const constant int& in_vec_size, - const constant int& out_vec_size, - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int num_simdgroups = 2; - constexpr int pack_factor = 32 / bits; - constexpr int tn = 32 / pack_factor; - constexpr int blocksize = SIMD_SIZE; - - typedef float U; - typedef struct { - uint32_t wi[tn]; - } vec_w; - - thread vec_w w_local; - thread U result[tn * pack_factor] = {0}; - thread U scale = 1; - thread U bias = 0; - thread U x_local = 0; - - // Adjust positions - const int out_vec_size_w = out_vec_size / pack_factor; - const int out_vec_size_g = out_vec_size / group_size; - int out_col = - tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn; - w += out_col / pack_factor + simd_lid * out_vec_size_w; - scales += out_col / group_size + simd_lid * out_vec_size_g; - biases += out_col / group_size + simd_lid * out_vec_size_g; - x += tid.y * in_vec_size + simd_lid; - y += tid.y * out_vec_size + out_col; - - if (out_col >= out_vec_size) { - return; - } - - // Loop over in_vec in blocks of blocksize - int remaining = in_vec_size % blocksize; - if (remaining == 0) { - for (int i = 0; i < in_vec_size; i += blocksize) { - x_local = *x; - scale = *scales; - bias = *biases; - w_local = *((device vec_w*)w); - - qouter( - (thread uint8_t*)&w_local, x_local, scale, bias, result); - - x += blocksize; - scales += blocksize * out_vec_size_g; - biases += blocksize * out_vec_size_g; - w += blocksize * out_vec_size_w; - } - } else { - for (int i = blocksize; i < in_vec_size; i += blocksize) { - x_local = *x; - scale = *scales; - bias = *biases; - w_local = *((device vec_w*)w); - - qouter( - (thread uint8_t*)&w_local, x_local, scale, bias, result); - - x += blocksize; - scales += blocksize * out_vec_size_g; - biases += blocksize * out_vec_size_g; - w += blocksize * out_vec_size_w; - } - if (static_cast(simd_lid) < remaining) { - x_local = *x; - scale = *scales; - bias = *biases; - w_local = *((device vec_w*)w); - } else { - x_local = 0; - scale = 0; - bias = 0; - } - qouter( - (thread uint8_t*)&w_local, x_local, scale, bias, result); - 
} - -// Accumulate in the simdgroup -#pragma clang loop unroll(full) - for (int k = 0; k < tn * pack_factor; k++) { - result[k] = simd_sum(result[k]); - } - - // Store the result - if (simd_lid == 0) { -#pragma clang loop unroll(full) - for (int k = 0; k < tn * pack_factor; k++) { - y[k] = static_cast(result[k]); - } - } -} - -template < - typename T, - const int BM, - const int BK, - const int BN, - const int group_size, - const int bits, - const bool aligned_N> -METAL_FUNC void qmm_t_impl( - const device T* x, - const device uint32_t* w, - const device T* scales, - const device T* biases, - device T* y, - threadgroup T* Xs, - threadgroup T* Ws, - const constant int& M, - const constant int& N, - const constant int& K, - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); - static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); - - (void)lid; - - constexpr int WM = 2; - constexpr int WN = 2; - constexpr int pack_factor = 32 / bits; - constexpr int BK_padded = (BK + 16 / sizeof(T)); - - // Instantiate the appropriate BlockMMA and Loader - using mma_t = mlx::steel:: - BlockMMA; - using loader_x_t = - mlx::steel::BlockLoader; - using loader_w_t = QuantizedBlockLoader< - T, - BN, - BK, - BK_padded, - 1, - WM * WN * SIMD_SIZE, - group_size, - bits>; - - // Set the block - const int K_w = K / pack_factor; - const int K_g = K / group_size; - const int y_row = tid.y * BM; - const int y_col = tid.x * BN; - - x += y_row * K; - w += y_col * K_w; - scales += y_col * K_g; - biases += y_col * K_g; - y += y_row * N + y_col; - - // Make the x loader and mma operation - const short num_els = min(BM, M - y_row); - const short num_outs = min(BN, N - y_col); - loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); - loader_w_t loader_w(w, scales, biases, K, Ws, simd_gid, simd_lid); - mma_t mma_op(simd_gid, simd_lid); - - if (num_els < BM) { - if (!aligned_N && num_outs < BN) { - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_safe(short2(BK, num_els)); - loader_w.load_safe(short2(BK, num_outs)); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - } else { - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_safe(short2(BK, num_els)); - loader_w.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - } - } else { - if (!aligned_N && num_outs < BN) { - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_unsafe(); - loader_w.load_safe(short2(BK, num_outs)); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - } else { - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_unsafe(); - loader_w.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - } - } - - // Store results to device memory - threadgroup_barrier(mem_flags::mem_threadgroup); - if (num_els < BM || num_outs < BN) { - mma_op.store_result_safe(y, N, short2(num_outs, num_els)); - } else { - mma_op.store_result(y, N); - } -} - -template < - typename T, - 
const int BM, - const int BK, - const int BN, - const int group_size, - const int bits> -METAL_FUNC void qmm_n_impl( - const device T* x, - const device uint32_t* w, - const device T* scales, - const device T* biases, - device T* y, - threadgroup T* Xs, - threadgroup T* Ws, - const constant int& M, - const constant int& N, - const constant int& K, - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); - static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); - - (void)lid; - - constexpr int WM = 2; - constexpr int WN = 2; - constexpr int pack_factor = 32 / bits; - constexpr int BK_padded = (BK + 16 / sizeof(T)); - constexpr int BN_padded = (BN + 16 / sizeof(T)); - - // Instantiate the appropriate BlockMMA and Loader - using mma_t = mlx::steel:: - BlockMMA; - using loader_x_t = mlx::steel:: - BlockLoader; - using loader_w_t = QuantizedBlockLoader< - T, - BK, - BN, - BN_padded, - 0, - WM * WN * SIMD_SIZE, - group_size, - bits>; - - // Set the block - const int y_row = tid.y * BM; - const int y_col = tid.x * BN; - x += y_row * K; - w += y_col / pack_factor; - scales += y_col / group_size; - biases += y_col / group_size; - y += y_row * N + y_col; - - // Make the x loader and mma operation - const short num_els = min(BM, M - y_row); - loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); - loader_w_t loader_w(w, scales, biases, N, Ws, simd_gid, simd_lid); - mma_t mma_op(simd_gid, simd_lid); - - if (num_els < BM) { - if ((K % BK) != 0) { - const int k_blocks = K / BK; - for (int k = 0; k < k_blocks; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_safe(short2(BK, num_els)); - loader_w.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - const short num_k = K - k_blocks * BK; - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_safe(short2(num_k, num_els)); - loader_w.load_safe(short2(BN, num_k)); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - } else { - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_safe(short2(BK, num_els)); - loader_w.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - } - } else { - if ((K % BK) != 0) { - const int k_blocks = K / BK; - for (int k = 0; k < k_blocks; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_unsafe(); - loader_w.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - const short num_k = K - k_blocks * BK; - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_safe(short2(num_k, BM)); - loader_w.load_safe(short2(BN, num_k)); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - } else { - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_unsafe(); - loader_w.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - } - } - - // Store results to device memory - threadgroup_barrier(mem_flags::mem_threadgroup); - if (num_els < BM) { - mma_op.store_result_safe(y, N, short2(BN, num_els)); - } else { - 
mma_op.store_result(y, N); - } -} - -template -METAL_FUNC void adjust_matrix_offsets( - const device T*& x, - const device uint32_t*& w, - const device T*& scales, - const device T*& biases, - const device uint32_t* lhs_indices, - const device uint32_t* rhs_indices, - device T*& y, - int output_stride, - const constant int& batch_ndims, - const constant int* batch_shape, - const constant size_t* lhs_strides, - const constant size_t* rhs_strides, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant size_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant size_t* w_strides, - const constant size_t* s_strides, - const constant size_t* b_strides, - uint3 tid [[threadgroup_position_in_grid]]) { - // Set the input/output matrices - uint32_t x_idx; - uint32_t w_idx; - if (batch_ndims == 1) { - x_idx = lhs_indices[tid.z * lhs_strides[0]]; - w_idx = rhs_indices[tid.z * rhs_strides[0]]; - } else { - ulong2 idx = elem_to_loc_broadcast( - tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); - x_idx = lhs_indices[idx.x]; - w_idx = rhs_indices[idx.y]; - } - if (x_batch_ndims == 1) { - x += x_idx * x_strides[0]; - } else { - x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); - } - if (w_batch_ndims == 1) { - w += w_idx * w_strides[0]; - scales += w_idx * s_strides[0]; - biases += w_idx * b_strides[0]; - } else { - ulong3 idx = elem_to_loc_broadcast( - w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); - w += idx.x; - scales += idx.y; - biases += idx.z; - } - y += tid.z * output_stride; -} - -template -[[kernel]] void qmv_fast( - const device uint32_t* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& in_vec_size [[buffer(5)]], - const constant int& out_vec_size [[buffer(6)]], - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - qmv_fast_impl( - w, - scales, - biases, - x, - y, - in_vec_size, - out_vec_size, - tid, - simd_gid, - simd_lid); -} - -template -[[kernel]] void qmv( - const device uint32_t* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& in_vec_size [[buffer(5)]], - const constant int& out_vec_size [[buffer(6)]], - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - qmv_impl( - w, - scales, - biases, - x, - y, - in_vec_size, - out_vec_size, - tid, - simd_gid, - simd_lid); -} - -template -[[kernel]] void qvm( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& in_vec_size [[buffer(5)]], - const constant int& out_vec_size [[buffer(6)]], - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - qvm_impl( - x, - w, - scales, - biases, - y, - in_vec_size, - out_vec_size, - tid, - simd_gid, - simd_lid); -} - -template < - typename T, - const int BM, - const int BK, - const int BN, - const int group_size, - const int bits, - const bool aligned_N> -[[kernel]] void qmm_t( - const device T* x 
[[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& M [[buffer(5)]], - const constant int& N [[buffer(6)]], - const constant int& K [[buffer(7)]], - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)lid; - - constexpr int BK_padded = (BK + 16 / sizeof(T)); - - threadgroup T Xs[BM * BK_padded]; - threadgroup T Ws[BN * BK_padded]; - - qmm_t_impl( - x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); -} - -template < - typename T, - const int BM, - const int BK, - const int BN, - const int group_size, - const int bits> -[[kernel]] void qmm_n( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& M [[buffer(5)]], - const constant int& N [[buffer(6)]], - const constant int& K [[buffer(7)]], - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)lid; - - constexpr int BK_padded = (BK + 16 / sizeof(T)); - constexpr int BN_padded = (BN + 16 / sizeof(T)); - - threadgroup T Xs[BM * BK_padded]; - threadgroup T Ws[BK * BN_padded]; - - qmm_n_impl( - x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); -} - -template -[[kernel]] void bs_qmv_fast( - const device uint32_t* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - const device T* x [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], - const constant int& in_vec_size [[buffer(7)]], - const constant int& out_vec_size [[buffer(8)]], - const constant int& batch_ndims [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant size_t* lhs_strides [[buffer(11)]], - const constant size_t* rhs_strides [[buffer(12)]], - const constant int& x_batch_ndims [[buffer(13)]], - const constant int* x_shape [[buffer(14)]], - const constant size_t* x_strides [[buffer(15)]], - const constant int& w_batch_ndims [[buffer(16)]], - const constant int* w_shape [[buffer(17)]], - const constant size_t* w_strides [[buffer(18)]], - const constant size_t* s_strides [[buffer(19)]], - const constant size_t* b_strides [[buffer(20)]], - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - adjust_matrix_offsets( - x, - w, - scales, - biases, - lhs_indices, - rhs_indices, - y, - out_vec_size, - batch_ndims, - batch_shape, - lhs_strides, - rhs_strides, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - b_strides, - tid); - qmv_fast_impl( - w, - scales, - biases, - x, - y, - in_vec_size, - out_vec_size, - tid, - simd_gid, - simd_lid); -} - -template -[[kernel]] void bs_qmv( - const device uint32_t* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - const device T* x [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], - const 
constant int& in_vec_size [[buffer(7)]], - const constant int& out_vec_size [[buffer(8)]], - const constant int& batch_ndims [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant size_t* lhs_strides [[buffer(11)]], - const constant size_t* rhs_strides [[buffer(12)]], - const constant int& x_batch_ndims [[buffer(13)]], - const constant int* x_shape [[buffer(14)]], - const constant size_t* x_strides [[buffer(15)]], - const constant int& w_batch_ndims [[buffer(16)]], - const constant int* w_shape [[buffer(17)]], - const constant size_t* w_strides [[buffer(18)]], - const constant size_t* s_strides [[buffer(19)]], - const constant size_t* b_strides [[buffer(20)]], - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - adjust_matrix_offsets( - x, - w, - scales, - biases, - lhs_indices, - rhs_indices, - y, - out_vec_size, - batch_ndims, - batch_shape, - lhs_strides, - rhs_strides, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - b_strides, - tid); - qmv_impl( - w, - scales, - biases, - x, - y, - in_vec_size, - out_vec_size, - tid, - simd_gid, - simd_lid); -} - -template -[[kernel]] void bs_qvm( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], - const constant int& in_vec_size [[buffer(7)]], - const constant int& out_vec_size [[buffer(8)]], - const constant int& batch_ndims [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant size_t* lhs_strides [[buffer(11)]], - const constant size_t* rhs_strides [[buffer(12)]], - const constant int& x_batch_ndims [[buffer(13)]], - const constant int* x_shape [[buffer(14)]], - const constant size_t* x_strides [[buffer(15)]], - const constant int& w_batch_ndims [[buffer(16)]], - const constant int* w_shape [[buffer(17)]], - const constant size_t* w_strides [[buffer(18)]], - const constant size_t* s_strides [[buffer(19)]], - const constant size_t* b_strides [[buffer(20)]], - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - adjust_matrix_offsets( - x, - w, - scales, - biases, - lhs_indices, - rhs_indices, - y, - out_vec_size, - batch_ndims, - batch_shape, - lhs_strides, - rhs_strides, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - b_strides, - tid); - qvm_impl( - x, - w, - scales, - biases, - y, - in_vec_size, - out_vec_size, - tid, - simd_gid, - simd_lid); -} - -template < - typename T, - const int BM, - const int BK, - const int BN, - const int group_size, - const int bits, - const bool aligned_N> -[[kernel]] void bs_qmm_t( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], - const constant int& M [[buffer(7)]], - const constant int& N [[buffer(8)]], - const constant int& K [[buffer(9)]], - const constant int& batch_ndims [[buffer(10)]], - const constant int* batch_shape [[buffer(11)]], - const constant size_t* lhs_strides [[buffer(12)]], - const 
constant size_t* rhs_strides [[buffer(13)]], - const constant int& x_batch_ndims [[buffer(14)]], - const constant int* x_shape [[buffer(15)]], - const constant size_t* x_strides [[buffer(16)]], - const constant int& w_batch_ndims [[buffer(17)]], - const constant int* w_shape [[buffer(18)]], - const constant size_t* w_strides [[buffer(19)]], - const constant size_t* s_strides [[buffer(20)]], - const constant size_t* b_strides [[buffer(21)]], - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)lid; - - constexpr int BK_padded = (BK + 16 / sizeof(T)); - - threadgroup T Xs[BM * BK_padded]; - threadgroup T Ws[BN * BK_padded]; - - adjust_matrix_offsets( - x, - w, - scales, - biases, - lhs_indices, - rhs_indices, - y, - M * N, - batch_ndims, - batch_shape, - lhs_strides, - rhs_strides, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - b_strides, - tid); - qmm_t_impl( - x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); -} - -template < - typename T, - const int BM, - const int BK, - const int BN, - const int group_size, - const int bits> -[[kernel]] void bs_qmm_n( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], - const constant int& M [[buffer(7)]], - const constant int& N [[buffer(8)]], - const constant int& K [[buffer(9)]], - const constant int& batch_ndims [[buffer(10)]], - const constant int* batch_shape [[buffer(11)]], - const constant size_t* lhs_strides [[buffer(12)]], - const constant size_t* rhs_strides [[buffer(13)]], - const constant int& x_batch_ndims [[buffer(14)]], - const constant int* x_shape [[buffer(15)]], - const constant size_t* x_strides [[buffer(16)]], - const constant int& w_batch_ndims [[buffer(17)]], - const constant int* w_shape [[buffer(18)]], - const constant size_t* w_strides [[buffer(19)]], - const constant size_t* s_strides [[buffer(20)]], - const constant size_t* b_strides [[buffer(21)]], - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)lid; - - constexpr int BK_padded = (BK + 16 / sizeof(T)); - constexpr int BN_padded = (BN + 16 / sizeof(T)); - - threadgroup T Xs[BM * BK_padded]; - threadgroup T Ws[BK * BN_padded]; - - adjust_matrix_offsets( - x, - w, - scales, - biases, - lhs_indices, - rhs_indices, - y, - M * N, - batch_ndims, - batch_shape, - lhs_strides, - rhs_strides, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - b_strides, - tid); - qmm_n_impl( - x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); -} - -#define instantiate_qmv_fast(name, itype, group_size, bits, packs_per_thread) \ - template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits \ - "_fast")]] [[kernel]] void \ - qmv_fast( \ - const device uint32_t* w [[buffer(0)]], \ - const device itype* scales [[buffer(1)]], \ - const device itype* biases [[buffer(2)]], \ - const device itype* x [[buffer(3)]], \ - device itype* y [[buffer(4)]], \ - const constant int& in_vec_size [[buffer(5)]], \ - const constant int& out_vec_size 
[[buffer(6)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); - -// clang-format off -#define instantiate_qmv_fast_types(group_size, bits, packs_per_thread) \ - instantiate_qmv_fast(float32, float, group_size, bits, packs_per_thread) \ - instantiate_qmv_fast(float16, half, group_size, bits, packs_per_thread) \ - instantiate_qmv_fast(bfloat16, bfloat16_t, group_size, bits, packs_per_thread) // clang-format on - -// clang-format off -instantiate_qmv_fast_types(128, 2, 1) -instantiate_qmv_fast_types(128, 4, 2) -instantiate_qmv_fast_types(128, 8, 2) -instantiate_qmv_fast_types( 64, 2, 1) -instantiate_qmv_fast_types( 64, 4, 2) -instantiate_qmv_fast_types( 64, 8, 2) -instantiate_qmv_fast_types( 32, 2, 1) -instantiate_qmv_fast_types( 32, 4, 2) -instantiate_qmv_fast_types( 32, 8, 2) // clang-format on - -#define instantiate_qmv(name, itype, group_size, bits) \ - template [[host_name("qmv_" #name "_gs_" #group_size \ - "_b_" #bits)]] [[kernel]] void \ - qmv( \ - const device uint32_t* w [[buffer(0)]], \ - const device itype* scales [[buffer(1)]], \ - const device itype* biases [[buffer(2)]], \ - const device itype* x [[buffer(3)]], \ - device itype* y [[buffer(4)]], \ - const constant int& in_vec_size [[buffer(5)]], \ - const constant int& out_vec_size [[buffer(6)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); - -// clang-format off #define instantiate_qmv_types(group_size, bits) \ - instantiate_qmv(float32, float, group_size, bits) \ - instantiate_qmv(float16, half, group_size, bits) \ - instantiate_qmv(bfloat16, bfloat16_t, group_size, bits) // clang-format on + instantiate_qmv(float, group_size, bits) \ + instantiate_qmv(float16_t, group_size, bits) \ + instantiate_qmv(bfloat16_t, group_size, bits) - // clang-format off instantiate_qmv_types(128, 2) instantiate_qmv_types(128, 4) instantiate_qmv_types(128, 8) @@ -1521,30 +50,21 @@ instantiate_qmv_types( 64, 4) instantiate_qmv_types( 64, 8) instantiate_qmv_types( 32, 2) instantiate_qmv_types( 32, 4) -instantiate_qmv_types( 32, 8) // clang-format on +instantiate_qmv_types( 32, 8) -#define instantiate_qvm(name, itype, group_size, bits) \ - template [[host_name("qvm_" #name "_gs_" #group_size \ - "_b_" #bits)]] [[kernel]] void \ - qvm( \ - const device itype* x [[buffer(0)]], \ - const device uint32_t* w [[buffer(1)]], \ - const device itype* scales [[buffer(2)]], \ - const device itype* biases [[buffer(3)]], \ - device itype* y [[buffer(4)]], \ - const constant int& in_vec_size [[buffer(5)]], \ - const constant int& out_vec_size [[buffer(6)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); +#define instantiate_qvm(itype, group_size, bits) \ + instantiate_kernel( \ + "qvm_" #itype "_gs_" #group_size "_b_" #bits, \ + qvm, \ + itype, \ + group_size, \ + bits) -// clang-format off #define instantiate_qvm_types(group_size, bits) \ - instantiate_qvm(float32, float, group_size, bits) \ - instantiate_qvm(float16, half, group_size, bits) \ - instantiate_qvm(bfloat16, bfloat16_t, group_size, bits) // clang-format on + instantiate_qvm(float, group_size, bits) \ + instantiate_qvm(float16_t, group_size, bits) \ + instantiate_qvm(bfloat16_t, group_size, bits) - // clang-format off instantiate_qvm_types(128, 2) instantiate_qvm_types(128, 4) 
instantiate_qvm_types(128, 8) @@ -1553,35 +73,25 @@ instantiate_qvm_types( 64, 4) instantiate_qvm_types( 64, 8) instantiate_qvm_types( 32, 2) instantiate_qvm_types( 32, 4) -instantiate_qvm_types( 32, 8) // clang-format on +instantiate_qvm_types( 32, 8) -#define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \ - template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits \ - "_alN_" #aligned_N)]] [[kernel]] void \ - qmm_t( \ - const device itype* x [[buffer(0)]], \ - const device uint32_t* w [[buffer(1)]], \ - const device itype* scales [[buffer(2)]], \ - const device itype* biases [[buffer(3)]], \ - device itype* y [[buffer(4)]], \ - const constant int& M [[buffer(5)]], \ - const constant int& N [[buffer(6)]], \ - const constant int& K [[buffer(7)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint lid [[thread_index_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); +#define instantiate_qmm_t(itype, group_size, bits, aligned_N) \ + instantiate_kernel( \ + "qmm_t_" #itype "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N, \ + qmm_t, \ + itype, \ + group_size, \ + bits, \ + aligned_N) -// clang-format off #define instantiate_qmm_t_types(group_size, bits) \ - instantiate_qmm_t(float32, float, group_size, bits, false) \ - instantiate_qmm_t(float16, half, group_size, bits, false) \ - instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, false) \ - instantiate_qmm_t(float32, float, group_size, bits, true) \ - instantiate_qmm_t(float16, half, group_size, bits, true) \ - instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, true) // clang-format on + instantiate_qmm_t(float, group_size, bits, false) \ + instantiate_qmm_t(float16_t, group_size, bits, false) \ + instantiate_qmm_t(bfloat16_t, group_size, bits, false) \ + instantiate_qmm_t(float, group_size, bits, true) \ + instantiate_qmm_t(float16_t, group_size, bits, true) \ + instantiate_qmm_t(bfloat16_t, group_size, bits, true) - // clang-format off instantiate_qmm_t_types(128, 2) instantiate_qmm_t_types(128, 4) instantiate_qmm_t_types(128, 8) @@ -1590,32 +100,21 @@ instantiate_qmm_t_types( 64, 4) instantiate_qmm_t_types( 64, 8) instantiate_qmm_t_types( 32, 2) instantiate_qmm_t_types( 32, 4) -instantiate_qmm_t_types( 32, 8) // clang-format on +instantiate_qmm_t_types( 32, 8) -#define instantiate_qmm_n(name, itype, group_size, bits) \ - template [[host_name("qmm_n_" #name "_gs_" #group_size \ - "_b_" #bits)]] [[kernel]] void \ - qmm_n( \ - const device itype* x [[buffer(0)]], \ - const device uint32_t* w [[buffer(1)]], \ - const device itype* scales [[buffer(2)]], \ - const device itype* biases [[buffer(3)]], \ - device itype* y [[buffer(4)]], \ - const constant int& M [[buffer(5)]], \ - const constant int& N [[buffer(6)]], \ - const constant int& K [[buffer(7)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint lid [[thread_index_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); +#define instantiate_qmm_n(itype, group_size, bits) \ + instantiate_kernel( \ + "qmm_n_" #itype "_gs_" #group_size "_b_" #bits, \ + qmm_n, \ + itype, \ + group_size, \ + bits) -// clang-format off #define instantiate_qmm_n_types(group_size, bits) \ - instantiate_qmm_n(float32, float, group_size, bits) \ - instantiate_qmm_n(float16, half, group_size, bits) \ - instantiate_qmm_n(bfloat16, bfloat16_t, group_size, bits) // clang-format on + instantiate_qmm_n(float, group_size, bits) \ + 
instantiate_qmm_n(float16_t, group_size, bits) \ + instantiate_qmm_n(bfloat16_t, group_size, bits) - // clang-format off instantiate_qmm_n_types(128, 2) instantiate_qmm_n_types(128, 4) instantiate_qmm_n_types(128, 8) @@ -1624,91 +123,44 @@ instantiate_qmm_n_types( 64, 4) instantiate_qmm_n_types( 64, 8) instantiate_qmm_n_types( 32, 2) instantiate_qmm_n_types( 32, 4) -instantiate_qmm_n_types( 32, 8) // clang-format on +instantiate_qmm_n_types( 32, 8) -#define instantiate_bs_qmv_fast( \ - name, itype, group_size, bits, packs_per_thread) \ - template [[host_name("bs_qmv_" #name "_gs_" #group_size "_b_" #bits \ - "_fast")]] [[kernel]] void \ - bs_qmv_fast( \ - const device uint32_t* w [[buffer(0)]], \ - const device itype* scales [[buffer(1)]], \ - const device itype* biases [[buffer(2)]], \ - const device itype* x [[buffer(3)]], \ - const device uint32_t* lhs_indices [[buffer(4)]], \ - const device uint32_t* rhs_indices [[buffer(5)]], \ - device itype* y [[buffer(6)]], \ - const constant int& in_vec_size [[buffer(7)]], \ - const constant int& out_vec_size [[buffer(8)]], \ - const constant int& batch_ndims [[buffer(9)]], \ - const constant int* batch_shape [[buffer(10)]], \ - const constant size_t* lhs_strides [[buffer(11)]], \ - const constant size_t* rhs_strides [[buffer(12)]], \ - const constant int& x_batch_ndims [[buffer(13)]], \ - const constant int* x_shape [[buffer(14)]], \ - const constant size_t* x_strides [[buffer(15)]], \ - const constant int& w_batch_ndims [[buffer(16)]], \ - const constant int* w_shape [[buffer(17)]], \ - const constant size_t* w_strides [[buffer(18)]], \ - const constant size_t* s_strides [[buffer(19)]], \ - const constant size_t* b_strides [[buffer(20)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); +#define instantiate_bs_qmv_fast(itype, group_size, bits) \ + instantiate_kernel( \ + "bs_qmv_" #itype "_gs_" #group_size "_b_" #bits "_fast", \ + bs_qmv_fast, \ + itype, \ + group_size, \ + bits) -// clang-format off -#define instantiate_bs_qmv_fast_types(group_size, bits, packs_per_thread) \ - instantiate_bs_qmv_fast(float32, float, group_size, bits, packs_per_thread) \ - instantiate_bs_qmv_fast(float16, half, group_size, bits, packs_per_thread) \ - instantiate_bs_qmv_fast(bfloat16, bfloat16_t, group_size, bits, packs_per_thread) // clang-format on +#define instantiate_bs_qmv_fast_types(group_size, bits) \ + instantiate_bs_qmv_fast(float, group_size, bits) \ + instantiate_bs_qmv_fast(float16_t, group_size, bits) \ + instantiate_bs_qmv_fast(bfloat16_t, group_size, bits) - // clang-format off -instantiate_bs_qmv_fast_types(128, 2, 1) -instantiate_bs_qmv_fast_types(128, 4, 2) -instantiate_bs_qmv_fast_types(128, 8, 2) -instantiate_bs_qmv_fast_types( 64, 2, 1) -instantiate_bs_qmv_fast_types( 64, 4, 2) -instantiate_bs_qmv_fast_types( 64, 8, 2) -instantiate_bs_qmv_fast_types( 32, 2, 1) -instantiate_bs_qmv_fast_types( 32, 4, 2) -instantiate_bs_qmv_fast_types( 32, 8, 2) // clang-format on +instantiate_bs_qmv_fast_types(128, 2) +instantiate_bs_qmv_fast_types(128, 4) +instantiate_bs_qmv_fast_types(128, 8) +instantiate_bs_qmv_fast_types( 64, 2) +instantiate_bs_qmv_fast_types( 64, 4) +instantiate_bs_qmv_fast_types( 64, 8) +instantiate_bs_qmv_fast_types( 32, 2) +instantiate_bs_qmv_fast_types( 32, 4) +instantiate_bs_qmv_fast_types( 32, 8) -#define instantiate_bs_qmv(name, itype, group_size, bits) \ - template [[host_name("bs_qmv_" #name "_gs_" #group_size \ - "_b_" #bits)]] 
[[kernel]] void \ - bs_qmv( \ - const device uint32_t* w [[buffer(0)]], \ - const device itype* scales [[buffer(1)]], \ - const device itype* biases [[buffer(2)]], \ - const device itype* x [[buffer(3)]], \ - const device uint32_t* lhs_indices [[buffer(4)]], \ - const device uint32_t* rhs_indices [[buffer(5)]], \ - device itype* y [[buffer(6)]], \ - const constant int& in_vec_size [[buffer(7)]], \ - const constant int& out_vec_size [[buffer(8)]], \ - const constant int& batch_ndims [[buffer(9)]], \ - const constant int* batch_shape [[buffer(10)]], \ - const constant size_t* lhs_strides [[buffer(11)]], \ - const constant size_t* rhs_strides [[buffer(12)]], \ - const constant int& x_batch_ndims [[buffer(13)]], \ - const constant int* x_shape [[buffer(14)]], \ - const constant size_t* x_strides [[buffer(15)]], \ - const constant int& w_batch_ndims [[buffer(16)]], \ - const constant int* w_shape [[buffer(17)]], \ - const constant size_t* w_strides [[buffer(18)]], \ - const constant size_t* s_strides [[buffer(19)]], \ - const constant size_t* b_strides [[buffer(20)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); +#define instantiate_bs_qmv(itype, group_size, bits) \ + instantiate_kernel( \ + "bs_qmv_" #itype "_gs_" #group_size "_b_" #bits, \ + bs_qmv, \ + itype, \ + group_size, \ + bits) -// clang-format off #define instantiate_bs_qmv_types(group_size, bits) \ - instantiate_bs_qmv(float32, float, group_size, bits) \ - instantiate_bs_qmv(float16, half, group_size, bits) \ - instantiate_bs_qmv(bfloat16, bfloat16_t, group_size, bits) // clang-format on + instantiate_bs_qmv(float, group_size, bits) \ + instantiate_bs_qmv(float16_t, group_size, bits) \ + instantiate_bs_qmv(bfloat16_t, group_size, bits) - // clang-format off instantiate_bs_qmv_types(128, 2) instantiate_bs_qmv_types(128, 4) instantiate_bs_qmv_types(128, 8) @@ -1717,44 +169,21 @@ instantiate_bs_qmv_types( 64, 4) instantiate_bs_qmv_types( 64, 8) instantiate_bs_qmv_types( 32, 2) instantiate_bs_qmv_types( 32, 4) -instantiate_bs_qmv_types( 32, 8) // clang-format on +instantiate_bs_qmv_types( 32, 8) -#define instantiate_bs_qvm(name, itype, group_size, bits) \ - template [[host_name("bs_qvm_" #name "_gs_" #group_size \ - "_b_" #bits)]] [[kernel]] void \ - bs_qvm( \ - const device itype* x [[buffer(0)]], \ - const device uint32_t* w [[buffer(1)]], \ - const device itype* scales [[buffer(2)]], \ - const device itype* biases [[buffer(3)]], \ - const device uint32_t* lhs_indices [[buffer(4)]], \ - const device uint32_t* rhs_indices [[buffer(5)]], \ - device itype* y [[buffer(6)]], \ - const constant int& in_vec_size [[buffer(7)]], \ - const constant int& out_vec_size [[buffer(8)]], \ - const constant int& batch_ndims [[buffer(9)]], \ - const constant int* batch_shape [[buffer(10)]], \ - const constant size_t* lhs_strides [[buffer(11)]], \ - const constant size_t* rhs_strides [[buffer(12)]], \ - const constant int& x_batch_ndims [[buffer(13)]], \ - const constant int* x_shape [[buffer(14)]], \ - const constant size_t* x_strides [[buffer(15)]], \ - const constant int& w_batch_ndims [[buffer(16)]], \ - const constant int* w_shape [[buffer(17)]], \ - const constant size_t* w_strides [[buffer(18)]], \ - const constant size_t* s_strides [[buffer(19)]], \ - const constant size_t* b_strides [[buffer(20)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid 
[[thread_index_in_simdgroup]]); +#define instantiate_bs_qvm(itype, group_size, bits) \ + instantiate_kernel( \ + "bs_qvm_" #itype "_gs_" #group_size "_b_" #bits, \ + bs_qvm, \ + itype, \ + group_size, \ + bits) -// clang-format off #define instantiate_bs_qvm_types(group_size, bits) \ - instantiate_bs_qvm(float32, float, group_size, bits) \ - instantiate_bs_qvm(float16, half, group_size, bits) \ - instantiate_bs_qvm(bfloat16, bfloat16_t, group_size, bits) // clang-format on + instantiate_bs_qvm(float, group_size, bits) \ + instantiate_bs_qvm(float16_t, group_size, bits) \ + instantiate_bs_qvm(bfloat16_t, group_size, bits) - // clang-format off instantiate_bs_qvm_types(128, 2) instantiate_bs_qvm_types(128, 4) instantiate_bs_qvm_types(128, 8) @@ -1763,49 +192,25 @@ instantiate_bs_qvm_types( 64, 4) instantiate_bs_qvm_types( 64, 8) instantiate_bs_qvm_types( 32, 2) instantiate_bs_qvm_types( 32, 4) -instantiate_bs_qvm_types( 32, 8) // clang-format on +instantiate_bs_qvm_types( 32, 8) -#define instantiate_bs_qmm_t(name, itype, group_size, bits, aligned_N) \ - template [[host_name("bs_qmm_t_" #name "_gs_" #group_size "_b_" #bits \ - "_alN_" #aligned_N)]] [[kernel]] void \ - bs_qmm_t( \ - const device itype* x [[buffer(0)]], \ - const device uint32_t* w [[buffer(1)]], \ - const device itype* scales [[buffer(2)]], \ - const device itype* biases [[buffer(3)]], \ - const device uint32_t* lhs_indices [[buffer(4)]], \ - const device uint32_t* rhs_indices [[buffer(5)]], \ - device itype* y [[buffer(6)]], \ - const constant int& M [[buffer(7)]], \ - const constant int& N [[buffer(8)]], \ - const constant int& K [[buffer(9)]], \ - const constant int& batch_ndims [[buffer(10)]], \ - const constant int* batch_shape [[buffer(11)]], \ - const constant size_t* lhs_strides [[buffer(12)]], \ - const constant size_t* rhs_strides [[buffer(13)]], \ - const constant int& x_batch_ndims [[buffer(14)]], \ - const constant int* x_shape [[buffer(15)]], \ - const constant size_t* x_strides [[buffer(16)]], \ - const constant int& w_batch_ndims [[buffer(17)]], \ - const constant int* w_shape [[buffer(18)]], \ - const constant size_t* w_strides [[buffer(19)]], \ - const constant size_t* s_strides [[buffer(20)]], \ - const constant size_t* b_strides [[buffer(21)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint lid [[thread_index_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); +#define instantiate_bs_qmm_t(itype, group_size, bits, aligned_N) \ + instantiate_kernel( \ + "bs_qmm_t_" #itype "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N, \ + bs_qmm_t, \ + itype, \ + group_size, \ + bits, \ + aligned_N) -// clang-format off #define instantiate_bs_qmm_t_types(group_size, bits) \ - instantiate_bs_qmm_t(float32, float, group_size, bits, false) \ - instantiate_bs_qmm_t(float16, half, group_size, bits, false) \ - instantiate_bs_qmm_t(bfloat16, bfloat16_t, group_size, bits, false) \ - instantiate_bs_qmm_t(float32, float, group_size, bits, true) \ - instantiate_bs_qmm_t(float16, half, group_size, bits, true) \ - instantiate_bs_qmm_t(bfloat16, bfloat16_t, group_size, bits, true) // clang-format on + instantiate_bs_qmm_t(float, group_size, bits, false) \ + instantiate_bs_qmm_t(float16_t, group_size, bits, false) \ + instantiate_bs_qmm_t(bfloat16_t, group_size, bits, false) \ + instantiate_bs_qmm_t(float, group_size, bits, true) \ + instantiate_bs_qmm_t(float16_t, group_size, bits, true) \ + instantiate_bs_qmm_t(bfloat16_t, group_size, bits, true) - // 
clang-format off
 instantiate_bs_qmm_t_types(128, 2)
 instantiate_bs_qmm_t_types(128, 4)
 instantiate_bs_qmm_t_types(128, 8)
@@ -1814,46 +219,21 @@ instantiate_bs_qmm_t_types( 64, 4)
 instantiate_bs_qmm_t_types( 64, 8)
 instantiate_bs_qmm_t_types( 32, 2)
 instantiate_bs_qmm_t_types( 32, 4)
-instantiate_bs_qmm_t_types( 32, 8) // clang-format on
+instantiate_bs_qmm_t_types( 32, 8)
 
-#define instantiate_bs_qmm_n(name, itype, group_size, bits) \
-  template [[host_name("bs_qmm_n_" #name "_gs_" #group_size \
-                       "_b_" #bits)]] [[kernel]] void \
-  bs_qmm_n<itype, group_size, bits>( \
-      const device itype* x [[buffer(0)]], \
-      const device uint32_t* w [[buffer(1)]], \
-      const device itype* scales [[buffer(2)]], \
-      const device itype* biases [[buffer(3)]], \
-      const device uint32_t* lhs_indices [[buffer(4)]], \
-      const device uint32_t* rhs_indices [[buffer(5)]], \
-      device itype* y [[buffer(6)]], \
-      const constant int& M [[buffer(7)]], \
-      const constant int& N [[buffer(8)]], \
-      const constant int& K [[buffer(9)]], \
-      const constant int& batch_ndims [[buffer(10)]], \
-      const constant int* batch_shape [[buffer(11)]], \
-      const constant size_t* lhs_strides [[buffer(12)]], \
-      const constant size_t* rhs_strides [[buffer(13)]], \
-      const constant int& x_batch_ndims [[buffer(14)]], \
-      const constant int* x_shape [[buffer(15)]], \
-      const constant size_t* x_strides [[buffer(16)]], \
-      const constant int& w_batch_ndims [[buffer(17)]], \
-      const constant int* w_shape [[buffer(18)]], \
-      const constant size_t* w_strides [[buffer(19)]], \
-      const constant size_t* s_strides [[buffer(20)]], \
-      const constant size_t* b_strides [[buffer(21)]], \
-      uint3 tid [[threadgroup_position_in_grid]], \
-      uint lid [[thread_index_in_threadgroup]], \
-      uint simd_gid [[simdgroup_index_in_threadgroup]], \
-      uint simd_lid [[thread_index_in_simdgroup]]);
+#define instantiate_bs_qmm_n(itype, group_size, bits) \
+  instantiate_kernel( \
+      "bs_qmm_n_" #itype "_gs_" #group_size "_b_" #bits, \
+      bs_qmm_n, \
+      itype, \
+      group_size, \
+      bits)
 
-// clang-format off
 #define instantiate_bs_qmm_n_types(group_size, bits) \
-  instantiate_bs_qmm_n(float32, float, group_size, bits) \
-  instantiate_bs_qmm_n(float16, half, group_size, bits) \
-  instantiate_bs_qmm_n(bfloat16, bfloat16_t, group_size, bits) // clang-format on
+  instantiate_bs_qmm_n(float, group_size, bits) \
+  instantiate_bs_qmm_n(float16_t, group_size, bits) \
+  instantiate_bs_qmm_n(bfloat16_t, group_size, bits)
 
-// clang-format off
 instantiate_bs_qmm_n_types(128, 2)
 instantiate_bs_qmm_n_types(128, 4)
 instantiate_bs_qmm_n_types(128, 8)
diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp
index d789bf2e9..28ab672e5 100644
--- a/mlx/backend/metal/nojit_kernels.cpp
+++ b/mlx/backend/metal/nojit_kernels.cpp
@@ -195,13 +195,16 @@ MTL::ComputePipelineState* get_fft_kernel(
     metal::Device& d,
     const std::string& kernel_name,
     const std::string& hash_name,
-    const int tg_mem_size,
-    const std::string& in_type,
-    const std::string& out_type,
-    int step,
-    bool real,
-    const metal::MTLFCList& func_consts) {
+    const metal::MTLFCList& func_consts,
+    const std::string&) {
   return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
 }
 
+MTL::ComputePipelineState* get_quantized_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const std::string&) {
+  return d.get_kernel(kernel_name);
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index 48f1387a9..f0a64d5e6 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -2,8 +2,10 @@
 
 #include <cassert>
 
+#include "mlx/backend/common/compiled.h"
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/primitives.h"
 
@@ -44,12 +46,15 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Route to the fast qmv kernel that has no bounds checking
   if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
     std::ostringstream kname;
-    kname << "qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
-          << bits_ << "_fast";
+    auto type_string = get_type_string(x.dtype());
+    kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_
+          << "_fast";
 
     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto template_def = get_template_definition(
+        kname.str(), "qmv_fast", type_string, group_size_, bits_);
+    auto kernel = get_quantized_kernel(d, kname.str(), template_def);
     compute_encoder->setComputePipelineState(kernel);
 
     int bo = 8;
@@ -71,12 +76,14 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Route to the qmv kernel
   else if (B < 6) {
     std::ostringstream kname;
-    kname << "qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
-          << bits_;
+    auto type_string = get_type_string(x.dtype());
+    kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_;
 
     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto template_def = get_template_definition(
+        kname.str(), "qmv", type_string, group_size_, bits_);
+    auto kernel = get_quantized_kernel(d, kname.str(), template_def);
     compute_encoder->setComputePipelineState(kernel);
 
     int bo = 8;
@@ -98,12 +105,16 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Route to the qmm_t kernel
   else {
     std::ostringstream kname;
-    kname << "qmm_t_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
-          << bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0);
+    std::string aligned_n = (O % 32) == 0 ? "true" : "false";
+    auto type_string = get_type_string(x.dtype());
+    kname << "qmm_t_" << type_string << "_gs_" << group_size_ << "_b_"
+          << bits_ << "_alN_" << aligned_n;
 
     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto template_def = get_template_definition(
+        kname.str(), "qmm_t", type_string, group_size_, bits_, aligned_n);
+    auto kernel = get_quantized_kernel(d, kname.str(), template_def);
     compute_encoder->setComputePipelineState(kernel);
 
     int wn = 2;
@@ -129,12 +140,14 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Route to the qvm kernel
   if (B < 4) {
     std::ostringstream kname;
-    kname << "qvm_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
-          << bits_;
+    auto type_string = get_type_string(x.dtype());
+    kname << "qvm_" << type_string << "_gs_" << group_size_ << "_b_" << bits_;
 
     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto template_def = get_template_definition(
+        kname.str(), "qvm", type_string, group_size_, bits_);
+    auto kernel = get_quantized_kernel(d, kname.str(), template_def);
     compute_encoder->setComputePipelineState(kernel);
 
     int bo = 64;
@@ -156,12 +169,15 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Route to the qmm_n kernel
   else {
     std::ostringstream kname;
-    kname << "qmm_n_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+    auto type_string = get_type_string(x.dtype());
+    kname << "qmm_n_" << type_string << "_gs_" << group_size_ << "_b_"
           << bits_;
 
     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto template_def = get_template_definition(
+        kname.str(), "qmm_n", type_string, group_size_, bits_);
+    auto kernel = get_quantized_kernel(d, kname.str(), template_def);
     compute_encoder->setComputePipelineState(kernel);
 
     int wn = 2;
@@ -253,12 +269,15 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Route to the fast bs_qmv kernel that has no bounds checking
   if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
     std::ostringstream kname;
-    kname << "bs_qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+    auto type_string = get_type_string(x.dtype());
+    kname << "bs_qmv_" << type_string << "_gs_" << group_size_ << "_b_"
           << bits_ << "_fast";
 
     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto template_def = get_template_definition(
+        kname.str(), "bs_qmv_fast", type_string, group_size_, bits_);
+    auto kernel = get_quantized_kernel(d, kname.str(), template_def);
     compute_encoder->setComputePipelineState(kernel);
 
     int bo = 8;
@@ -295,12 +314,15 @@
 
   else if (B < 6) {
     std::ostringstream kname;
-    kname << "bs_qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+    auto type_string = get_type_string(x.dtype());
+    kname << "bs_qmv_" << type_string << "_gs_" << group_size_ << "_b_"
           << bits_;
 
     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto template_def = get_template_definition(
+        kname.str(), "bs_qmv", type_string, group_size_, bits_);
+    auto kernel = get_quantized_kernel(d, kname.str(), template_def);
     compute_encoder->setComputePipelineState(kernel);
 
     int bo = 8;
@@ -338,12 +360,16 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Route to the bs_qmm_t
   else {
     std::ostringstream kname;
-    kname << "bs_qmm_t_" << type_to_name(out) << "_gs_" << group_size_
-          << "_b_" << bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0);
+    std::string aligned_n = (O % 32) == 0 ? "true" : "false";
+    auto type_string = get_type_string(out.dtype());
+    kname << "bs_qmm_t_" << type_string << "_gs_" << group_size_ << "_b_"
+          << bits_ << "_alN_" << aligned_n;
 
     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto template_def = get_template_definition(
+        kname.str(), "bs_qmm_t", type_string, group_size_, bits_, aligned_n);
+    auto kernel = get_quantized_kernel(d, kname.str(), template_def);
     compute_encoder->setComputePipelineState(kernel);
 
     int wn = 2;
@@ -385,12 +411,15 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Route to the bs_qvm kernel
   if (B < 4) {
     std::ostringstream kname;
-    kname << "bs_qvm_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+    auto type_string = get_type_string(out.dtype());
+    kname << "bs_qvm_" << type_string << "_gs_" << group_size_ << "_b_"
          << bits_;
 
     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto template_def = get_template_definition(
+        kname.str(), "bs_qvm", type_string, group_size_, bits_);
+    auto kernel = get_quantized_kernel(d, kname.str(), template_def);
     compute_encoder->setComputePipelineState(kernel);
 
     int bo = 64;
@@ -428,12 +457,15 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Route to bs_qmm_n
   else {
     std::ostringstream kname;
-    kname << "bs_qmm_n_" << type_to_name(out) << "_gs_" << group_size_
-          << "_b_" << bits_;
+    auto type_string = get_type_string(out.dtype());
+    kname << "bs_qmm_n_" << type_string << "_gs_" << group_size_ << "_b_"
+          << bits_;
 
     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto template_def = get_template_definition(
+        kname.str(), "bs_qmm_n", type_string, group_size_, bits_);
+    auto kernel = get_quantized_kernel(d, kname.str(), template_def);
     compute_encoder->setComputePipelineState(kernel);
 
     int wn = 2;
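Editor's note: every call site above follows the same pattern: build the mangled kernel name, then hand a matching template-definition string to get_quantized_kernel so the JIT path can instantiate the kernel from quantized.h on demand (the no-JIT variant above simply ignores that extra string). The sketch below is illustrative only: build_template_definition is a hypothetical stand-in for the helper used in this patch, and its exact output format is an assumption, not MLX's implementation.

    // Hypothetical sketch (not the MLX implementation): assemble an explicit
    // instantiation line for a templated [[kernel]] from a host-side kernel
    // name, the kernel function, and its template arguments.
    #include <iostream>
    #include <sstream>
    #include <string>

    template <typename... Args>
    std::string build_template_definition(
        const std::string& kernel_name, // e.g. "qmv_float16_t_gs_64_b_4"
        const std::string& func,        // e.g. "qmv"
        Args... args) {                 // e.g. "float16_t", 64, 4
      std::ostringstream arg_list;
      bool first = true;
      // C++17 fold: stream each template argument, comma-separated.
      ((arg_list << (first ? "" : ", ") << args, first = false), ...);
      std::ostringstream out;
      out << "\ntemplate [[host_name(\"" << kernel_name << "\")]] [[kernel]] "
          << "decltype(" << func << "<" << arg_list.str() << ">) " << func
          << "<" << arg_list.str() << ">;\n";
      return out.str();
    }

    int main() {
      // Mirrors the call sites above in spirit: kernel name, base function,
      // then the type string, group size, and bit width.
      std::cout << build_template_definition(
          "qmv_float16_t_gs_64_b_4", "qmv", "float16_t", 64, 4);
    }

Under this assumption, the JIT build only needs to append such a string to the quantized.h source before compiling, which is consistent with make_jit_source(quantized) being added for the MLX_METAL_JIT configuration.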