// Copyright © 2023-2024 Apple Inc.

#include <metal_simdgroup>
#include <metal_stdlib>

constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];

using namespace metal;

#define MLX_MTL_CONST static constant constexpr const

MLX_MTL_CONST int SIMD_SIZE = 32;
MLX_MTL_CONST int QUAD_SIZE = 4;

template <int bits, int wsize = 8>
inline constexpr short get_pack_factor() {
  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}

template <int bits, int wsize = 8>
inline constexpr short get_bytes_per_pack() {
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}

template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T* x, thread U* x_thread) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
          bits == 8,
      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

  U sum = 0;

  if (bits == 2) {
    for (int i = 0; i < values_per_thread; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 4.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 64.0f;
    }
  } else if (bits == 3) {
    for (int i = 0; i < values_per_thread; i += 8) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
          x[i + 6] + x[i + 7];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 8.0f;
      x_thread[i + 2] = x[i + 2] / 64.0f;
      x_thread[i + 3] = x[i + 3] / 2.0f;
      x_thread[i + 4] = x[i + 4] / 16.0f;
      x_thread[i + 5] = x[i + 5] / 128.0f;
      x_thread[i + 6] = x[i + 6] / 4.0f;
      x_thread[i + 7] = x[i + 7] / 32.0f;
    }
  } else if (bits == 4) {
    for (int i = 0; i < values_per_thread; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 16.0f;
      x_thread[i + 2] = x[i + 2] / 256.0f;
      x_thread[i + 3] = x[i + 3] / 4096.0f;
    }
  } else if (bits == 5) {
    for (int i = 0; i < values_per_thread; i += 8) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
          x[i + 6] + x[i + 7];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 32.0f;
      x_thread[i + 2] = x[i + 2] / 4.0f;
      x_thread[i + 3] = x[i + 3] / 128.0f;
      x_thread[i + 4] = x[i + 4] / 16.0f;
      x_thread[i + 5] = x[i + 5] / 2.0f;
      x_thread[i + 6] = x[i + 6] / 64.0f;
      x_thread[i + 7] = x[i + 7] / 8.0f;
    }
  } else if (bits == 6) {
    for (int i = 0; i < values_per_thread; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 64.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 4.0f;
    }
  } else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      sum += x[i];
      x_thread[i] = x[i];
    }
  }

  return sum;
}
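// Note: load_vector pre-divides each activation by the place value of the
// packed weight it will be paired with, so qdot below can multiply by the
// masked-but-unshifted bits directly. For example, with bits == 4:
//   (w & 0xf0) * (x / 16.0f) == ((w >> 4) & 0xf) * x
// load_vector_safe below is the bounds-checked variant that only reads the
// first N elements and zero-fills the rest of the thread-local slice.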
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
          bits == 8,
      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

  U sum = 0;

  if (bits == 2) {
    for (int i = 0; i < N; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 4.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 64.0f;
    }
  } else if (bits == 3) {
    for (int i = 0; i < N; i += 8) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
          x[i + 6] + x[i + 7];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 8.0f;
      x_thread[i + 2] = x[i + 2] / 64.0f;
      x_thread[i + 3] = x[i + 3] / 2.0f;
      x_thread[i + 4] = x[i + 4] / 16.0f;
      x_thread[i + 5] = x[i + 5] / 128.0f;
      x_thread[i + 6] = x[i + 6] / 4.0f;
      x_thread[i + 7] = x[i + 7] / 32.0f;
    }
  } else if (bits == 4) {
    for (int i = 0; i < N; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 16.0f;
      x_thread[i + 2] = x[i + 2] / 256.0f;
      x_thread[i + 3] = x[i + 3] / 4096.0f;
    }
  } else if (bits == 5) {
    for (int i = 0; i < N; i += 8) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
          x[i + 6] + x[i + 7];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 32.0f;
      x_thread[i + 2] = x[i + 2] / 4.0f;
      x_thread[i + 3] = x[i + 3] / 128.0f;
      x_thread[i + 4] = x[i + 4] / 16.0f;
      x_thread[i + 5] = x[i + 5] / 2.0f;
      x_thread[i + 6] = x[i + 6] / 64.0f;
      x_thread[i + 7] = x[i + 7] / 8.0f;
    }
  } else if (bits == 6) {
    for (int i = 0; i < N; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 64.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 4.0f;
    }
  } else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      sum += x[i];
      x_thread[i] = x[i];
    }
  }

  for (int i = N; i < values_per_thread; i++) {
    x_thread[i] = 0;
  }

  return sum;
}

template <typename U, int values_per_thread, int bits>
inline U qdot(
    const device uint8_t* w,
    const thread U* x_thread,
    U scale,
    U bias,
    U sum) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
          bits == 8,
      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

  U accum = 0;

  if (bits == 2) {
    for (int i = 0; i < (values_per_thread / 4); i++) {
      accum +=
          (x_thread[4 * i] * (w[i] & 0x03) +
           x_thread[4 * i + 1] * (w[i] & 0x0c) +
           x_thread[4 * i + 2] * (w[i] & 0x30) +
           x_thread[4 * i + 3] * (w[i] & 0xc0));
    }
  } else if (bits == 3) {
    for (int i = 0; i < (values_per_thread / 8); i++) {
      x_thread += 8 * i;
      w += 3 * i;

      accum += (w[0] & 0x07) * x_thread[0];
      accum += (w[0] & 0x38) * x_thread[1];
      accum += (w[0] & 0xc0) * x_thread[2];
      accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);

      accum += (w[1] & 0x0e) * x_thread[3];
      accum += (w[1] & 0x70) * x_thread[4];
      accum += (w[1] & 0x80) * x_thread[5];
      accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);

      accum += (w[2] & 0x1c) * x_thread[6];
      accum += (w[2] & 0xe0) * x_thread[7];
    }
  } else if (bits == 4) {
    const device uint16_t* ws = (const device uint16_t*)w;
    for (int i = 0; i < (values_per_thread / 4); i++) {
      accum +=
          (x_thread[4 * i] * (ws[i] & 0x000f) +
           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
           x_thread[4 * i + 3] * (ws[i] & 0xf000));
    }
  } else if (bits == 5) {
    for (int i = 0; i < (values_per_thread / 8); i++) {
      x_thread += 8 * i;
      w += 5 * i;

      accum += (w[0] & 0x1f) * x_thread[0];
      accum += (w[0] & 0xe0) * x_thread[1];

      accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);
      accum += (w[1] & 0x7c) * x_thread[2];
      accum += (w[1] & 0x80) * x_thread[3];

      accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);
      accum += (w[2] & 0xf0) * x_thread[4];

      accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);
      accum += (w[3] & 0x3e) * x_thread[5];
      accum += (w[3] & 0xc0) * x_thread[6];

      accum += (w[4] & 0x7) * (x_thread[6] * 256.0f);
      accum += (w[4] & 0xf8) * x_thread[7];
    }
  } else if (bits == 6) {
    for (int i = 0; i < (values_per_thread / 4); i++) {
      x_thread += 4 * i;
      w += 3 * i;

      accum += (w[0] & 0x3f) * x_thread[0];

      accum += (w[0] & 0xc0) * x_thread[1];
      accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);

      accum += (w[1] & 0xf0) * x_thread[2];
      accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);

      accum += (w[2] & 0xfc) * x_thread[3];
    }
  } else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      accum += x_thread[i] * w[i];
    }
  }

  return scale * accum + sum * bias;
}
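// qdot_safe mirrors qdot but only consumes the first N values of the thread's
// slice; it handles the ragged tail of the reduction dimension when
// in_vec_size is not a multiple of the block size.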
template <typename U, int values_per_thread, int bits>
inline U qdot_safe(
    const device uint8_t* w,
    const thread U* x_thread,
    U scale,
    U bias,
    U sum,
    int N) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
          bits == 8,
      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

  U accum = 0;

  if (bits == 2) {
    for (int i = 0; i < (N / 4); i++) {
      accum +=
          (x_thread[4 * i] * (w[i] & 0x03) +
           x_thread[4 * i + 1] * (w[i] & 0x0c) +
           x_thread[4 * i + 2] * (w[i] & 0x30) +
           x_thread[4 * i + 3] * (w[i] & 0xc0));
    }
  } else if (bits == 3) {
    for (int i = 0; i < (N / 8); i++) {
      x_thread += 8 * i;
      w += 3 * i;

      accum += (w[0] & 0x07) * x_thread[0];
      accum += (w[0] & 0x38) * x_thread[1];
      accum += (w[0] & 0xc0) * x_thread[2];
      accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);

      accum += (w[1] & 0x0e) * x_thread[3];
      accum += (w[1] & 0x70) * x_thread[4];
      accum += (w[1] & 0x80) * x_thread[5];
      accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);

      accum += (w[2] & 0x1c) * x_thread[6];
      accum += (w[2] & 0xe0) * x_thread[7];
    }
  } else if (bits == 4) {
    const device uint16_t* ws = (const device uint16_t*)w;
    for (int i = 0; i < (N / 4); i++) {
      accum +=
          (x_thread[4 * i] * (ws[i] & 0x000f) +
           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
           x_thread[4 * i + 3] * (ws[i] & 0xf000));
    }
  } else if (bits == 5) {
    for (int i = 0; i < (N / 8); i++) {
      x_thread += 8 * i;
      w += 5 * i;

      accum += (w[0] & 0x1f) * x_thread[0];
      accum += (w[0] & 0xe0) * x_thread[1];

      accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);
      accum += (w[1] & 0x7c) * x_thread[2];
      accum += (w[1] & 0x80) * x_thread[3];

      accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);
      accum += (w[2] & 0xf0) * x_thread[4];

      accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);
      accum += (w[3] & 0x3e) * x_thread[5];
      accum += (w[3] & 0xc0) * x_thread[6];

      accum += (w[4] & 0x7) * (x_thread[6] * 256.0f);
      accum += (w[4] & 0xf8) * x_thread[7];
    }
  } else if (bits == 6) {
    for (int i = 0; i < (N / 4); i++) {
      x_thread += 4 * i;
      w += 3 * i;

      accum += (w[0] & 0x3f) * x_thread[0];

      accum += (w[0] & 0xc0) * x_thread[1];
      accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);

      accum += (w[1] & 0xf0) * x_thread[2];
      accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);

      accum += (w[2] & 0xfc) * x_thread[3];
    }
  } else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      accum += x_thread[i] * w[i];
    }
  }

  return scale * accum + sum * bias;
}

template <typename U, int values_per_thread, int bits>
inline void
qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
          bits == 8,
      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

  if (bits == 2) {
    U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
    for (int i = 0; i < (values_per_thread / 4); i++) {
      result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);
      result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);
      result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);
      result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);
    }
  } else if (bits == 3) {
    for (int i = 0; i < (values_per_thread / 8); i++) {
      uint8_t w0 = w[3 * i];
      uint8_t w1 = w[3 * i + 1];
      uint8_t w2 = w[3 * i + 2];

      result[8 * i] += x * ((w0 & 0x7) * scale + bias);
      result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias);
      result[8 * i + 2] +=
          x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias);
      result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias);
      result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias);
      result[8 * i + 5] +=
          x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias);
      result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias);
      result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias);
    }
  } else if (bits == 4) {
    U s[2] = {scale, scale / 16.0f};
    for (int i = 0; i < (values_per_thread / 2); i++) {
      result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
      result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
    }
  } else if (bits == 5) {
    for (int i = 0; i < (values_per_thread / 8); i++) {
      uint8_t w0 = w[5 * i];
      uint8_t w1 = w[5 * i + 1];
      uint8_t w2 = w[5 * i + 2];
      uint8_t w3 = w[5 * i + 3];
      uint8_t w4 = w[5 * i + 4];

      result[8 * i] += x * ((w0 & 0x1f) * scale + bias);
      result[8 * i + 1] +=
          x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias);
      result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias);
      result[8 * i + 3] +=
          x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias);
      result[8 * i + 4] +=
          x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias);
      result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias);
      result[8 * i + 6] +=
          x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias);
      result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias);
    }
  } else if (bits == 6) {
    for (int i = 0; i < (values_per_thread / 4); i++) {
      uint8_t w0 = w[3 * i];
      uint8_t w1 = w[3 * i + 1];
      uint8_t w2 = w[3 * i + 2];

      result[4 * i] += x * ((w0 & 0x3f) * scale + bias);
      result[4 * i + 1] +=
          x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias);
      result[4 * i + 2] +=
          x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias);
      result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias);
    }
  } else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      result[i] += x * (scale * w[i] + bias);
    }
  }
}
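// qouter accumulates x * dequantized(w) into `result`, i.e. one row of the
// outer product used by the vector-matrix (qvm) path. dequantize below
// instead unpacks a run of N packed values into threadgroup memory for the
// tiled matmul loaders.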
template <typename U, int N, int bits>
inline void
dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
          bits == 8,
      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

  if (bits == 2) {
    U s[4] = {
        scale,
        scale / static_cast<U>(4.0f),
        scale / static_cast<U>(16.0f),
        scale / static_cast<U>(64.0f)};
    for (int i = 0; i < (N / 4); i++) {
      w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
      w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
      w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
      w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
    }
  } else if (bits == 3) {
    for (int i = 0; i < (N / 8); i++) {
      w_local += 8 * i;
      w += 3 * i;

      w_local[0] = (w[0] & 0x7) * scale + bias;
      w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias;
      w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
      w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias;
      w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias;
      w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
      w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
      w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
    }
  } else if (bits == 4) {
    U s[2] = {scale, scale / static_cast<U>(16.0f)};
    for (int i = 0; i < (N / 2); i++) {
      w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
      w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
    }
  } else if (bits == 5) {
    for (int i = 0; i < (N / 8); i++) {
      w_local += 8 * i;
      w += 5 * i;

      w_local[0] = (w[0] & 0x1f) * scale + bias;
      w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
      w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
      w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;
      w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;
      w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias;
      w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;
      w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias;
    }
  } else if (bits == 6) {
    for (int i = 0; i < (N / 4); i++) {
      w_local += 4 * i;
      w += 3 * i;

      w_local[0] = (w[0] & 0x3f) * scale + bias;
      w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
      w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
      w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
    }
  } else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      w_local[i] = scale * w[i] + bias;
    }
  }
}
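// QuantizedBlockLoader cooperatively loads a BROWS x BCOLS tile of packed
// weights, dequantizes it with the scale/bias of the corresponding group, and
// writes it to threadgroup memory with a row stride of dst_ld. reduction_dim
// selects whether next() steps along the packed columns (1) or whole rows (0)
// of the source, and the scale/bias pointers advance accordingly.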
template <
    typename T,
    short BROWS,
    short BCOLS,
    short dst_ld,
    short reduction_dim,
    short tgp_size,
    short group_size,
    short bits>
struct QuantizedBlockLoader {
  static_assert(
      BCOLS <= group_size,
      "The group size should be larger than the columns");
  static_assert(
      group_size % BCOLS == 0,
      "The group size should be divisible by the columns");
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
          bits == 8,
      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

  MLX_MTL_CONST short pack_factor = get_pack_factor<bits>();
  MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack<bits>();
  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
  MLX_MTL_CONST short n_reads =
      (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
  MLX_MTL_CONST short group_steps = group_size / BCOLS;

  const int src_ld;
  const int tile_stride;
  short group_step_cnt;
  const int group_stride;

  const short thread_idx;
  const short bi;
  const short bj;

  threadgroup T* dst;
  const device uint8_t* src;
  const device T* scales;
  const device T* biases;

  QuantizedBlockLoader(
      const device uint8_t* src_,
      const device T* scales_,
      const device T* biases_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(
            reduction_dim ? BCOLS_PACKED * bytes_per_pack
                          : BROWS * src_ld * bytes_per_pack / pack_factor),
        group_step_cnt(0),
        group_stride(BROWS * src_ld / group_size),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(n_reads * thread_idx / BCOLS_PACKED),
        bj((n_reads * thread_idx) % BCOLS_PACKED),
        dst(dst_ + bi * dst_ld + bj * pack_factor),
        src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
            bj * bytes_per_pack),
        scales(scales_ + bi * src_ld / group_size),
        biases(biases_ + bi * src_ld / group_size) {}

  void load_unsafe() const {
    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
      return;
    }

    T scale = *scales;
    T bias = *biases;
    for (int i = 0; i < n_reads; i++) {
      dequantize<T, pack_factor, bits>(
          src + i * bytes_per_pack, scale, bias, dst + i * pack_factor);
    }
  }

  void load_safe(short2 src_tile_dim) const {
    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
      return;
    }

    if (reduction_dim == 1 && bi >= src_tile_dim.y) {
      for (int i = 0; i < n_reads * pack_factor; i++) {
        dst[i] = T(0);
      }
      return;
    }

    if (reduction_dim == 0 && bi >= src_tile_dim.x) {
      for (int i = 0; i < n_reads * pack_factor; i++) {
        dst[i] = T(0);
      }
      return;
    }

    T scale = *scales;
    T bias = *biases;
    for (int i = 0; i < n_reads; i++) {
      dequantize<T, pack_factor, bits>(
          (device uint8_t*)(src + i * bytes_per_pack),
          scale,
          bias,
          dst + i * pack_factor);
    }
  }

  void next() {
    src += tile_stride;
    if (reduction_dim == 1) {
      if (group_steps > 1) {
        group_step_cnt++;
        if (group_step_cnt == group_steps) {
          group_step_cnt = 0;
          scales++;
          biases++;
        }
      } else {
        scales++;
        biases++;
      }
    } else {
      scales += group_stride;
      biases += group_stride;
    }
  }
};

template <typename T, int group_size, int bits, int D>
METAL_FUNC void qmv_quad_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint quad_gid [[quadgroup_index_in_threadgroup]],
    uint quad_lid [[thread_index_in_quadgroup]]) {
  constexpr int quads_per_simd = SIMD_SIZE /
QUAD_SIZE; constexpr int pack_factor = 32 / bits; constexpr int values_per_thread = D / QUAD_SIZE; constexpr int packs_per_thread = values_per_thread / pack_factor; constexpr int scale_step_per_thread = group_size / values_per_thread; constexpr int results_per_quadgroup = 8; typedef float U; thread U x_thread[values_per_thread]; thread U result[results_per_quadgroup] = {0}; // Adjust positions const int in_vec_size_w = in_vec_size / pack_factor; const int in_vec_size_g = in_vec_size / group_size; const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid; w += out_row * in_vec_size_w + quad_lid * packs_per_thread; scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread; biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread; x += tid.x * in_vec_size + quad_lid * values_per_thread; y += tid.x * out_vec_size + out_row; U sum = load_vector(x, x_thread); for (int row = 0; row < results_per_quadgroup; row++) { auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd); const device T* sl = scales + row * in_vec_size_g * quads_per_simd; const device T* bl = biases + row * in_vec_size_g * quads_per_simd; U s = sl[0]; U b = bl[0]; if (row * quads_per_simd + out_row < out_vec_size) { result[row] += qdot(wl, x_thread, s, b, sum); } } for (int row = 0; row < results_per_quadgroup; row++) { result[row] = quad_sum(result[row]); if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) { y[row * quads_per_simd] = static_cast(result[row]); } } } template METAL_FUNC void qmv_fast_impl( const device uint32_t* w, const device T* scales, const device T* biases, const device T* x, device T* y, const constant int& in_vec_size, const constant int& out_vec_size, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int packs_per_thread = bits == 2 ? 
1 : 2; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; constexpr int pack_factor = get_pack_factor(); constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; const device uint8_t* ws = (const device uint8_t*)w; typedef float U; thread U x_thread[values_per_thread]; thread U result[results_per_simdgroup] = {0}; // Adjust positions const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; const int in_vec_size_g = in_vec_size / group_size; const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup; ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; x += tid.x * in_vec_size + simd_lid * values_per_thread; y += tid.x * out_vec_size + out_row; for (int k = 0; k < in_vec_size; k += block_size) { U sum = load_vector(x, x_thread); for (int row = 0; row < results_per_simdgroup; row++) { auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device T* sl = scales + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g; U s = sl[0]; U b = bl[0]; result[row] += qdot(wl, x_thread, s, b, sum); } ws += block_size * bytes_per_pack / pack_factor; scales += block_size / group_size; biases += block_size / group_size; x += block_size; } for (int row = 0; row < results_per_simdgroup; row++) { result[row] = simd_sum(result[row]); if (simd_lid == 0) { y[row] = static_cast(result[row]); } } } template METAL_FUNC void qmv_impl( const device uint32_t* w, const device T* scales, const device T* biases, const device T* x, device T* y, const constant int& in_vec_size, const constant int& out_vec_size, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; constexpr int packs_per_thread = 1; constexpr int pack_factor = get_pack_factor(); constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; const device uint8_t* ws = (const device uint8_t*)w; typedef float U; thread U x_thread[values_per_thread]; thread U result[results_per_simdgroup] = {0}; // Adjust positions const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; const int in_vec_size_g = in_vec_size / group_size; const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup; const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row); if (out_row >= out_vec_size) { return; } // In this case we need to properly guard all our reads because there isn't // even 1 tile in the matrix if (out_vec_size < (num_simdgroups * results_per_simdgroup)) { ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; x += tid.x * in_vec_size + simd_lid * values_per_thread; y += tid.x * out_vec_size + out_row; int k = 0; for (; k < in_vec_size - 
block_size; k += block_size) { U sum = load_vector(x, x_thread); for (int row = 0; out_row + row < out_vec_size; row++) { auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device T* sl = scales + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g; U s = sl[0]; U b = bl[0]; result[row] += qdot(wl, x_thread, s, b, sum); } ws += block_size * bytes_per_pack / pack_factor; scales += block_size / group_size; biases += block_size / group_size; x += block_size; } const int remaining = clamp( static_cast(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread); if (remaining > 0) { U sum = load_vector_safe( x, x_thread, remaining); for (int row = 0; out_row + row < out_vec_size; row++) { auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device T* sl = scales + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g; U s = sl[0]; U b = bl[0]; result[row] += qdot(wl, x_thread, s, b, sum); } } for (int row = 0; out_row + row < out_vec_size; row++) { result[row] = simd_sum(result[row]); if (simd_lid == 0) { y[row] = static_cast(result[row]); } } } // In this case the last tile is moved back to redo some output values else { ws += used_out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; x += tid.x * in_vec_size + simd_lid * values_per_thread; y += tid.x * out_vec_size + used_out_row; int k = 0; for (; k < in_vec_size - block_size; k += block_size) { U sum = load_vector(x, x_thread); for (int row = 0; row < results_per_simdgroup; row++) { auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device T* sl = scales + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g; U s = sl[0]; U b = bl[0]; result[row] += qdot(wl, x_thread, s, b, sum); } ws += block_size * bytes_per_pack / pack_factor; scales += block_size / group_size; biases += block_size / group_size; x += block_size; } const int remaining = clamp( static_cast(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread); if (remaining > 0) { U sum = load_vector_safe( x, x_thread, remaining); for (int row = 0; row < results_per_simdgroup; row++) { auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device T* sl = scales + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g; U s = sl[0]; U b = bl[0]; result[row] += qdot_safe( wl, x_thread, s, b, sum, remaining); } } for (int row = 0; row < results_per_simdgroup; row++) { result[row] = simd_sum(result[row]); if (simd_lid == 0) { y[row] = static_cast(result[row]); } } } } template METAL_FUNC void qvm_impl( const device uint32_t* w, const device T* scales, const device T* biases, const device T* x, device T* y, const int in_vec_size, const int out_vec_size, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int num_simdgroups = 2; constexpr int pack_factor = get_pack_factor(); constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int tn = 32 / pack_factor; constexpr int block_size = SIMD_SIZE; using W_T = typename ConditionalType::type; const device W_T* ws = (const device W_T*)w; typedef float U; typedef struct { W_T wi[tn * bytes_per_pack]; } vec_w; thread vec_w w_local; thread U result[tn * pack_factor] = {0}; 
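  // Each lane of the simdgroup handles one input element per block_size step
  // and accumulates a (tn * pack_factor)-wide slice of the output row; the
  // partial results are reduced across the simdgroup with simd_sum before the
  // store below.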
thread U scale = 1; thread U bias = 0; thread U x_local = 0; // Adjust positions const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor; const int out_vec_size_g = out_vec_size / group_size; int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid); ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w; scales += out_col / group_size + simd_lid * out_vec_size_g; biases += out_col / group_size + simd_lid * out_vec_size_g; x += tid.x * in_vec_size + simd_lid; y += tid.x * out_vec_size + out_col; if (out_col >= out_vec_size) { return; } // Loop over in_vec in blocks of block_size int remaining = in_vec_size % block_size; if (remaining == 0) { for (int i = 0; i < in_vec_size; i += block_size) { x_local = *x; scale = *scales; bias = *biases; w_local = *((device vec_w*)ws); qouter( (thread uint8_t*)&w_local, x_local, scale, bias, result); x += block_size; scales += block_size * out_vec_size_g; biases += block_size * out_vec_size_g; ws += block_size * out_vec_size_w; } } else { for (int i = block_size; i < in_vec_size; i += block_size) { x_local = *x; scale = *scales; bias = *biases; w_local = *((device vec_w*)ws); qouter( (thread uint8_t*)&w_local, x_local, scale, bias, result); x += block_size; scales += block_size * out_vec_size_g; biases += block_size * out_vec_size_g; ws += block_size * out_vec_size_w; } if (static_cast(simd_lid) < remaining) { x_local = *x; scale = *scales; bias = *biases; w_local = *((device vec_w*)ws); } else { x_local = 0; scale = 0; bias = 0; } qouter( (thread uint8_t*)&w_local, x_local, scale, bias, result); } // Accumulate in the simdgroup #pragma clang loop unroll(full) for (int k = 0; k < tn * pack_factor; k++) { result[k] = simd_sum(result[k]); } // Store the result if (simd_lid == 0) { #pragma clang loop unroll(full) for (int k = 0; k < tn * pack_factor; k++) { y[k] = static_cast(result[k]); } } } template < typename T, const int group_size, const int bits, const bool aligned_N, const int BM = 32, const int BK = 32, const int BN = 32> METAL_FUNC void qmm_t_impl( const device uint32_t* w, const device T* scales, const device T* biases, const device T* x, device T* y, threadgroup T* Xs, threadgroup T* Ws, const constant int& K, const constant int& N, const constant int& M, uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); (void)lid; constexpr int WM = 2; constexpr int WN = 2; constexpr int pack_factor = get_pack_factor(); constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int BK_padded = (BK + 16 / sizeof(T)); // Instantiate the appropriate BlockMMA and Loader using mma_t = mlx::steel:: BlockMMA; using loader_x_t = mlx::steel::BlockLoader; using loader_w_t = QuantizedBlockLoader< T, BN, BK, BK_padded, 1, WM * WN * SIMD_SIZE, group_size, bits>; // Set the block const int K_w = K * bytes_per_pack / pack_factor; const int K_g = K / group_size; const int y_row = tid.y * BM; const int y_col = tid.x * BN; auto wl = (const device uint8_t*)w; x += y_row * static_cast(K); wl += y_col * K_w; scales += y_col * K_g; biases += y_col * K_g; y += y_row * static_cast(N) + y_col; // Make the x loader and mma operation const short num_els = min(BM, M - y_row); const short num_outs = min(BN, N - y_col); loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); 
loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid); mma_t mma_op(simd_gid, simd_lid); if (num_els < BM) { if (!aligned_N && num_outs < BN) { for (int k = 0; k < K; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_x.load_safe(short2(BK, num_els)); loader_w.load_safe(short2(BK, num_outs)); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(Xs, Ws); loader_x.next(); loader_w.next(); } } else { for (int k = 0; k < K; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_x.load_safe(short2(BK, num_els)); loader_w.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(Xs, Ws); loader_x.next(); loader_w.next(); } } } else { if (!aligned_N && num_outs < BN) { for (int k = 0; k < K; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_x.load_unsafe(); loader_w.load_safe(short2(BK, num_outs)); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(Xs, Ws); loader_x.next(); loader_w.next(); } } else { for (int k = 0; k < K; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_x.load_unsafe(); loader_w.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(Xs, Ws); loader_x.next(); loader_w.next(); } } } // Store results to device memory threadgroup_barrier(mem_flags::mem_threadgroup); if (num_els < BM || num_outs < BN) { mma_op.store_result_safe(y, N, short2(num_outs, num_els)); } else { mma_op.store_result(y, N); } } template < typename T, const int group_size, const int bits, const int BM = 32, const int BK = 32, const int BN = 32> METAL_FUNC void qmm_n_impl( const device uint32_t* w, const device T* scales, const device T* biases, const device T* x, device T* y, threadgroup T* Xs, threadgroup T* Ws, const constant int& K, const constant int& N, const constant int& M, uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); (void)lid; constexpr int WM = 2; constexpr int WN = 2; constexpr int pack_factor = get_pack_factor(); constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); // Instantiate the appropriate BlockMMA and Loader using mma_t = mlx::steel:: BlockMMA; using loader_x_t = mlx::steel:: BlockLoader; using loader_w_t = QuantizedBlockLoader< T, BK, BN, BN_padded, 0, WM * WN * SIMD_SIZE, group_size, bits>; auto wl = (const device uint8_t*)w; // Set the block const int y_row = tid.y * BM; const int y_col = tid.x * BN; x += y_row * static_cast(K); wl += y_col * bytes_per_pack / pack_factor; scales += y_col / group_size; biases += y_col / group_size; y += y_row * static_cast(N) + y_col; // Make the x loader and mma operation const short num_els = min(BM, M - y_row); loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid); mma_t mma_op(simd_gid, simd_lid); if (num_els < BM) { if ((K % BK) != 0) { const int k_blocks = K / BK; for (int k = 0; k < k_blocks; k++) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_x.load_safe(short2(BK, num_els)); loader_w.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(Xs, Ws); loader_x.next(); loader_w.next(); } const short num_k = K - k_blocks * BK; 
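  // K is not a multiple of BK here, so finish with one bounds-checked pass
  // over the final partial K block.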
threadgroup_barrier(mem_flags::mem_threadgroup); loader_x.load_safe(short2(num_k, num_els)); loader_w.load_safe(short2(BN, num_k)); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(Xs, Ws); } else { for (int k = 0; k < K; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_x.load_safe(short2(BK, num_els)); loader_w.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(Xs, Ws); loader_x.next(); loader_w.next(); } } } else { if ((K % BK) != 0) { const int k_blocks = K / BK; for (int k = 0; k < k_blocks; k++) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_x.load_unsafe(); loader_w.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(Xs, Ws); loader_x.next(); loader_w.next(); } const short num_k = K - k_blocks * BK; threadgroup_barrier(mem_flags::mem_threadgroup); loader_x.load_safe(short2(num_k, BM)); loader_w.load_safe(short2(BN, num_k)); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(Xs, Ws); } else { for (int k = 0; k < K; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_x.load_unsafe(); loader_w.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(Xs, Ws); loader_x.next(); loader_w.next(); } } } // Store results to device memory threadgroup_barrier(mem_flags::mem_threadgroup); if (num_els < BM) { mma_op.store_result_safe(y, N, short2(BN, num_els)); } else { mma_op.store_result(y, N); } } template METAL_FUNC void adjust_matrix_offsets( const device T*& x, const device uint32_t*& w, const device T*& scales, const device T*& biases, device T*& y, int output_stride, const constant int& x_batch_ndims, const constant int* x_shape, const constant int64_t* x_strides, const constant int& w_batch_ndims, const constant int* w_shape, const constant int64_t* w_strides, const constant int64_t* s_strides, const constant int64_t* b_strides, uint3 tid [[threadgroup_position_in_grid]]) { // Set the input/output matrices uint32_t x_idx = tid.z; uint32_t w_idx = tid.z; if (x_batch_ndims == 1) { x += x_idx * x_strides[0]; } else { x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); } if (w_batch_ndims == 1) { w += w_idx * w_strides[0]; scales += w_idx * s_strides[0]; biases += w_idx * b_strides[0]; } else { ulong3 idx = elem_to_loc_broadcast( w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); w += idx.x; scales += idx.y; biases += idx.z; } y += tid.z * output_stride; } template METAL_FUNC void adjust_matrix_offsets( const device T*& x, const device uint32_t*& w, const device T*& scales, const device T*& biases, const device uint32_t* lhs_indices, const device uint32_t* rhs_indices, device T*& y, int output_stride, const constant int& batch_ndims, const constant int* batch_shape, const constant int64_t* lhs_strides, const constant int64_t* rhs_strides, const constant int& x_batch_ndims, const constant int* x_shape, const constant int64_t* x_strides, const constant int& w_batch_ndims, const constant int* w_shape, const constant int64_t* w_strides, const constant int64_t* s_strides, const constant int64_t* b_strides, uint3 tid [[threadgroup_position_in_grid]]) { // Set the input/output matrices uint32_t x_idx; uint32_t w_idx; if (batch_ndims == 1) { x_idx = lhs_indices[tid.z * lhs_strides[0]]; w_idx = rhs_indices[tid.z * rhs_strides[0]]; } else { ulong2 idx = elem_to_loc_broadcast( tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); x_idx = lhs_indices[idx.x]; w_idx = rhs_indices[idx.y]; } if (x_batch_ndims == 1) { x += x_idx * x_strides[0]; } else { x 
+= elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); } if (w_batch_ndims == 1) { w += w_idx * w_strides[0]; scales += w_idx * s_strides[0]; biases += w_idx * b_strides[0]; } else { ulong3 idx = elem_to_loc_broadcast( w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); w += idx.x; scales += idx.y; biases += idx.z; } y += tid.z * output_stride; } template [[kernel]] void qmv_quad( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], device T* y [[buffer(4)]], const constant int& in_vec_size [[buffer(5)]], const constant int& out_vec_size [[buffer(6)]], const constant int& x_batch_ndims [[buffer(7)]], const constant int* x_shape [[buffer(8)]], const constant int64_t* x_strides [[buffer(9)]], const constant int& w_batch_ndims [[buffer(10)]], const constant int* w_shape [[buffer(11)]], const constant int64_t* w_strides [[buffer(12)]], const constant int64_t* s_strides [[buffer(13)]], const constant int64_t* b_strides [[buffer(14)]], uint3 tid [[threadgroup_position_in_grid]], uint quad_gid [[quadgroup_index_in_threadgroup]], uint quad_lid [[thread_index_in_quadgroup]]) { if (batched) { int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, scales, biases, y, out_vec_size * M, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid); } qmv_quad_impl( w, scales, biases, x, y, in_vec_size, out_vec_size, tid, quad_gid, quad_lid); } template [[kernel]] void qmv_fast( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], device T* y [[buffer(4)]], const constant int& in_vec_size [[buffer(5)]], const constant int& out_vec_size [[buffer(6)]], const constant int& x_batch_ndims [[buffer(7)]], const constant int* x_shape [[buffer(8)]], const constant int64_t* x_strides [[buffer(9)]], const constant int& w_batch_ndims [[buffer(10)]], const constant int* w_shape [[buffer(11)]], const constant int64_t* w_strides [[buffer(12)]], const constant int64_t* s_strides [[buffer(13)]], const constant int64_t* b_strides [[buffer(14)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { if (batched) { int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, scales, biases, y, out_vec_size * M, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid); } qmv_fast_impl( w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); } template [[kernel]] void qmv( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], device T* y [[buffer(4)]], const constant int& in_vec_size [[buffer(5)]], const constant int& out_vec_size [[buffer(6)]], const constant int& x_batch_ndims [[buffer(7)]], const constant int* x_shape [[buffer(8)]], const constant int64_t* x_strides [[buffer(9)]], const constant int& w_batch_ndims [[buffer(10)]], const constant int* w_shape [[buffer(11)]], const constant int64_t* w_strides [[buffer(12)]], const constant int64_t* s_strides [[buffer(13)]], const constant int64_t* b_strides [[buffer(14)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { if (batched) { int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, 
scales, biases, y, out_vec_size * M, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid); } qmv_impl( w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); } template [[kernel]] void qvm( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], device T* y [[buffer(4)]], const constant int& in_vec_size [[buffer(5)]], const constant int& out_vec_size [[buffer(6)]], const constant int& x_batch_ndims [[buffer(7)]], const constant int* x_shape [[buffer(8)]], const constant int64_t* x_strides [[buffer(9)]], const constant int& w_batch_ndims [[buffer(10)]], const constant int* w_shape [[buffer(11)]], const constant int64_t* w_strides [[buffer(12)]], const constant int64_t* s_strides [[buffer(13)]], const constant int64_t* b_strides [[buffer(14)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { if (batched) { int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, scales, biases, y, out_vec_size * M, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid); } qvm_impl( w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); } template [[kernel]] void qvm_split_k( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], device T* y [[buffer(4)]], const constant int& in_vec_size [[buffer(5)]], const constant int& out_vec_size [[buffer(6)]], const constant int& x_batch_ndims [[buffer(7)]], const constant int* x_shape [[buffer(8)]], const constant int64_t* x_strides [[buffer(9)]], const constant int& w_batch_ndims [[buffer(10)]], const constant int* w_shape [[buffer(11)]], const constant int64_t* w_strides [[buffer(12)]], const constant int64_t* s_strides [[buffer(13)]], const constant int64_t* b_strides [[buffer(14)]], const constant int& final_block_size [[buffer(15)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, scales, biases, y, out_vec_size * M, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid); // When (in_vec_size % split_k != 0) the final block needs to be smaller int in_vec_size_adj = tid.z % split_k == split_k - 1 ? 
final_block_size : in_vec_size; qvm_impl( w, scales, biases, x, y, in_vec_size_adj, out_vec_size, tid, simd_gid, simd_lid); } template < typename T, const int group_size, const int bits, const bool aligned_N, const bool batched, const int BM = 32, const int BK = 32, const int BN = 32> [[kernel]] void qmm_t( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], device T* y [[buffer(4)]], const constant int& K [[buffer(5)]], const constant int& N [[buffer(6)]], const constant int& M [[buffer(7)]], const constant int& x_batch_ndims [[buffer(8)]], const constant int* x_shape [[buffer(9)]], const constant int64_t* x_strides [[buffer(10)]], const constant int& w_batch_ndims [[buffer(11)]], const constant int* w_shape [[buffer(12)]], const constant int64_t* w_strides [[buffer(13)]], const constant int64_t* s_strides [[buffer(14)]], const constant int64_t* b_strides [[buffer(15)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { (void)lid; constexpr int BK_padded = (BK + 16 / sizeof(T)); threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BN * BK_padded]; if (batched) { adjust_matrix_offsets( x, w, scales, biases, y, M * N, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid); } qmm_t_impl( w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template < typename T, const int group_size, const int bits, const bool batched, const int BM = 32, const int BK = 32, const int BN = 32> [[kernel]] void qmm_n( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], device T* y [[buffer(4)]], const constant int& K [[buffer(5)]], const constant int& N [[buffer(6)]], const constant int& M [[buffer(7)]], const constant int& x_batch_ndims [[buffer(8)]], const constant int* x_shape [[buffer(9)]], const constant int64_t* x_strides [[buffer(10)]], const constant int& w_batch_ndims [[buffer(11)]], const constant int* w_shape [[buffer(12)]], const constant int64_t* w_strides [[buffer(13)]], const constant int64_t* s_strides [[buffer(14)]], const constant int64_t* b_strides [[buffer(15)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { (void)lid; constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BK * BN_padded]; if (batched) { adjust_matrix_offsets( x, w, scales, biases, y, M * N, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid); } qmm_n_impl( w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template [[kernel]] void gather_qmv_fast( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], const device uint32_t* lhs_indices [[buffer(4)]], const device uint32_t* rhs_indices [[buffer(5)]], device T* y [[buffer(6)]], const constant int& in_vec_size [[buffer(7)]], const constant int& out_vec_size [[buffer(8)]], const constant int& x_batch_ndims [[buffer(9)]], const constant int* x_shape [[buffer(10)]], const constant int64_t* x_strides 
[[buffer(11)]], const constant int& w_batch_ndims [[buffer(12)]], const constant int* w_shape [[buffer(13)]], const constant int64_t* w_strides [[buffer(14)]], const constant int64_t* s_strides [[buffer(15)]], const constant int64_t* b_strides [[buffer(16)]], const constant int& batch_ndims [[buffer(17)]], const constant int* batch_shape [[buffer(18)]], const constant int64_t* lhs_strides [[buffer(19)]], const constant int64_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size * M, batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid); qmv_fast_impl( w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); } template [[kernel]] void gather_qmv( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], const device uint32_t* lhs_indices [[buffer(4)]], const device uint32_t* rhs_indices [[buffer(5)]], device T* y [[buffer(6)]], const constant int& in_vec_size [[buffer(7)]], const constant int& out_vec_size [[buffer(8)]], const constant int& x_batch_ndims [[buffer(9)]], const constant int* x_shape [[buffer(10)]], const constant int64_t* x_strides [[buffer(11)]], const constant int& w_batch_ndims [[buffer(12)]], const constant int* w_shape [[buffer(13)]], const constant int64_t* w_strides [[buffer(14)]], const constant int64_t* s_strides [[buffer(15)]], const constant int64_t* b_strides [[buffer(16)]], const constant int& batch_ndims [[buffer(17)]], const constant int* batch_shape [[buffer(18)]], const constant int64_t* lhs_strides [[buffer(19)]], const constant int64_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size * M, batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid); qmv_impl( w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); } template [[kernel]] void gather_qvm( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], const device uint32_t* lhs_indices [[buffer(4)]], const device uint32_t* rhs_indices [[buffer(5)]], device T* y [[buffer(6)]], const constant int& in_vec_size [[buffer(7)]], const constant int& out_vec_size [[buffer(8)]], const constant int& x_batch_ndims [[buffer(9)]], const constant int* x_shape [[buffer(10)]], const constant int64_t* x_strides [[buffer(11)]], const constant int& w_batch_ndims [[buffer(12)]], const constant int* w_shape [[buffer(13)]], const constant int64_t* w_strides [[buffer(14)]], const constant int64_t* s_strides [[buffer(15)]], const constant int64_t* b_strides [[buffer(16)]], const constant int& batch_ndims [[buffer(17)]], const constant int* batch_shape [[buffer(18)]], const constant int64_t* lhs_strides [[buffer(19)]], const constant int64_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid 
[[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size * M, batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid); qvm_impl( w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); } template < typename T, const int group_size, const int bits, const bool aligned_N, const int BM = 32, const int BK = 32, const int BN = 32> [[kernel]] void gather_qmm_t( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], const device uint32_t* lhs_indices [[buffer(4)]], const device uint32_t* rhs_indices [[buffer(5)]], device T* y [[buffer(6)]], const constant int& K [[buffer(7)]], const constant int& N [[buffer(8)]], const constant int& M [[buffer(9)]], const constant int& x_batch_ndims [[buffer(10)]], const constant int* x_shape [[buffer(11)]], const constant int64_t* x_strides [[buffer(12)]], const constant int& w_batch_ndims [[buffer(13)]], const constant int* w_shape [[buffer(14)]], const constant int64_t* w_strides [[buffer(15)]], const constant int64_t* s_strides [[buffer(16)]], const constant int64_t* b_strides [[buffer(17)]], const constant int& batch_ndims [[buffer(18)]], const constant int* batch_shape [[buffer(19)]], const constant int64_t* lhs_strides [[buffer(20)]], const constant int64_t* rhs_strides [[buffer(21)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { (void)lid; constexpr int BK_padded = (BK + 16 / sizeof(T)); threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BN * BK_padded]; adjust_matrix_offsets( x, w, scales, biases, lhs_indices, rhs_indices, y, M * N, batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid); qmm_t_impl( w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template < typename T, const int group_size, const int bits, const int BM = 32, const int BK = 32, const int BN = 32> [[kernel]] void gather_qmm_n( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], const device uint32_t* lhs_indices [[buffer(4)]], const device uint32_t* rhs_indices [[buffer(5)]], device T* y [[buffer(6)]], const constant int& K [[buffer(7)]], const constant int& N [[buffer(8)]], const constant int& M [[buffer(9)]], const constant int& x_batch_ndims [[buffer(10)]], const constant int* x_shape [[buffer(11)]], const constant int64_t* x_strides [[buffer(12)]], const constant int& w_batch_ndims [[buffer(13)]], const constant int* w_shape [[buffer(14)]], const constant int64_t* w_strides [[buffer(15)]], const constant int64_t* s_strides [[buffer(16)]], const constant int64_t* b_strides [[buffer(17)]], const constant int& batch_ndims [[buffer(18)]], const constant int* batch_shape [[buffer(19)]], const constant int64_t* lhs_strides [[buffer(20)]], const constant int64_t* rhs_strides [[buffer(21)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { (void)lid; constexpr int 
BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BK * BN_padded]; adjust_matrix_offsets( x, w, scales, biases, lhs_indices, rhs_indices, y, M * N, batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid); qmm_n_impl( w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template METAL_FUNC void gemm_loop_aligned( threadgroup T* As, threadgroup T* Bs, thread mma_t& mma_op, thread loader_a_t& loader_a, thread loader_b_t& loader_b, const int k_iterations) { for (int k = 0; k < k_iterations; k++) { threadgroup_barrier(mem_flags::mem_threadgroup); // Load elements into threadgroup memory loader_a.load_unsafe(); loader_b.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); // Multiply and accumulate threadgroup elements mma_op.mma(As, Bs); // Prepare for next iteration loader_a.next(); loader_b.next(); } } template < bool rows_aligned, bool cols_aligned, bool transpose, typename T, typename mma_t, typename loader_a_t, typename loader_b_t> METAL_FUNC void gemm_loop_unaligned( threadgroup T* As, threadgroup T* Bs, thread mma_t& mma_op, thread loader_a_t& loader_a, thread loader_b_t& loader_b, const int k_iterations, const short tgp_bm, const short tgp_bn, const short tgp_bk) { for (int k = 0; k < k_iterations; k++) { threadgroup_barrier(mem_flags::mem_threadgroup); // Load elements into threadgroup memory if (rows_aligned) { loader_a.load_unsafe(); } else { loader_a.load_safe(short2(tgp_bk, tgp_bm)); } if (cols_aligned) { loader_b.load_unsafe(); } else { loader_b.load_safe( transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk)); } threadgroup_barrier(mem_flags::mem_threadgroup); // Multiply and accumulate threadgroup elements mma_op.mma(As, Bs); // Prepare for next iteration loader_a.next(); loader_b.next(); } } template METAL_FUNC void gemm_loop_finalize( threadgroup T* As, threadgroup T* Bs, thread mma_t& mma_op, thread loader_a_t& loader_a, thread loader_b_t& loader_b, const short2 tile_a, const short2 tile_b) { loader_a.load_safe(tile_a); loader_b.load_safe(tile_b); threadgroup_barrier(mem_flags::mem_threadgroup); mma_op.mma(As, Bs); } template < typename T, int group_size, int bits, int BM, int BN, int BK, int WM, int WN, bool transpose> [[kernel]] void gather_qmm_rhs( const device T* x [[buffer(0)]], const device uint32_t* w [[buffer(1)]], const device T* scales [[buffer(2)]], const device T* biases [[buffer(3)]], const device uint32_t* indices [[buffer(4)]], device T* y [[buffer(5)]], const constant int& M [[buffer(6)]], const constant int& N [[buffer(7)]], const constant int& K [[buffer(8)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]]) { constexpr int pack_factor = get_pack_factor(); constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); using mma_t = mlx::steel::BlockMMA< T, T, BM, BN, BK, WM, WN, false, transpose, BK_padded, transpose ? BK_padded : BN_padded>; using loader_x_t = mlx::steel::BlockLoader; using loader_w_t = QuantizedBlockLoader< T, transpose ? BN : BK, transpose ? BK : BN, transpose ? BK_padded : BN_padded, transpose, WM * WN * SIMD_SIZE, group_size, bits>; threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[transpose ? 
  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];

  // Compute the block
  const int K_w = K * bytes_per_pack / pack_factor;
  const int K_g = K / group_size;
  const int N_w = N * bytes_per_pack / pack_factor;
  const int N_g = N / group_size;
  const int K_it = K / BK;
  const size_t stride_w = transpose ? N * K_w : K * N_w;
  const size_t stride_s = transpose ? N * K_g : K * N_g;
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;
  const size_t y_row_long = size_t(y_row);
  const size_t y_col_long = size_t(y_col);

  // Prepare threadgroup bounds
  const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));
  const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));

  // Calculate the final tiles in the case that K is not aligned
  const int k_remain = K - K_it * BK;
  const short2 tile_x = short2(k_remain, tgp_bm);
  const short2 tile_w =
      transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);

  // Move x and output to the correct block
  auto wl = (const device uint8_t*)w;
  x += y_row_long * K;
  y += y_row_long * N + y_col_long;
  wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;
  scales += transpose ? y_col_long * K_g : y_col / group_size;
  biases += transpose ? y_col_long * K_g : y_col / group_size;

  // Do as many matmuls as necessary
  uint32_t index;
  short offset;
  uint32_t index_next = indices[y_row];
  short offset_next = 0;
  int n = 0;
  while (n < tgp_bm) {
    n++;
    offset = offset_next;
    index = index_next;
    offset_next = tgp_bm;
    for (; n < tgp_bm; n++) {
      if (indices[y_row + n] != index) {
        offset_next = n;
        index_next = indices[y_row + n];
        break;
      }
    }
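    // Rows [offset, offset_next) of this BM tile all select the same weight
    // matrix `index`, so the single matmul below covers the whole run.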
    threadgroup_barrier(mem_flags::mem_none);

    // Prepare threadgroup mma operation
    thread mma_t mma_op(simd_group_id, simd_lane_id);

    // Prepare threadgroup loading operations
    thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id);
    thread loader_w_t loader_w(
        wl + index * stride_w,
        scales + index * stride_s,
        biases + index * stride_s,
        transpose ? K : N,
        Ws,
        simd_group_id,
        simd_lane_id);

    // Matrices are all aligned, check nothing
    if (align_M && align_N) {
      gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
      if (!align_K) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
      }

      // Store results to device memory
      if (offset_next - offset == BM) {
        mma_op.store_result(y, N);
      } else {
        mma_op.store_result_slice(
            y, N, short2(0, offset), short2(BN, offset_next));
      }
    } else {
      // Tile aligned so check outside of the hot loop
      if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
        gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
        if (!align_K) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          gemm_loop_finalize(
              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
        }

        // Store results to device memory
        if (offset_next - offset == BM) {
          mma_op.store_result(y, N);
        } else {
          mma_op.store_result_slice(
              y, N, short2(0, offset), short2(BN, offset_next));
        }
      }

      // Tile partially aligned, check rows
      else if (align_N || tgp_bn == BN) {
        gemm_loop_unaligned<false, true, transpose>(
            Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
        if (!align_K) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          gemm_loop_finalize(
              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
        }
        mma_op.store_result_slice(
            y, N, short2(0, offset), short2(BN, offset_next));
      }

      // Tile partially aligned, check cols
      else if (align_M || tgp_bm == BM) {
        gemm_loop_unaligned<true, false, transpose>(
            Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
        if (!align_K) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          gemm_loop_finalize(
              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
        }
        mma_op.store_result_slice(
            y, N, short2(0, offset), short2(tgp_bn, offset_next));
      }

      // Nothing aligned so check both rows and cols
      else {
        gemm_loop_unaligned<false, false, transpose>(
            Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
        if (!align_K) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          gemm_loop_finalize(
              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
        }
        mma_op.store_result_slice(
            y, N, short2(0, offset), short2(tgp_bn, offset_next));
      }
    }
  }
}

template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
    const device T* w [[buffer(0)]],
    device uint8_t* out [[buffer(1)]],
    device T* scales [[buffer(2)]],
    device T* biases [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  constexpr float eps = 1e-7;
  constexpr int simd_size = 32;
  constexpr float n_bins = (1 << bits) - 1;
  constexpr int pack_factor = get_pack_factor<bits, 8>();
  constexpr int bytes_per_pack = get_bytes_per_pack<bits, 8>();
  constexpr int values_per_reduce = group_size / simd_size;
  constexpr int writes_per_reduce = pack_factor / values_per_reduce;
  constexpr int writes_per_pack =
      writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor;
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;

  static_assert(
      group_size % simd_size == 0,
      "Group size must be divisible by simd size.");

  size_t offset = index.x + grid_dim.x * size_t(index.y);
  size_t in_index = offset * values_per_reduce;
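  // Each simdgroup quantizes one group of `group_size` values: every thread
  // reduces values_per_reduce of them, and the simd_min / simd_max below
  // produce the per-group range from which scale and bias are derived.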
  size_t out_index = power_of_2_bits
      ? offset * writes_per_pack
      : offset * bytes_per_pack / writes_per_reduce;

  float w_thread[values_per_reduce];
  float w_min = Limits<float>::max;
  float w_max = 0;

#pragma clang loop unroll(full)
  for (int i = 0; i < values_per_reduce; i++) {
    float val = w[in_index + i];
    w_thread[i] = val;
    w_min = min(w_min, val);
    w_max = max(w_max, val);
  }

  w_min = simd_min(w_min);
  w_max = simd_max(w_max);

  float scale = max((w_max - w_min) / n_bins, eps);
  bool side = abs(w_min) > abs(w_max);
  scale = side ? scale : -scale;
  float edge = side ? w_min : w_max;
  float q0 = round(edge / scale);
  bool at_zero = q0 == 0.0f;
  scale = at_zero ? scale : edge / q0;
  float bias = at_zero ? 0 : edge;

  // Write out the scales and biases
  size_t gindex = in_index / group_size;
  if (in_index % group_size == 0) {
    scales[gindex] = static_cast<T>(scale);
    biases[gindex] = static_cast<T>(bias);
  }

  using OutType = metal::conditional_t<bits == 5, uint64_t, uint32_t>;
  OutType output = 0;

#pragma clang loop unroll(full)
  for (int i = 0; i < values_per_reduce; i++) {
    uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
    if (bits == 8) {
      output = val;
    } else {
      output |= val << (bits * (i % pack_factor));
    }

    if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) {
      out[out_index + i / pack_factor] = output;
      output = 0;
    } else {
#pragma clang loop unroll(full)
      for (int j = 1; j < writes_per_reduce; j++) {
        uint8_t sval = simd_shuffle_down(val, j);
        output |= static_cast<OutType>(sval)
            << (bits * (j * values_per_reduce + i));
      }
    }
  }
  if (bits == 3 || bits == 6) {
    if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
      out[out_index] = output & 0xff;
      out[out_index + 1] = (output & 0xff00) >> 8;
      out[out_index + 2] = (output & 0xff0000) >> 16;
    }
  } else if (bits == 5) {
    if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
      out[out_index] = output & 0xff;
      out[out_index + 1] = (output & 0xff00) >> 8;
      out[out_index + 2] = (output & 0xff0000) >> 16;
      out[out_index + 3] = (output & 0xff000000) >> 24;
      out[out_index + 4] = (output & 0xff00000000) >> 32;
    }
  } else {
    if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
      out[out_index / writes_per_reduce] = output;
    }
  }
}
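// affine_dequantize reverses the packing above: each thread expands one pack
// of pack_factor codes back into floats as out = scale * q + bias, using the
// per-group scale and bias written by affine_quantize.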
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_dequantize(
    const device uint8_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    device T* out [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  constexpr int pack_factor = get_pack_factor<bits, 8>();
  constexpr int bytes_per_pack = get_bytes_per_pack<bits, 8>();

  size_t offset = index.x + grid_dim.x * size_t(index.y);
  size_t oindex = offset * pack_factor;
  size_t gindex = oindex / group_size;
  T scale = scales[gindex];
  T bias = biases[gindex];
  out += oindex;

  if (bits == 3) {
    w += offset * bytes_per_pack;
    out[0] = (w[0] & 0x7) * scale + bias;
    out[1] = ((w[0] & 0x38) >> 3) * scale + bias;
    out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
    out[3] = ((w[1] & 0xe) >> 1) * scale + bias;
    out[4] = ((w[1] & 0x70) >> 4) * scale + bias;
    out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
    out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
    out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
  } else if (bits == 5) {
    w += offset * bytes_per_pack;
    out[0] = (w[0] & 0x1f) * scale + bias;
    out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
    out[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
    out[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;
    out[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;
    out[5] = ((w[3] & 0x3e) >> 1) * scale + bias;
    out[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;
    out[7] = ((w[4] & 0xf8) >> 3) * scale + bias;
  } else if (bits == 6) {
    w += offset * bytes_per_pack;
    out[0] = (w[0] & 0x3f) * scale + bias;
    out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
    out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
    out[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
  } else {
    uint val = w[offset];
#pragma clang loop unroll(full)
    for (int i = 0; i < pack_factor; i++) {
      uint8_t d;
      if (bits == 2) {
        d = (val >> (bits * i)) & 0x03;
      } else if (bits == 4) {
        d = (val >> (bits * i)) & 0x0f;
      } else if (bits == 8) {
        d = val;
      }
      out[i] = scale * d + bias;
    }
  }
}
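// Worked example of the affine scheme above (illustrative only): with
// bits = 4 and group_size = 64, each group of 64 weights shares one scale and
// one bias. A weight is encoded as q = min(round((w - bias) / scale), 15) and
// recovered as w ~= scale * q + bias; two 4-bit codes share one byte, low
// nibble first, which is what the (val >> (bits * i)) & 0x0f unpacking reads.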