Batched Quantized Matmul + Fast Small QMV (#1503)
* add fast qmv for small dims
* fix test
* batched cpu
* add batched template param
* refactor metal quantized.cpp
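For orientation, a minimal sketch of the kernel routing that the refactored code below implements (editorial summary; the names and thresholds are taken from the diff, the condensed form is not):

// transpose == true:
//   B < 6 and D in {64, 128}                   -> qmv_quad
//   B < 6, O % 8 == 0, D % 512 == 0, D >= 512  -> qmv_fast (no bounds checks)
//   B < 6 otherwise                            -> qmv
//   else                                       -> qmm_t ("alN_true" when O % 32 == 0)
// transpose == false:
//   B < 4  -> qvm
//   else   -> qmm_n (requires O % 32 == 0)
// Gather variants get a "bs_" prefix; non-gather kernels also carry a
// "batch" template parameter selected from the array shapes.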
@@ -12,231 +12,29 @@
namespace mlx::core {

void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 4);

  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  auto& s = stream();
  auto& d = metal::device(s.device);

void launch_qmm(
    std::string name,
    const std::vector<array>& inputs,
    array& out,
    int group_size,
    int bits,
    int D,
    int O,
    int B,
    int N,
    MTL::Size& group_dims,
    MTL::Size& grid_dims,
    bool batched,
    bool matrix,
    bool gather,
    bool aligned,
    bool quad,
    const Stream& s) {
  auto& x_pre = inputs[0];
  auto& w_pre = inputs[1];
  auto& scales_pre = inputs[2];
  auto& biases_pre = inputs[3];

  std::vector<array> copies;
  auto ensure_row_contiguous = [&copies, &s](const array& arr) {
    if (arr.flags().row_contiguous) {
      return arr;
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy_gpu(arr, arr_copy, CopyType::General, s);
      copies.push_back(arr_copy);
      return arr_copy;
    }
  };
  auto x = ensure_row_contiguous(x_pre);
  auto w = ensure_row_contiguous(w_pre);
  auto scales = ensure_row_contiguous(scales_pre);
  auto biases = ensure_row_contiguous(biases_pre);

  int D = x.shape(-1);
  int B = x.size() / D;
  int O = out.shape(-1);
  if (transpose_) {
    // Route to the fast qmv kernel that has no bounds checking
    if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
      std::ostringstream kname;
      auto type_string = get_type_string(x.dtype());
      kname << "qmv_fast_" << type_string << "_gs_" << group_size_ << "_b_"
            << bits_;

      // Encode and dispatch kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      auto template_def = get_template_definition(
          kname.str(), "qmv_fast", type_string, group_size_, bits_);
      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
      compute_encoder->setComputePipelineState(kernel);

      int bo = 8;
      int bd = 32;
      MTL::Size group_dims = MTL::Size(bd, 2, 1);
      MTL::Size grid_dims = MTL::Size(O / bo, B, 1);

      compute_encoder.set_input_array(w, 0);
      compute_encoder.set_input_array(scales, 1);
      compute_encoder.set_input_array(biases, 2);
      compute_encoder.set_input_array(x, 3);
      compute_encoder.set_output_array(out, 4);
      compute_encoder->setBytes(&D, sizeof(int), 5);
      compute_encoder->setBytes(&O, sizeof(int), 6);

      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
    }

    // Route to the qmv kernel
    else if (B < 6) {
      std::ostringstream kname;
      auto type_string = get_type_string(x.dtype());
      kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_;

      // Encode and dispatch kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      auto template_def = get_template_definition(
          kname.str(), "qmv", type_string, group_size_, bits_);
      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
      compute_encoder->setComputePipelineState(kernel);

      int bo = 8;
      int bd = 32;
      MTL::Size group_dims = MTL::Size(bd, 2, 1);
      MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, 1);

      compute_encoder.set_input_array(w, 0);
      compute_encoder.set_input_array(scales, 1);
      compute_encoder.set_input_array(biases, 2);
      compute_encoder.set_input_array(x, 3);
      compute_encoder.set_output_array(out, 4);
      compute_encoder->setBytes(&D, sizeof(int), 5);
      compute_encoder->setBytes(&O, sizeof(int), 6);

      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
    }

    // Route to the qmm_t kernel
    else {
      std::ostringstream kname;
      std::string aligned_n = (O % 32) == 0 ? "true" : "false";
      auto type_string = get_type_string(x.dtype());
      kname << "qmm_t_" << type_string << "_gs_" << group_size_ << "_b_"
            << bits_ << "_alN_" << aligned_n;

      // Encode and dispatch kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      auto template_def = get_template_definition(
          kname.str(), "qmm_t", type_string, group_size_, bits_, aligned_n);
      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
      compute_encoder->setComputePipelineState(kernel);

      int wn = 2;
      int wm = 2;
      int bm = 32;
      int bn = 32;
      int bk = 32;
      MTL::Size group_dims = MTL::Size(32, wn, wm);
      MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, 1);

      compute_encoder.set_input_array(x, 0);
      compute_encoder.set_input_array(w, 1);
      compute_encoder.set_input_array(scales, 2);
      compute_encoder.set_input_array(biases, 3);
      compute_encoder.set_output_array(out, 4);
      compute_encoder->setBytes(&B, sizeof(int), 5);
      compute_encoder->setBytes(&O, sizeof(int), 6);
      compute_encoder->setBytes(&D, sizeof(int), 7);

      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
    }
  } else {
    // Route to the qvm kernel
    if (B < 4) {
      std::ostringstream kname;
      auto type_string = get_type_string(x.dtype());
      kname << "qvm_" << type_string << "_gs_" << group_size_ << "_b_" << bits_;

      // Encode and dispatch kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      auto template_def = get_template_definition(
          kname.str(), "qvm", type_string, group_size_, bits_);
      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
      compute_encoder->setComputePipelineState(kernel);

      int bo = 64;
      int bd = 32;
      MTL::Size group_dims = MTL::Size(bd, 2, 1);
      MTL::Size grid_dims = MTL::Size(O / bo, B, 1);

      compute_encoder.set_input_array(x, 0);
      compute_encoder.set_input_array(w, 1);
      compute_encoder.set_input_array(scales, 2);
      compute_encoder.set_input_array(biases, 3);
      compute_encoder.set_output_array(out, 4);
      compute_encoder->setBytes(&D, sizeof(int), 5);
      compute_encoder->setBytes(&O, sizeof(int), 6);

      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
    }

    // Route to the qmm_n kernel
    else {
      std::ostringstream kname;
      auto type_string = get_type_string(x.dtype());
      kname << "qmm_n_" << type_string << "_gs_" << group_size_ << "_b_"
            << bits_;

      // Encode and dispatch kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      auto template_def = get_template_definition(
          kname.str(), "qmm_n", type_string, group_size_, bits_);
      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
      compute_encoder->setComputePipelineState(kernel);

      int wn = 2;
      int wm = 2;
      int bm = 32;
      int bn = 32;
      int bk = 32;
      MTL::Size group_dims = MTL::Size(32, wn, wm);
      MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);

      if ((O % bn) != 0) {
        std::ostringstream msg;
        msg << "[quantized_matmul] The output size should be divisible by "
            << bn << " but received " << O << ".";
        throw std::runtime_error(msg.str());
      }

      compute_encoder.set_input_array(x, 0);
      compute_encoder.set_input_array(w, 1);
      compute_encoder.set_input_array(scales, 2);
      compute_encoder.set_input_array(biases, 3);
      compute_encoder.set_output_array(out, 4);
      compute_encoder->setBytes(&B, sizeof(int), 5);
      compute_encoder->setBytes(&O, sizeof(int), 6);
      compute_encoder->setBytes(&D, sizeof(int), 7);

      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
    }
  }

  if (!copies.empty()) {
    d.get_command_buffer(s.index)->addCompletedHandler(
        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
          copies.clear();
        });
  }
}
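// Note (editorial): the temporary copies made by `ensure_row_contiguous`
// must stay alive until the GPU is done reading them, so instead of being
// freed at encode time they are moved into the command buffer's completed
// handler, which clears the vector once the work has actually finished.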

void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 6);

  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  auto& s = stream();
  auto& d = metal::device(s.device);

  auto& x_pre = inputs[0];
  auto& w_pre = inputs[1];
  auto& scales_pre = inputs[2];
  auto& biases_pre = inputs[3];
  auto& lhs_indices = inputs[4];
  auto& rhs_indices = inputs[5];

  // TODO: collapse batch dims
  auto& batch_shape = lhs_indices.shape();
  int batch_ndims = batch_shape.size();
  auto& lhs_strides = lhs_indices.strides();
  auto& rhs_strides = rhs_indices.strides();

  // Ensure that the last two dims are row contiguous.
  // TODO: Check if we really need this for x as well...
  std::vector<array> copies;
@@ -266,256 +64,205 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& s_strides = scales.strides();
  auto& b_strides = biases.strides();

  std::string aligned_n = (O % 32) == 0 ? "true" : "false";

  std::ostringstream kname;
  auto type_string = get_type_string(x.dtype());
  kname << name << "_" << type_string << "_gs_" << group_size << "_b_" << bits;
  if (quad) {
    kname << "_d_" << D;
  }
  if (aligned) {
    kname << "_alN_" << aligned_n;
  }
  if (!gather) {
    kname << "_batch_" << batched;
  }
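  // Example (illustrative, assuming get_type_string(float16) == "float16_t"):
  // with group_size = 64 and bits = 4 this builder produces names such as
  // "qmv_fast_float16_t_gs_64_b_4_batch_1" (batched, non-gather) or
  // "bs_qmm_t_float16_t_gs_64_b_4_alN_true" (aligned gather).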

  // Encode and dispatch kernel
  std::string template_def;
  if (quad) {
    template_def = get_template_definition(
        kname.str(), name, type_string, group_size, bits, D, batched);
  } else if (aligned && !gather) {
    template_def = get_template_definition(
        kname.str(), name, type_string, group_size, bits, aligned_n, batched);
  } else if (!gather && !aligned) {
    template_def = get_template_definition(
        kname.str(), name, type_string, group_size, bits, batched);
  } else if (aligned && gather) {
    template_def = get_template_definition(
        kname.str(), name, type_string, group_size, bits, aligned_n);
  } else {
    template_def = get_template_definition(
        kname.str(), name, type_string, group_size, bits);
  }
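  // Note (editorial): `batched` rides along as a kernel template parameter
  // (the "_batch_" suffix above); presumably this lets the unbatched variant
  // compile out the batch-offset arithmetic instead of branching at runtime,
  // matching the "batched template param" item in the commit message.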
  auto& d = metal::device(s.device);
  auto kernel = get_quantized_kernel(d, kname.str(), template_def);
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);

  compute_encoder.set_input_array(w, 0);
  compute_encoder.set_input_array(scales, 1);
  compute_encoder.set_input_array(biases, 2);
  compute_encoder.set_input_array(x, 3);
  compute_encoder.set_output_array(out, 4);
  compute_encoder->setBytes(&D, sizeof(int), 5);
  compute_encoder->setBytes(&O, sizeof(int), 6);

  int offset = 7;
  if (matrix) {
    compute_encoder->setBytes(&B, sizeof(int), 7);
    offset += 1;
  }

  if (batched || gather) {
    compute_encoder->setBytes(&x_batch_ndims, sizeof(int), offset);
    set_vector_bytes(compute_encoder, x_shape, offset + 1);
    set_vector_bytes(compute_encoder, x_strides, offset + 2);
    compute_encoder->setBytes(&w_batch_ndims, sizeof(int), offset + 3);
    set_vector_bytes(compute_encoder, w_shape, offset + 4);
    set_vector_bytes(compute_encoder, w_strides, offset + 5);
    set_vector_bytes(compute_encoder, s_strides, offset + 6);
    set_vector_bytes(compute_encoder, b_strides, offset + 7);
  }
  if (gather) {
    auto& lhs_indices = inputs[4];
    auto& rhs_indices = inputs[5];

    // TODO: collapse batch dims
    auto& batch_shape = lhs_indices.shape();
    int batch_ndims = batch_shape.size();
    auto& lhs_strides = lhs_indices.strides();
    auto& rhs_strides = rhs_indices.strides();

    compute_encoder->setBytes(&batch_ndims, sizeof(int), offset + 8);
    set_vector_bytes(compute_encoder, batch_shape, offset + 9);
    compute_encoder.set_input_array(lhs_indices, offset + 10);
    compute_encoder.set_input_array(rhs_indices, offset + 11);
    set_vector_bytes(compute_encoder, lhs_strides, offset + 12);
    set_vector_bytes(compute_encoder, rhs_strides, offset + 13);
  }

  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
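// Argument-layout example for launch_qmm above (editorial): with
// matrix = true the scalar B occupies buffer index 7 and offset becomes 8,
// so the batch metadata (x_batch_ndims, x_shape, x_strides, w_batch_ndims,
// w_shape, w_strides, s_strides, b_strides) lands at indices 8..15; with
// matrix = false it sits at 7..14. The gather indices and strides, when
// present, follow at offset + 8 .. offset + 13.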

void qmm_op(
    const std::vector<array>& inputs,
    array& out,
    bool transpose,
    int group_size,
    int bits,
    bool gather,
    const Stream& s) {
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  MTL::Size group_dims;
  MTL::Size grid_dims;

  auto& x = inputs[0];
  auto& w = inputs[1];
  bool batched = !gather && (w.ndim() > 2 || !x.flags().row_contiguous);

  int D = x.shape(-1);
  int B = x.shape(-2);
  int O = out.shape(-1);
  int N = out.size() / B / O;
  if (transpose_) {
    // Route to the fast bs_qmv kernel that has no bounds checking
    if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
      std::ostringstream kname;
      auto type_string = get_type_string(x.dtype());
      kname << "bs_qmv_fast_" << type_string << "_gs_" << group_size_ << "_b_"
            << bits_;
  // For the unbatched W case, avoid `adjust_matrix_offsets`
  // for a small performance gain.
  int B = (batched || gather) ? x.shape(-2) : x.size() / D;
  int N = (batched || gather) ? out.size() / B / O : 1;
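  // Worked example (editorial): for x of shape (2, 3, 512) and a 2-D w,
  // the unbatched path flattens the leading dims to B = 6 with N = 1, so a
  // single dispatch covers all rows; were w batched, B = x.shape(-2) = 3
  // and N = out.size() / B / O = 2 dispatch planes instead.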

      // Encode and dispatch kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      auto template_def = get_template_definition(
          kname.str(), "bs_qmv_fast", type_string, group_size_, bits_);
      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
      compute_encoder->setComputePipelineState(kernel);
  std::string name = gather ? "bs_" : "";
  bool matrix = false;
  bool aligned = false;
  bool quad = false;

  if (transpose) {
    if (B < 6 && (D == 128 || D == 64)) {
      name += "qmv_quad";
      constexpr int quads_per_simd = 8;
      constexpr int results_per_quadgroup = 8;
      int bo = quads_per_simd * results_per_quadgroup;
      int simdgroup_size = 32;
      group_dims = MTL::Size(simdgroup_size, 1, 1);
      grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
      quad = true;
    } else if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
      name += "qmv_fast";
      int bo = 8;
      int bd = 32;
      MTL::Size group_dims = MTL::Size(bd, 2, 1);
      MTL::Size grid_dims = MTL::Size(O / bo, B, N);

      compute_encoder.set_input_array(w, 0);
      compute_encoder.set_input_array(scales, 1);
      compute_encoder.set_input_array(biases, 2);
      compute_encoder.set_input_array(x, 3);
      compute_encoder.set_input_array(lhs_indices, 4);
      compute_encoder.set_input_array(rhs_indices, 5);
      compute_encoder.set_output_array(out, 6);
      compute_encoder->setBytes(&D, sizeof(int), 7);
      compute_encoder->setBytes(&O, sizeof(int), 8);

      compute_encoder->setBytes(&batch_ndims, sizeof(int), 9);
      set_vector_bytes(compute_encoder, batch_shape, 10);
      set_vector_bytes(compute_encoder, lhs_strides, 11);
      set_vector_bytes(compute_encoder, rhs_strides, 12);

      compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13);
      set_vector_bytes(compute_encoder, x_shape, 14);
      set_vector_bytes(compute_encoder, x_strides, 15);
      compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16);
      set_vector_bytes(compute_encoder, w_shape, 17);
      set_vector_bytes(compute_encoder, w_strides, 18);
      set_vector_bytes(compute_encoder, s_strides, 19);
      set_vector_bytes(compute_encoder, b_strides, 20);

      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
    }

    else if (B < 6) {
      std::ostringstream kname;
      auto type_string = get_type_string(x.dtype());
      kname << "bs_qmv_" << type_string << "_gs_" << group_size_ << "_b_"
            << bits_;

      // Encode and dispatch kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      auto template_def = get_template_definition(
          kname.str(), "bs_qmv", type_string, group_size_, bits_);
      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
      compute_encoder->setComputePipelineState(kernel);

      group_dims = MTL::Size(bd, 2, 1);
      grid_dims = MTL::Size(O / bo, B, N);
    } else if (B < 6) {
      name += "qmv";
      int bo = 8;
      int bd = 32;
      MTL::Size group_dims = MTL::Size(bd, 2, 1);
      MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, N);

      compute_encoder.set_input_array(w, 0);
      compute_encoder.set_input_array(scales, 1);
      compute_encoder.set_input_array(biases, 2);
      compute_encoder.set_input_array(x, 3);
      compute_encoder.set_input_array(lhs_indices, 4);
      compute_encoder.set_input_array(rhs_indices, 5);
      compute_encoder.set_output_array(out, 6);
      compute_encoder->setBytes(&D, sizeof(int), 7);
      compute_encoder->setBytes(&O, sizeof(int), 8);

      compute_encoder->setBytes(&batch_ndims, sizeof(int), 9);
      set_vector_bytes(compute_encoder, batch_shape, 10);
      set_vector_bytes(compute_encoder, lhs_strides, 11);
      set_vector_bytes(compute_encoder, rhs_strides, 12);

      compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13);
      set_vector_bytes(compute_encoder, x_shape, 14);
      set_vector_bytes(compute_encoder, x_strides, 15);
      compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16);
      set_vector_bytes(compute_encoder, w_shape, 17);
      set_vector_bytes(compute_encoder, w_strides, 18);
      set_vector_bytes(compute_encoder, s_strides, 19);
      set_vector_bytes(compute_encoder, b_strides, 20);

      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
    }

    // Route to the bs_qmm_t
    else {
      std::ostringstream kname;
      std::string aligned_n = (O % 32) == 0 ? "true" : "false";
      auto type_string = get_type_string(out.dtype());
      kname << "bs_qmm_t_" << type_string << "_gs_" << group_size_ << "_b_"
            << bits_ << "_alN_" << aligned_n;

      // Encode and dispatch kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      auto template_def = get_template_definition(
          kname.str(), "bs_qmm_t", type_string, group_size_, bits_, aligned_n);
      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
      compute_encoder->setComputePipelineState(kernel);

      group_dims = MTL::Size(bd, 2, 1);
      grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
    } else {
      int wn = 2;
      int wm = 2;
      int bm = 32;
      int bn = 32;
      int bk = 32;
      MTL::Size group_dims = MTL::Size(32, wn, wm);
      MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, N);

      compute_encoder.set_input_array(x, 0);
      compute_encoder.set_input_array(w, 1);
      compute_encoder.set_input_array(scales, 2);
      compute_encoder.set_input_array(biases, 3);
      compute_encoder.set_input_array(lhs_indices, 4);
      compute_encoder.set_input_array(rhs_indices, 5);
      compute_encoder.set_output_array(out, 6);
      compute_encoder->setBytes(&B, sizeof(int), 7);
      compute_encoder->setBytes(&O, sizeof(int), 8);
      compute_encoder->setBytes(&D, sizeof(int), 9);

      compute_encoder->setBytes(&batch_ndims, sizeof(int), 10);
      set_vector_bytes(compute_encoder, batch_shape, 11);
      set_vector_bytes(compute_encoder, lhs_strides, 12);
      set_vector_bytes(compute_encoder, rhs_strides, 13);

      compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 14);
      set_vector_bytes(compute_encoder, x_shape, 15);
      set_vector_bytes(compute_encoder, x_strides, 16);
      compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 17);
      set_vector_bytes(compute_encoder, w_shape, 18);
      set_vector_bytes(compute_encoder, w_strides, 19);
      set_vector_bytes(compute_encoder, s_strides, 20);
      set_vector_bytes(compute_encoder, b_strides, 21);

      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
      group_dims = MTL::Size(32, wn, wm);
      grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, N);
      name += "qmm_t";
      matrix = true;
      aligned = true;
    }
  } else {
    // Route to the bs_qvm kernel
    if (B < 4) {
      std::ostringstream kname;
      auto type_string = get_type_string(out.dtype());
      kname << "bs_qvm_" << type_string << "_gs_" << group_size_ << "_b_"
            << bits_;

      // Encode and dispatch kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      auto template_def = get_template_definition(
          kname.str(), "bs_qvm", type_string, group_size_, bits_);
      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
      compute_encoder->setComputePipelineState(kernel);

      name += "qvm";
      int bo = 64;
      int bd = 32;
      MTL::Size group_dims = MTL::Size(bd, 2, 1);
      MTL::Size grid_dims = MTL::Size(O / bo, B, N);

      compute_encoder.set_input_array(x, 0);
      compute_encoder.set_input_array(w, 1);
      compute_encoder.set_input_array(scales, 2);
      compute_encoder.set_input_array(biases, 3);
      compute_encoder.set_input_array(lhs_indices, 4);
      compute_encoder.set_input_array(rhs_indices, 5);
      compute_encoder.set_output_array(out, 6);
      compute_encoder->setBytes(&D, sizeof(int), 7);
      compute_encoder->setBytes(&O, sizeof(int), 8);

      compute_encoder->setBytes(&batch_ndims, sizeof(int), 9);
      set_vector_bytes(compute_encoder, batch_shape, 10);
      set_vector_bytes(compute_encoder, lhs_strides, 11);
      set_vector_bytes(compute_encoder, rhs_strides, 12);

      compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13);
      set_vector_bytes(compute_encoder, x_shape, 14);
      set_vector_bytes(compute_encoder, x_strides, 15);
      compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16);
      set_vector_bytes(compute_encoder, w_shape, 17);
      set_vector_bytes(compute_encoder, w_strides, 18);
      set_vector_bytes(compute_encoder, s_strides, 19);
      set_vector_bytes(compute_encoder, b_strides, 20);

      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
    }

    // Route to bs_qmm_n
    else {
      std::ostringstream kname;
      auto type_string = get_type_string(out.dtype());
      kname << "bs_qmm_n_" << type_string << "_gs_" << group_size_ << "_b_"
            << bits_;

      // Encode and dispatch kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      auto template_def = get_template_definition(
          kname.str(), "bs_qmm_n", type_string, group_size_, bits_);
      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
      compute_encoder->setComputePipelineState(kernel);

      group_dims = MTL::Size(bd, 2, 1);
      grid_dims = MTL::Size(O / bo, B, N);
    } else {
      name += "qmm_n";
      int wn = 2;
      int wm = 2;
      int bm = 32;
      int bn = 32;
      int bk = 32;
      MTL::Size group_dims = MTL::Size(32, wn, wm);
      MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, N);

      group_dims = MTL::Size(32, wn, wm);
      grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, N);
      matrix = true;
      if ((O % bn) != 0) {
        std::ostringstream msg;
        msg << "[quantized_matmul] The output size should be divisible by "
            << bn << " but received " << O << ".";
        throw std::runtime_error(msg.str());
      }

      compute_encoder.set_input_array(x, 0);
      compute_encoder.set_input_array(w, 1);
      compute_encoder.set_input_array(scales, 2);
      compute_encoder.set_input_array(biases, 3);
      compute_encoder.set_input_array(lhs_indices, 4);
      compute_encoder.set_input_array(rhs_indices, 5);
      compute_encoder.set_output_array(out, 6);
      compute_encoder->setBytes(&B, sizeof(int), 7);
      compute_encoder->setBytes(&O, sizeof(int), 8);
      compute_encoder->setBytes(&D, sizeof(int), 9);

      compute_encoder->setBytes(&batch_ndims, sizeof(int), 10);
      set_vector_bytes(compute_encoder, batch_shape, 11);
      set_vector_bytes(compute_encoder, lhs_strides, 12);
      set_vector_bytes(compute_encoder, rhs_strides, 13);

      compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 14);
      set_vector_bytes(compute_encoder, x_shape, 15);
      set_vector_bytes(compute_encoder, x_strides, 16);
      compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 17);
      set_vector_bytes(compute_encoder, w_shape, 18);
      set_vector_bytes(compute_encoder, w_strides, 19);
      set_vector_bytes(compute_encoder, s_strides, 20);
      set_vector_bytes(compute_encoder, b_strides, 21);

      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
    }
  }
  launch_qmm(
      name,
      inputs,
      out,
      group_size,
      bits,
      D,
      O,
      B,
      N,
      group_dims,
      grid_dims,
      batched,
      matrix,
      gather,
      aligned,
      quad,
      s);
}
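// Dispatch math (editorial): the qmm_t route uses 32 x 2 x 2 = 128-thread
// threadgroups, each tiling a 32 x 32 block of the output, so B = 64,
// O = 100, N = 1 yields a grid of (ceil(100/32), ceil(64/32), 1) = (4, 2, 1)
// threadgroups, with the "_alN_false" kernel variant covering the ragged
// final column block.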

void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 4);
  qmm_op(
      inputs, out, transpose_, group_size_, bits_, /*gather=*/false, stream());
}

void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 6);
  qmm_op(
      inputs, out, transpose_, group_size_, bits_, /*gather=*/true, stream());
}

void fast::AffineQuantize::eval_gpu(