mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
Attempt different packing
This commit is contained in:
parent
a06c968f4d
commit
e4b587819c
@ -5,11 +5,11 @@ from functools import partial
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
|
||||
D = 16384
|
||||
D = 8192
|
||||
group_size = 64
|
||||
bits = 3
|
||||
dtype = mx.float16
|
||||
loops = 10
|
||||
loops = 100
|
||||
|
||||
|
||||
def qmv_(x, wq, q_type):
|
||||
|
@ -2180,34 +2180,28 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
|
||||
|
||||
// Adjust positions
|
||||
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
|
||||
const int in_vec_size_g =
|
||||
in_vec_size * results_per_simdgroup * 2 / group_size;
|
||||
const int scales_row = tid.x * num_simdgroups + simd_gid;
|
||||
const int out_row = scales_row * results_per_simdgroup;
|
||||
const int out_vec_size_g = 2 * out_vec_size;
|
||||
const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
|
||||
simd_gid * results_per_simdgroup;
|
||||
const int scales_blocksize = (block_size / group_size) * out_vec_size_g;
|
||||
|
||||
ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
|
||||
scales += scales_row * in_vec_size_g +
|
||||
results_per_simdgroup * 2 * (simd_lid / scale_step_per_thread);
|
||||
scales += (simd_lid / scale_step_per_thread) * out_vec_size_g + 2 * out_row;
|
||||
x += tid.y * in_vec_size + simd_lid * values_per_thread;
|
||||
y += tid.y * out_vec_size + out_row;
|
||||
|
||||
for (int k = 0; k < in_vec_size; k += block_size) {
|
||||
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
||||
|
||||
U sb[2 * results_per_simdgroup];
|
||||
for (int i = 0; i < 2 * results_per_simdgroup; i++) {
|
||||
sb[i] = scales[i];
|
||||
}
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
||||
result[row] += qdot<U, values_per_thread, bits>(
|
||||
wl, x_thread, sb[2 * row + 0], sb[2 * row + 1], sum);
|
||||
wl, x_thread, scales[2 * row + 0], scales[2 * row + 1], sum);
|
||||
}
|
||||
|
||||
ws += block_size * bytes_per_pack / pack_factor;
|
||||
scales += block_size * 2 * results_per_simdgroup / group_size;
|
||||
x += block_size;
|
||||
scales += scales_blocksize;
|
||||
}
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
|
@ -433,6 +433,16 @@ void affine_packed_qmv(
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void affine_packed_qmm(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
int B,
|
||||
int D,
|
||||
int O,
|
||||
int group_size,
|
||||
int bits,
|
||||
const Stream& s) {}
|
||||
|
||||
void affine_packed_qmm_op(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
@ -451,6 +461,7 @@ void affine_packed_qmm_op(
|
||||
if (B < 6) {
|
||||
affine_packed_qmv(inputs, out, B, D, O, group_size, bits, s);
|
||||
} else {
|
||||
affine_packed_qmm(inputs, out, B, D, O, group_size, bits, s);
|
||||
}
|
||||
} else {
|
||||
}
|
||||
|
23
mlx/ops.cpp
23
mlx/ops.cpp
@ -128,9 +128,14 @@ std::pair<int, int> extract_quantized_matmul_dims(
|
||||
}
|
||||
|
||||
int weight_dims = w.shape(-1) * 32 / bits;
|
||||
int scales_dims = scales.shape(-1) * group_size;
|
||||
if (type == QuantizationType::AffinePacked) {
|
||||
scales_dims /= 8;
|
||||
int scales_dims;
|
||||
switch (type) {
|
||||
case QuantizationType::Affine:
|
||||
scales_dims = scales.shape(-1) * group_size;
|
||||
break;
|
||||
case QuantizationType::AffinePacked:
|
||||
scales_dims = scales.shape(-2) * group_size;
|
||||
break;
|
||||
}
|
||||
|
||||
if (weight_dims != scales_dims) {
|
||||
@ -3782,14 +3787,12 @@ std::tuple<array, array, std::optional<array>> quantize(
|
||||
|
||||
// Pack scales and biases
|
||||
if (type == QuantizationType::AffinePacked) {
|
||||
scales = unflatten(scales, -2, {-1, 4, 1}, s);
|
||||
biases = unflatten(biases, -2, {-1, 4, 1}, s);
|
||||
scales = concatenate({scales, biases}, -2, s);
|
||||
scales = flatten(scales, -3, -2, s);
|
||||
scales = moveaxis(scales, -2, -1, s);
|
||||
scales = flatten(scales, -2, -1, s);
|
||||
array packed_scales_biases =
|
||||
flatten(stack({scales, biases}, -2, s), -3, -2, s);
|
||||
packed_scales_biases =
|
||||
contiguous(moveaxis(packed_scales_biases, -2, -1, s), false, s);
|
||||
|
||||
return std::make_tuple(wq, scales, std::nullopt);
|
||||
return std::make_tuple(wq, packed_scales_biases, std::nullopt);
|
||||
} else {
|
||||
return std::make_tuple(wq, scales, biases);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user