mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-17 22:51:19 +08:00
Vectorized reads
This commit is contained in:
parent
05cb54ae3f
commit
d7ed624502
@ -2150,10 +2150,41 @@ template <typename T, const int group_size, const int bits>
|
||||
}
|
||||
}
|
||||
|
||||
// Partial quantized dot product over four packed 32-bit weight words at once.
//
// Each component of `w` is treated as an independent packed word; component
// `r` of the returned vector is the sum of products between the leading
// entries of `x` and the fields unpacked from `w[r]`. The same `x` entries
// are reused for all four components.
//
//   bits == 4: every word is split into two 16-bit halves of four 4-bit
//              fields each, consuming x[0..7].
//   bits == 8: every word is split into four 8-bit fields, consuming x[0..3].
//
// The remaining widths accepted by the static_assert (2, 3, 6) have no branch
// here, so the accumulator is returned as all zeros for them.
//
// `T` is not referenced in the body; it is only part of the template
// signature.
template <typename T, typename U, int bits>
inline vec<U, 4> partial_qdot_vec(const thread U* x, vec<uint32_t, 4> w) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");

  vec<U, 4> accum = 0;

  if (bits == 8) {
    for (int row = 0; row < 4; row++) {
      // View the 32-bit word as four byte-wide quantized values.
      const vec<uint8_t, 4> bytes = as_type<vec<uint8_t, 4>>(w[row]);
      for (int lane = 0; lane < 4; lane++) {
        accum[row] += x[lane] * bytes[lane];
      }
    }
  } else if (bits == 4) {
    for (int row = 0; row < 4; row++) {
      // View the 32-bit word as two 16-bit halves, four nibbles per half.
      const vec<uint16_t, 2> halves = as_type<vec<uint16_t, 2>>(w[row]);
      for (int half = 0; half < 2; half++) {
        const thread U* xs = x + 4 * half;
        const uint16_t nibbles = halves[half];
        // NOTE(review): the masked nibbles are not shifted down to bit 0, so
        // the products carry factors of 16 / 256 / 4096. Presumably `x` has
        // been pre-scaled by the caller to compensate -- confirm.
        accum[row] += xs[0] * (nibbles & 0x000f) + xs[1] * (nibbles & 0x00f0) +
            xs[2] * (nibbles & 0x0f00) + xs[3] * (nibbles & 0xf000);
      }
    }
  }

  return accum;
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
METAL_FUNC void affine_packed_qmv_fast_impl(
|
||||
const device uint32_t* w,
|
||||
const device T* scales,
|
||||
const device vec<uint32_t, 4>* w,
|
||||
const device vec<T, 4>* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& in_vec_size,
|
||||
@ -2162,7 +2193,7 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||
constexpr int packs_per_thread = 1;
|
||||
constexpr int packs_per_thread = 2;
|
||||
constexpr int num_simdgroups = 2;
|
||||
constexpr int results_per_simdgroup = 4;
|
||||
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
||||
@ -2171,48 +2202,50 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
|
||||
constexpr int block_size = values_per_thread * SIMD_SIZE;
|
||||
constexpr int scale_step_per_thread = group_size / values_per_thread;
|
||||
|
||||
const device uint8_t* ws = (const device uint8_t*)w;
|
||||
|
||||
typedef float U;
|
||||
|
||||
thread U x_thread[values_per_thread];
|
||||
thread U result[results_per_simdgroup] = {0};
|
||||
vec<U, results_per_simdgroup> result = 0;
|
||||
|
||||
// Adjust positions
|
||||
const int in_vec_size_w =
|
||||
in_vec_size * results_per_simdgroup * bytes_per_pack / pack_factor;
|
||||
const int in_vec_size_g =
|
||||
in_vec_size * results_per_simdgroup * 2 / group_size;
|
||||
const int in_vec_size_w = in_vec_size / pack_factor;
|
||||
const int in_vec_size_g = in_vec_size * 2 / group_size;
|
||||
const int w_row = tid.x * num_simdgroups + simd_gid;
|
||||
const int out_row = w_row * results_per_simdgroup;
|
||||
|
||||
ws += w_row * in_vec_size_w +
|
||||
simd_lid * results_per_simdgroup * packs_per_thread * bytes_per_pack;
|
||||
scales += w_row * in_vec_size_g +
|
||||
results_per_simdgroup * 2 * (simd_lid / scale_step_per_thread);
|
||||
w += w_row * in_vec_size_w + simd_lid * packs_per_thread;
|
||||
scales += w_row * in_vec_size_g + 2 * (simd_lid / scale_step_per_thread);
|
||||
x += tid.y * in_vec_size + simd_lid * values_per_thread;
|
||||
y += tid.y * out_vec_size + out_row;
|
||||
|
||||
for (int k = 0; k < in_vec_size; k += block_size) {
|
||||
// Load the input vector
|
||||
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
result[row] += qdot<U, values_per_thread, bits>(
|
||||
ws + row * bytes_per_pack,
|
||||
x_thread,
|
||||
scales[2 * row + 0],
|
||||
scales[2 * row + 1],
|
||||
sum);
|
||||
// Load the scales and biases
|
||||
vec<T, 4> s = scales[0];
|
||||
vec<T, 4> b = scales[1];
|
||||
|
||||
// Load the weights and perform the partial dot product
|
||||
vec<U, 4> accum = 0;
|
||||
for (int pack = 0; pack < packs_per_thread; pack++) {
|
||||
accum +=
|
||||
partial_qdot_vec<T, U, bits>(x_thread + pack * pack_factor, w[pack]);
|
||||
}
|
||||
|
||||
ws += results_per_simdgroup * block_size * bytes_per_pack / pack_factor;
|
||||
scales += block_size * 2 * results_per_simdgroup / group_size;
|
||||
// Finalize the dot product and accumulate it
|
||||
for (int i = 0; i < 4; i++) {
|
||||
result[i] += static_cast<U>(s[i]) * accum[i] + static_cast<U>(b[i]) * sum;
|
||||
}
|
||||
|
||||
w += block_size / pack_factor;
|
||||
scales += 2 * block_size / group_size;
|
||||
x += block_size;
|
||||
}
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
result[row] = simd_sum(result[row]);
|
||||
result = simd_sum(result);
|
||||
if (simd_lid == 0) {
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
y[row] = static_cast<T>(result[row]);
|
||||
}
|
||||
}
|
||||
@ -2220,8 +2253,8 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
[[kernel]] void affine_packed_qmv_fast(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
const device vec<uint32_t, 4>* w [[buffer(0)]],
|
||||
const device vec<T, 4>* scales [[buffer(1)]],
|
||||
const device T* x [[buffer(2)]],
|
||||
device T* y [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(5)]],
|
||||
|
@ -3789,10 +3789,9 @@ std::tuple<array, array, std::optional<array>> quantize(
|
||||
case QuantizationType::AffinePacked: {
|
||||
auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s);
|
||||
|
||||
scales = unflatten(scales, -2, {-1, 4, 1}, s);
|
||||
biases = unflatten(biases, -2, {-1, 4, 1}, s);
|
||||
scales = unflatten(scales, -2, {-1, 4}, s);
|
||||
biases = unflatten(biases, -2, {-1, 4}, s);
|
||||
scales = concatenate({scales, biases}, -2, s);
|
||||
scales = flatten(scales, -3, -2, s);
|
||||
scales = moveaxis(scales, -2, -1, s);
|
||||
scales = flatten(scales, -2, -1, s);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user