Batched Quantized Matmul + Fast Small QMV (#1503)

* add fast qmv for small dims

* fix test

* batched cpu

* add batched template param

* refactor metal quantized.cpp
This commit is contained in:
Alex Barron
2024-10-21 16:23:17 -07:00
committed by GitHub
parent 58a855682c
commit d15fa13daf
9 changed files with 866 additions and 761 deletions

View File

@@ -12,231 +12,29 @@
namespace mlx::core {
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
void launch_qmm(
std::string name,
const std::vector<array>& inputs,
array& out,
int group_size,
int bits,
int D,
int O,
int B,
int N,
MTL::Size& group_dims,
MTL::Size& grid_dims,
bool batched,
bool matrix,
bool gather,
bool aligned,
bool quad,
const Stream& s) {
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
std::vector<array> copies;
auto ensure_row_contiguous = [&copies, &s](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
return arr_copy;
}
};
auto x = ensure_row_contiguous(x_pre);
auto w = ensure_row_contiguous(w_pre);
auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre);
int D = x.shape(-1);
int B = x.size() / D;
int O = out.shape(-1);
if (transpose_) {
// Route to the fast qmv kernel that has no bounds checking
if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "qmv_fast_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "qmv_fast", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
int bo = 8;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size(O / bo, B, 1);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Route to the qmv kernel
else if (B < 6) {
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "qmv", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
int bo = 8;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, 1);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Route to the qmm_t kernel
else {
std::ostringstream kname;
std::string aligned_n = (O % 32) == 0 ? "true" : "false";
auto type_string = get_type_string(x.dtype());
kname << "qmm_t_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_ << "_alN_" << aligned_n;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "qmm_t", type_string, group_size_, bits_, aligned_n);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
int wn = 2;
int wm = 2;
int bm = 32;
int bn = 32;
int bk = 32;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, 1);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder->setBytes(&B, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
} else {
// Route to the qvm kernel
if (B < 4) {
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "qvm_" << type_string << "_gs_" << group_size_ << "_b_" << bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "qvm", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
int bo = 64;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size(O / bo, B, 1);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Route to the qmm_n kernel
else {
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "qmm_n_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "qmm_n", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
int wn = 2;
int wm = 2;
int bm = 32;
int bn = 32;
int bk = 32;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
if ((O % bn) != 0) {
std::ostringstream msg;
msg << "[quantized_matmul] The output size should be divisible by "
<< bn << " but received " << O << ".";
throw std::runtime_error(msg.str());
}
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder->setBytes(&B, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
}
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
}
void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
auto& lhs_indices = inputs[4];
auto& rhs_indices = inputs[5];
// TODO: collapse batch dims
auto& batch_shape = lhs_indices.shape();
int batch_ndims = batch_shape.size();
auto& lhs_strides = lhs_indices.strides();
auto& rhs_strides = rhs_indices.strides();
// Ensure that the last two dims are row contiguous.
// TODO: Check if we really need this for x as well...
std::vector<array> copies;
@@ -266,256 +64,205 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s_strides = scales.strides();
auto& b_strides = biases.strides();
std::string aligned_n = (O % 32) == 0 ? "true" : "false";
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << name << "_" << type_string << "_gs_" << group_size << "_b_" << bits;
if (quad) {
kname << "_d_" << D;
}
if (aligned) {
kname << "_alN_" << aligned_n;
}
if (!gather) {
kname << "_batch_" << batched;
}
// Encode and dispatch kernel
std::string template_def;
if (quad) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, D, batched);
} else if (aligned && !gather) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, aligned_n, batched);
} else if (!gather && !aligned) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, batched);
} else if (aligned && gather) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, aligned_n);
} else {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits);
}
auto& d = metal::device(s.device);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
int offset = 7;
if (matrix) {
compute_encoder->setBytes(&B, sizeof(int), 7);
offset += 1;
}
if (batched || gather) {
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), offset);
set_vector_bytes(compute_encoder, x_shape, offset + 1);
set_vector_bytes(compute_encoder, x_strides, offset + 2);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), offset + 3);
set_vector_bytes(compute_encoder, w_shape, offset + 4);
set_vector_bytes(compute_encoder, w_strides, offset + 5);
set_vector_bytes(compute_encoder, s_strides, offset + 6);
set_vector_bytes(compute_encoder, b_strides, offset + 7);
}
if (gather) {
auto& lhs_indices = inputs[4];
auto& rhs_indices = inputs[5];
// TODO: collapse batch dims
auto& batch_shape = lhs_indices.shape();
int batch_ndims = batch_shape.size();
auto& lhs_strides = lhs_indices.strides();
auto& rhs_strides = rhs_indices.strides();
compute_encoder->setBytes(&batch_ndims, sizeof(int), offset + 8);
set_vector_bytes(compute_encoder, batch_shape, offset + 9);
compute_encoder.set_input_array(lhs_indices, offset + 10);
compute_encoder.set_input_array(rhs_indices, offset + 11);
set_vector_bytes(compute_encoder, lhs_strides, offset + 12);
set_vector_bytes(compute_encoder, rhs_strides, offset + 13);
}
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void qmm_op(
const std::vector<array>& inputs,
array& out,
bool transpose,
int group_size,
int bits,
bool gather,
const Stream& s) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
MTL::Size group_dims;
MTL::Size grid_dims;
auto& x = inputs[0];
auto& w = inputs[1];
bool batched = !gather && (w.ndim() > 2 || !x.flags().row_contiguous);
int D = x.shape(-1);
int B = x.shape(-2);
int O = out.shape(-1);
int N = out.size() / B / O;
if (transpose_) {
// Route to the fast bs_qmv kernel that has no bounds checking
if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "bs_qmv_fast_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
// For the unbatched W case, avoid `adjust_matrix_offsets`
// for a small performance gain.
int B = (batched || gather) ? x.shape(-2) : x.size() / D;
int N = (batched || gather) ? out.size() / B / O : 1;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "bs_qmv_fast", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
std::string name = gather ? "bs_" : "";
bool matrix = false;
bool aligned = false;
bool quad = false;
if (transpose) {
if (B < 6 && (D == 128 || D == 64)) {
name += "qmv_quad";
constexpr int quads_per_simd = 8;
constexpr int results_per_quadgroup = 8;
int bo = quads_per_simd * results_per_quadgroup;
int simdgroup_size = 32;
group_dims = MTL::Size(simdgroup_size, 1, 1);
grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
quad = true;
} else if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
name += "qmv_fast";
int bo = 8;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size(O / bo, B, N);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
compute_encoder->setBytes(&O, sizeof(int), 8);
compute_encoder->setBytes(&batch_ndims, sizeof(int), 9);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, lhs_strides, 11);
set_vector_bytes(compute_encoder, rhs_strides, 12);
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13);
set_vector_bytes(compute_encoder, x_shape, 14);
set_vector_bytes(compute_encoder, x_strides, 15);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16);
set_vector_bytes(compute_encoder, w_shape, 17);
set_vector_bytes(compute_encoder, w_strides, 18);
set_vector_bytes(compute_encoder, s_strides, 19);
set_vector_bytes(compute_encoder, b_strides, 20);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
else if (B < 6) {
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "bs_qmv_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "bs_qmv", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
group_dims = MTL::Size(bd, 2, 1);
grid_dims = MTL::Size(O / bo, B, N);
} else if (B < 6) {
name += "qmv";
int bo = 8;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
compute_encoder->setBytes(&O, sizeof(int), 8);
compute_encoder->setBytes(&batch_ndims, sizeof(int), 9);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, lhs_strides, 11);
set_vector_bytes(compute_encoder, rhs_strides, 12);
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13);
set_vector_bytes(compute_encoder, x_shape, 14);
set_vector_bytes(compute_encoder, x_strides, 15);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16);
set_vector_bytes(compute_encoder, w_shape, 17);
set_vector_bytes(compute_encoder, w_strides, 18);
set_vector_bytes(compute_encoder, s_strides, 19);
set_vector_bytes(compute_encoder, b_strides, 20);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Route to the bs_qmm_t
else {
std::ostringstream kname;
std::string aligned_n = (O % 32) == 0 ? "true" : "false";
auto type_string = get_type_string(out.dtype());
kname << "bs_qmm_t_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_ << "_alN_" << aligned_n;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "bs_qmm_t", type_string, group_size_, bits_, aligned_n);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
group_dims = MTL::Size(bd, 2, 1);
grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
} else {
int wn = 2;
int wm = 2;
int bm = 32;
int bn = 32;
int bk = 32;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, N);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder->setBytes(&B, sizeof(int), 7);
compute_encoder->setBytes(&O, sizeof(int), 8);
compute_encoder->setBytes(&D, sizeof(int), 9);
compute_encoder->setBytes(&batch_ndims, sizeof(int), 10);
set_vector_bytes(compute_encoder, batch_shape, 11);
set_vector_bytes(compute_encoder, lhs_strides, 12);
set_vector_bytes(compute_encoder, rhs_strides, 13);
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 14);
set_vector_bytes(compute_encoder, x_shape, 15);
set_vector_bytes(compute_encoder, x_strides, 16);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 17);
set_vector_bytes(compute_encoder, w_shape, 18);
set_vector_bytes(compute_encoder, w_strides, 19);
set_vector_bytes(compute_encoder, s_strides, 20);
set_vector_bytes(compute_encoder, b_strides, 21);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
group_dims = MTL::Size(32, wn, wm);
grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, N);
name += "qmm_t";
matrix = true;
aligned = true;
}
} else {
// Route to the bs_qvm kernel
if (B < 4) {
std::ostringstream kname;
auto type_string = get_type_string(out.dtype());
kname << "bs_qvm_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "bs_qvm", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
name += "qvm";
int bo = 64;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size(O / bo, B, N);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
compute_encoder->setBytes(&O, sizeof(int), 8);
compute_encoder->setBytes(&batch_ndims, sizeof(int), 9);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, lhs_strides, 11);
set_vector_bytes(compute_encoder, rhs_strides, 12);
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13);
set_vector_bytes(compute_encoder, x_shape, 14);
set_vector_bytes(compute_encoder, x_strides, 15);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16);
set_vector_bytes(compute_encoder, w_shape, 17);
set_vector_bytes(compute_encoder, w_strides, 18);
set_vector_bytes(compute_encoder, s_strides, 19);
set_vector_bytes(compute_encoder, b_strides, 20);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Route to bs_qmm_n
else {
std::ostringstream kname;
auto type_string = get_type_string(out.dtype());
kname << "bs_qmm_n_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "bs_qmm_n", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
group_dims = MTL::Size(bd, 2, 1);
grid_dims = MTL::Size(O / bo, B, N);
} else {
name += "qmm_n";
int wn = 2;
int wm = 2;
int bm = 32;
int bn = 32;
int bk = 32;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, N);
group_dims = MTL::Size(32, wn, wm);
grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, N);
matrix = true;
if ((O % bn) != 0) {
std::ostringstream msg;
msg << "[quantized_matmul] The output size should be divisible by "
<< bn << " but received " << O << ".";
throw std::runtime_error(msg.str());
}
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder->setBytes(&B, sizeof(int), 7);
compute_encoder->setBytes(&O, sizeof(int), 8);
compute_encoder->setBytes(&D, sizeof(int), 9);
compute_encoder->setBytes(&batch_ndims, sizeof(int), 10);
set_vector_bytes(compute_encoder, batch_shape, 11);
set_vector_bytes(compute_encoder, lhs_strides, 12);
set_vector_bytes(compute_encoder, rhs_strides, 13);
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 14);
set_vector_bytes(compute_encoder, x_shape, 15);
set_vector_bytes(compute_encoder, x_strides, 16);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 17);
set_vector_bytes(compute_encoder, w_shape, 18);
set_vector_bytes(compute_encoder, w_strides, 19);
set_vector_bytes(compute_encoder, s_strides, 20);
set_vector_bytes(compute_encoder, b_strides, 21);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
}
launch_qmm(
name,
inputs,
out,
group_size,
bits,
D,
O,
B,
N,
group_dims,
grid_dims,
batched,
matrix,
gather,
aligned,
quad,
s);
}
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
qmm_op(
inputs, out, transpose_, group_size_, bits_, /*gather=*/false, stream());
}
void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
qmm_op(
inputs, out, transpose_, group_size_, bits_, /*gather=*/true, stream());
}
void fast::AffineQuantize::eval_gpu(