Add Quantized Ops to the JIT (#1204)

* JIT for quantized ops

* remove unused imports

* address comments

* fix imports

* second attempt to fix imports

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
Author: Alex Barron
Date: 2024-06-12 09:47:12 -07:00 (committed by GitHub)
Parent: df964132fb
Commit: dd7d8e5e29
13 changed files with 1778 additions and 1948 deletions

mlx/backend/metal/quantized.cpp
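The diff below changes every kernel dispatch site in this file in the same way: the kernel name is built from the input's dtype string (via get_type_string) instead of type_to_name(out), a matching template instantiation string is generated with get_template_definition, and the kernel is obtained through get_quantized_kernel, which can JIT-compile that instantiation on first use rather than requiring it to exist in a pre-built metallib. The first qmv_fast hunk shows the pattern; a sketch of the generated instantiation follows it.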

@@ -2,8 +2,10 @@
 #include <cassert>
+#include "mlx/backend/common/compiled.h"
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/primitives.h"
@@ -44,12 +46,15 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Route to the fast qmv kernel that has no bounds checking
     if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
       std::ostringstream kname;
-      kname << "qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
-            << bits_ << "_fast";
+      auto type_string = get_type_string(x.dtype());
+      kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_
+            << "_fast";
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
-      auto kernel = d.get_kernel(kname.str());
+      auto template_def = get_template_definition(
+          kname.str(), "qmv_fast", type_string, group_size_, bits_);
+      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
       int bo = 8;
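To make the JIT step concrete, here is a sketch of what a helper like get_template_definition plausibly emits for the call above. The exact format is an assumption based on the standard Metal idiom of binding an explicit template instantiation to a host-visible name; only the call sites are confirmed by this diff.

    #include <sstream>
    #include <string>

    // Hypothetical stand-in for get_template_definition; the real MLX
    // signature and output may differ.
    template <typename... Args>
    std::string make_template_definition(
        const std::string& host_name, // e.g. "qmv_float16_gs_64_b_4_fast"
        const std::string& func,      // e.g. "qmv_fast"
        Args... targs) {              // e.g. "float16_t", 64, 4
      std::ostringstream args;
      bool first = true;
      // Comma-join the template arguments: "float16_t, 64, 4".
      ((args << (first ? "" : ", ") << targs, first = false), ...);
      std::ostringstream out;
      // Instantiate the kernel template and bind it to the name that the
      // compute encoder will look up at dispatch time.
      out << "template [[host_name(\"" << host_name << "\")]] [[kernel]] "
          << "decltype(" << func << "<" << args.str() << ">) " << func << "<"
          << args.str() << ">;\n";
      return out.str();
    }

Appending a string like this to the quantized Metal source and compiling it through get_quantized_kernel builds exactly the (type, group size, bits) instantiation the current call needs, instead of shipping every combination pre-compiled.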
@@ -71,12 +76,14 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Route to the qmv kernel
     else if (B < 6) {
       std::ostringstream kname;
-      kname << "qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
-            << bits_;
+      auto type_string = get_type_string(x.dtype());
+      kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_;
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
-      auto kernel = d.get_kernel(kname.str());
+      auto template_def = get_template_definition(
+          kname.str(), "qmv", type_string, group_size_, bits_);
+      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
       int bo = 8;
@@ -98,12 +105,16 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Route to the qmm_t kernel
     else {
       std::ostringstream kname;
-      kname << "qmm_t_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
-            << bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0);
+      std::string aligned_n = (O % 32) == 0 ? "true" : "false";
+      auto type_string = get_type_string(x.dtype());
+      kname << "qmm_t_" << type_string << "_gs_" << group_size_ << "_b_"
+            << bits_ << "_alN_" << aligned_n;
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
-      auto kernel = d.get_kernel(kname.str());
+      auto template_def = get_template_definition(
+          kname.str(), "qmm_t", type_string, group_size_, bits_, aligned_n);
+      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
       int wn = 2;
@@ -129,12 +140,14 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Route to the qvm kernel
     if (B < 4) {
       std::ostringstream kname;
-      kname << "qvm_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
-            << bits_;
+      auto type_string = get_type_string(x.dtype());
+      kname << "qvm_" << type_string << "_gs_" << group_size_ << "_b_" << bits_;
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
-      auto kernel = d.get_kernel(kname.str());
+      auto template_def = get_template_definition(
+          kname.str(), "qvm", type_string, group_size_, bits_);
+      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
       int bo = 64;
@@ -156,12 +169,15 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Route to the qmm_n kernel
     else {
       std::ostringstream kname;
-      kname << "qmm_n_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+      auto type_string = get_type_string(x.dtype());
+      kname << "qmm_n_" << type_string << "_gs_" << group_size_ << "_b_"
             << bits_;
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
-      auto kernel = d.get_kernel(kname.str());
+      auto template_def = get_template_definition(
+          kname.str(), "qmm_n", type_string, group_size_, bits_);
+      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
       int wn = 2;
@@ -253,12 +269,15 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Route to the fast bs_qmv kernel that has no bounds checking
    if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
       std::ostringstream kname;
-      kname << "bs_qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+      auto type_string = get_type_string(x.dtype());
+      kname << "bs_qmv_" << type_string << "_gs_" << group_size_ << "_b_"
             << bits_ << "_fast";
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
-      auto kernel = d.get_kernel(kname.str());
+      auto template_def = get_template_definition(
+          kname.str(), "bs_qmv_fast", type_string, group_size_, bits_);
+      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
       int bo = 8;
@@ -295,12 +314,15 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     else if (B < 6) {
       std::ostringstream kname;
-      kname << "bs_qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+      auto type_string = get_type_string(x.dtype());
+      kname << "bs_qmv_" << type_string << "_gs_" << group_size_ << "_b_"
             << bits_;
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
-      auto kernel = d.get_kernel(kname.str());
+      auto template_def = get_template_definition(
+          kname.str(), "bs_qmv", type_string, group_size_, bits_);
+      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
       int bo = 8;
@@ -338,12 +360,16 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Route to the bs_qmm_t
     else {
       std::ostringstream kname;
-      kname << "bs_qmm_t_" << type_to_name(out) << "_gs_" << group_size_
-            << "_b_" << bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0);
+      std::string aligned_n = (O % 32) == 0 ? "true" : "false";
+      auto type_string = get_type_string(out.dtype());
+      kname << "bs_qmm_t_" << type_string << "_gs_" << group_size_ << "_b_"
+            << bits_ << "_alN_" << aligned_n;
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
-      auto kernel = d.get_kernel(kname.str());
+      auto template_def = get_template_definition(
+          kname.str(), "bs_qmm_t", type_string, group_size_, bits_, aligned_n);
+      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
       int wn = 2;
@@ -385,12 +411,15 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Route to the bs_qvm kernel
     if (B < 4) {
       std::ostringstream kname;
-      kname << "bs_qvm_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+      auto type_string = get_type_string(out.dtype());
+      kname << "bs_qvm_" << type_string << "_gs_" << group_size_ << "_b_"
             << bits_;
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
-      auto kernel = d.get_kernel(kname.str());
+      auto template_def = get_template_definition(
+          kname.str(), "bs_qvm", type_string, group_size_, bits_);
+      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
       int bo = 64;
@@ -428,12 +457,15 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Route to bs_qmm_n
     else {
       std::ostringstream kname;
-      kname << "bs_qmm_n_" << type_to_name(out) << "_gs_" << group_size_
-            << "_b_" << bits_;
+      auto type_string = get_type_string(out.dtype());
+      kname << "bs_qmm_n_" << type_string << "_gs_" << group_size_ << "_b_"
+            << bits_;
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
-      auto kernel = d.get_kernel(kname.str());
+      auto template_def = get_template_definition(
+          kname.str(), "bs_qmm_n", type_string, group_size_, bits_);
+      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
       int wn = 2;
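All ten dispatch sites above follow the same three-step shape. Purely as a reading aid, here is that pattern factored into a hypothetical helper; this lambda is illustrative and not part of the commit. d, x, group_size_, and bits_ are the surrounding locals and members, and the aligned_n variants (qmm_t, bs_qmm_t) would need one extra template argument.

    // Sketch of the repeated pattern, assuming the helpers behave as used above.
    auto dispatch_quantized = [&](const std::string& func,     // "qmv", "bs_qvm", ...
                                  const std::string& suffix) { // "" or "_fast"
      auto type_string = get_type_string(x.dtype());
      std::ostringstream kname;
      kname << func << "_" << type_string << "_gs_" << group_size_ << "_b_"
            << bits_ << suffix;
      auto template_def = get_template_definition(
          kname.str(), func + suffix, type_string, group_size_, bits_);
      return get_quantized_kernel(d, kname.str(), template_def);
    };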