Add Quantized Ops to the JIT (#1204)

* JIT for quantized ops

* remove unused imports

* address comments

* fix imports

* second attempt to fix imports

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
Alex Barron 2024-06-12 09:47:12 -07:00 committed by GitHub
parent df964132fb
commit dd7d8e5e29
13 changed files with 1778 additions and 1948 deletions

View File

@ -112,6 +112,7 @@ if (MLX_METAL_JIT)
kernels/steel/defines.h
kernels/steel/conv/loaders/loader_general.h
)
make_jit_source(quantized)
else()
target_sources(
mlx

View File

@ -661,34 +661,45 @@ void fft_op(
std::ostringstream kname;
std::string inv_string = inverse ? "true" : "false";
std::string real_string = real ? "true" : "false";
std::string func_name;
if (plan.bluestein_n > 0) {
kname << "bluestein_fft_mem_" << threadgroup_mem_size << "_"
<< in_type_str << "_" << out_type_str;
func_name = "bluestein_fft";
} else if (plan.rader_n > 1) {
kname << "rader_fft_mem_" << threadgroup_mem_size << "_" << in_type_str
<< "_" << out_type_str;
func_name = "rader_fft";
} else if (four_step_params.required) {
step = four_step_params.first_step ? 0 : 1;
kname << "four_step_mem_" << threadgroup_mem_size << "_" << in_type_str
<< "_" << out_type_str << "_" << step << "_" << real_string;
func_name = "four_step_fft";
} else {
kname << "fft_mem_" << threadgroup_mem_size << "_" << in_type_str << "_"
<< out_type_str;
func_name = "fft";
}
std::string base_name = kname.str();
// We use a specialized kernel for each FFT size
kname << "_n" << fft_size << "_inv_" << inverse;
std::string hash_name = kname.str();
auto kernel = get_fft_kernel(
d,
base_name,
hash_name,
threadgroup_mem_size,
in_type_str,
out_type_str,
step,
real,
func_consts);
auto template_def = func_name == "four_step_fft" ? get_template_definition(
base_name,
func_name,
threadgroup_mem_size,
in_type_str,
out_type_str,
step,
real)
: get_template_definition(
base_name,
func_name,
threadgroup_mem_size,
in_type_str,
out_type_str);
auto kernel =
get_fft_kernel(d, base_name, hash_name, func_consts, template_def);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in_contiguous, 0);
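
With this change the JIT path builds the instantiation text on the host and hands it to get_fft_kernel, instead of selecting one of the per-variant format strings that the deleted jit/fft.h used to provide. A minimal sketch of the strings this block produces, assuming an illustrative Bluestein FFT of size 1021 with a 2048-element threadgroup buffer and float2 input/output (example values, not taken from this diff):

// base_name: "bluestein_fft_mem_2048_float2_float2"
// hash_name: "bluestein_fft_mem_2048_float2_float2_n1021_inv_0"
// template_def, from get_template_definition(base_name, "bluestein_fft", 2048, "float2", "float2"):
//   template [[host_name("bluestein_fft_mem_2048_float2_float2")]] [[kernel]]
//   decltype(bluestein_fft<2048, float2, float2>) bluestein_fft<2048, float2, float2>;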

View File

@ -1,53 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view fft_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
fft<{tg_mem_size}, {in_T}, {out_T}>(
const device {in_T}* in [[buffer(0)]],
device {out_T}* out [[buffer(1)]],
constant const int& n,
constant const int& batch_size,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]);
)";
constexpr std::string_view rader_fft_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
rader_fft<{tg_mem_size}, {in_T}, {out_T}>(
const device {in_T}* in [[buffer(0)]],
device {out_T}* out [[buffer(1)]],
const device float2* raders_b_q [[buffer(2)]],
const device short* raders_g_q [[buffer(3)]],
const device short* raders_g_minus_q [[buffer(4)]],
constant const int& n,
constant const int& batch_size,
constant const int& rader_n,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]);
)";
constexpr std::string_view bluestein_fft_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
bluestein_fft<{tg_mem_size}, {in_T}, {out_T}>(
const device {in_T}* in [[buffer(0)]],
device {out_T}* out [[buffer(1)]],
const device float2* w_q [[buffer(2)]],
const device float2* w_k [[buffer(3)]],
constant const int& length,
constant const int& n,
constant const int& batch_size,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]);
)";
constexpr std::string_view four_step_fft_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
four_step_fft<{tg_mem_size}, {in_T}, {out_T}, {step}, {real}>(
const device {in_T}* in [[buffer(0)]],
device {out_T}* out [[buffer(1)]],
constant const int& n1,
constant const int& n2,
constant const int& batch_size,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]);
)";

View File

@ -18,6 +18,7 @@ const char* binary();
const char* binary_two();
const char* copy();
const char* fft();
const char* quantized();
const char* ternary();
const char* scan();
const char* softmax();

View File

@ -1,12 +1,10 @@
// Copyright © 2024 Apple Inc.
#include <fmt/format.h>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/binary.h"
#include "mlx/backend/metal/jit/binary_two.h"
#include "mlx/backend/metal/jit/copy.h"
#include "mlx/backend/metal/jit/fft.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/reduce.h"
#include "mlx/backend/metal/jit/scan.h"
@ -494,47 +492,32 @@ MTL::ComputePipelineState* get_fft_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const int tg_mem_size,
const std::string& in_type,
const std::string& out_type,
int step,
bool real,
const metal::MTLFCList& func_consts) {
const metal::MTLFCList& func_consts,
const std::string& template_def) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
std::string kernel_string;
if (lib_name.find("bluestein") != std::string::npos) {
kernel_string = bluestein_fft_kernel;
} else if (lib_name.find("rader") != std::string::npos) {
kernel_string = rader_fft_kernel;
} else if (lib_name.find("four_step") != std::string::npos) {
kernel_string = four_step_fft_kernel;
} else {
kernel_string = fft_kernel;
}
kernel_source << metal::fft();
if (lib_name.find("four_step") != std::string::npos) {
kernel_source << fmt::format(
kernel_string,
"name"_a = lib_name,
"tg_mem_size"_a = tg_mem_size,
"in_T"_a = in_type,
"out_T"_a = out_type,
"step"_a = step,
"real"_a = real);
} else {
kernel_source << fmt::format(
kernel_string,
"name"_a = lib_name,
"tg_mem_size"_a = tg_mem_size,
"in_T"_a = in_type,
"out_T"_a = out_type);
}
kernel_source << metal::fft() << template_def;
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
MTL::ComputePipelineState* get_quantized_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& template_def) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm() << metal::quantized()
<< template_def;
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
} // namespace mlx::core
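
A hedged usage sketch of the new entry point with made-up parameter values (the real call sites are in quantized.cpp further down): the caller formats both the kernel name and the matching template instantiation, and get_quantized_kernel compiles utils + gemm + quantized + that instantiation the first time the name is seen, then reuses the cached library.

auto type_string = std::string("float16_t");              // assumed example dtype
auto name = "qmv_" + type_string + "_gs_64_b_4_fast";     // assumed group_size/bits
auto template_def =
    get_template_definition(name, "qmv_fast", type_string, 64, 4);
auto kernel = get_quantized_kernel(d, name, template_def); // d is a metal::Device&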

View File

@ -1,5 +1,7 @@
// Copyright © 2024 Apple Inc.
#include <fmt/format.h>
#include "mlx/array.h"
#include "mlx/backend/metal/device.h"
@ -159,11 +161,34 @@ MTL::ComputePipelineState* get_fft_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const int tg_mem_size,
const std::string& in_type,
const std::string& out_type,
int step,
bool real,
const metal::MTLFCList& func_consts);
const metal::MTLFCList& func_consts,
const std::string& template_def);
MTL::ComputePipelineState* get_quantized_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& template_def);
// Create a GPU kernel template definition for JIT compilation
template <typename... Args>
std::string
get_template_definition(std::string name, std::string func, Args... args) {
std::ostringstream s;
s << func << "<";
bool first = true;
auto add_arg = [&s, &first](const auto& arg) {
if (!first) {
s << ", ";
}
first = false;
s << arg;
};
(add_arg(args), ...);
s << ">";
std::string base_string = R"(
template [[host_name("{0}")]] [[kernel]] decltype({1}) {1};
)";
return fmt::format(base_string, name, s.str());
}
} // namespace mlx::core
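
To make the helper concrete, here is the string it returns for the fft kernel with illustrative arguments (4096, "float", "float2" are example values, not fixed by this header):

get_template_definition("fft_mem_4096_float_float2", "fft", 4096, "float", "float2")
// returns:
// template [[host_name("fft_mem_4096_float_float2")]] [[kernel]]
// decltype(fft<4096, float, float2>) fft<4096, float, float2>;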

View File

@ -12,9 +12,7 @@ set(
KERNELS
"arg_reduce"
"conv"
"fft"
"gemv"
"quantized"
"random"
"rms_norm"
"layer_norm"
@ -32,6 +30,8 @@ set(
"unary"
"ternary"
"copy"
"fft"
"quantized"
"softmax"
"sort"
"scan"
@ -51,6 +51,7 @@ set(
fft.h
fft/radix.h
fft/readwrite.h
quantized.h
softmax.h
sort.h
scan.h

View File

@ -13,3 +13,11 @@ static MTL_CONST constexpr int REDUCE_N_READS = 16;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int RMS_N_READS = 4;
static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;
// Instantiate a templated kernel.
// Extra args are used as template parameters:
// e.g. instantiate_kernel(binary_int, binary, a, b) ->
// [[host_name(binary_int)]] [kernel] binary<a, b>
#define instantiate_kernel(name, func, ...) \
template [[host_name( \
name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;
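
This macro is the compile-time counterpart of get_template_definition in kernels.h, so the instantiations baked into the metallib and the ones generated by the JIT have the same shape. Expanding the example from the comment above (binary, a, and b stand in for an arbitrary kernel and its template arguments):

instantiate_kernel("binary_int", binary, a, b)
// expands to:
// template [[host_name("binary_int")]] [[kernel]] decltype(binary<a, b>) binary<a, b>;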

View File

@ -1,58 +1,41 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/fft.h"
#define instantiate_fft(tg_mem_size, in_T, out_T) \
template [[host_name("fft_mem_" #tg_mem_size "_" #in_T \
"_" #out_T)]] [[kernel]] void \
fft<tg_mem_size, in_T, out_T>( \
const device in_T* in [[buffer(0)]], \
device out_T* out [[buffer(1)]], \
constant const int& n, \
constant const int& batch_size, \
uint3 elem [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
#define instantiate_fft(tg_mem_size, in_T, out_T) \
instantiate_kernel( \
"fft_mem_" #tg_mem_size "_" #in_T "_" #out_T, \
fft, \
tg_mem_size, \
in_T, \
out_T)
#define instantiate_rader(tg_mem_size, in_T, out_T) \
template [[host_name("rader_fft_mem_" #tg_mem_size "_" #in_T \
"_" #out_T)]] [[kernel]] void \
rader_fft<tg_mem_size, in_T, out_T>( \
const device in_T* in [[buffer(0)]], \
device out_T* out [[buffer(1)]], \
const device float2* raders_b_q [[buffer(2)]], \
const device short* raders_g_q [[buffer(3)]], \
const device short* raders_g_minus_q [[buffer(4)]], \
constant const int& n, \
constant const int& batch_size, \
constant const int& rader_n, \
uint3 elem [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
#define instantiate_rader(tg_mem_size, in_T, out_T) \
instantiate_kernel( \
"rader_fft_mem_" #tg_mem_size "_" #in_T "_" #out_T, \
rader_fft, \
tg_mem_size, \
in_T, \
out_T)
#define instantiate_bluestein(tg_mem_size, in_T, out_T) \
template [[host_name("bluestein_fft_mem_" #tg_mem_size "_" #in_T \
"_" #out_T)]] [[kernel]] void \
bluestein_fft<tg_mem_size, in_T, out_T>( \
const device in_T* in [[buffer(0)]], \
device out_T* out [[buffer(1)]], \
const device float2* w_q [[buffer(2)]], \
const device float2* w_k [[buffer(3)]], \
constant const int& length, \
constant const int& n, \
constant const int& batch_size, \
uint3 elem [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
#define instantiate_bluestein(tg_mem_size, in_T, out_T) \
instantiate_kernel( \
"bluestein_fft_mem_" #tg_mem_size "_" #in_T "_" #out_T, \
bluestein_fft, \
tg_mem_size, \
in_T, \
out_T)
#define instantiate_four_step(tg_mem_size, in_T, out_T, step, real) \
template [[host_name("four_step_mem_" #tg_mem_size "_" #in_T "_" #out_T \
"_" #step "_" #real)]] [[kernel]] void \
four_step_fft<tg_mem_size, in_T, out_T, step, real>( \
const device in_T* in [[buffer(0)]], \
device out_T* out [[buffer(1)]], \
constant const int& n1, \
constant const int& n2, \
constant const int& batch_size, \
uint3 elem [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
#define instantiate_four_step(tg_mem_size, in_T, out_T, step, real) \
instantiate_kernel( \
"four_step_mem_" #tg_mem_size "_" #in_T "_" #out_T "_" #step "_" #real, \
four_step_fft, \
tg_mem_size, \
in_T, \
out_T, \
step, \
real)
// clang-format off
#define instantiate_ffts(tg_mem_size) \

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -195,13 +195,16 @@ MTL::ComputePipelineState* get_fft_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const int tg_mem_size,
const std::string& in_type,
const std::string& out_type,
int step,
bool real,
const metal::MTLFCList& func_consts) {
const metal::MTLFCList& func_consts,
const std::string&) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
}
MTL::ComputePipelineState* get_quantized_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string&) {
return d.get_kernel(kernel_name);
}
} // namespace mlx::core

View File

@ -2,8 +2,10 @@
#include <cassert>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
@ -44,12 +46,15 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
// Route to the fast qmv kernel that has no bounds checking
if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
std::ostringstream kname;
kname << "qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
<< bits_ << "_fast";
auto type_string = get_type_string(x.dtype());
kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_
<< "_fast";
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto template_def = get_template_definition(
kname.str(), "qmv_fast", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
int bo = 8;
@ -71,12 +76,14 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
// Route to the qmv kernel
else if (B < 6) {
std::ostringstream kname;
kname << "qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
<< bits_;
auto type_string = get_type_string(x.dtype());
kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto template_def = get_template_definition(
kname.str(), "qmv", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
int bo = 8;
@ -98,12 +105,16 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
// Route to the qmm_t kernel
else {
std::ostringstream kname;
kname << "qmm_t_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
<< bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0);
std::string aligned_n = (O % 32) == 0 ? "true" : "false";
auto type_string = get_type_string(x.dtype());
kname << "qmm_t_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_ << "_alN_" << aligned_n;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto template_def = get_template_definition(
kname.str(), "qmm_t", type_string, group_size_, bits_, aligned_n);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
int wn = 2;
@ -129,12 +140,14 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
// Route to the qvm kernel
if (B < 4) {
std::ostringstream kname;
kname << "qvm_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
<< bits_;
auto type_string = get_type_string(x.dtype());
kname << "qvm_" << type_string << "_gs_" << group_size_ << "_b_" << bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto template_def = get_template_definition(
kname.str(), "qvm", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
int bo = 64;
@ -156,12 +169,15 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
// Route to the qmm_n kernel
else {
std::ostringstream kname;
kname << "qmm_n_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
auto type_string = get_type_string(x.dtype());
kname << "qmm_n_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto template_def = get_template_definition(
kname.str(), "qmm_n", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
int wn = 2;
@ -253,12 +269,15 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Route to the fast bs_qmv kernel that has no bounds checking
if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
std::ostringstream kname;
kname << "bs_qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
auto type_string = get_type_string(x.dtype());
kname << "bs_qmv_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_ << "_fast";
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto template_def = get_template_definition(
kname.str(), "bs_qmv_fast", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
int bo = 8;
@ -295,12 +314,15 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
else if (B < 6) {
std::ostringstream kname;
kname << "bs_qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
auto type_string = get_type_string(x.dtype());
kname << "bs_qmv_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto template_def = get_template_definition(
kname.str(), "bs_qmv", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
int bo = 8;
@ -338,12 +360,16 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Route to the bs_qmm_t
else {
std::ostringstream kname;
kname << "bs_qmm_t_" << type_to_name(out) << "_gs_" << group_size_
<< "_b_" << bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0);
std::string aligned_n = (O % 32) == 0 ? "true" : "false";
auto type_string = get_type_string(out.dtype());
kname << "bs_qmm_t_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_ << "_alN_" << aligned_n;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto template_def = get_template_definition(
kname.str(), "bs_qmm_t", type_string, group_size_, bits_, aligned_n);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
int wn = 2;
@ -385,12 +411,15 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Route to the bs_qvm kernel
if (B < 4) {
std::ostringstream kname;
kname << "bs_qvm_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
auto type_string = get_type_string(out.dtype());
kname << "bs_qvm_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto template_def = get_template_definition(
kname.str(), "bs_qvm", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
int bo = 64;
@ -428,12 +457,15 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Route to bs_qmm_n
else {
std::ostringstream kname;
kname << "bs_qmm_n_" << type_to_name(out) << "_gs_" << group_size_
<< "_b_" << bits_;
auto type_string = get_type_string(out.dtype());
kname << "bs_qmm_n_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto template_def = get_template_definition(
kname.str(), "bs_qmm_n", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
int wn = 2;
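
Every branch above follows the same pattern: derive the kernel name from the dtype, group_size_, and bits_ (plus aligned_n for the qmm_t paths), generate the matching template instantiation, and let get_quantized_kernel JIT-compile it on first use. As a hedged illustration with assumed values (float16 input, group_size_ = 64, bits_ = 4, O a multiple of 32), the qmm_t branch would produce:

// kname:        "qmm_t_float16_t_gs_64_b_4_alN_true"
// template_def: template [[host_name("qmm_t_float16_t_gs_64_b_4_alN_true")]] [[kernel]]
//               decltype(qmm_t<float16_t, 64, 4, true>) qmm_t<float16_t, 64, 4, true>;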