mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
tune quant dispatch (#2031)
This commit is contained in:
parent
c23888acd7
commit
9ba81e3da4
@ -586,13 +586,13 @@ METAL_FUNC void qmv_quad_impl(
|
|||||||
// Adjust positions
|
// Adjust positions
|
||||||
const int in_vec_size_w = in_vec_size / pack_factor;
|
const int in_vec_size_w = in_vec_size / pack_factor;
|
||||||
const int in_vec_size_g = in_vec_size / group_size;
|
const int in_vec_size_g = in_vec_size / group_size;
|
||||||
const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid;
|
const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid;
|
||||||
|
|
||||||
w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
|
w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
|
||||||
scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
|
scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
|
||||||
biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
|
biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
|
||||||
x += tid.y * in_vec_size + quad_lid * values_per_thread;
|
x += tid.x * in_vec_size + quad_lid * values_per_thread;
|
||||||
y += tid.y * out_vec_size + out_row;
|
y += tid.x * out_vec_size + out_row;
|
||||||
|
|
||||||
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
||||||
|
|
||||||
|
@ -298,23 +298,69 @@ void qmm_op(
|
|||||||
bool aligned = false;
|
bool aligned = false;
|
||||||
bool quad = false;
|
bool quad = false;
|
||||||
|
|
||||||
|
auto get_qmv_batch_limit = [s](int D, int O) {
|
||||||
|
auto arch = metal::device(s.device).get_architecture();
|
||||||
|
auto arch_size = arch.back();
|
||||||
|
auto arch_gen = arch.substr(arch.size() - 3, 2);
|
||||||
|
if (arch_gen == "13" || arch_gen == "14") {
|
||||||
|
switch (arch_size) {
|
||||||
|
case 'd':
|
||||||
|
if (D <= 2048 && O <= 2048) {
|
||||||
|
return 32;
|
||||||
|
} else if (D <= 4096 && O <= 4096) {
|
||||||
|
return 18;
|
||||||
|
} else {
|
||||||
|
return 12;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if (D <= 2048 && O <= 2048) {
|
||||||
|
return 14;
|
||||||
|
} else if (D <= 4096 && O <= 4096) {
|
||||||
|
return 10;
|
||||||
|
} else {
|
||||||
|
return 6;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
switch (arch_size) {
|
||||||
|
case 'd':
|
||||||
|
if (D <= 2048 && O <= 2048) {
|
||||||
|
return 32;
|
||||||
|
} else if (D <= 4096 && O <= 4096) {
|
||||||
|
return 18;
|
||||||
|
} else {
|
||||||
|
return 12;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if (D <= 2048 && O <= 2048) {
|
||||||
|
return 18;
|
||||||
|
} else if (D <= 4096 && O <= 4096) {
|
||||||
|
return 12;
|
||||||
|
} else {
|
||||||
|
return 10;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
if (transpose) {
|
if (transpose) {
|
||||||
if (B < 6 && (D == 128 || D == 64) && is_power_of_2(bits)) {
|
auto qmv_batch_limit = get_qmv_batch_limit(D, O);
|
||||||
|
if (B < qmv_batch_limit && (D == 128 || D == 64) && is_power_of_2(bits)) {
|
||||||
name += "qmv_quad";
|
name += "qmv_quad";
|
||||||
constexpr int quads_per_simd = 8;
|
constexpr int quads_per_simd = 8;
|
||||||
constexpr int results_per_quadgroup = 8;
|
constexpr int results_per_quadgroup = 8;
|
||||||
int bo = quads_per_simd * results_per_quadgroup;
|
int bo = quads_per_simd * results_per_quadgroup;
|
||||||
int simdgroup_size = 32;
|
int simdgroup_size = 32;
|
||||||
group_dims = MTL::Size(simdgroup_size, 1, 1);
|
group_dims = MTL::Size(simdgroup_size, 1, 1);
|
||||||
grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
|
grid_dims = MTL::Size(B, (O + bo - 1) / bo, N);
|
||||||
quad = true;
|
quad = true;
|
||||||
} else if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
|
} else if (B < qmv_batch_limit && O % 8 == 0 && D % 512 == 0 && D >= 512) {
|
||||||
name += "qmv_fast";
|
name += "qmv_fast";
|
||||||
int bo = 8;
|
int bo = 8;
|
||||||
int bd = 32;
|
int bd = 32;
|
||||||
group_dims = MTL::Size(bd, 2, 1);
|
group_dims = MTL::Size(bd, 2, 1);
|
||||||
grid_dims = MTL::Size(B, O / bo, N);
|
grid_dims = MTL::Size(B, O / bo, N);
|
||||||
} else if (B < 6) {
|
} else if (B < qmv_batch_limit) {
|
||||||
name += "qmv";
|
name += "qmv";
|
||||||
int bo = 8;
|
int bo = 8;
|
||||||
int bd = 32;
|
int bd = 32;
|
||||||
|
Loading…
Reference in New Issue
Block a user