mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Improved quantized matrix vector product (#786)
This commit is contained in:
committed by
GitHub
parent
cbcf44a4ca
commit
14b4e51a7c
@@ -41,8 +41,35 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int B = x.size() / D;
|
||||
int O = out.shape(-1);
|
||||
if (transpose_) {
|
||||
// Route to the fast qmv kernel that has no bounds checking
|
||||
if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
|
||||
std::ostringstream kname;
|
||||
kname << "qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
|
||||
<< bits_ << "_fast";
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
int bo = 8;
|
||||
int bd = 32;
|
||||
MTL::Size group_dims = MTL::Size(bd, 2, 1);
|
||||
MTL::Size grid_dims = MTL::Size(1, O / bo, B);
|
||||
|
||||
set_array_buffer(compute_encoder, w, 0);
|
||||
set_array_buffer(compute_encoder, scales, 1);
|
||||
set_array_buffer(compute_encoder, biases, 2);
|
||||
set_array_buffer(compute_encoder, x, 3);
|
||||
set_array_buffer(compute_encoder, out, 4);
|
||||
compute_encoder->setBytes(&D, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
// Route to the qmv kernel
|
||||
if (B < 6) {
|
||||
else if (B < 6) {
|
||||
std::ostringstream kname;
|
||||
kname << "qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
|
||||
<< bits_;
|
||||
@@ -52,9 +79,9 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
int bo = std::min(32, O);
|
||||
int bo = 8;
|
||||
int bd = 32;
|
||||
MTL::Size group_dims = MTL::Size(bd, bo, 1);
|
||||
MTL::Size group_dims = MTL::Size(bd, 2, 1);
|
||||
MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
|
||||
|
||||
set_array_buffer(compute_encoder, w, 0);
|
||||
|
||||
Reference in New Issue
Block a user