Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-26 10:41:14 +08:00)

Attempt different packing

This commit is contained in:
parent a06c968f4d
commit e4b587819c
@@ -5,11 +5,11 @@ from functools import partial
 import mlx.core as mx
 from time_utils import time_fn
 
-D = 16384
+D = 8192
 group_size = 64
 bits = 3
 dtype = mx.float16
-loops = 10
+loops = 100
 
 
 def qmv_(x, wq, q_type):
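For context, a runnable sketch of the benchmark these constants feed (an assumption: the rest of the file is not shown in this hunk, so the driver below is modeled on the `time_fn` helper in MLX's benchmarks/python, and the plumbing of `q_type` is guessed, not visible here):

import mlx.core as mx
from time_utils import time_fn  # helper from benchmarks/python/time_utils.py

D = 8192
group_size = 64
bits = 3
dtype = mx.float16
loops = 100


def qmv_(x, wq, q_type):
    # Assumed body: chain `loops` quantized matvecs so per-launch overhead
    # amortizes out of the measurement. `q_type` presumably selects the
    # packing variant under test; its use is not visible in this hunk.
    for _ in range(loops):
        x = mx.quantized_matmul(
            x, *wq, transpose=True, group_size=group_size, bits=bits
        )
    return x


if __name__ == "__main__":
    mx.random.seed(3)
    x = mx.random.normal(shape=(1, D)).astype(dtype)
    w = mx.random.normal(shape=(D, D)).astype(dtype)
    wq = mx.quantize(w, group_size=group_size, bits=bits)
    mx.eval(x, wq)
    time_fn(qmv_, x, wq, "affine")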
@@ -2180,34 +2180,28 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
 
   // Adjust positions
   const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
-  const int in_vec_size_g =
-      in_vec_size * results_per_simdgroup * 2 / group_size;
-  const int scales_row = tid.x * num_simdgroups + simd_gid;
-  const int out_row = scales_row * results_per_simdgroup;
+  const int out_vec_size_g = 2 * out_vec_size;
+  const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
+      simd_gid * results_per_simdgroup;
+  const int scales_blocksize = (block_size / group_size) * out_vec_size_g;
 
   ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
-  scales += scales_row * in_vec_size_g +
-      results_per_simdgroup * 2 * (simd_lid / scale_step_per_thread);
+  scales += (simd_lid / scale_step_per_thread) * out_vec_size_g + 2 * out_row;
   x += tid.y * in_vec_size + simd_lid * values_per_thread;
   y += tid.y * out_vec_size + out_row;
 
   for (int k = 0; k < in_vec_size; k += block_size) {
     U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
 
-    U sb[2 * results_per_simdgroup];
-    for (int i = 0; i < 2 * results_per_simdgroup; i++) {
-      sb[i] = scales[i];
-    }
-
     for (int row = 0; row < results_per_simdgroup; row++) {
       auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
       result[row] += qdot<U, values_per_thread, bits>(
-          wl, x_thread, sb[2 * row + 0], sb[2 * row + 1], sum);
+          wl, x_thread, scales[2 * row + 0], scales[2 * row + 1], sum);
     }
 
     ws += block_size * bytes_per_pack / pack_factor;
-    scales += block_size * 2 * results_per_simdgroup / group_size;
     x += block_size;
+    scales += scales_blocksize;
   }
 
   for (int row = 0; row < results_per_simdgroup; row++) {
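To make the new indexing concrete: the repacked scales/biases put one row of 2 * out_vec_size interleaved (scale, bias) values per quantization group, so stepping one block_size chunk along the input advances (block_size / group_size) such rows, which is exactly scales_blocksize, and scale/bias pairs can be read in place instead of being staged in the sb[] buffer. A NumPy sketch of that pointer walk (my reading of the layout, not code from the commit; it tracks only the first group of each block, where the real kernel offsets by simd_lid / scale_step_per_thread):

import numpy as np

O, D, group_size, block_size = 8, 128, 16, 32  # toy sizes
n_groups = D // group_size
scale = np.arange(O * n_groups, dtype=np.float32).reshape(O, n_groups)
bias = -scale

# Packed layout: row g holds [s(0,g), b(0,g), s(1,g), b(1,g), ...].
packed = np.stack([scale, bias], axis=-2).reshape(2 * O, n_groups)
flat = np.moveaxis(packed, -2, -1).reshape(-1)

out_vec_size_g = 2 * O
scales_blocksize = (block_size // group_size) * out_vec_size_g

offset = 0  # the kernel also adds 2 * out_row; take out_row == 0 here
for k in range(0, D, block_size):
    g = k // group_size  # first group covered by this block
    for row in range(O):
        assert flat[offset + 2 * row + 0] == scale[row, g]  # scales[2 * row + 0]
        assert flat[offset + 2 * row + 1] == bias[row, g]   # scales[2 * row + 1]
    offset += scales_blocksize  # scales += scales_blocksize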
@@ -433,6 +433,16 @@ void affine_packed_qmv(
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }
 
+void affine_packed_qmm(
+    const std::vector<array>& inputs,
+    array& out,
+    int B,
+    int D,
+    int O,
+    int group_size,
+    int bits,
+    const Stream& s) {}
+
 void affine_packed_qmm_op(
     const std::vector<array>& inputs,
     array& out,
@@ -451,6 +461,7 @@ void affine_packed_qmm_op(
     if (B < 6) {
       affine_packed_qmv(inputs, out, B, D, O, group_size, bits, s);
     } else {
+      affine_packed_qmm(inputs, out, B, D, O, group_size, bits, s);
     }
   } else {
   }
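The B < 6 cutoff mirrors the existing affine path: skinny activations take the matrix-vector kernel, anything wider will route to affine_packed_qmm once the stub above gets a body. As a sketch (hypothetical helper, not in the commit):

def affine_packed_dispatch(B: int) -> str:
    # Mirrors affine_packed_qmm_op above; qmm is still an empty stub here.
    return "affine_packed_qmv" if B < 6 else "affine_packed_qmm"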
mlx/ops.cpp (23 changed lines)
@@ -128,9 +128,14 @@ std::pair<int, int> extract_quantized_matmul_dims(
   }
 
   int weight_dims = w.shape(-1) * 32 / bits;
-  int scales_dims = scales.shape(-1) * group_size;
-  if (type == QuantizationType::AffinePacked) {
-    scales_dims /= 8;
+  int scales_dims;
+  switch (type) {
+    case QuantizationType::Affine:
+      scales_dims = scales.shape(-1) * group_size;
+      break;
+    case QuantizationType::AffinePacked:
+      scales_dims = scales.shape(-2) * group_size;
+      break;
   }
 
   if (weight_dims != scales_dims) {
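Spelled out as a sketch (hypothetical helper mirroring the switch): both branches recover the quantized inner dimension, but for the packed layout the group count sits on axis -2, since axis -1 now holds the 2 * O interleaved scale/bias values.

def scales_dims_match(w_shape, scales_shape, group_size, bits, packed):
    weight_dims = w_shape[-1] * 32 // bits  # elements packed into each uint32
    axis = -2 if packed else -1             # AffinePacked vs. Affine
    return weight_dims == scales_shape[axis] * group_size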
@@ -3782,14 +3787,12 @@ std::tuple<array, array, std::optional<array>> quantize(
 
   // Pack scales and biases
   if (type == QuantizationType::AffinePacked) {
-    scales = unflatten(scales, -2, {-1, 4, 1}, s);
-    biases = unflatten(biases, -2, {-1, 4, 1}, s);
-    scales = concatenate({scales, biases}, -2, s);
-    scales = flatten(scales, -3, -2, s);
-    scales = moveaxis(scales, -2, -1, s);
-    scales = flatten(scales, -2, -1, s);
+    array packed_scales_biases =
+        flatten(stack({scales, biases}, -2, s), -3, -2, s);
+    packed_scales_biases =
+        contiguous(moveaxis(packed_scales_biases, -2, -1, s), false, s);
 
-    return std::make_tuple(wq, scales, std::nullopt);
+    return std::make_tuple(wq, packed_scales_biases, std::nullopt);
   } else {
     return std::make_tuple(wq, scales, biases);
   }
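The same repacking, traced shape-by-shape through the Python API as a sketch (an assumption: mx.stack/mx.flatten/mx.moveaxis mirror the C++ stack/flatten/moveaxis used above; the trailing contiguous() only forces a row-major copy and is omitted here):

import mlx.core as mx

O, n_groups = 4, 3  # toy sizes: output rows x quantization groups
scales = mx.reshape(mx.arange(O * n_groups, dtype=mx.float32), (O, n_groups))
biases = -scales

stacked = mx.stack([scales, biases], axis=-2)  # (O, 2, n_groups)
packed = mx.flatten(stacked, -3, -2)           # (2 * O, n_groups): rows s0, b0, s1, b1, ...
packed = mx.moveaxis(packed, -2, -1)           # (n_groups, 2 * O)
print(packed.shape)                            # (3, 8): one interleaved row per group

Compared with the previous chain (unflatten into blocks of 4 output rows, concatenate, interleave, flatten), this keeps all 2 * O scale/bias pairs of a group contiguous, which is what lets the kernel read scales[2 * row + 0/1] directly.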