Attempt different packing

Angelos Katharopoulos 2024-12-13 18:36:36 -08:00
parent a06c968f4d
commit e4b587819c
4 changed files with 33 additions and 25 deletions


@@ -5,11 +5,11 @@ from functools import partial
 import mlx.core as mx
 from time_utils import time_fn
 
-D = 16384
+D = 8192
 group_size = 64
 bits = 3
 dtype = mx.float16
-loops = 10
+loops = 100
 
 def qmv_(x, wq, q_type):
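
For context, a minimal sketch of the kind of loop this benchmark times, assuming only the public mx.quantize and mx.quantized_matmul APIs (the file's q_type argument selects the quantization layout and is omitted here):

    import mlx.core as mx

    D, group_size, bits, loops = 8192, 64, 3, 100

    x = mx.random.normal((1, D)).astype(mx.float16)
    w = mx.random.normal((D, D)).astype(mx.float16)
    wq, scales, biases = mx.quantize(w, group_size=group_size, bits=bits)

    for _ in range(loops):
        # One quantized matrix-vector product per iteration.
        x = mx.quantized_matmul(
            x, wq, scales, biases, transpose=True,
            group_size=group_size, bits=bits)
    mx.eval(x)

Halving D while raising loops 10x shifts the measurement toward many smaller launches, presumably to get a steadier timing of the small-matrix case.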


@@ -2180,34 +2180,28 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
   // Adjust positions
   const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
-  const int in_vec_size_g =
-      in_vec_size * results_per_simdgroup * 2 / group_size;
-  const int scales_row = tid.x * num_simdgroups + simd_gid;
-  const int out_row = scales_row * results_per_simdgroup;
+  const int out_vec_size_g = 2 * out_vec_size;
+  const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
+      simd_gid * results_per_simdgroup;
+  const int scales_blocksize = (block_size / group_size) * out_vec_size_g;
   ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
-  scales += scales_row * in_vec_size_g +
-      results_per_simdgroup * 2 * (simd_lid / scale_step_per_thread);
+  scales += (simd_lid / scale_step_per_thread) * out_vec_size_g + 2 * out_row;
   x += tid.y * in_vec_size + simd_lid * values_per_thread;
   y += tid.y * out_vec_size + out_row;
 
   for (int k = 0; k < in_vec_size; k += block_size) {
     U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
 
-    U sb[2 * results_per_simdgroup];
-    for (int i = 0; i < 2 * results_per_simdgroup; i++) {
-      sb[i] = scales[i];
-    }
-
     for (int row = 0; row < results_per_simdgroup; row++) {
       auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
       result[row] += qdot<U, values_per_thread, bits>(
-          wl, x_thread, sb[2 * row + 0], sb[2 * row + 1], sum);
+          wl, x_thread, scales[2 * row + 0], scales[2 * row + 1], sum);
     }
 
     ws += block_size * bytes_per_pack / pack_factor;
-    scales += block_size * 2 * results_per_simdgroup / group_size;
     x += block_size;
+    scales += scales_blocksize;
   }
 
   for (int row = 0; row < results_per_simdgroup; row++) {
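
The deleted code copied scales and biases into a per-thread sb staging buffer and advanced scales with a stride derived from in_vec_size. Under the new packing, each group of group_size inputs owns one contiguous row of 2 * out_vec_size scale/bias pairs, so the kernel reads scales[2 * row] and scales[2 * row + 1] in place and advances by a constant scales_blocksize per block. A hedged Python sketch of the new offset arithmetic (names mirror the kernel; the default parameter values are assumptions, not the kernel's actual launch configuration):

    def scales_offsets(tid_x, simd_gid, simd_lid, *,
                       num_simdgroups=2, results_per_simdgroup=4,
                       out_vec_size=1024, scale_step_per_thread=8,
                       block_size=512, group_size=64):
        out_vec_size_g = 2 * out_vec_size
        out_row = (tid_x * (num_simdgroups * results_per_simdgroup)
                   + simd_gid * results_per_simdgroup)
        # Start: the thread's group row, then its output rows within it.
        start = ((simd_lid // scale_step_per_thread) * out_vec_size_g
                 + 2 * out_row)
        # Advance: (block_size / group_size) whole rows of the table.
        step = (block_size // group_size) * out_vec_size_g
        return start, step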


@@ -433,6 +433,16 @@ void affine_packed_qmv(
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }
 
+void affine_packed_qmm(
+    const std::vector<array>& inputs,
+    array& out,
+    int B,
+    int D,
+    int O,
+    int group_size,
+    int bits,
+    const Stream& s) {}
+
 void affine_packed_qmm_op(
     const std::vector<array>& inputs,
     array& out,
@@ -451,6 +461,7 @@ void affine_packed_qmm_op(
     if (B < 6) {
       affine_packed_qmv(inputs, out, B, D, O, group_size, bits, s);
     } else {
+      affine_packed_qmm(inputs, out, B, D, O, group_size, bits, s);
     }
   } else {
   }
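
affine_packed_qmm is only a stub in this commit; the dispatch merely reserves the branch. The heuristic itself is the usual GEMV/GEMM split, sketched here in Python with hypothetical names:

    def route(batch_rows: int) -> str:
        # Up to five rows, one matrix-vector launch per row is cheap;
        # beyond that a tiled matmul kernel (not yet written) should win.
        return "affine_packed_qmv" if batch_rows < 6 else "affine_packed_qmm"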


@@ -128,9 +128,14 @@ std::pair<int, int> extract_quantized_matmul_dims(
   }
 
   int weight_dims = w.shape(-1) * 32 / bits;
-  int scales_dims = scales.shape(-1) * group_size;
-  if (type == QuantizationType::AffinePacked) {
-    scales_dims /= 8;
+  int scales_dims;
+  switch (type) {
+    case QuantizationType::Affine:
+      scales_dims = scales.shape(-1) * group_size;
+      break;
+    case QuantizationType::AffinePacked:
+      scales_dims = scales.shape(-2) * group_size;
+      break;
   }
 
   if (weight_dims != scales_dims) {
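
The switch encodes the two layouts directly: affine scales are (out, in / group_size), while the packed layout puts the group axis first and fuses scales with biases along the last axis. A quick shape check with hypothetical dimensions:

    K, O, group_size = 4096, 8, 64   # in dim, out dim, group size
    G = K // group_size

    affine_shape = (O, G)        # shape[-1] * group_size == K
    packed_shape = (G, 2 * O)    # shape[-2] * group_size == K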
@@ -3782,14 +3787,12 @@ std::tuple<array, array, std::optional<array>> quantize(
   // Pack scales and biases
   if (type == QuantizationType::AffinePacked) {
-    scales = unflatten(scales, -2, {-1, 4, 1}, s);
-    biases = unflatten(biases, -2, {-1, 4, 1}, s);
-    scales = concatenate({scales, biases}, -2, s);
-    scales = flatten(scales, -3, -2, s);
-    scales = moveaxis(scales, -2, -1, s);
-    scales = flatten(scales, -2, -1, s);
-    return std::make_tuple(wq, scales, std::nullopt);
+    array packed_scales_biases =
+        flatten(stack({scales, biases}, -2, s), -3, -2, s);
+    packed_scales_biases =
+        contiguous(moveaxis(packed_scales_biases, -2, -1, s), false, s);
+    return std::make_tuple(wq, packed_scales_biases, std::nullopt);
   } else {
     return std::make_tuple(wq, scales, biases);
   }
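
A NumPy sketch of the new packing for scales and biases of shape (O, G), with hypothetical sizes: stacking on axis -2 pairs each output row's bias with its scale, flattening yields (2 * O, G), and the moveaxis plus contiguous copy produces a (G, 2 * O) table whose row g stores s0, b0, s1, b1, ... for group g, exactly the order the kernel consumes via scales[2 * row] and scales[2 * row + 1]:

    import numpy as np

    O, G = 8, 4                                    # hypothetical dims
    scales = np.arange(O * G, dtype=np.float32).reshape(O, G)
    biases = -scales

    packed = np.stack([scales, biases], axis=-2)   # (O, 2, G)
    packed = packed.reshape(2 * O, G)              # flatten(-3, -2)
    packed = np.ascontiguousarray(
        np.moveaxis(packed, -2, -1))               # (G, 2 * O)

    assert np.all(packed[0, 0::2] == scales[:, 0]) # group 0: every row's scale
    assert np.all(packed[0, 1::2] == biases[:, 0]) # group 0: every row's bias

The contiguous(..., false, s) call in the diff asks for a materialized row-major copy, so the kernel's pointer arithmetic sees this layout rather than a strided view.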