mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
Support for quantized matmul with w and w^T (#349)
* Add the metal qvm implementation * Add qmm_n * Add gradient wrt to input for quantized_matmul
This commit is contained in:

committed by
GitHub

parent
d7ac050f4b
commit
e7f5059fe4
@@ -104,6 +104,108 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
}
|
||||
|
||||
|
||||
template <typename T, const int BM, const int BN, const int group_size, const int bits>
|
||||
[[kernel]] void qvm(
|
||||
const device T* x [[buffer(0)]],
|
||||
const device uint32_t* w [[buffer(1)]],
|
||||
const device T* scales [[buffer(2)]],
|
||||
const device T* biases [[buffer(3)]],
|
||||
device T* y [[buffer(4)]],
|
||||
const constant int& in_vec_size [[buffer(5)]],
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_index_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
static_assert(BM == SIMD_SIZE, "qvm expects BM to be equal to SIMD_SIZE");
|
||||
static_assert(BN == BM, "qvm expects a block size of 32x32");
|
||||
|
||||
(void)lid;
|
||||
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int el_per_int = 32 / bits;
|
||||
constexpr int colgroup = BN * el_per_int;
|
||||
constexpr int groups_per_block = colgroup / group_size;
|
||||
|
||||
threadgroup T scales_block[BM * groups_per_block];
|
||||
threadgroup T biases_block[BM * groups_per_block];
|
||||
threadgroup T x_block[BM];
|
||||
|
||||
thread uint32_t w_local;
|
||||
thread T result[el_per_int] = {0};
|
||||
thread T scale = 1;
|
||||
thread T bias = 0;
|
||||
thread T x_local = 0;
|
||||
|
||||
// Adjust positions
|
||||
const int out_vec_size_w = out_vec_size / el_per_int;
|
||||
const int out_vec_size_g = out_vec_size / group_size;
|
||||
int out_col = (tid.y * BN + simd_gid) * el_per_int;
|
||||
w += out_col / el_per_int;
|
||||
scales += out_col / group_size;
|
||||
biases += out_col / group_size;
|
||||
x += tid.z * in_vec_size;
|
||||
y += tid.z * out_vec_size + out_col;
|
||||
|
||||
if (out_col >= out_vec_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Loop over in_vec in blocks of colgroup
|
||||
for (int i=0; i<in_vec_size; i+=BM) {
|
||||
// Load the vec to shared memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_gid == 0) {
|
||||
x_block[simd_lid] = x[simd_lid + i];
|
||||
}
|
||||
|
||||
// Load the scales and biases to shared memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_gid == 0) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int j=0; j<groups_per_block; j++) {
|
||||
scales_block[simd_lid * groups_per_block + j] = scales[(i + simd_lid) * out_vec_size_g + j];
|
||||
}
|
||||
#pragma clang loop unroll(full)
|
||||
for (int j=0; j<groups_per_block; j++) {
|
||||
biases_block[simd_lid * groups_per_block + j] = biases[(i + simd_lid) * out_vec_size_g + j];
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load in_vec, scale, bias to registers
|
||||
x_local = x_block[simd_lid];
|
||||
scale = scales_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
|
||||
bias = biases_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
|
||||
|
||||
// Load the matrix elements
|
||||
w_local = w[(i + simd_lid) * out_vec_size_w];
|
||||
|
||||
// Do all the work.
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k=0; k<el_per_int; k++) {
|
||||
result[k] += (scale * static_cast<T>(w_local & bitmask) + bias) * x_local;
|
||||
w_local >>= bits;
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulate in the simdgroup
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k=0; k<el_per_int; k++) {
|
||||
result[k] = simd_sum(result[k]);
|
||||
}
|
||||
|
||||
// Store the result
|
||||
if (simd_lid == 0) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k=0; k<el_per_int; k++) {
|
||||
y[k] = result[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
|
||||
[[kernel]] void qmm_t(
|
||||
const device T* x [[buffer(0)]],
|
||||
@@ -133,8 +235,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
constexpr int groups_per_simd = BN / (WM * WN);
|
||||
constexpr int w_els_per_thread = (BN * BK / el_per_int) / (SIMD_SIZE * WM * WN);
|
||||
|
||||
// Using the kernel just as a type to instantiate the appropriate BlockMMA
|
||||
// and constexpr size calculations
|
||||
// Instantiate the appropriate BlockMMA and Loader
|
||||
using mma_t = BlockMMA<T, BM, BN, BK, WM, WN, false, true>;
|
||||
using loader_x_t = BlockLoader<T, BM, BK, BK, 4, WM * WN * SIMD_SIZE, false, true, 0>;
|
||||
|
||||
@@ -231,8 +332,133 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
}
|
||||
|
||||
|
||||
template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
|
||||
[[kernel]] void qmm_n(
|
||||
const device T* x [[buffer(0)]],
|
||||
const device uint32_t* w [[buffer(1)]],
|
||||
const device T* scales [[buffer(2)]],
|
||||
const device T* biases [[buffer(3)]],
|
||||
device T* y [[buffer(4)]],
|
||||
const constant int& M [[buffer(5)]],
|
||||
const constant int& N [[buffer(6)]],
|
||||
const constant int& K [[buffer(7)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_index_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
|
||||
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
|
||||
|
||||
const uint lidy = lid / SIMD_SIZE;
|
||||
|
||||
constexpr int WM = 2;
|
||||
constexpr int WN = 2;
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int el_per_int = 32 / bits;
|
||||
constexpr int groups_per_block = (BN / group_size > 0) ? (BN / group_size) : 1;
|
||||
constexpr int groups_per_simd = BK / (WM * WN);
|
||||
constexpr int w_els_per_thread = (BK * BN / el_per_int) / (SIMD_SIZE * WM * WN);
|
||||
|
||||
// Instantiate the appropriate BlockMMA and Loader
|
||||
using mma_t = BlockMMA<T, BM, BN, BK, WM, WN, false, false>;
|
||||
using loader_x_t = BlockLoader<T, BM, BK, BK, 4, WM * WN * SIMD_SIZE, false, true, 0>;
|
||||
|
||||
threadgroup T scales_block[BK * groups_per_block];
|
||||
threadgroup T biases_block[BK * groups_per_block];
|
||||
threadgroup T Xs[BM * BK];
|
||||
threadgroup T Ws[BK * BN];
|
||||
|
||||
// Set the block
|
||||
const int N_w = N / el_per_int;
|
||||
const int N_g = N / group_size;
|
||||
const int y_row = tid.y * BM;
|
||||
const int y_col = tid.x * BN;
|
||||
x += y_row * K;
|
||||
w += y_col / el_per_int;
|
||||
scales += y_col / group_size;
|
||||
biases += y_col / group_size;
|
||||
y += y_row * N + y_col;
|
||||
|
||||
// Make the x loader and mma operation
|
||||
const short num_els = min(BM, M - y_row);
|
||||
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
|
||||
mma_t mma_op(simd_gid, simd_lid);
|
||||
|
||||
for (int k=0; k<K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load the x tile
|
||||
if (num_els < BM) {
|
||||
loader_x.load_safe(short2(BK, num_els));
|
||||
} else {
|
||||
loader_x.load_unsafe();
|
||||
}
|
||||
|
||||
// Load the scale and bias
|
||||
if (simd_lid == 0) {
|
||||
threadgroup T *scales_block_local = scales_block + lidy * groups_per_block * groups_per_simd;
|
||||
threadgroup T *biases_block_local = biases_block + lidy * groups_per_block * groups_per_simd;
|
||||
const device T *scales_local = scales + lidy * groups_per_simd * N_g;
|
||||
const device T *biases_local = biases + lidy * groups_per_simd * N_g;
|
||||
#pragma clang loop unroll(full)
|
||||
for (int gs=0; gs<groups_per_simd; gs++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int gc=0; gc<groups_per_block; gc++) {
|
||||
scales_block_local[gc] = scales_local[gc];
|
||||
biases_block_local[gc] = biases_local[gc];
|
||||
}
|
||||
scales_block_local += groups_per_block;
|
||||
scales_local += N_g;
|
||||
biases_block_local += groups_per_block;
|
||||
biases_local += N_g;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load the w tile
|
||||
{
|
||||
for (int wo=0; wo<w_els_per_thread; wo++) {
|
||||
int offset = lid * w_els_per_thread + wo;
|
||||
int offset_row = offset / (BN / el_per_int);
|
||||
int offset_col = offset % (BN / el_per_int);
|
||||
const device uint32_t * w_local = w + offset_row * N_w + offset_col;
|
||||
threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
|
||||
|
||||
uint32_t wi = *w_local;
|
||||
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
|
||||
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for (int t=0; t<el_per_int; t++) {
|
||||
Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
|
||||
wi >>= bits;
|
||||
}
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(Xs, Ws);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_x.next();
|
||||
w += BK * N_w;
|
||||
scales += BK * N_g;
|
||||
biases += BK * N_g;
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (num_els < BM) {
|
||||
mma_op.store_result_safe(y, N, short2(BN, num_els));
|
||||
} else {
|
||||
mma_op.store_result(y, N);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#define instantiate_qmv(name, itype, group_size, bits) \
|
||||
template [[host_name("qmv_n_" #name "_gs_" #group_size "_b_" #bits)]] \
|
||||
template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits)]] \
|
||||
[[kernel]] void qmv<itype, 32, 32, group_size, bits>( \
|
||||
const device uint32_t* w [[buffer(0)]], \
|
||||
const device itype* scales [[buffer(1)]], \
|
||||
@@ -258,6 +484,33 @@ instantiate_qmv_types( 64, 2)
|
||||
instantiate_qmv_types( 64, 4)
|
||||
instantiate_qmv_types( 64, 8)
|
||||
|
||||
#define instantiate_qvm(name, itype, group_size, bits) \
|
||||
template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \
|
||||
[[kernel]] void qvm<itype, 32, 32, group_size, bits>( \
|
||||
const device itype* x [[buffer(0)]], \
|
||||
const device uint32_t* w [[buffer(1)]], \
|
||||
const device itype* scales [[buffer(2)]], \
|
||||
const device itype* biases [[buffer(3)]], \
|
||||
device itype* y [[buffer(4)]], \
|
||||
const constant int& in_vec_size [[buffer(5)]], \
|
||||
const constant int& out_vec_size [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint lid [[thread_index_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_qvm_types(group_size, bits) \
|
||||
instantiate_qvm(float32, float, group_size, bits) \
|
||||
instantiate_qvm(float16, half, group_size, bits) \
|
||||
instantiate_qvm(bfloat16, bfloat16_t, group_size, bits)
|
||||
|
||||
instantiate_qvm_types(128, 2)
|
||||
instantiate_qvm_types(128, 4)
|
||||
instantiate_qvm_types(128, 8)
|
||||
instantiate_qvm_types( 64, 2)
|
||||
instantiate_qvm_types( 64, 4)
|
||||
instantiate_qvm_types( 64, 8)
|
||||
|
||||
#define instantiate_qmm_t(name, itype, group_size, bits) \
|
||||
template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits)]] \
|
||||
[[kernel]] void qmm_t<itype, 32, 64, 32, group_size, bits>( \
|
||||
@@ -285,3 +538,31 @@ instantiate_qmm_t_types(128, 8)
|
||||
instantiate_qmm_t_types( 64, 2)
|
||||
instantiate_qmm_t_types( 64, 4)
|
||||
instantiate_qmm_t_types( 64, 8)
|
||||
|
||||
#define instantiate_qmm_n(name, itype, group_size, bits) \
|
||||
template [[host_name("qmm_n_" #name "_gs_" #group_size "_b_" #bits)]] \
|
||||
[[kernel]] void qmm_n<itype, 32, 32, 64, group_size, bits>( \
|
||||
const device itype* x [[buffer(0)]], \
|
||||
const device uint32_t* w [[buffer(1)]], \
|
||||
const device itype* scales [[buffer(2)]], \
|
||||
const device itype* biases [[buffer(3)]], \
|
||||
device itype* y [[buffer(4)]], \
|
||||
const constant int& M [[buffer(5)]], \
|
||||
const constant int& N [[buffer(6)]], \
|
||||
const constant int& K [[buffer(7)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint lid [[thread_index_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_qmm_n_types(group_size, bits) \
|
||||
instantiate_qmm_n(float32, float, group_size, bits) \
|
||||
instantiate_qmm_n(float16, half, group_size, bits) \
|
||||
instantiate_qmm_n(bfloat16, bfloat16_t, group_size, bits)
|
||||
|
||||
instantiate_qmm_n_types(128, 2)
|
||||
instantiate_qmm_n_types(128, 4)
|
||||
instantiate_qmm_n_types(128, 8)
|
||||
instantiate_qmm_n_types( 64, 2)
|
||||
instantiate_qmm_n_types( 64, 4)
|
||||
instantiate_qmm_n_types( 64, 8)
|
||||
|
@@ -1,7 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
@@ -23,97 +22,147 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& biases_pre = inputs[3];
|
||||
|
||||
std::vector<array> copies;
|
||||
auto check_transpose = [&copies, &s](const array& arr) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (stx == arr.shape(-1) && sty == 1) {
|
||||
return std::make_tuple(false, stx, arr);
|
||||
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||
return std::make_tuple(true, sty, arr);
|
||||
auto ensure_row_contiguous = [&copies, &s](const array& arr) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return arr;
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||
copies.push_back(arr_copy);
|
||||
size_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
return arr_copy;
|
||||
}
|
||||
};
|
||||
auto [x_transposed, x_cols, x] = check_transpose(x_pre);
|
||||
auto [w_transposed, w_cols, w] = check_transpose(w_pre);
|
||||
auto [scales_transposed, scales_cols, scales] = check_transpose(scales_pre);
|
||||
auto [biases_transposed, biases_cols, biases] = check_transpose(biases_pre);
|
||||
|
||||
if (!w_transposed) {
|
||||
throw std::runtime_error("The quantized weight should be transposed.");
|
||||
}
|
||||
|
||||
if (x_transposed || scales_transposed || biases_transposed) {
|
||||
throw std::runtime_error("x, scales and biases should be row contiguous.");
|
||||
}
|
||||
auto x = ensure_row_contiguous(x_pre);
|
||||
auto w = ensure_row_contiguous(w_pre);
|
||||
auto scales = ensure_row_contiguous(scales_pre);
|
||||
auto biases = ensure_row_contiguous(biases_pre);
|
||||
|
||||
int D = x.shape(-1);
|
||||
int B = x.size() / D;
|
||||
int O = out.shape(-1);
|
||||
if (transpose_) {
|
||||
// Route to the qmv kernel
|
||||
if (B < 6) {
|
||||
std::ostringstream kname;
|
||||
kname << "qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
|
||||
<< bits_;
|
||||
|
||||
// Route to the qmv kernel
|
||||
if (B == 1) {
|
||||
std::ostringstream kname;
|
||||
kname << "qmv_" << (w_transposed ? "n_" : "t_") << type_to_name(out)
|
||||
<< "_gs_" << group_size_ << "_b_" << bits_;
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
int bo = 32;
|
||||
int bd = 32;
|
||||
MTL::Size group_dims = MTL::Size(bd, bo, 1);
|
||||
MTL::Size grid_dims = MTL::Size(1, O / bo, B);
|
||||
|
||||
int O = w.size() / w_cols;
|
||||
set_array_buffer(compute_encoder, w, 0);
|
||||
set_array_buffer(compute_encoder, scales, 1);
|
||||
set_array_buffer(compute_encoder, biases, 2);
|
||||
set_array_buffer(compute_encoder, x, 3);
|
||||
set_array_buffer(compute_encoder, out, 4);
|
||||
compute_encoder->setBytes(&D, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
|
||||
int bo = 32;
|
||||
int bd = 32;
|
||||
MTL::Size group_dims = MTL::Size(bd, bo, 1);
|
||||
MTL::Size grid_dims = MTL::Size(1, O / bo, B);
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
set_array_buffer(compute_encoder, w, 0);
|
||||
set_array_buffer(compute_encoder, scales, 1);
|
||||
set_array_buffer(compute_encoder, biases, 2);
|
||||
set_array_buffer(compute_encoder, x, 3);
|
||||
set_array_buffer(compute_encoder, out, 4);
|
||||
compute_encoder->setBytes(&D, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
// Route to the qmm_t kernel
|
||||
else {
|
||||
std::ostringstream kname;
|
||||
kname << "qmm_t_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
|
||||
<< bits_;
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Route to the qmm kernel
|
||||
else {
|
||||
std::ostringstream kname;
|
||||
kname << "qmm_" << (w_transposed ? "t_" : "n_") << type_to_name(out)
|
||||
<< "_gs_" << group_size_ << "_b_" << bits_;
|
||||
int wn = 2;
|
||||
int wm = 2;
|
||||
int bm = 32;
|
||||
int bn = 32;
|
||||
int bk = 64;
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, x, 0);
|
||||
set_array_buffer(compute_encoder, w, 1);
|
||||
set_array_buffer(compute_encoder, scales, 2);
|
||||
set_array_buffer(compute_encoder, biases, 3);
|
||||
set_array_buffer(compute_encoder, out, 4);
|
||||
compute_encoder->setBytes(&B, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&D, sizeof(int), 7);
|
||||
|
||||
int O = w.size() / w_cols;
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
} else {
|
||||
// Route to the qvm kernel
|
||||
if (B < 4) {
|
||||
std::ostringstream kname;
|
||||
kname << "qvm_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
|
||||
<< bits_;
|
||||
|
||||
int wn = 2;
|
||||
int wm = 2;
|
||||
int bm = 32;
|
||||
int bn = 32;
|
||||
int bk = 64;
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
set_array_buffer(compute_encoder, x, 0);
|
||||
set_array_buffer(compute_encoder, w, 1);
|
||||
set_array_buffer(compute_encoder, scales, 2);
|
||||
set_array_buffer(compute_encoder, biases, 3);
|
||||
set_array_buffer(compute_encoder, out, 4);
|
||||
compute_encoder->setBytes(&B, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&D, sizeof(int), 7);
|
||||
int bo = 32;
|
||||
int bd = 32;
|
||||
MTL::Size group_dims = MTL::Size(bd, bo, 1);
|
||||
MTL::Size grid_dims = MTL::Size(1, (w.shape(1) + bo - 1) / bo, B);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
set_array_buffer(compute_encoder, x, 0);
|
||||
set_array_buffer(compute_encoder, w, 1);
|
||||
set_array_buffer(compute_encoder, scales, 2);
|
||||
set_array_buffer(compute_encoder, biases, 3);
|
||||
set_array_buffer(compute_encoder, out, 4);
|
||||
compute_encoder->setBytes(&D, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
// Route to the qmm_n kernel
|
||||
else {
|
||||
std::ostringstream kname;
|
||||
kname << "qmm_n_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
|
||||
<< bits_;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
int wn = 2;
|
||||
int wm = 2;
|
||||
int bm = 32;
|
||||
int bn = 64;
|
||||
int bk = 32;
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
|
||||
|
||||
if ((O % bn) != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantized_matmul] The output size should be divisible by "
|
||||
<< bn << " but received " << O << ".";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
set_array_buffer(compute_encoder, x, 0);
|
||||
set_array_buffer(compute_encoder, w, 1);
|
||||
set_array_buffer(compute_encoder, scales, 2);
|
||||
set_array_buffer(compute_encoder, biases, 3);
|
||||
set_array_buffer(compute_encoder, out, 4);
|
||||
compute_encoder->setBytes(&B, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&D, sizeof(int), 7);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
|
Reference in New Issue
Block a user