mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-18 15:11:14 +08:00)

Commit: 05cb54ae3f ("Another packing")
Parent: cb358dbdda
@@ -2162,7 +2162,7 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
-  constexpr int packs_per_thread = bits == 2 ? 1 : 2;
+  constexpr int packs_per_thread = 1;
   constexpr int num_simdgroups = 2;
   constexpr int results_per_simdgroup = 4;
   constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
@@ -2179,14 +2179,16 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
   thread U result[results_per_simdgroup] = {0};

   // Adjust positions
-  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
+  const int in_vec_size_w =
+      in_vec_size * results_per_simdgroup * bytes_per_pack / pack_factor;
   const int in_vec_size_g =
       in_vec_size * results_per_simdgroup * 2 / group_size;
-  const int scales_row = tid.x * num_simdgroups + simd_gid;
-  const int out_row = scales_row * results_per_simdgroup;
+  const int w_row = tid.x * num_simdgroups + simd_gid;
+  const int out_row = w_row * results_per_simdgroup;

-  ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
-  scales += scales_row * in_vec_size_g +
+  ws += w_row * in_vec_size_w +
+      simd_lid * results_per_simdgroup * packs_per_thread * bytes_per_pack;
+  scales += w_row * in_vec_size_g +
       results_per_simdgroup * 2 * (simd_lid / scale_step_per_thread);
   x += tid.y * in_vec_size + simd_lid * values_per_thread;
   y += tid.y * out_vec_size + out_row;
@@ -2194,18 +2196,16 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
   for (int k = 0; k < in_vec_size; k += block_size) {
     U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

-    U sb[2 * results_per_simdgroup];
-    for (int i = 0; i < 2 * results_per_simdgroup; i++) {
-      sb[i] = scales[i];
-    }
-
     for (int row = 0; row < results_per_simdgroup; row++) {
-      auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
       result[row] += qdot<U, values_per_thread, bits>(
-          wl, x_thread, sb[2 * row + 0], sb[2 * row + 1], sum);
+          ws + row * bytes_per_pack,
+          x_thread,
+          scales[2 * row + 0],
+          scales[2 * row + 1],
+          sum);
     }

-    ws += block_size * bytes_per_pack / pack_factor;
+    ws += results_per_simdgroup * block_size * bytes_per_pack / pack_factor;
     scales += block_size * 2 * results_per_simdgroup / group_size;
     x += block_size;
   }
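Note: the addressing changes above imply a weight layout in which the packs of the four rows served by a simdgroup are interleaved pack by pack: row `row`'s current pack starts at `ws + row * bytes_per_pack`, and its next pack lies `results_per_simdgroup * bytes_per_pack` bytes further on. Below is a minimal host-side C++ sketch of that indexing; the sizes (4-bit values, one uint32 pack per step, 8 packs per row) are illustrative assumptions, not values taken from this kernel.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  constexpr int rows_per_group = 4;  // results_per_simdgroup in the kernel
  constexpr int bytes_per_pack = 4;  // one uint32 holds pack_factor values
  constexpr int packs_per_row = 8;   // illustrative row length in packs

  // Interleaved buffer: for each pack index k, the packs of rows 0..3 are
  // stored back to back, i.e. [k=0: r0 r1 r2 r3][k=1: r0 r1 r2 r3]...
  std::vector<uint8_t> ws(rows_per_group * packs_per_row * bytes_per_pack);
  for (int k = 0; k < packs_per_row; ++k) {
    for (int r = 0; r < rows_per_group; ++r) {
      for (int b = 0; b < bytes_per_pack; ++b) {
        ws[(k * rows_per_group + r) * bytes_per_pack + b] =
            static_cast<uint8_t>(r * 16 + k);  // tag byte with (row, pack)
      }
    }
  }

  // Same addressing pattern as the kernel: start at row * bytes_per_pack and
  // advance by rows_per_group * bytes_per_pack per pack consumed.
  for (int r = 0; r < rows_per_group; ++r) {
    const uint8_t* p = ws.data() + r * bytes_per_pack;
    for (int k = 0; k < packs_per_row; ++k) {
      if (*p != r * 16 + k) {
        std::printf("unexpected byte for row %d, pack %d\n", r, k);
        return 1;
      }
      p += rows_per_group * bytes_per_pack;
    }
  }
  std::printf("each row's packs are visited in order\n");
  return 0;
}

Keeping the four rows adjacent at every pack position appears to be what lets the kernel drop the per-row `wl` pointer and the `sb` scale staging buffer in the loop above.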
mlx/ops.cpp (22 changes)
@@ -131,6 +131,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
   int scales_dims = scales.shape(-1) * group_size;
   if (type == QuantizationType::AffinePacked) {
     scales_dims /= 8;
+    weight_dims /= 4;
   }

   if (weight_dims != scales_dims) {
@@ -147,8 +148,12 @@ std::pair<int, int> extract_quantized_matmul_dims(
   int x_inner_dims = x.shape(-1);

   // Calculate the expanded w's dims
-  int w_inner_dims = (transpose) ? weight_dims : w.shape(-2);
-  int w_outer_dims = (transpose) ? w.shape(-2) : weight_dims;
+  int weight_dims_other = w.shape(-2);
+  if (type == QuantizationType::AffinePacked) {
+    weight_dims_other *= 4;
+  }
+  int w_inner_dims = (transpose) ? weight_dims : weight_dims_other;
+  int w_outer_dims = (transpose) ? weight_dims_other : weight_dims;

   if (w_inner_dims != x_inner_dims) {
     std::ostringstream msg;
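Note: with the packed layout, the stored last axis of `w` carries four interleaved rows, so the element count recovered from it is four times the logical inner dimension, while `w.shape(-2)` is a quarter of the logical outer dimension; the `weight_dims /= 4` and `weight_dims_other *= 4` corrections above undo that. A small worked check follows, assuming `weight_dims` is recovered as `w.shape(-1) * 32 / bits` (that line is outside this diff) and using an illustrative 4-bit weight with logical shape 32 x 64.

#include <cassert>
#include <cstdio>

int main() {
  // Illustrative logical weight: M x K, 4-bit quantization.
  constexpr int M = 32, K = 64, bits = 4;

  // After AffinePacked packing (see the quantize() hunks below), wq has shape
  // (M / 4, 4 * K * bits / 32): four logical rows per stored row, and a last
  // axis that is four times wider in 32-bit words.
  const int packed_rows = M / 4;              // w.shape(-2)
  const int packed_cols = 4 * K * bits / 32;  // w.shape(-1), in uint32 words

  // Assumed recovery of element counts, mirroring extract_quantized_matmul_dims.
  int weight_dims = packed_cols * 32 / bits;  // 4 * K
  int weight_dims_other = packed_rows;        // M / 4
  weight_dims /= 4;        // AffinePacked correction from the first hunk
  weight_dims_other *= 4;  // AffinePacked correction from the second hunk

  assert(weight_dims == K);
  assert(weight_dims_other == M);
  std::printf("inner dims = %d, outer dims = %d\n", weight_dims, weight_dims_other);
  return 0;
}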
@@ -3778,10 +3783,12 @@ std::tuple<array, array, std::optional<array>> quantize(
     int bits /* = 4 */,
     QuantizationType type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
+  switch (type) {
+    case QuantizationType::Affine:
+      return fast::affine_quantize(w, group_size, bits, s);
+    case QuantizationType::AffinePacked: {
   auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s);

-  // Pack scales and biases
-  if (type == QuantizationType::AffinePacked) {
     scales = unflatten(scales, -2, {-1, 4, 1}, s);
     biases = unflatten(biases, -2, {-1, 4, 1}, s);
     scales = concatenate({scales, biases}, -2, s);
@@ -3789,9 +3796,12 @@ std::tuple<array, array, std::optional<array>> quantize(
     scales = moveaxis(scales, -2, -1, s);
     scales = flatten(scales, -2, -1, s);

+    wq = unflatten(wq, -2, {-1, 4}, s);
+    wq = moveaxis(wq, -2, -1, s);
+    wq = flatten(wq, -2, -1, s);
+
     return std::make_tuple(wq, scales, std::nullopt);
-  } else {
-    return std::make_tuple(wq, scales, biases);
+    }
   }
 }
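Note: the wq packing added above groups every four consecutive output rows and interleaves their packs along the last axis: `unflatten(wq, -2, {-1, 4}, s)` splits the rows into groups of four, `moveaxis(wq, -2, -1, s)` moves the within-group index last, and `flatten(wq, -2, -1, s)` merges it into the pack axis, matching the layout the kernel reads with `ws + row * bytes_per_pack`. The scales and biases get the analogous treatment two hunks up, which is why `scales_dims /= 8` (four rows times scale and bias). The C++ sketch below applies the same permutation to a toy matrix of packs with plain loops; the shapes and tag values are illustrative, not MLX API calls.

#include <cstdio>
#include <vector>

int main() {
  // Toy "wq": M rows of P packs, stored row-major. Values encode (row, pack).
  constexpr int M = 8, P = 3, group = 4;
  std::vector<int> wq(M * P);
  for (int r = 0; r < M; ++r)
    for (int p = 0; p < P; ++p)
      wq[r * P + p] = r * 100 + p;

  // Equivalent, for a 2-D array, of unflatten(-2, {-1, 4}) -> moveaxis(-2, -1)
  // -> flatten(-2, -1): the output has M / 4 rows of length 4 * P, with the
  // packs of the four grouped rows interleaved pack by pack.
  std::vector<int> packed((M / group) * group * P);
  for (int g = 0; g < M / group; ++g)
    for (int p = 0; p < P; ++p)
      for (int r = 0; r < group; ++r)
        packed[(g * P + p) * group + r] = wq[(g * group + r) * P + p];

  // Print the first packed row: r0p0 r1p0 r2p0 r3p0 r0p1 r1p1 ...
  for (int i = 0; i < group * P; ++i)
    std::printf("%d ", packed[i]);
  std::printf("\n");
  return 0;
}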