Vectorized reads

Angelos Katharopoulos 2024-12-14 15:08:24 -08:00
parent 05cb54ae3f
commit d7ed624502
2 changed files with 63 additions and 31 deletions

@@ -2150,10 +2150,41 @@ template <typename T, const int group_size, const int bits>
}
}
template <typename T, typename U, int bits>
inline vec<U, 4> partial_qdot_vec(const thread U* x, vec<uint32_t, 4> w) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");
  vec<U, 4> accum = 0;
  if (bits == 4) {
    for (int i = 0; i < 4; i++) {
      auto ws = as_type<vec<uint16_t, 2>>(w[i]);
      for (int j = 0; j < 2; j++) {
        accum[i] +=
            (x[4 * j + 0] * (ws[j] & 0x000f) + x[4 * j + 1] * (ws[j] & 0x00f0) +
             x[4 * j + 2] * (ws[j] & 0x0f00) + x[4 * j + 3] * (ws[j] & 0xf000));
      }
    }
  } else if (bits == 8) {
    for (int i = 0; i < 4; i++) {
      auto ws = as_type<vec<uint8_t, 4>>(w[i]);
      for (int j = 0; j < 4; j++) {
        accum[i] += x[j] * ws[j];
      }
    }
  }
  return accum;
}
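Note on the bits == 4 branch above: the nibbles are masked out of each 16-bit lane but never shifted down to bit 0, so the k-th product carries a spurious factor of 16^k. This cancels only if the matching x values were pre-divided by those powers of two when loaded, which load_vector elsewhere in this file is assumed to do for 4-bit weights. A minimal standalone sketch of the identity:

```cpp
#include <cstdint>
#include <cstdio>

// Toy check of the mask-without-shift trick in partial_qdot_vec for
// bits == 4: each nibble stays in place inside the 16-bit lane, so the
// k-th product is 16^k too large. Pre-dividing the matching x values by
// 16^k at load time (assumed load_vector behavior) cancels the factors.
int main() {
  uint16_t ws = 0x4321; // packed nibbles: w0=1, w1=2, w2=3, w3=4
  float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};

  // Pre-scaled copy of x, mirroring the assumed load_vector behavior.
  float xs[4] = {x[0], x[1] / 16.0f, x[2] / 256.0f, x[3] / 4096.0f};

  float accum = xs[0] * (ws & 0x000f) + xs[1] * (ws & 0x00f0) +
                xs[2] * (ws & 0x0f00) + xs[3] * (ws & 0xf000);

  // Reference: unpack the nibbles with shifts and do a plain dot product.
  float ref = 0.0f;
  for (int k = 0; k < 4; k++) {
    ref += x[k] * ((ws >> (4 * k)) & 0xf);
  }
  printf("accum = %g, ref = %g\n", accum, ref); // both print 30
}
```

The bits == 8 branch needs no such correction: as_type already yields each byte at its natural value.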
template <typename T, int group_size, int bits>
METAL_FUNC void affine_packed_qmv_fast_impl(
const device uint32_t* w,
const device T* scales,
const device vec<uint32_t, 4>* w,
const device vec<T, 4>* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
@@ -2162,7 +2193,7 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int packs_per_thread = 1;
constexpr int packs_per_thread = 2;
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
@@ -2171,48 +2202,50 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread;
const device uint8_t* ws = (const device uint8_t*)w;
typedef float U;
thread U x_thread[values_per_thread];
thread U result[results_per_simdgroup] = {0};
vec<U, results_per_simdgroup> result = 0;
// Adjust positions
const int in_vec_size_w =
in_vec_size * results_per_simdgroup * bytes_per_pack / pack_factor;
const int in_vec_size_g =
in_vec_size * results_per_simdgroup * 2 / group_size;
const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size * 2 / group_size;
const int w_row = tid.x * num_simdgroups + simd_gid;
const int out_row = w_row * results_per_simdgroup;
ws += w_row * in_vec_size_w +
simd_lid * results_per_simdgroup * packs_per_thread * bytes_per_pack;
scales += w_row * in_vec_size_g +
results_per_simdgroup * 2 * (simd_lid / scale_step_per_thread);
w += w_row * in_vec_size_w + simd_lid * packs_per_thread;
scales += w_row * in_vec_size_g + 2 * (simd_lid / scale_step_per_thread);
x += tid.y * in_vec_size + simd_lid * values_per_thread;
y += tid.y * out_vec_size + out_row;
for (int k = 0; k < in_vec_size; k += block_size) {
// Load the input vector
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_simdgroup; row++) {
result[row] += qdot<U, values_per_thread, bits>(
ws + row * bytes_per_pack,
x_thread,
scales[2 * row + 0],
scales[2 * row + 1],
sum);
// Load the scales and biases
vec<T, 4> s = scales[0];
vec<T, 4> b = scales[1];
// Load the weights and perform the partial dot product
vec<U, 4> accum = 0;
for (int pack = 0; pack < packs_per_thread; pack++) {
accum +=
partial_qdot_vec<T, U, bits>(x_thread + pack * pack_factor, w[pack]);
}
ws += results_per_simdgroup * block_size * bytes_per_pack / pack_factor;
scales += block_size * 2 * results_per_simdgroup / group_size;
// Finalize the dot product and accumulate it
for (int i = 0; i < 4; i++) {
result[i] += static_cast<U>(s[i]) * accum[i] + static_cast<U>(b[i]) * sum;
}
w += block_size / pack_factor;
scales += 2 * block_size / group_size;
x += block_size;
}
for (int row = 0; row < results_per_simdgroup; row++) {
result[row] = simd_sum(result[row]);
if (simd_lid == 0) {
result = simd_sum(result);
if (simd_lid == 0) {
for (int row = 0; row < results_per_simdgroup; row++) {
y[row] = static_cast<T>(result[row]);
}
}
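The pointer type changes above encode the AffinePacked weight layout: each 16-byte vec<uint32_t, 4> interleaves one 32-bit word from each of 4 consecutive output rows, which is why partial_qdot_vec accumulates into a vec<U, 4> (one lane per row) and why the scales and biases arrive as alternating vec<T, 4> pairs. Bumping packs_per_thread from 1 to 2 doubles each thread's weight read to two such packs per iteration, giving the vectorized reads of the commit title. A worked sketch of the resulting tiling constants for bits = 4, group_size = 64 (SIMD_SIZE = 32 and the values_per_thread definition, which falls in the elided lines, are assumptions):

```cpp
#include <cstdio>

// Worked example of the per-thread tiling after this change, for the
// common case bits = 4, group_size = 64. SIMD_SIZE = 32 (the simdgroup
// width on Apple GPUs) and values_per_thread = packs_per_thread *
// pack_factor are assumptions about the elided parts of the file.
constexpr int bits = 4;
constexpr int group_size = 64;
constexpr int SIMD_SIZE = 32;

constexpr int packs_per_thread = 2; // was 1 before this commit
constexpr int pack_factor = 32 / bits; // 8 values per uint32_t
constexpr int values_per_thread = packs_per_thread * pack_factor; // 16
constexpr int block_size = values_per_thread * SIMD_SIZE; // 512
constexpr int scale_step_per_thread = group_size / values_per_thread; // 4

int main() {
  // Per iteration, each thread issues two 16-byte vec<uint32_t, 4> loads
  // (32 bytes of weights), covering 16 input positions for each of the
  // 4 interleaved output rows; 4 neighboring threads share one group's
  // vec<T, 4> scale/bias pair.
  printf(
      "values_per_thread=%d block_size=%d scale_step_per_thread=%d\n",
      values_per_thread,
      block_size,
      scale_step_per_thread);
}
```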
@@ -2220,8 +2253,8 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
template <typename T, int group_size, int bits>
[[kernel]] void affine_packed_qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device vec<uint32_t, 4>* w [[buffer(0)]],
const device vec<T, 4>* scales [[buffer(1)]],
const device T* x [[buffer(2)]],
device T* y [[buffer(3)]],
const constant int& in_vec_size [[buffer(5)]],
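The same reinterpretation appears at the kernel entry point: the buffer contents are untouched, only the declared access granularity widens from one uint32_t to a 16-byte lane (which assumes the base pointer is 16-byte aligned, as Metal buffer allocations are). A CPU-side analogue, with a hypothetical stand-in for Metal's vec<uint32_t, 4>:

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for Metal's vec<uint32_t, 4>.
struct alignas(16) uint32x4 {
  uint32_t v[4];
};

int main() {
  alignas(16) uint32_t w[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  // Widened view of the same bytes: two 16-byte packs instead of
  // eight scalar words.
  auto* packs = reinterpret_cast<const uint32x4*>(w);
  printf("%u %u\n", packs[0].v[0], packs[1].v[3]); // prints 0 7
}
```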

@@ -3789,10 +3789,9 @@ std::tuple<array, array, std::optional<array>> quantize(
case QuantizationType::AffinePacked: {
auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s);
scales = unflatten(scales, -2, {-1, 4, 1}, s);
biases = unflatten(biases, -2, {-1, 4, 1}, s);
scales = unflatten(scales, -2, {-1, 4}, s);
biases = unflatten(biases, -2, {-1, 4}, s);
scales = concatenate({scales, biases}, -2, s);
scales = flatten(scales, -3, -2, s);
scales = moveaxis(scales, -2, -1, s);
scales = flatten(scales, -2, -1, s);
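The net effect of the new reshaping is to store, for every (4-row block, group) pair, four scales immediately followed by four biases, matching the two consecutive vec<T, 4> loads the kernel performs per group. A toy sketch of that layout with plain loops (shapes inferred from the ops above, since the hunk is truncated):

```cpp
#include <cstdio>
#include <vector>

// Toy sketch of the scale/bias packing built by the reshaping above.
// Assumed shape chain: (O, G) -> unflatten -> (O/4, 4, G)
// -> concat with biases -> (O/4, 8, G) -> moveaxis -> (O/4, G, 8)
// -> flatten -> (O/4, 8 * G).
int main() {
  const int O = 8, G = 2; // 8 output rows, 2 groups per row
  std::vector<float> scales(O * G), biases(O * G), packed;
  for (int i = 0; i < O * G; i++) {
    scales[i] = 100 + i; // scale of (row, group) = 100 + row * G + group
    biases[i] = 200 + i;
  }
  // For each (4-row block, group) pair emit 4 scales then 4 biases:
  // exactly the two consecutive vec<T, 4> reads in the kernel.
  for (int blk = 0; blk < O / 4; blk++) {
    for (int g = 0; g < G; g++) {
      for (int r = 0; r < 4; r++) {
        packed.push_back(scales[(4 * blk + r) * G + g]);
      }
      for (int r = 0; r < 4; r++) {
        packed.push_back(biases[(4 * blk + r) * G + g]);
      }
    }
  }
  for (float v : packed) {
    printf("%g ", v);
  }
  printf("\n");
}
```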