Another packing

This commit is contained in:
Angelos Katharopoulos 2024-12-13 23:48:25 -08:00
parent cb358dbdda
commit 05cb54ae3f
2 changed files with 38 additions and 28 deletions

View File

@ -2162,7 +2162,7 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int packs_per_thread = bits == 2 ? 1 : 2; constexpr int packs_per_thread = 1;
constexpr int num_simdgroups = 2; constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4; constexpr int results_per_simdgroup = 4;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
@ -2179,14 +2179,16 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
thread U result[results_per_simdgroup] = {0}; thread U result[results_per_simdgroup] = {0};
// Adjust positions // Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; const int in_vec_size_w =
in_vec_size * results_per_simdgroup * bytes_per_pack / pack_factor;
const int in_vec_size_g = const int in_vec_size_g =
in_vec_size * results_per_simdgroup * 2 / group_size; in_vec_size * results_per_simdgroup * 2 / group_size;
const int scales_row = tid.x * num_simdgroups + simd_gid; const int w_row = tid.x * num_simdgroups + simd_gid;
const int out_row = scales_row * results_per_simdgroup; const int out_row = w_row * results_per_simdgroup;
ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; ws += w_row * in_vec_size_w +
scales += scales_row * in_vec_size_g + simd_lid * results_per_simdgroup * packs_per_thread * bytes_per_pack;
scales += w_row * in_vec_size_g +
results_per_simdgroup * 2 * (simd_lid / scale_step_per_thread); results_per_simdgroup * 2 * (simd_lid / scale_step_per_thread);
x += tid.y * in_vec_size + simd_lid * values_per_thread; x += tid.y * in_vec_size + simd_lid * values_per_thread;
y += tid.y * out_vec_size + out_row; y += tid.y * out_vec_size + out_row;
@ -2194,18 +2196,16 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
for (int k = 0; k < in_vec_size; k += block_size) { for (int k = 0; k < in_vec_size; k += block_size) {
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread); U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
U sb[2 * results_per_simdgroup];
for (int i = 0; i < 2 * results_per_simdgroup; i++) {
sb[i] = scales[i];
}
for (int row = 0; row < results_per_simdgroup; row++) { for (int row = 0; row < results_per_simdgroup; row++) {
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
result[row] += qdot<U, values_per_thread, bits>( result[row] += qdot<U, values_per_thread, bits>(
wl, x_thread, sb[2 * row + 0], sb[2 * row + 1], sum); ws + row * bytes_per_pack,
x_thread,
scales[2 * row + 0],
scales[2 * row + 1],
sum);
} }
ws += block_size * bytes_per_pack / pack_factor; ws += results_per_simdgroup * block_size * bytes_per_pack / pack_factor;
scales += block_size * 2 * results_per_simdgroup / group_size; scales += block_size * 2 * results_per_simdgroup / group_size;
x += block_size; x += block_size;
} }

View File

@ -131,6 +131,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
int scales_dims = scales.shape(-1) * group_size; int scales_dims = scales.shape(-1) * group_size;
if (type == QuantizationType::AffinePacked) { if (type == QuantizationType::AffinePacked) {
scales_dims /= 8; scales_dims /= 8;
weight_dims /= 4;
} }
if (weight_dims != scales_dims) { if (weight_dims != scales_dims) {
@ -147,8 +148,12 @@ std::pair<int, int> extract_quantized_matmul_dims(
int x_inner_dims = x.shape(-1); int x_inner_dims = x.shape(-1);
// Calculate the expanded w's dims // Calculate the expanded w's dims
int w_inner_dims = (transpose) ? weight_dims : w.shape(-2); int weight_dims_other = w.shape(-2);
int w_outer_dims = (transpose) ? w.shape(-2) : weight_dims; if (type == QuantizationType::AffinePacked) {
weight_dims_other *= 4;
}
int w_inner_dims = (transpose) ? weight_dims : weight_dims_other;
int w_outer_dims = (transpose) ? weight_dims_other : weight_dims;
if (w_inner_dims != x_inner_dims) { if (w_inner_dims != x_inner_dims) {
std::ostringstream msg; std::ostringstream msg;
@ -3778,10 +3783,12 @@ std::tuple<array, array, std::optional<array>> quantize(
int bits /* = 4 */, int bits /* = 4 */,
QuantizationType type /* = QuantizationType::Affine */, QuantizationType type /* = QuantizationType::Affine */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
switch (type) {
case QuantizationType::Affine:
return fast::affine_quantize(w, group_size, bits, s);
case QuantizationType::AffinePacked: {
auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s); auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s);
// Pack scales and biases
if (type == QuantizationType::AffinePacked) {
scales = unflatten(scales, -2, {-1, 4, 1}, s); scales = unflatten(scales, -2, {-1, 4, 1}, s);
biases = unflatten(biases, -2, {-1, 4, 1}, s); biases = unflatten(biases, -2, {-1, 4, 1}, s);
scales = concatenate({scales, biases}, -2, s); scales = concatenate({scales, biases}, -2, s);
@ -3789,9 +3796,12 @@ std::tuple<array, array, std::optional<array>> quantize(
scales = moveaxis(scales, -2, -1, s); scales = moveaxis(scales, -2, -1, s);
scales = flatten(scales, -2, -1, s); scales = flatten(scales, -2, -1, s);
wq = unflatten(wq, -2, {-1, 4}, s);
wq = moveaxis(wq, -2, -1, s);
wq = flatten(wq, -2, -1, s);
return std::make_tuple(wq, scales, std::nullopt); return std::make_tuple(wq, scales, std::nullopt);
} else { }
return std::make_tuple(wq, scales, biases);
} }
} }