Add Quantized Ops to the JIT (#1204)

* JIT for quantized ops

* remove unused imports

* address comments

* fix imports

* second attempt to fix imports

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
Alex Barron 2024-06-12 09:47:12 -07:00 committed by GitHub
parent df964132fb
commit dd7d8e5e29
13 changed files with 1778 additions and 1948 deletions

View File

@@ -112,6 +112,7 @@ if (MLX_METAL_JIT)
     kernels/steel/defines.h
     kernels/steel/conv/loaders/loader_general.h
   )
+  make_jit_source(quantized)
 else()
   target_sources(
     mlx

View File

@@ -661,34 +661,45 @@ void fft_op(
   std::ostringstream kname;
   std::string inv_string = inverse ? "true" : "false";
   std::string real_string = real ? "true" : "false";
+  std::string func_name;
   if (plan.bluestein_n > 0) {
     kname << "bluestein_fft_mem_" << threadgroup_mem_size << "_"
           << in_type_str << "_" << out_type_str;
+    func_name = "bluestein_fft";
   } else if (plan.rader_n > 1) {
     kname << "rader_fft_mem_" << threadgroup_mem_size << "_" << in_type_str
           << "_" << out_type_str;
+    func_name = "rader_fft";
   } else if (four_step_params.required) {
     step = four_step_params.first_step ? 0 : 1;
     kname << "four_step_mem_" << threadgroup_mem_size << "_" << in_type_str
           << "_" << out_type_str << "_" << step << "_" << real_string;
+    func_name = "four_step_fft";
   } else {
     kname << "fft_mem_" << threadgroup_mem_size << "_" << in_type_str << "_"
           << out_type_str;
+    func_name = "fft";
   }
   std::string base_name = kname.str();
   // We use a specialized kernel for each FFT size
   kname << "_n" << fft_size << "_inv_" << inverse;
   std::string hash_name = kname.str();
-  auto kernel = get_fft_kernel(
-      d,
-      base_name,
-      hash_name,
-      threadgroup_mem_size,
-      in_type_str,
-      out_type_str,
-      step,
-      real,
-      func_consts);
+  auto template_def = func_name == "four_step_fft" ? get_template_definition(
+                                                         base_name,
+                                                         func_name,
+                                                         threadgroup_mem_size,
+                                                         in_type_str,
+                                                         out_type_str,
+                                                         step,
+                                                         real)
+                                                   : get_template_definition(
+                                                         base_name,
+                                                         func_name,
+                                                         threadgroup_mem_size,
+                                                         in_type_str,
+                                                         out_type_str);
+  auto kernel =
+      get_fft_kernel(d, base_name, hash_name, func_consts, template_def);
   compute_encoder->setComputePipelineState(kernel);
   compute_encoder.set_input_array(in_contiguous, 0);
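
For illustration, with hypothetical values (a plain fft kernel, i.e. no Bluestein/Rader/four-step, threadgroup_mem_size of 4096, float2 input and output), base_name would be "fft_mem_4096_float2_float2" and the non-four-step branch of get_template_definition would produce the instantiation string:

template [[host_name("fft_mem_4096_float2_float2")]] [[kernel]] decltype(fft<4096, float2, float2>) fft<4096, float2, float2>;

get_fft_kernel appends this to the metal::fft() source, so only the requested specialization is compiled at run time.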

View File

@@ -1,53 +0,0 @@
-// Copyright © 2024 Apple Inc.
-
-constexpr std::string_view fft_kernel = R"(
-template [[host_name("{name}")]] [[kernel]] void
-fft<{tg_mem_size}, {in_T}, {out_T}>(
-    const device {in_T}* in [[buffer(0)]],
-    device {out_T}* out [[buffer(1)]],
-    constant const int& n,
-    constant const int& batch_size,
-    uint3 elem [[thread_position_in_grid]],
-    uint3 grid [[threads_per_grid]]);
-)";
-
-constexpr std::string_view rader_fft_kernel = R"(
-template [[host_name("{name}")]] [[kernel]] void
-rader_fft<{tg_mem_size}, {in_T}, {out_T}>(
-    const device {in_T}* in [[buffer(0)]],
-    device {out_T}* out [[buffer(1)]],
-    const device float2* raders_b_q [[buffer(2)]],
-    const device short* raders_g_q [[buffer(3)]],
-    const device short* raders_g_minus_q [[buffer(4)]],
-    constant const int& n,
-    constant const int& batch_size,
-    constant const int& rader_n,
-    uint3 elem [[thread_position_in_grid]],
-    uint3 grid [[threads_per_grid]]);
-)";
-
-constexpr std::string_view bluestein_fft_kernel = R"(
-template [[host_name("{name}")]] [[kernel]] void
-bluestein_fft<{tg_mem_size}, {in_T}, {out_T}>(
-    const device {in_T}* in [[buffer(0)]],
-    device {out_T}* out [[buffer(1)]],
-    const device float2* w_q [[buffer(2)]],
-    const device float2* w_k [[buffer(3)]],
-    constant const int& length,
-    constant const int& n,
-    constant const int& batch_size,
-    uint3 elem [[thread_position_in_grid]],
-    uint3 grid [[threads_per_grid]]);
-)";
-
-constexpr std::string_view four_step_fft_kernel = R"(
-template [[host_name("{name}")]] [[kernel]] void
-four_step_fft<{tg_mem_size}, {in_T}, {out_T}, {step}, {real}>(
-    const device {in_T}* in [[buffer(0)]],
-    device {out_T}* out [[buffer(1)]],
-    constant const int& n1,
-    constant const int& n2,
-    constant const int& batch_size,
-    uint3 elem [[thread_position_in_grid]],
-    uint3 grid [[threads_per_grid]]);
-)";

View File

@@ -18,6 +18,7 @@ const char* binary();
 const char* binary_two();
 const char* copy();
 const char* fft();
+const char* quantized();
 const char* ternary();
 const char* scan();
 const char* softmax();

View File

@@ -1,12 +1,10 @@
 // Copyright © 2024 Apple Inc.

-#include <fmt/format.h>
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/metal/jit/arange.h"
 #include "mlx/backend/metal/jit/binary.h"
 #include "mlx/backend/metal/jit/binary_two.h"
 #include "mlx/backend/metal/jit/copy.h"
-#include "mlx/backend/metal/jit/fft.h"
 #include "mlx/backend/metal/jit/includes.h"
 #include "mlx/backend/metal/jit/reduce.h"
 #include "mlx/backend/metal/jit/scan.h"
@@ -494,47 +492,32 @@ MTL::ComputePipelineState* get_fft_kernel(
     metal::Device& d,
     const std::string& kernel_name,
     const std::string& hash_name,
-    const int tg_mem_size,
-    const std::string& in_type,
-    const std::string& out_type,
-    int step,
-    bool real,
-    const metal::MTLFCList& func_consts) {
+    const metal::MTLFCList& func_consts,
+    const std::string& template_def) {
   const auto& lib_name = kernel_name;
   auto lib = d.get_library(lib_name);
   if (lib == nullptr) {
     std::ostringstream kernel_source;
-    std::string kernel_string;
-    if (lib_name.find("bluestein") != std::string::npos) {
-      kernel_string = bluestein_fft_kernel;
-    } else if (lib_name.find("rader") != std::string::npos) {
-      kernel_string = rader_fft_kernel;
-    } else if (lib_name.find("four_step") != std::string::npos) {
-      kernel_string = four_step_fft_kernel;
-    } else {
-      kernel_string = fft_kernel;
-    }
-    kernel_source << metal::fft();
-    if (lib_name.find("four_step") != std::string::npos) {
-      kernel_source << fmt::format(
-          kernel_string,
-          "name"_a = lib_name,
-          "tg_mem_size"_a = tg_mem_size,
-          "in_T"_a = in_type,
-          "out_T"_a = out_type,
-          "step"_a = step,
-          "real"_a = real);
-    } else {
-      kernel_source << fmt::format(
-          kernel_string,
-          "name"_a = lib_name,
-          "tg_mem_size"_a = tg_mem_size,
-          "in_T"_a = in_type,
-          "out_T"_a = out_type);
-    }
+    kernel_source << metal::fft() << template_def;
     lib = d.get_library(lib_name, kernel_source.str());
   }
   return d.get_kernel(kernel_name, lib, hash_name, func_consts);
 }

+MTL::ComputePipelineState* get_quantized_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const std::string& template_def) {
+  const auto& lib_name = kernel_name;
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    kernel_source << metal::utils() << metal::gemm() << metal::quantized()
+                  << template_def;
+    lib = d.get_library(lib_name, kernel_source.str());
+  }
+  return d.get_kernel(kernel_name, lib);
+}
+
 } // namespace mlx::core
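
A minimal sketch of the resulting call pattern, with hypothetical values (float16 input, group size 64, 4-bit weights), matching how the quantized matmul dispatch below uses it:

std::string name = "qmv_float16_t_gs_64_b_4";
auto template_def = get_template_definition(name, "qmv", "float16_t", 64, 4);
auto kernel = get_quantized_kernel(d, name, template_def);

Since the library is cached under the kernel name, the utils/gemm/quantized sources are concatenated and compiled only on first use of each specialization.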

View File

@@ -1,5 +1,7 @@
 // Copyright © 2024 Apple Inc.

+#include <fmt/format.h>
+
 #include "mlx/array.h"
 #include "mlx/backend/metal/device.h"
@@ -159,11 +161,34 @@ MTL::ComputePipelineState* get_fft_kernel(
     metal::Device& d,
     const std::string& kernel_name,
     const std::string& hash_name,
-    const int tg_mem_size,
-    const std::string& in_type,
-    const std::string& out_type,
-    int step,
-    bool real,
-    const metal::MTLFCList& func_consts);
+    const metal::MTLFCList& func_consts,
+    const std::string& template_def);
+
+MTL::ComputePipelineState* get_quantized_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const std::string& template_def);
+
+// Create a GPU kernel template definition for JIT compilation
+template <typename... Args>
+std::string
+get_template_definition(std::string name, std::string func, Args... args) {
+  std::ostringstream s;
+  s << func << "<";
+  bool first = true;
+  auto add_arg = [&s, &first](const auto& arg) {
+    if (!first) {
+      s << ", ";
+    }
+    first = false;
+    s << arg;
+  };
+  (add_arg(args), ...);
+  s << ">";
+  std::string base_string = R"(
+template [[host_name("{0}")]] [[kernel]] decltype({1}) {1};
+)";
+  return fmt::format(base_string, name, s.str());
+}
+
 } // namespace mlx::core
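
A usage sketch with hypothetical argument values: the fold expression streams each trailing argument into the template parameter list, so strings and integers mix freely.

// def holds:
//   template [[host_name("qmv_float16_t_gs_64_b_4")]] [[kernel]]
//   decltype(qmv<float16_t, 64, 4>) qmv<float16_t, 64, 4>;
auto def = get_template_definition(
    "qmv_float16_t_gs_64_b_4", "qmv", "float16_t", 64, 4);

The decltype({1}) {1} form instantiates a kernel template without restating its full signature, which is what lets one helper serve kernels with arbitrary parameter lists.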

View File

@@ -12,9 +12,7 @@ set(
   KERNELS
   "arg_reduce"
   "conv"
-  "fft"
   "gemv"
-  "quantized"
   "random"
   "rms_norm"
   "layer_norm"
@@ -32,6 +30,8 @@ set(
   "unary"
   "ternary"
   "copy"
+  "fft"
+  "quantized"
   "softmax"
   "sort"
   "scan"
@@ -51,6 +51,7 @@ set(
   fft.h
   fft/radix.h
   fft/readwrite.h
+  quantized.h
   softmax.h
   sort.h
   scan.h

View File

@@ -13,3 +13,11 @@ static MTL_CONST constexpr int REDUCE_N_READS = 16;
 static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
 static MTL_CONST constexpr int RMS_N_READS = 4;
 static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;
+
+// Instantiate a templated kernel.
+// Extra args are used as template parameters:
+// e.g. instantiate_kernel(binary_int, binary, a, b) ->
+//   [[host_name(binary_int)]] [kernel] binary<a, b>
+#define instantiate_kernel(name, func, ...) \
+  template [[host_name(                     \
+      name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;
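
For example (values hypothetical), instantiate_kernel("fft_mem_4096_float2_float2", fft, 4096, float2, float2) expands to:

template [[host_name("fft_mem_4096_float2_float2")]] [[kernel]] decltype(fft<4096, float2, float2>) fft<4096, float2, float2>;

which is the same decltype-based instantiation that get_template_definition emits for the JIT, so precompiled and JIT builds expose identically named pipelines.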

View File

@@ -1,58 +1,41 @@
 // Copyright © 2024 Apple Inc.

+#include "mlx/backend/metal/kernels/defines.h"
 #include "mlx/backend/metal/kernels/fft.h"

 #define instantiate_fft(tg_mem_size, in_T, out_T)        \
-  template [[host_name("fft_mem_" #tg_mem_size "_" #in_T \
-                       "_" #out_T)]] [[kernel]] void     \
-  fft<tg_mem_size, in_T, out_T>(                         \
-      const device in_T* in [[buffer(0)]],               \
-      device out_T* out [[buffer(1)]],                   \
-      constant const int& n,                             \
-      constant const int& batch_size,                    \
-      uint3 elem [[thread_position_in_grid]],            \
-      uint3 grid [[threads_per_grid]]);
+  instantiate_kernel(                                    \
+      "fft_mem_" #tg_mem_size "_" #in_T "_" #out_T,      \
+      fft,                                               \
+      tg_mem_size,                                       \
+      in_T,                                              \
+      out_T)

 #define instantiate_rader(tg_mem_size, in_T, out_T)            \
-  template [[host_name("rader_fft_mem_" #tg_mem_size "_" #in_T \
-                       "_" #out_T)]] [[kernel]] void           \
-  rader_fft<tg_mem_size, in_T, out_T>(                         \
-      const device in_T* in [[buffer(0)]],                     \
-      device out_T* out [[buffer(1)]],                         \
-      const device float2* raders_b_q [[buffer(2)]],           \
-      const device short* raders_g_q [[buffer(3)]],            \
-      const device short* raders_g_minus_q [[buffer(4)]],      \
-      constant const int& n,                                   \
-      constant const int& batch_size,                          \
-      constant const int& rader_n,                             \
-      uint3 elem [[thread_position_in_grid]],                  \
-      uint3 grid [[threads_per_grid]]);
+  instantiate_kernel(                                          \
+      "rader_fft_mem_" #tg_mem_size "_" #in_T "_" #out_T,      \
+      rader_fft,                                               \
+      tg_mem_size,                                             \
+      in_T,                                                    \
+      out_T)

 #define instantiate_bluestein(tg_mem_size, in_T, out_T)            \
-  template [[host_name("bluestein_fft_mem_" #tg_mem_size "_" #in_T \
-                       "_" #out_T)]] [[kernel]] void               \
-  bluestein_fft<tg_mem_size, in_T, out_T>(                         \
-      const device in_T* in [[buffer(0)]],                         \
-      device out_T* out [[buffer(1)]],                             \
-      const device float2* w_q [[buffer(2)]],                      \
-      const device float2* w_k [[buffer(3)]],                      \
-      constant const int& length,                                  \
-      constant const int& n,                                       \
-      constant const int& batch_size,                              \
-      uint3 elem [[thread_position_in_grid]],                      \
-      uint3 grid [[threads_per_grid]]);
+  instantiate_kernel(                                              \
+      "bluestein_fft_mem_" #tg_mem_size "_" #in_T "_" #out_T,      \
+      bluestein_fft,                                               \
+      tg_mem_size,                                                 \
+      in_T,                                                        \
+      out_T)

 #define instantiate_four_step(tg_mem_size, in_T, out_T, step, real)           \
-  template [[host_name("four_step_mem_" #tg_mem_size "_" #in_T "_" #out_T     \
-                       "_" #step "_" #real)]] [[kernel]] void                 \
-  four_step_fft<tg_mem_size, in_T, out_T, step, real>(                        \
-      const device in_T* in [[buffer(0)]],                                    \
-      device out_T* out [[buffer(1)]],                                        \
-      constant const int& n1,                                                 \
-      constant const int& n2,                                                 \
-      constant const int& batch_size,                                         \
-      uint3 elem [[thread_position_in_grid]],                                 \
-      uint3 grid [[threads_per_grid]]);
+  instantiate_kernel(                                                         \
+      "four_step_mem_" #tg_mem_size "_" #in_T "_" #out_T "_" #step "_" #real, \
+      four_step_fft,                                                          \
+      tg_mem_size,                                                            \
+      in_T,                                                                   \
+      out_T,                                                                  \
+      step,                                                                   \
+      real)

 // clang-format off
 #define instantiate_ffts(tg_mem_size) \

File diff suppressed because it is too large.

File diff suppressed because it is too large.

View File

@@ -195,13 +195,16 @@ MTL::ComputePipelineState* get_fft_kernel(
     metal::Device& d,
     const std::string& kernel_name,
     const std::string& hash_name,
-    const int tg_mem_size,
-    const std::string& in_type,
-    const std::string& out_type,
-    int step,
-    bool real,
-    const metal::MTLFCList& func_consts) {
+    const metal::MTLFCList& func_consts,
+    const std::string&) {
   return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
 }

+MTL::ComputePipelineState* get_quantized_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const std::string&) {
+  return d.get_kernel(kernel_name);
+}
+
 } // namespace mlx::core

View File

@@ -2,8 +2,10 @@

 #include <cassert>

+#include "mlx/backend/common/compiled.h"
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/primitives.h"
@@ -44,12 +46,15 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Route to the fast qmv kernel that has no bounds checking
   if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
     std::ostringstream kname;
-    kname << "qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
-          << bits_ << "_fast";
+    auto type_string = get_type_string(x.dtype());
+    kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_
+          << "_fast";

     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto template_def = get_template_definition(
+        kname.str(), "qmv_fast", type_string, group_size_, bits_);
+    auto kernel = get_quantized_kernel(d, kname.str(), template_def);
     compute_encoder->setComputePipelineState(kernel);

     int bo = 8;
@@ -71,12 +76,14 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Route to the qmv kernel
   else if (B < 6) {
     std::ostringstream kname;
-    kname << "qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
-          << bits_;
+    auto type_string = get_type_string(x.dtype());
+    kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_;

     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto template_def = get_template_definition(
+        kname.str(), "qmv", type_string, group_size_, bits_);
+    auto kernel = get_quantized_kernel(d, kname.str(), template_def);
     compute_encoder->setComputePipelineState(kernel);

     int bo = 8;
@@ -98,12 +105,16 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Route to the qmm_t kernel
   else {
     std::ostringstream kname;
-    kname << "qmm_t_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
-          << bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0);
+    std::string aligned_n = (O % 32) == 0 ? "true" : "false";
+    auto type_string = get_type_string(x.dtype());
+    kname << "qmm_t_" << type_string << "_gs_" << group_size_ << "_b_"
+          << bits_ << "_alN_" << aligned_n;

     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto template_def = get_template_definition(
+        kname.str(), "qmm_t", type_string, group_size_, bits_, aligned_n);
+    auto kernel = get_quantized_kernel(d, kname.str(), template_def);
     compute_encoder->setComputePipelineState(kernel);

     int wn = 2;
@@ -129,12 +140,14 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Route to the qvm kernel
   if (B < 4) {
     std::ostringstream kname;
-    kname << "qvm_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
-          << bits_;
+    auto type_string = get_type_string(x.dtype());
+    kname << "qvm_" << type_string << "_gs_" << group_size_ << "_b_" << bits_;

     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto template_def = get_template_definition(
+        kname.str(), "qvm", type_string, group_size_, bits_);
+    auto kernel = get_quantized_kernel(d, kname.str(), template_def);
     compute_encoder->setComputePipelineState(kernel);

     int bo = 64;
@@ -156,12 +169,15 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Route to the qmm_n kernel
   else {
     std::ostringstream kname;
-    kname << "qmm_n_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+    auto type_string = get_type_string(x.dtype());
+    kname << "qmm_n_" << type_string << "_gs_" << group_size_ << "_b_"
          << bits_;

     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto template_def = get_template_definition(
+        kname.str(), "qmm_n", type_string, group_size_, bits_);
+    auto kernel = get_quantized_kernel(d, kname.str(), template_def);
     compute_encoder->setComputePipelineState(kernel);

     int wn = 2;
@@ -253,12 +269,15 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Route to the fast bs_qmv kernel that has no bounds checking
   if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
     std::ostringstream kname;
-    kname << "bs_qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+    auto type_string = get_type_string(x.dtype());
+    kname << "bs_qmv_" << type_string << "_gs_" << group_size_ << "_b_"
          << bits_ << "_fast";

     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto template_def = get_template_definition(
+        kname.str(), "bs_qmv_fast", type_string, group_size_, bits_);
+    auto kernel = get_quantized_kernel(d, kname.str(), template_def);
     compute_encoder->setComputePipelineState(kernel);

     int bo = 8;
@@ -295,12 +314,15 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   else if (B < 6) {
     std::ostringstream kname;
-    kname << "bs_qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+    auto type_string = get_type_string(x.dtype());
+    kname << "bs_qmv_" << type_string << "_gs_" << group_size_ << "_b_"
          << bits_;

     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto template_def = get_template_definition(
+        kname.str(), "bs_qmv", type_string, group_size_, bits_);
+    auto kernel = get_quantized_kernel(d, kname.str(), template_def);
     compute_encoder->setComputePipelineState(kernel);

     int bo = 8;
@@ -338,12 +360,16 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Route to the bs_qmm_t
   else {
     std::ostringstream kname;
-    kname << "bs_qmm_t_" << type_to_name(out) << "_gs_" << group_size_
-          << "_b_" << bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0);
+    std::string aligned_n = (O % 32) == 0 ? "true" : "false";
+    auto type_string = get_type_string(out.dtype());
+    kname << "bs_qmm_t_" << type_string << "_gs_" << group_size_ << "_b_"
+          << bits_ << "_alN_" << aligned_n;

     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto template_def = get_template_definition(
+        kname.str(), "bs_qmm_t", type_string, group_size_, bits_, aligned_n);
+    auto kernel = get_quantized_kernel(d, kname.str(), template_def);
     compute_encoder->setComputePipelineState(kernel);

     int wn = 2;
@@ -385,12 +411,15 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Route to the bs_qvm kernel
   if (B < 4) {
     std::ostringstream kname;
-    kname << "bs_qvm_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+    auto type_string = get_type_string(out.dtype());
+    kname << "bs_qvm_" << type_string << "_gs_" << group_size_ << "_b_"
          << bits_;

     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto template_def = get_template_definition(
+        kname.str(), "bs_qvm", type_string, group_size_, bits_);
+    auto kernel = get_quantized_kernel(d, kname.str(), template_def);
     compute_encoder->setComputePipelineState(kernel);

     int bo = 64;
@@ -428,12 +457,15 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Route to bs_qmm_n
   else {
     std::ostringstream kname;
-    kname << "bs_qmm_n_" << type_to_name(out) << "_gs_" << group_size_
-          << "_b_" << bits_;
+    auto type_string = get_type_string(out.dtype());
+    kname << "bs_qmm_n_" << type_string << "_gs_" << group_size_ << "_b_"
+          << bits_;

     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto template_def = get_template_definition(
+        kname.str(), "bs_qmm_n", type_string, group_size_, bits_);
+    auto kernel = get_quantized_kernel(d, kname.str(), template_def);
     compute_encoder->setComputePipelineState(kernel);

     int wn = 2;
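
All of these branches share one pattern: every template parameter is first rendered as a string, then used both in the kernel name and in the template definition. With hypothetical values (float16 output, O divisible by 32, group_size_ = 64, bits_ = 4), the bs_qmm_t branch would produce:

// kname:        bs_qmm_t_float16_t_gs_64_b_4_alN_true
// template_def: template [[host_name("bs_qmm_t_float16_t_gs_64_b_4_alN_true")]]
//               [[kernel]] decltype(bs_qmm_t<float16_t, 64, 4, true>)
//               bs_qmm_t<float16_t, 64, 4, true>;

Replacing std::boolalpha << ((O % 32) == 0) with the explicit aligned_n string keeps the name suffix and the boolean template argument in sync from a single value.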