Add Quantized Ops to the JIT (#1204)
* JIT for quantized ops

* remove unused imports

* address comments

* fix imports

* second attempt to fix imports

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
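The pattern repeated in every hunk below is the same: the old path looked the kernel up in the precompiled metallib via d.get_kernel(kname.str()); the new path builds an explicit template instantiation string with get_template_definition and hands it to get_quantized_kernel, so the specialization for a given (type, group_size, bits) combination can be JIT-compiled on first use. A minimal sketch of what such a template definition string could look like (the exact format is an assumption; only the function names come from the diff):

#include <sstream>
#include <string>

// Hypothetical stand-in for get_template_definition: emit an explicit
// instantiation of a Metal kernel template so it can be compiled and then
// fetched by name. The [[host_name]]/decltype format here is an assumption,
// not MLX's exact output.
std::string template_definition(
    const std::string& kname, // e.g. "qmv_float16_t_gs_64_b_4_fast"
    const std::string& func,  // e.g. "qmv_fast"
    const std::string& type,  // e.g. "float16_t"
    int group_size,
    int bits) {
  std::ostringstream s;
  s << "template [[host_name(\"" << kname << "\")]] [[kernel]] decltype("
    << func << "<" << type << ", " << group_size << ", " << bits << ">) "
    << func << "<" << type << ", " << group_size << ", " << bits << ">;";
  return s.str();
}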
@@ -2,8 +2,10 @@
 #include <cassert>
 
+#include "mlx/backend/common/compiled.h"
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/primitives.h"
 
@@ -44,12 +46,15 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Route to the fast qmv kernel that has no bounds checking
     if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
       std::ostringstream kname;
-      kname << "qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
-            << bits_ << "_fast";
+      auto type_string = get_type_string(x.dtype());
+      kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_
+            << "_fast";
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
-      auto kernel = d.get_kernel(kname.str());
+      auto template_def = get_template_definition(
+          kname.str(), "qmv_fast", type_string, group_size_, bits_);
+      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
 
       int bo = 8;
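For concreteness, with a float16 input and illustrative values group_size_ = 64 and bits_ = 4 (assuming get_type_string returns the C++ type name, e.g. "float16_t"), the stream above yields:

#include <iostream>
#include <sstream>

int main() {
  std::ostringstream kname;
  // Mirrors the kname construction in the hunk above, with values filled in.
  kname << "qmv_" << "float16_t" << "_gs_" << 64 << "_b_" << 4 << "_fast";
  std::cout << kname.str() << "\n"; // prints: qmv_float16_t_gs_64_b_4_fast
}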
@@ -71,12 +76,14 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Route to the qmv kernel
     else if (B < 6) {
       std::ostringstream kname;
-      kname << "qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
-            << bits_;
+      auto type_string = get_type_string(x.dtype());
+      kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_;
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
-      auto kernel = d.get_kernel(kname.str());
+      auto template_def = get_template_definition(
+          kname.str(), "qmv", type_string, group_size_, bits_);
+      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
 
       int bo = 8;
@@ -98,12 +105,16 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Route to the qmm_t kernel
     else {
       std::ostringstream kname;
-      kname << "qmm_t_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
-            << bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0);
+      std::string aligned_n = (O % 32) == 0 ? "true" : "false";
+      auto type_string = get_type_string(x.dtype());
+      kname << "qmm_t_" << type_string << "_gs_" << group_size_ << "_b_"
+            << bits_ << "_alN_" << aligned_n;
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
-      auto kernel = d.get_kernel(kname.str());
+      auto template_def = get_template_definition(
+          kname.str(), "qmm_t", type_string, group_size_, bits_, aligned_n);
+      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
 
       int wn = 2;
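Note the qmm_t hunk also replaces the inline std::boolalpha << ((O % 32) == 0) with a named aligned_n string, so the same "true"/"false" token can feed both the kernel name and the extra boolean template argument. A small sketch of the effect (values and the instantiation form are illustrative):

#include <string>

// The alignment flag, captured once as a string instead of streamed inline.
// With O = 64 (illustrative), aligned_n_for(O) == "true", so the kernel name
// ends in "_alN_true" and the assumed instantiation gains a trailing bool:
//   qmm_t<float16_t, 64, 4, true>
std::string aligned_n_for(int O) {
  return (O % 32) == 0 ? "true" : "false";
}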
@@ -129,12 +140,14 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Route to the qvm kernel
     if (B < 4) {
       std::ostringstream kname;
-      kname << "qvm_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
-            << bits_;
+      auto type_string = get_type_string(x.dtype());
+      kname << "qvm_" << type_string << "_gs_" << group_size_ << "_b_" << bits_;
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
-      auto kernel = d.get_kernel(kname.str());
+      auto template_def = get_template_definition(
+          kname.str(), "qvm", type_string, group_size_, bits_);
+      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
 
       int bo = 64;
@@ -156,12 +169,15 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Route to the qmm_n kernel
     else {
       std::ostringstream kname;
-      kname << "qmm_n_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+      auto type_string = get_type_string(x.dtype());
+      kname << "qmm_n_" << type_string << "_gs_" << group_size_ << "_b_"
             << bits_;
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
-      auto kernel = d.get_kernel(kname.str());
+      auto template_def = get_template_definition(
+          kname.str(), "qmm_n", type_string, group_size_, bits_);
+      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
 
       int wn = 2;
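Taken together, the QuantizedMatmul hunks cover both sides of a transpose split (the split itself sits outside these hunks; its presence is an assumption based on the qmv/qmm_t vs. qvm/qmm_n grouping). A compact sketch of the routing as it reads from the conditions above:

#include <string>

// Sketch of the dispatch visible in the hunks; B is the number of batched
// rows, O the output features, D the inner dimension. Illustrative only.
std::string route(bool transposed, int B, int O, int D) {
  if (transposed) {
    if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
      return "qmv_fast"; // no bounds checking, needs aligned shapes
    }
    if (B < 6) {
      return "qmv";
    }
    return "qmm_t"; // full matmul; alN records whether O % 32 == 0
  }
  return B < 4 ? "qvm" : "qmm_n";
}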
@@ -253,12 +269,15 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Route to the fast bs_qmv kernel that has no bounds checking
     if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
       std::ostringstream kname;
-      kname << "bs_qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+      auto type_string = get_type_string(x.dtype());
+      kname << "bs_qmv_" << type_string << "_gs_" << group_size_ << "_b_"
             << bits_ << "_fast";
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
-      auto kernel = d.get_kernel(kname.str());
+      auto template_def = get_template_definition(
+          kname.str(), "bs_qmv_fast", type_string, group_size_, bits_);
+      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
 
       int bo = 8;
@@ -295,12 +314,15 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
 
     else if (B < 6) {
       std::ostringstream kname;
-      kname << "bs_qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+      auto type_string = get_type_string(x.dtype());
+      kname << "bs_qmv_" << type_string << "_gs_" << group_size_ << "_b_"
             << bits_;
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
-      auto kernel = d.get_kernel(kname.str());
+      auto template_def = get_template_definition(
+          kname.str(), "bs_qmv", type_string, group_size_, bits_);
+      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
 
       int bo = 8;
@@ -338,12 +360,16 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Route to the bs_qmm_t
     else {
      std::ostringstream kname;
-      kname << "bs_qmm_t_" << type_to_name(out) << "_gs_" << group_size_
-            << "_b_" << bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0);
+      std::string aligned_n = (O % 32) == 0 ? "true" : "false";
+      auto type_string = get_type_string(out.dtype());
+      kname << "bs_qmm_t_" << type_string << "_gs_" << group_size_ << "_b_"
+            << bits_ << "_alN_" << aligned_n;
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
-      auto kernel = d.get_kernel(kname.str());
+      auto template_def = get_template_definition(
+          kname.str(), "bs_qmm_t", type_string, group_size_, bits_, aligned_n);
+      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
 
       int wn = 2;
@@ -385,12 +411,15 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Route to the bs_qvm kernel
    if (B < 4) {
       std::ostringstream kname;
-      kname << "bs_qvm_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+      auto type_string = get_type_string(out.dtype());
+      kname << "bs_qvm_" << type_string << "_gs_" << group_size_ << "_b_"
             << bits_;
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
-      auto kernel = d.get_kernel(kname.str());
+      auto template_def = get_template_definition(
+          kname.str(), "bs_qvm", type_string, group_size_, bits_);
+      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
 
       int bo = 64;
@@ -428,12 +457,15 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Route to bs_qmm_n
     else {
       std::ostringstream kname;
-      kname << "bs_qmm_n_" << type_to_name(out) << "_gs_" << group_size_
-            << "_b_" << bits_;
+      auto type_string = get_type_string(out.dtype());
+      kname << "bs_qmm_n_" << type_string << "_gs_" << group_size_ << "_b_"
+            << bits_;
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
-      auto kernel = d.get_kernel(kname.str());
+      auto template_def = get_template_definition(
+          kname.str(), "bs_qmm_n", type_string, group_size_, bits_);
+      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
 
       int wn = 2;
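One reason to route every call site through get_quantized_kernel rather than compiling inline: the pipeline for a given kernel name only needs to be built once, on first use. A generic memoization shape for such a getter (purely illustrative; not MLX's implementation):

#include <functional>
#include <string>
#include <unordered_map>

// Illustrative cache: compile a kernel from its instantiation string on the
// first request, then reuse the handle for every later dispatch.
template <typename Pipeline>
class JitKernelCache {
 public:
  Pipeline get(
      const std::string& kname,
      const std::string& template_def,
      const std::function<Pipeline(const std::string&, const std::string&)>&
          compile) {
    auto it = cache_.find(kname);
    if (it == cache_.end()) {
      // Cache miss: compile the source plus instantiation, then remember it.
      it = cache_.emplace(kname, compile(kname, template_def)).first;
    }
    return it->second;
  }

 private:
  std::unordered_map<std::string, Pipeline> cache_;
};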