mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 07:31:26 +08:00
2589 lines
79 KiB
C++
2589 lines
79 KiB
C++
// Copyright © 2023-2024 Apple Inc.
|
|
|
|
#include <metal_simdgroup>
|
|
#include <metal_stdlib>
|
|
|
|
constant bool align_M [[function_constant(200)]];
|
|
constant bool align_N [[function_constant(201)]];
|
|
constant bool align_K [[function_constant(202)]];
|
|
|
|
using namespace metal;
|
|
|
|
#define MLX_MTL_CONST static constant constexpr const
|
|
|
|
MLX_MTL_CONST int SIMD_SIZE = 32;
|
|
MLX_MTL_CONST int QUAD_SIZE = 4;
|
|
|
|
template <int bits, int wsize = 8>
|
|
inline constexpr short get_pack_factor() {
|
|
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
|
|
}
|
|
|
|
template <int bits, int wsize = 8>
|
|
inline constexpr short get_bytes_per_pack() {
|
|
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
|
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
|
|
}
|
|
|
|
template <typename T, typename U, int values_per_thread, int bits>
|
|
inline U load_vector(const device T* x, thread U* x_thread) {
|
|
static_assert(
|
|
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
|
|
bits == 8,
|
|
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
|
|
|
|
U sum = 0;
|
|
|
|
if (bits == 2) {
|
|
for (int i = 0; i < values_per_thread; i += 4) {
|
|
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
|
|
x_thread[i] = x[i];
|
|
x_thread[i + 1] = x[i + 1] / 4.0f;
|
|
x_thread[i + 2] = x[i + 2] / 16.0f;
|
|
x_thread[i + 3] = x[i + 3] / 64.0f;
|
|
}
|
|
}
|
|
|
|
else if (bits == 3) {
|
|
for (int i = 0; i < values_per_thread; i += 8) {
|
|
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
|
|
x[i + 6] + x[i + 7];
|
|
x_thread[i] = x[i];
|
|
x_thread[i + 1] = x[i + 1] / 8.0f;
|
|
x_thread[i + 2] = x[i + 2] / 64.0f;
|
|
x_thread[i + 3] = x[i + 3] / 2.0f;
|
|
x_thread[i + 4] = x[i + 4] / 16.0f;
|
|
x_thread[i + 5] = x[i + 5] / 128.0f;
|
|
x_thread[i + 6] = x[i + 6] / 4.0f;
|
|
x_thread[i + 7] = x[i + 7] / 32.0f;
|
|
}
|
|
}
|
|
|
|
else if (bits == 4) {
|
|
for (int i = 0; i < values_per_thread; i += 4) {
|
|
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
|
|
x_thread[i] = x[i];
|
|
x_thread[i + 1] = x[i + 1] / 16.0f;
|
|
x_thread[i + 2] = x[i + 2] / 256.0f;
|
|
x_thread[i + 3] = x[i + 3] / 4096.0f;
|
|
}
|
|
}
|
|
|
|
else if (bits == 5) {
|
|
for (int i = 0; i < values_per_thread; i += 8) {
|
|
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
|
|
x[i + 6] + x[i + 7];
|
|
x_thread[i] = x[i];
|
|
x_thread[i + 1] = x[i + 1] / 32.0f;
|
|
x_thread[i + 2] = x[i + 2] / 4.0f;
|
|
x_thread[i + 3] = x[i + 3] / 128.0f;
|
|
x_thread[i + 4] = x[i + 4] / 16.0f;
|
|
x_thread[i + 5] = x[i + 5] / 2.0f;
|
|
x_thread[i + 6] = x[i + 6] / 64.0f;
|
|
x_thread[i + 7] = x[i + 7] / 8.0f;
|
|
}
|
|
}
|
|
|
|
else if (bits == 6) {
|
|
for (int i = 0; i < values_per_thread; i += 4) {
|
|
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
|
|
x_thread[i] = x[i];
|
|
x_thread[i + 1] = x[i + 1] / 64.0f;
|
|
x_thread[i + 2] = x[i + 2] / 16.0f;
|
|
x_thread[i + 3] = x[i + 3] / 4.0f;
|
|
}
|
|
}
|
|
|
|
else if (bits == 8) {
|
|
for (int i = 0; i < values_per_thread; i++) {
|
|
sum += x[i];
|
|
x_thread[i] = x[i];
|
|
}
|
|
}
|
|
|
|
return sum;
|
|
}
|
|
|
|
template <typename T, typename U, int values_per_thread, int bits>
|
|
inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
|
|
static_assert(
|
|
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
|
|
bits == 8,
|
|
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
|
|
|
|
U sum = 0;
|
|
|
|
if (bits == 2) {
|
|
for (int i = 0; i < N; i += 4) {
|
|
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
|
|
x_thread[i] = x[i];
|
|
x_thread[i + 1] = x[i + 1] / 4.0f;
|
|
x_thread[i + 2] = x[i + 2] / 16.0f;
|
|
x_thread[i + 3] = x[i + 3] / 64.0f;
|
|
}
|
|
}
|
|
|
|
else if (bits == 3) {
|
|
for (int i = 0; i < N; i += 8) {
|
|
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
|
|
x[i + 6] + x[i + 7];
|
|
|
|
x_thread[i] = x[i];
|
|
x_thread[i + 1] = x[i + 1] / 8.0f;
|
|
x_thread[i + 2] = x[i + 2] / 64.0f;
|
|
x_thread[i + 3] = x[i + 3] / 2.0f;
|
|
x_thread[i + 4] = x[i + 4] / 16.0f;
|
|
x_thread[i + 5] = x[i + 5] / 128.0f;
|
|
x_thread[i + 6] = x[i + 6] / 4.0f;
|
|
x_thread[i + 7] = x[i + 7] / 32.0f;
|
|
}
|
|
}
|
|
|
|
else if (bits == 4) {
|
|
for (int i = 0; i < N; i += 4) {
|
|
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
|
|
x_thread[i] = x[i];
|
|
x_thread[i + 1] = x[i + 1] / 16.0f;
|
|
x_thread[i + 2] = x[i + 2] / 256.0f;
|
|
x_thread[i + 3] = x[i + 3] / 4096.0f;
|
|
}
|
|
}
|
|
|
|
else if (bits == 5) {
|
|
for (int i = 0; i < N; i += 8) {
|
|
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
|
|
x[i + 6] + x[i + 7];
|
|
x_thread[i] = x[i];
|
|
x_thread[i + 1] = x[i + 1] / 32.0f;
|
|
x_thread[i + 2] = x[i + 2] / 4.0f;
|
|
x_thread[i + 3] = x[i + 3] / 128.0f;
|
|
x_thread[i + 4] = x[i + 4] / 16.0f;
|
|
x_thread[i + 5] = x[i + 5] / 2.0f;
|
|
x_thread[i + 6] = x[i + 6] / 64.0f;
|
|
x_thread[i + 7] = x[i + 7] / 8.0f;
|
|
}
|
|
}
|
|
|
|
else if (bits == 6) {
|
|
for (int i = 0; i < N; i += 4) {
|
|
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
|
|
x_thread[i] = x[i];
|
|
x_thread[i + 1] = x[i + 1] / 64.0f;
|
|
x_thread[i + 2] = x[i + 2] / 16.0f;
|
|
x_thread[i + 3] = x[i + 3] / 4.0f;
|
|
}
|
|
}
|
|
|
|
else if (bits == 8) {
|
|
for (int i = 0; i < N; i++) {
|
|
sum += x[i];
|
|
x_thread[i] = x[i];
|
|
}
|
|
}
|
|
|
|
for (int i = N; i < values_per_thread; i++) {
|
|
x_thread[i] = 0;
|
|
}
|
|
|
|
return sum;
|
|
}
|
|
|
|
template <typename U, int values_per_thread, int bits>
|
|
inline U qdot(
|
|
const device uint8_t* w,
|
|
const thread U* x_thread,
|
|
U scale,
|
|
U bias,
|
|
U sum) {
|
|
static_assert(
|
|
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
|
|
bits == 8,
|
|
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
|
|
|
|
U accum = 0;
|
|
|
|
if (bits == 2) {
|
|
for (int i = 0; i < (values_per_thread / 4); i++) {
|
|
accum +=
|
|
(x_thread[4 * i] * (w[i] & 0x03) +
|
|
x_thread[4 * i + 1] * (w[i] & 0x0c) +
|
|
x_thread[4 * i + 2] * (w[i] & 0x30) +
|
|
x_thread[4 * i + 3] * (w[i] & 0xc0));
|
|
}
|
|
}
|
|
|
|
else if (bits == 3) {
|
|
for (int i = 0; i < (values_per_thread / 8); i++) {
|
|
x_thread += 8 * i;
|
|
w += 3 * i;
|
|
|
|
accum += (w[0] & 0x07) * x_thread[0];
|
|
accum += (w[0] & 0x38) * x_thread[1];
|
|
accum += (w[0] & 0xc0) * x_thread[2];
|
|
accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
|
|
|
|
accum += (w[1] & 0x0e) * x_thread[3];
|
|
accum += (w[1] & 0x70) * x_thread[4];
|
|
accum += (w[1] & 0x80) * x_thread[5];
|
|
accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
|
|
|
|
accum += (w[2] & 0x1c) * x_thread[6];
|
|
accum += (w[2] & 0xe0) * x_thread[7];
|
|
}
|
|
}
|
|
|
|
else if (bits == 4) {
|
|
const device uint16_t* ws = (const device uint16_t*)w;
|
|
for (int i = 0; i < (values_per_thread / 4); i++) {
|
|
accum +=
|
|
(x_thread[4 * i] * (ws[i] & 0x000f) +
|
|
x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
|
|
x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
|
|
x_thread[4 * i + 3] * (ws[i] & 0xf000));
|
|
}
|
|
}
|
|
|
|
else if (bits == 5) {
|
|
for (int i = 0; i < (values_per_thread / 8); i++) {
|
|
x_thread += 8 * i;
|
|
w += 5 * i;
|
|
|
|
accum += (w[0] & 0x1f) * x_thread[0];
|
|
accum += (w[0] & 0xe0) * x_thread[1];
|
|
accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);
|
|
accum += (w[1] & 0x7c) * x_thread[2];
|
|
accum += (w[1] & 0x80) * x_thread[3];
|
|
accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);
|
|
accum += (w[2] & 0xf0) * x_thread[4];
|
|
accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);
|
|
accum += (w[3] & 0x3e) * x_thread[5];
|
|
accum += (w[3] & 0xc0) * x_thread[6];
|
|
accum += (w[4] & 0x7) * (x_thread[6] * 256.0f);
|
|
accum += (w[4] & 0xf8) * x_thread[7];
|
|
}
|
|
}
|
|
|
|
else if (bits == 6) {
|
|
for (int i = 0; i < (values_per_thread / 4); i++) {
|
|
x_thread += 4 * i;
|
|
w += 3 * i;
|
|
|
|
accum += (w[0] & 0x3f) * x_thread[0];
|
|
|
|
accum += (w[0] & 0xc0) * x_thread[1];
|
|
accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
|
|
|
|
accum += (w[1] & 0xf0) * x_thread[2];
|
|
accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
|
|
|
|
accum += (w[2] & 0xfc) * x_thread[3];
|
|
}
|
|
}
|
|
|
|
else if (bits == 8) {
|
|
for (int i = 0; i < values_per_thread; i++) {
|
|
accum += x_thread[i] * w[i];
|
|
}
|
|
}
|
|
|
|
return scale * accum + sum * bias;
|
|
}
|
|
|
|
template <typename U, int values_per_thread, int bits>
|
|
inline U qdot_safe(
|
|
const device uint8_t* w,
|
|
const thread U* x_thread,
|
|
U scale,
|
|
U bias,
|
|
U sum,
|
|
int N) {
|
|
static_assert(
|
|
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
|
|
bits == 8,
|
|
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
|
|
|
|
U accum = 0;
|
|
|
|
if (bits == 2) {
|
|
for (int i = 0; i < (N / 4); i++) {
|
|
accum +=
|
|
(x_thread[4 * i] * (w[i] & 0x03) +
|
|
x_thread[4 * i + 1] * (w[i] & 0x0c) +
|
|
x_thread[4 * i + 2] * (w[i] & 0x30) +
|
|
x_thread[4 * i + 3] * (w[i] & 0xc0));
|
|
}
|
|
}
|
|
|
|
else if (bits == 3) {
|
|
for (int i = 0; i < (N / 8); i++) {
|
|
x_thread += 8 * i;
|
|
w += 3 * i;
|
|
|
|
accum += (w[0] & 0x07) * x_thread[0];
|
|
accum += (w[0] & 0x38) * x_thread[1];
|
|
accum += (w[0] & 0xc0) * x_thread[2];
|
|
accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
|
|
|
|
accum += (w[1] & 0x0e) * x_thread[3];
|
|
accum += (w[1] & 0x70) * x_thread[4];
|
|
accum += (w[1] & 0x80) * x_thread[5];
|
|
accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
|
|
|
|
accum += (w[2] & 0x1c) * x_thread[6];
|
|
accum += (w[2] & 0xe0) * x_thread[7];
|
|
}
|
|
}
|
|
|
|
else if (bits == 4) {
|
|
const device uint16_t* ws = (const device uint16_t*)w;
|
|
for (int i = 0; i < (N / 4); i++) {
|
|
accum +=
|
|
(x_thread[4 * i] * (ws[i] & 0x000f) +
|
|
x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
|
|
x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
|
|
x_thread[4 * i + 3] * (ws[i] & 0xf000));
|
|
}
|
|
}
|
|
|
|
else if (bits == 5) {
|
|
for (int i = 0; i < (N / 8); i++) {
|
|
x_thread += 8 * i;
|
|
w += 5 * i;
|
|
|
|
accum += (w[0] & 0x1f) * x_thread[0];
|
|
accum += (w[0] & 0xe0) * x_thread[1];
|
|
accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);
|
|
accum += (w[1] & 0x7c) * x_thread[2];
|
|
accum += (w[1] & 0x80) * x_thread[3];
|
|
accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);
|
|
accum += (w[2] & 0xf0) * x_thread[4];
|
|
accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);
|
|
accum += (w[3] & 0x3e) * x_thread[5];
|
|
accum += (w[3] & 0xc0) * x_thread[6];
|
|
accum += (w[4] & 0x7) * (x_thread[6] * 256.0f);
|
|
accum += (w[4] & 0xf8) * x_thread[7];
|
|
}
|
|
}
|
|
|
|
else if (bits == 6) {
|
|
for (int i = 0; i < (N / 4); i++) {
|
|
x_thread += 4 * i;
|
|
w += 3 * i;
|
|
|
|
accum += (w[0] & 0x3f) * x_thread[0];
|
|
|
|
accum += (w[0] & 0xc0) * x_thread[1];
|
|
accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
|
|
|
|
accum += (w[1] & 0xf0) * x_thread[2];
|
|
accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
|
|
|
|
accum += (w[2] & 0xfc) * x_thread[3];
|
|
}
|
|
}
|
|
|
|
else if (bits == 8) {
|
|
for (int i = 0; i < N; i++) {
|
|
accum += x_thread[i] * w[i];
|
|
}
|
|
}
|
|
|
|
return scale * accum + sum * bias;
|
|
}
|
|
|
|
template <typename U, int values_per_thread, int bits>
|
|
inline void
|
|
qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
|
|
static_assert(
|
|
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
|
|
bits == 8,
|
|
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
|
|
|
|
if (bits == 2) {
|
|
U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
|
|
for (int i = 0; i < (values_per_thread / 4); i++) {
|
|
result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);
|
|
result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);
|
|
result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);
|
|
result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);
|
|
}
|
|
}
|
|
|
|
else if (bits == 3) {
|
|
for (int i = 0; i < (values_per_thread / 8); i++) {
|
|
uint8_t w0 = w[3 * i];
|
|
uint8_t w1 = w[3 * i + 1];
|
|
uint8_t w2 = w[3 * i + 2];
|
|
|
|
result[8 * i] += x * ((w0 & 0x7) * scale + bias);
|
|
result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias);
|
|
result[8 * i + 2] +=
|
|
x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias);
|
|
result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias);
|
|
result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias);
|
|
result[8 * i + 5] +=
|
|
x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias);
|
|
result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias);
|
|
result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias);
|
|
}
|
|
}
|
|
|
|
else if (bits == 4) {
|
|
U s[2] = {scale, scale / 16.0f};
|
|
for (int i = 0; i < (values_per_thread / 2); i++) {
|
|
result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
|
|
result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
|
|
}
|
|
}
|
|
|
|
else if (bits == 5) {
|
|
for (int i = 0; i < (values_per_thread / 8); i++) {
|
|
uint8_t w0 = w[5 * i];
|
|
uint8_t w1 = w[5 * i + 1];
|
|
uint8_t w2 = w[5 * i + 2];
|
|
uint8_t w3 = w[5 * i + 3];
|
|
uint8_t w4 = w[5 * i + 4];
|
|
result[8 * i] += x * ((w0 & 0x1f) * scale + bias);
|
|
result[8 * i + 1] +=
|
|
x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias);
|
|
result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias);
|
|
result[8 * i + 3] +=
|
|
x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias);
|
|
result[8 * i + 4] +=
|
|
x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias);
|
|
result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias);
|
|
result[8 * i + 6] +=
|
|
x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias);
|
|
result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias);
|
|
}
|
|
}
|
|
|
|
else if (bits == 6) {
|
|
for (int i = 0; i < (values_per_thread / 4); i++) {
|
|
uint8_t w0 = w[3 * i];
|
|
uint8_t w1 = w[3 * i + 1];
|
|
uint8_t w2 = w[3 * i + 2];
|
|
|
|
result[4 * i] += x * ((w0 & 0x3f) * scale + bias);
|
|
result[4 * i + 1] +=
|
|
x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias);
|
|
result[4 * i + 2] +=
|
|
x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias);
|
|
result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias);
|
|
}
|
|
}
|
|
|
|
else if (bits == 8) {
|
|
for (int i = 0; i < values_per_thread; i++) {
|
|
result[i] += x * (scale * w[i] + bias);
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename U, int N, int bits>
|
|
inline void
|
|
dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
|
|
static_assert(
|
|
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
|
|
bits == 8,
|
|
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
|
|
|
|
if (bits == 2) {
|
|
U s[4] = {
|
|
scale,
|
|
scale / static_cast<U>(4.0f),
|
|
scale / static_cast<U>(16.0f),
|
|
scale / static_cast<U>(64.0f)};
|
|
for (int i = 0; i < (N / 4); i++) {
|
|
w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
|
|
w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
|
|
w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
|
|
w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
|
|
}
|
|
}
|
|
|
|
else if (bits == 3) {
|
|
for (int i = 0; i < (N / 8); i++) {
|
|
w_local += 8 * i;
|
|
w += 3 * i;
|
|
|
|
w_local[0] = (w[0] & 0x7) * scale + bias;
|
|
w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias;
|
|
w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
|
|
w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias;
|
|
w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias;
|
|
w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
|
|
w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
|
|
w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
|
|
}
|
|
}
|
|
|
|
else if (bits == 4) {
|
|
U s[2] = {scale, scale / static_cast<U>(16.0f)};
|
|
for (int i = 0; i < (N / 2); i++) {
|
|
w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
|
|
w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
|
|
}
|
|
}
|
|
|
|
else if (bits == 5) {
|
|
for (int i = 0; i < (N / 8); i++) {
|
|
w_local += 8 * i;
|
|
w += 5 * i;
|
|
|
|
w_local[0] = (w[0] & 0x1f) * scale + bias;
|
|
w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
|
|
w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
|
|
w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;
|
|
w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;
|
|
w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias;
|
|
w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;
|
|
w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias;
|
|
}
|
|
}
|
|
|
|
else if (bits == 6) {
|
|
for (int i = 0; i < (N / 4); i++) {
|
|
w_local += 4 * i;
|
|
w += 3 * i;
|
|
w_local[0] = (w[0] & 0x3f) * scale + bias;
|
|
w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
|
|
w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
|
|
w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
|
|
}
|
|
}
|
|
|
|
else if (bits == 8) {
|
|
for (int i = 0; i < N; i++) {
|
|
w_local[i] = scale * w[i] + bias;
|
|
}
|
|
}
|
|
}
|
|
|
|
template <
|
|
typename T,
|
|
short BROWS,
|
|
short BCOLS,
|
|
short dst_ld,
|
|
short reduction_dim,
|
|
short tgp_size,
|
|
short group_size,
|
|
short bits>
|
|
struct QuantizedBlockLoader {
|
|
static_assert(
|
|
BCOLS <= group_size,
|
|
"The group size should be larger than the columns");
|
|
static_assert(
|
|
group_size % BCOLS == 0,
|
|
"The group size should be divisible by the columns");
|
|
static_assert(
|
|
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
|
|
bits == 8,
|
|
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
|
|
|
|
MLX_MTL_CONST short pack_factor = get_pack_factor<bits, 8>();
|
|
MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack<bits>();
|
|
MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
|
|
MLX_MTL_CONST short n_reads =
|
|
(BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
|
|
MLX_MTL_CONST short group_steps = group_size / BCOLS;
|
|
|
|
const int src_ld;
|
|
const int tile_stride;
|
|
short group_step_cnt;
|
|
const int group_stride;
|
|
|
|
const short thread_idx;
|
|
const short bi;
|
|
const short bj;
|
|
|
|
threadgroup T* dst;
|
|
const device uint8_t* src;
|
|
const device T* scales;
|
|
const device T* biases;
|
|
|
|
QuantizedBlockLoader(
|
|
const device uint8_t* src_,
|
|
const device T* scales_,
|
|
const device T* biases_,
|
|
const int src_ld_,
|
|
threadgroup T* dst_,
|
|
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
|
: src_ld(src_ld_),
|
|
tile_stride(
|
|
reduction_dim ? BCOLS_PACKED * bytes_per_pack
|
|
: BROWS * src_ld * bytes_per_pack / pack_factor),
|
|
group_step_cnt(0),
|
|
group_stride(BROWS * src_ld / group_size),
|
|
thread_idx(simd_group_id * 32 + simd_lane_id),
|
|
bi(n_reads * thread_idx / BCOLS_PACKED),
|
|
bj((n_reads * thread_idx) % BCOLS_PACKED),
|
|
dst(dst_ + bi * dst_ld + bj * pack_factor),
|
|
src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
|
|
bj * bytes_per_pack),
|
|
scales(scales_ + bi * src_ld / group_size),
|
|
biases(biases_ + bi * src_ld / group_size) {}
|
|
|
|
void load_unsafe() const {
|
|
if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
|
|
return;
|
|
}
|
|
|
|
T scale = *scales;
|
|
T bias = *biases;
|
|
for (int i = 0; i < n_reads; i++) {
|
|
dequantize<T, pack_factor, bits>(
|
|
src + i * bytes_per_pack, scale, bias, dst + i * pack_factor);
|
|
}
|
|
}
|
|
|
|
void load_safe(short2 src_tile_dim) const {
|
|
if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
|
|
return;
|
|
}
|
|
|
|
if (reduction_dim == 1 && bi >= src_tile_dim.x) {
|
|
for (int i = 0; i < n_reads * pack_factor; i++) {
|
|
dst[i] = T(0);
|
|
}
|
|
return;
|
|
}
|
|
|
|
if (reduction_dim == 0 && bi >= src_tile_dim.y) {
|
|
for (int i = 0; i < n_reads * pack_factor; i++) {
|
|
dst[i] = T(0);
|
|
}
|
|
return;
|
|
}
|
|
|
|
T scale = *scales;
|
|
T bias = *biases;
|
|
for (int i = 0; i < n_reads; i++) {
|
|
dequantize<T, pack_factor, bits>(
|
|
(device uint8_t*)(src + i * bytes_per_pack),
|
|
scale,
|
|
bias,
|
|
dst + i * pack_factor);
|
|
}
|
|
}
|
|
|
|
void next() {
|
|
src += tile_stride;
|
|
if (reduction_dim == 1) {
|
|
if (group_steps > 1) {
|
|
group_step_cnt++;
|
|
if (group_step_cnt == group_steps) {
|
|
group_step_cnt = 0;
|
|
scales++;
|
|
biases++;
|
|
}
|
|
} else {
|
|
scales++;
|
|
biases++;
|
|
}
|
|
} else {
|
|
scales += group_stride;
|
|
biases += group_stride;
|
|
}
|
|
}
|
|
};
|
|
|
|
template <typename T, int group_size, int bits, int D>
|
|
METAL_FUNC void qmv_quad_impl(
|
|
const device uint32_t* w,
|
|
const device T* scales,
|
|
const device T* biases,
|
|
const device T* x,
|
|
device T* y,
|
|
constant int& in_vec_size,
|
|
const constant int& out_vec_size,
|
|
uint3 tid [[threadgroup_position_in_grid]],
|
|
uint quad_gid [[quadgroup_index_in_threadgroup]],
|
|
uint quad_lid [[thread_index_in_quadgroup]]) {
|
|
constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
|
|
constexpr int pack_factor = 32 / bits;
|
|
constexpr int values_per_thread = D / QUAD_SIZE;
|
|
constexpr int packs_per_thread = values_per_thread / pack_factor;
|
|
constexpr int scale_step_per_thread = group_size / values_per_thread;
|
|
constexpr int results_per_quadgroup = 8;
|
|
|
|
typedef float U;
|
|
|
|
thread U x_thread[values_per_thread];
|
|
thread U result[results_per_quadgroup] = {0};
|
|
|
|
// Adjust positions
|
|
const int in_vec_size_w = in_vec_size / pack_factor;
|
|
const int in_vec_size_g = in_vec_size / group_size;
|
|
const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid;
|
|
|
|
w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
|
|
scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
|
|
biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
|
|
x += tid.x * in_vec_size + quad_lid * values_per_thread;
|
|
y += tid.x * out_vec_size + out_row;
|
|
|
|
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
|
|
|
for (int row = 0; row < results_per_quadgroup; row++) {
|
|
auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
|
|
const device T* sl = scales + row * in_vec_size_g * quads_per_simd;
|
|
const device T* bl = biases + row * in_vec_size_g * quads_per_simd;
|
|
|
|
U s = sl[0];
|
|
U b = bl[0];
|
|
if (row * quads_per_simd + out_row < out_vec_size) {
|
|
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
|
}
|
|
}
|
|
|
|
for (int row = 0; row < results_per_quadgroup; row++) {
|
|
result[row] = quad_sum(result[row]);
|
|
if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) {
|
|
y[row * quads_per_simd] = static_cast<T>(result[row]);
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename T, int group_size, int bits>
|
|
METAL_FUNC void qmv_fast_impl(
|
|
const device uint32_t* w,
|
|
const device T* scales,
|
|
const device T* biases,
|
|
const device T* x,
|
|
device T* y,
|
|
const constant int& in_vec_size,
|
|
const constant int& out_vec_size,
|
|
uint3 tid [[threadgroup_position_in_grid]],
|
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
constexpr int packs_per_thread = bits == 2 ? 1 : 2;
|
|
constexpr int num_simdgroups = 2;
|
|
constexpr int results_per_simdgroup = 4;
|
|
constexpr int pack_factor = get_pack_factor<bits, 32>();
|
|
constexpr int bytes_per_pack = get_bytes_per_pack<bits, 32>();
|
|
constexpr int values_per_thread = pack_factor * packs_per_thread;
|
|
constexpr int block_size = values_per_thread * SIMD_SIZE;
|
|
constexpr int scale_step_per_thread = group_size / values_per_thread;
|
|
|
|
const device uint8_t* ws = (const device uint8_t*)w;
|
|
|
|
typedef float U;
|
|
|
|
thread U x_thread[values_per_thread];
|
|
thread U result[results_per_simdgroup] = {0};
|
|
|
|
// Adjust positions
|
|
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
|
|
const int in_vec_size_g = in_vec_size / group_size;
|
|
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
|
|
simd_gid * results_per_simdgroup;
|
|
|
|
ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
|
|
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
|
|
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
|
|
x += tid.x * in_vec_size + simd_lid * values_per_thread;
|
|
y += tid.x * out_vec_size + out_row;
|
|
|
|
for (int k = 0; k < in_vec_size; k += block_size) {
|
|
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
|
|
|
for (int row = 0; row < results_per_simdgroup; row++) {
|
|
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
|
const device T* sl = scales + row * in_vec_size_g;
|
|
const device T* bl = biases + row * in_vec_size_g;
|
|
|
|
U s = sl[0];
|
|
U b = bl[0];
|
|
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
|
}
|
|
|
|
ws += block_size * bytes_per_pack / pack_factor;
|
|
scales += block_size / group_size;
|
|
biases += block_size / group_size;
|
|
x += block_size;
|
|
}
|
|
|
|
for (int row = 0; row < results_per_simdgroup; row++) {
|
|
result[row] = simd_sum(result[row]);
|
|
if (simd_lid == 0) {
|
|
y[row] = static_cast<T>(result[row]);
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename T, int group_size, int bits>
|
|
METAL_FUNC void qmv_impl(
|
|
const device uint32_t* w,
|
|
const device T* scales,
|
|
const device T* biases,
|
|
const device T* x,
|
|
device T* y,
|
|
const constant int& in_vec_size,
|
|
const constant int& out_vec_size,
|
|
uint3 tid [[threadgroup_position_in_grid]],
|
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
constexpr int num_simdgroups = 2;
|
|
constexpr int results_per_simdgroup = 4;
|
|
constexpr int packs_per_thread = 1;
|
|
constexpr int pack_factor = get_pack_factor<bits, 32>();
|
|
constexpr int bytes_per_pack = get_bytes_per_pack<bits, 32>();
|
|
|
|
constexpr int values_per_thread = pack_factor * packs_per_thread;
|
|
constexpr int block_size = values_per_thread * SIMD_SIZE;
|
|
constexpr int scale_step_per_thread = group_size / values_per_thread;
|
|
|
|
const device uint8_t* ws = (const device uint8_t*)w;
|
|
|
|
typedef float U;
|
|
|
|
thread U x_thread[values_per_thread];
|
|
thread U result[results_per_simdgroup] = {0};
|
|
|
|
// Adjust positions
|
|
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
|
|
const int in_vec_size_g = in_vec_size / group_size;
|
|
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
|
|
simd_gid * results_per_simdgroup;
|
|
const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
|
|
|
|
if (out_row >= out_vec_size) {
|
|
return;
|
|
}
|
|
|
|
// In this case we need to properly guard all our reads because there isn't
|
|
// even 1 tile in the matrix
|
|
if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
|
|
ws +=
|
|
out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
|
|
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
|
|
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
|
|
x += tid.x * in_vec_size + simd_lid * values_per_thread;
|
|
y += tid.x * out_vec_size + out_row;
|
|
|
|
int k = 0;
|
|
for (; k < in_vec_size - block_size; k += block_size) {
|
|
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
|
|
|
for (int row = 0; out_row + row < out_vec_size; row++) {
|
|
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
|
const device T* sl = scales + row * in_vec_size_g;
|
|
const device T* bl = biases + row * in_vec_size_g;
|
|
|
|
U s = sl[0];
|
|
U b = bl[0];
|
|
result[row] +=
|
|
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
|
}
|
|
|
|
ws += block_size * bytes_per_pack / pack_factor;
|
|
scales += block_size / group_size;
|
|
biases += block_size / group_size;
|
|
x += block_size;
|
|
}
|
|
const int remaining = clamp(
|
|
static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
|
|
0,
|
|
values_per_thread);
|
|
if (remaining > 0) {
|
|
U sum = load_vector_safe<T, U, values_per_thread, bits>(
|
|
x, x_thread, remaining);
|
|
|
|
for (int row = 0; out_row + row < out_vec_size; row++) {
|
|
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
|
const device T* sl = scales + row * in_vec_size_g;
|
|
const device T* bl = biases + row * in_vec_size_g;
|
|
|
|
U s = sl[0];
|
|
U b = bl[0];
|
|
result[row] +=
|
|
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
|
}
|
|
}
|
|
|
|
for (int row = 0; out_row + row < out_vec_size; row++) {
|
|
result[row] = simd_sum(result[row]);
|
|
if (simd_lid == 0) {
|
|
y[row] = static_cast<T>(result[row]);
|
|
}
|
|
}
|
|
}
|
|
|
|
// In this case the last tile is moved back to redo some output values
|
|
else {
|
|
ws += used_out_row * in_vec_size_w +
|
|
simd_lid * packs_per_thread * bytes_per_pack;
|
|
scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
|
|
biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
|
|
x += tid.x * in_vec_size + simd_lid * values_per_thread;
|
|
y += tid.x * out_vec_size + used_out_row;
|
|
|
|
int k = 0;
|
|
for (; k < in_vec_size - block_size; k += block_size) {
|
|
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
|
|
|
for (int row = 0; row < results_per_simdgroup; row++) {
|
|
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
|
const device T* sl = scales + row * in_vec_size_g;
|
|
const device T* bl = biases + row * in_vec_size_g;
|
|
|
|
U s = sl[0];
|
|
U b = bl[0];
|
|
result[row] +=
|
|
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
|
}
|
|
|
|
ws += block_size * bytes_per_pack / pack_factor;
|
|
scales += block_size / group_size;
|
|
biases += block_size / group_size;
|
|
x += block_size;
|
|
}
|
|
const int remaining = clamp(
|
|
static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
|
|
0,
|
|
values_per_thread);
|
|
if (remaining > 0) {
|
|
U sum = load_vector_safe<T, U, values_per_thread, bits>(
|
|
x, x_thread, remaining);
|
|
|
|
for (int row = 0; row < results_per_simdgroup; row++) {
|
|
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
|
const device T* sl = scales + row * in_vec_size_g;
|
|
const device T* bl = biases + row * in_vec_size_g;
|
|
|
|
U s = sl[0];
|
|
U b = bl[0];
|
|
result[row] += qdot_safe<U, values_per_thread, bits>(
|
|
wl, x_thread, s, b, sum, remaining);
|
|
}
|
|
}
|
|
for (int row = 0; row < results_per_simdgroup; row++) {
|
|
result[row] = simd_sum(result[row]);
|
|
if (simd_lid == 0) {
|
|
y[row] = static_cast<T>(result[row]);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename T, const int group_size, const int bits>
|
|
METAL_FUNC void qvm_impl(
|
|
const device uint32_t* w,
|
|
const device T* scales,
|
|
const device T* biases,
|
|
const device T* x,
|
|
device T* y,
|
|
const int in_vec_size,
|
|
const int out_vec_size,
|
|
uint3 tid [[threadgroup_position_in_grid]],
|
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
|
constexpr int num_simdgroups = 2;
|
|
constexpr int pack_factor = get_pack_factor<bits, 32>();
|
|
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
|
|
|
|
constexpr int tn = 32 / pack_factor;
|
|
constexpr int block_size = SIMD_SIZE;
|
|
|
|
using W_T =
|
|
typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
|
|
const device W_T* ws = (const device W_T*)w;
|
|
|
|
typedef float U;
|
|
typedef struct {
|
|
W_T wi[tn * bytes_per_pack];
|
|
} vec_w;
|
|
|
|
thread vec_w w_local;
|
|
thread U result[tn * pack_factor] = {0};
|
|
thread U scale = 1;
|
|
thread U bias = 0;
|
|
thread U x_local = 0;
|
|
|
|
// Adjust positions
|
|
const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
|
|
const int out_vec_size_g = out_vec_size / group_size;
|
|
int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid);
|
|
ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w;
|
|
scales += out_col / group_size + simd_lid * out_vec_size_g;
|
|
biases += out_col / group_size + simd_lid * out_vec_size_g;
|
|
x += tid.x * in_vec_size + simd_lid;
|
|
y += tid.x * out_vec_size + out_col;
|
|
|
|
if (out_col >= out_vec_size) {
|
|
return;
|
|
}
|
|
|
|
// Loop over in_vec in blocks of block_size
|
|
int remaining = in_vec_size % block_size;
|
|
if (remaining == 0) {
|
|
for (int i = 0; i < in_vec_size; i += block_size) {
|
|
x_local = *x;
|
|
scale = *scales;
|
|
bias = *biases;
|
|
w_local = *((device vec_w*)ws);
|
|
qouter<U, tn * pack_factor, bits>(
|
|
(thread uint8_t*)&w_local, x_local, scale, bias, result);
|
|
|
|
x += block_size;
|
|
scales += block_size * out_vec_size_g;
|
|
biases += block_size * out_vec_size_g;
|
|
ws += block_size * out_vec_size_w;
|
|
}
|
|
} else {
|
|
for (int i = block_size; i < in_vec_size; i += block_size) {
|
|
x_local = *x;
|
|
scale = *scales;
|
|
bias = *biases;
|
|
w_local = *((device vec_w*)ws);
|
|
|
|
qouter<U, tn * pack_factor, bits>(
|
|
(thread uint8_t*)&w_local, x_local, scale, bias, result);
|
|
|
|
x += block_size;
|
|
scales += block_size * out_vec_size_g;
|
|
biases += block_size * out_vec_size_g;
|
|
ws += block_size * out_vec_size_w;
|
|
}
|
|
if (static_cast<int>(simd_lid) < remaining) {
|
|
x_local = *x;
|
|
scale = *scales;
|
|
bias = *biases;
|
|
w_local = *((device vec_w*)ws);
|
|
} else {
|
|
x_local = 0;
|
|
scale = 0;
|
|
bias = 0;
|
|
}
|
|
qouter<U, tn * pack_factor, bits>(
|
|
(thread uint8_t*)&w_local, x_local, scale, bias, result);
|
|
}
|
|
|
|
// Accumulate in the simdgroup
|
|
#pragma clang loop unroll(full)
|
|
for (int k = 0; k < tn * pack_factor; k++) {
|
|
result[k] = simd_sum(result[k]);
|
|
}
|
|
|
|
// Store the result
|
|
if (simd_lid == 0) {
|
|
#pragma clang loop unroll(full)
|
|
for (int k = 0; k < tn * pack_factor; k++) {
|
|
y[k] = static_cast<T>(result[k]);
|
|
}
|
|
}
|
|
}
|
|
|
|
template <
|
|
typename T,
|
|
const int group_size,
|
|
const int bits,
|
|
const bool aligned_N,
|
|
const int BM = 32,
|
|
const int BK = 32,
|
|
const int BN = 32>
|
|
METAL_FUNC void qmm_t_impl(
|
|
const device uint32_t* w,
|
|
const device T* scales,
|
|
const device T* biases,
|
|
const device T* x,
|
|
device T* y,
|
|
threadgroup T* Xs,
|
|
threadgroup T* Ws,
|
|
const constant int& K,
|
|
const constant int& N,
|
|
const constant int& M,
|
|
uint3 tid [[threadgroup_position_in_grid]],
|
|
uint lid [[thread_index_in_threadgroup]],
|
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
|
|
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
|
|
|
|
(void)lid;
|
|
|
|
constexpr int WM = 2;
|
|
constexpr int WN = 2;
|
|
constexpr int pack_factor = get_pack_factor<bits, 8>();
|
|
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
|
|
|
|
constexpr int BK_padded = (BK + 16 / sizeof(T));
|
|
|
|
// Instantiate the appropriate BlockMMA and Loader
|
|
using mma_t = mlx::steel::
|
|
BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
|
|
using loader_x_t =
|
|
mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
|
|
using loader_w_t = QuantizedBlockLoader<
|
|
T,
|
|
BN,
|
|
BK,
|
|
BK_padded,
|
|
1,
|
|
WM * WN * SIMD_SIZE,
|
|
group_size,
|
|
bits>;
|
|
|
|
// Set the block
|
|
const int K_w = K * bytes_per_pack / pack_factor;
|
|
const int K_g = K / group_size;
|
|
const int y_row = tid.y * BM;
|
|
const int y_col = tid.x * BN;
|
|
|
|
auto wl = (const device uint8_t*)w;
|
|
|
|
x += y_row * static_cast<int64_t>(K);
|
|
wl += y_col * K_w;
|
|
scales += y_col * K_g;
|
|
biases += y_col * K_g;
|
|
y += y_row * static_cast<int64_t>(N) + y_col;
|
|
|
|
// Make the x loader and mma operation
|
|
const short num_els = min(BM, M - y_row);
|
|
const short num_outs = min(BN, N - y_col);
|
|
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
|
|
loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);
|
|
mma_t mma_op(simd_gid, simd_lid);
|
|
|
|
if (num_els < BM) {
|
|
if (!aligned_N && num_outs < BN) {
|
|
for (int k = 0; k < K; k += BK) {
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
loader_x.load_safe(short2(BK, num_els));
|
|
loader_w.load_safe(short2(BK, num_outs));
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
mma_op.mma(Xs, Ws);
|
|
loader_x.next();
|
|
loader_w.next();
|
|
}
|
|
} else {
|
|
for (int k = 0; k < K; k += BK) {
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
loader_x.load_safe(short2(BK, num_els));
|
|
loader_w.load_unsafe();
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
mma_op.mma(Xs, Ws);
|
|
loader_x.next();
|
|
loader_w.next();
|
|
}
|
|
}
|
|
} else {
|
|
if (!aligned_N && num_outs < BN) {
|
|
for (int k = 0; k < K; k += BK) {
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
loader_x.load_unsafe();
|
|
loader_w.load_safe(short2(BK, num_outs));
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
mma_op.mma(Xs, Ws);
|
|
loader_x.next();
|
|
loader_w.next();
|
|
}
|
|
} else {
|
|
for (int k = 0; k < K; k += BK) {
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
loader_x.load_unsafe();
|
|
loader_w.load_unsafe();
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
mma_op.mma(Xs, Ws);
|
|
loader_x.next();
|
|
loader_w.next();
|
|
}
|
|
}
|
|
}
|
|
|
|
// Store results to device memory
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
if (num_els < BM || num_outs < BN) {
|
|
mma_op.store_result_safe(y, N, short2(num_outs, num_els));
|
|
} else {
|
|
mma_op.store_result(y, N);
|
|
}
|
|
}
|
|
|
|
template <
|
|
typename T,
|
|
const int group_size,
|
|
const int bits,
|
|
const int BM = 32,
|
|
const int BK = 32,
|
|
const int BN = 32>
|
|
METAL_FUNC void qmm_n_impl(
|
|
const device uint32_t* w,
|
|
const device T* scales,
|
|
const device T* biases,
|
|
const device T* x,
|
|
device T* y,
|
|
threadgroup T* Xs,
|
|
threadgroup T* Ws,
|
|
const constant int& K,
|
|
const constant int& N,
|
|
const constant int& M,
|
|
uint3 tid [[threadgroup_position_in_grid]],
|
|
uint lid [[thread_index_in_threadgroup]],
|
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
|
|
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
|
|
|
|
(void)lid;
|
|
|
|
constexpr int WM = 2;
|
|
constexpr int WN = 2;
|
|
constexpr int pack_factor = get_pack_factor<bits, 8>();
|
|
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
|
|
|
|
constexpr int BK_padded = (BK + 16 / sizeof(T));
|
|
constexpr int BN_padded = (BN + 16 / sizeof(T));
|
|
|
|
// Instantiate the appropriate BlockMMA and Loader
|
|
using mma_t = mlx::steel::
|
|
BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
|
|
using loader_x_t = mlx::steel::
|
|
BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
|
|
using loader_w_t = QuantizedBlockLoader<
|
|
T,
|
|
BK,
|
|
BN,
|
|
BN_padded,
|
|
0,
|
|
WM * WN * SIMD_SIZE,
|
|
group_size,
|
|
bits>;
|
|
|
|
auto wl = (const device uint8_t*)w;
|
|
|
|
// Set the block
|
|
const int y_row = tid.y * BM;
|
|
const int y_col = tid.x * BN;
|
|
x += y_row * static_cast<int64_t>(K);
|
|
wl += y_col * bytes_per_pack / pack_factor;
|
|
scales += y_col / group_size;
|
|
biases += y_col / group_size;
|
|
y += y_row * static_cast<int64_t>(N) + y_col;
|
|
|
|
// Make the x loader and mma operation
|
|
const short num_els = min(BM, M - y_row);
|
|
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
|
|
loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid);
|
|
mma_t mma_op(simd_gid, simd_lid);
|
|
|
|
if (num_els < BM) {
|
|
if ((K % BK) != 0) {
|
|
const int k_blocks = K / BK;
|
|
for (int k = 0; k < k_blocks; k++) {
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
loader_x.load_safe(short2(BK, num_els));
|
|
loader_w.load_unsafe();
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
mma_op.mma(Xs, Ws);
|
|
loader_x.next();
|
|
loader_w.next();
|
|
}
|
|
const short num_k = K - k_blocks * BK;
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
loader_x.load_safe(short2(num_k, num_els));
|
|
loader_w.load_safe(short2(BN, num_k));
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
mma_op.mma(Xs, Ws);
|
|
} else {
|
|
for (int k = 0; k < K; k += BK) {
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
loader_x.load_safe(short2(BK, num_els));
|
|
loader_w.load_unsafe();
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
mma_op.mma(Xs, Ws);
|
|
loader_x.next();
|
|
loader_w.next();
|
|
}
|
|
}
|
|
} else {
|
|
if ((K % BK) != 0) {
|
|
const int k_blocks = K / BK;
|
|
for (int k = 0; k < k_blocks; k++) {
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
loader_x.load_unsafe();
|
|
loader_w.load_unsafe();
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
mma_op.mma(Xs, Ws);
|
|
loader_x.next();
|
|
loader_w.next();
|
|
}
|
|
const short num_k = K - k_blocks * BK;
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
loader_x.load_safe(short2(num_k, BM));
|
|
loader_w.load_safe(short2(BN, num_k));
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
mma_op.mma(Xs, Ws);
|
|
} else {
|
|
for (int k = 0; k < K; k += BK) {
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
loader_x.load_unsafe();
|
|
loader_w.load_unsafe();
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
mma_op.mma(Xs, Ws);
|
|
loader_x.next();
|
|
loader_w.next();
|
|
}
|
|
}
|
|
}
|
|
|
|
// Store results to device memory
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
if (num_els < BM) {
|
|
mma_op.store_result_safe(y, N, short2(BN, num_els));
|
|
} else {
|
|
mma_op.store_result(y, N);
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
METAL_FUNC void adjust_matrix_offsets(
|
|
const device T*& x,
|
|
const device uint32_t*& w,
|
|
const device T*& scales,
|
|
const device T*& biases,
|
|
device T*& y,
|
|
int output_stride,
|
|
const constant int& x_batch_ndims,
|
|
const constant int* x_shape,
|
|
const constant int64_t* x_strides,
|
|
const constant int& w_batch_ndims,
|
|
const constant int* w_shape,
|
|
const constant int64_t* w_strides,
|
|
const constant int64_t* s_strides,
|
|
const constant int64_t* b_strides,
|
|
uint3 tid [[threadgroup_position_in_grid]]) {
|
|
// Set the input/output matrices
|
|
uint32_t x_idx = tid.z;
|
|
uint32_t w_idx = tid.z;
|
|
if (x_batch_ndims == 1) {
|
|
x += x_idx * x_strides[0];
|
|
} else {
|
|
x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
|
|
}
|
|
if (w_batch_ndims == 1) {
|
|
w += w_idx * w_strides[0];
|
|
scales += w_idx * s_strides[0];
|
|
biases += w_idx * b_strides[0];
|
|
} else {
|
|
ulong3 idx = elem_to_loc_broadcast(
|
|
w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
|
|
w += idx.x;
|
|
scales += idx.y;
|
|
biases += idx.z;
|
|
}
|
|
y += tid.z * output_stride;
|
|
}
|
|
|
|
template <typename T>
|
|
METAL_FUNC void adjust_matrix_offsets(
|
|
const device T*& x,
|
|
const device uint32_t*& w,
|
|
const device T*& scales,
|
|
const device T*& biases,
|
|
const device uint32_t* lhs_indices,
|
|
const device uint32_t* rhs_indices,
|
|
device T*& y,
|
|
int output_stride,
|
|
const constant int& batch_ndims,
|
|
const constant int* batch_shape,
|
|
const constant int64_t* lhs_strides,
|
|
const constant int64_t* rhs_strides,
|
|
const constant int& x_batch_ndims,
|
|
const constant int* x_shape,
|
|
const constant int64_t* x_strides,
|
|
const constant int& w_batch_ndims,
|
|
const constant int* w_shape,
|
|
const constant int64_t* w_strides,
|
|
const constant int64_t* s_strides,
|
|
const constant int64_t* b_strides,
|
|
uint3 tid [[threadgroup_position_in_grid]]) {
|
|
// Set the input/output matrices
|
|
uint32_t x_idx;
|
|
uint32_t w_idx;
|
|
if (batch_ndims == 1) {
|
|
x_idx = lhs_indices[tid.z * lhs_strides[0]];
|
|
w_idx = rhs_indices[tid.z * rhs_strides[0]];
|
|
} else {
|
|
ulong2 idx = elem_to_loc_broadcast(
|
|
tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);
|
|
x_idx = lhs_indices[idx.x];
|
|
w_idx = rhs_indices[idx.y];
|
|
}
|
|
if (x_batch_ndims == 1) {
|
|
x += x_idx * x_strides[0];
|
|
} else {
|
|
x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
|
|
}
|
|
if (w_batch_ndims == 1) {
|
|
w += w_idx * w_strides[0];
|
|
scales += w_idx * s_strides[0];
|
|
biases += w_idx * b_strides[0];
|
|
} else {
|
|
ulong3 idx = elem_to_loc_broadcast(
|
|
w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
|
|
w += idx.x;
|
|
scales += idx.y;
|
|
biases += idx.z;
|
|
}
|
|
y += tid.z * output_stride;
|
|
}
|
|
|
|
template <typename T, int group_size, int bits, int D, bool batched>
|
|
[[kernel]] void qmv_quad(
|
|
const device uint32_t* w [[buffer(0)]],
|
|
const device T* scales [[buffer(1)]],
|
|
const device T* biases [[buffer(2)]],
|
|
const device T* x [[buffer(3)]],
|
|
device T* y [[buffer(4)]],
|
|
const constant int& in_vec_size [[buffer(5)]],
|
|
const constant int& out_vec_size [[buffer(6)]],
|
|
const constant int& x_batch_ndims [[buffer(7)]],
|
|
const constant int* x_shape [[buffer(8)]],
|
|
const constant int64_t* x_strides [[buffer(9)]],
|
|
const constant int& w_batch_ndims [[buffer(10)]],
|
|
const constant int* w_shape [[buffer(11)]],
|
|
const constant int64_t* w_strides [[buffer(12)]],
|
|
const constant int64_t* s_strides [[buffer(13)]],
|
|
const constant int64_t* b_strides [[buffer(14)]],
|
|
uint3 tid [[threadgroup_position_in_grid]],
|
|
uint quad_gid [[quadgroup_index_in_threadgroup]],
|
|
uint quad_lid [[thread_index_in_quadgroup]]) {
|
|
if (batched) {
|
|
int M = x_shape[x_batch_ndims];
|
|
adjust_matrix_offsets<T>(
|
|
x,
|
|
w,
|
|
scales,
|
|
biases,
|
|
y,
|
|
out_vec_size * M,
|
|
x_batch_ndims,
|
|
x_shape,
|
|
x_strides,
|
|
w_batch_ndims,
|
|
w_shape,
|
|
w_strides,
|
|
s_strides,
|
|
b_strides,
|
|
tid);
|
|
}
|
|
qmv_quad_impl<T, group_size, bits, D>(
|
|
w,
|
|
scales,
|
|
biases,
|
|
x,
|
|
y,
|
|
in_vec_size,
|
|
out_vec_size,
|
|
tid,
|
|
quad_gid,
|
|
quad_lid);
|
|
}
|
|
|
|
template <typename T, int group_size, int bits, bool batched>
|
|
[[kernel]] void qmv_fast(
|
|
const device uint32_t* w [[buffer(0)]],
|
|
const device T* scales [[buffer(1)]],
|
|
const device T* biases [[buffer(2)]],
|
|
const device T* x [[buffer(3)]],
|
|
device T* y [[buffer(4)]],
|
|
const constant int& in_vec_size [[buffer(5)]],
|
|
const constant int& out_vec_size [[buffer(6)]],
|
|
const constant int& x_batch_ndims [[buffer(7)]],
|
|
const constant int* x_shape [[buffer(8)]],
|
|
const constant int64_t* x_strides [[buffer(9)]],
|
|
const constant int& w_batch_ndims [[buffer(10)]],
|
|
const constant int* w_shape [[buffer(11)]],
|
|
const constant int64_t* w_strides [[buffer(12)]],
|
|
const constant int64_t* s_strides [[buffer(13)]],
|
|
const constant int64_t* b_strides [[buffer(14)]],
|
|
uint3 tid [[threadgroup_position_in_grid]],
|
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
if (batched) {
|
|
int M = x_shape[x_batch_ndims];
|
|
adjust_matrix_offsets<T>(
|
|
x,
|
|
w,
|
|
scales,
|
|
biases,
|
|
y,
|
|
out_vec_size * M,
|
|
x_batch_ndims,
|
|
x_shape,
|
|
x_strides,
|
|
w_batch_ndims,
|
|
w_shape,
|
|
w_strides,
|
|
s_strides,
|
|
b_strides,
|
|
tid);
|
|
}
|
|
qmv_fast_impl<T, group_size, bits>(
|
|
w,
|
|
scales,
|
|
biases,
|
|
x,
|
|
y,
|
|
in_vec_size,
|
|
out_vec_size,
|
|
tid,
|
|
simd_gid,
|
|
simd_lid);
|
|
}
|
|
|
|
template <typename T, const int group_size, const int bits, bool batched>
|
|
[[kernel]] void qmv(
|
|
const device uint32_t* w [[buffer(0)]],
|
|
const device T* scales [[buffer(1)]],
|
|
const device T* biases [[buffer(2)]],
|
|
const device T* x [[buffer(3)]],
|
|
device T* y [[buffer(4)]],
|
|
const constant int& in_vec_size [[buffer(5)]],
|
|
const constant int& out_vec_size [[buffer(6)]],
|
|
const constant int& x_batch_ndims [[buffer(7)]],
|
|
const constant int* x_shape [[buffer(8)]],
|
|
const constant int64_t* x_strides [[buffer(9)]],
|
|
const constant int& w_batch_ndims [[buffer(10)]],
|
|
const constant int* w_shape [[buffer(11)]],
|
|
const constant int64_t* w_strides [[buffer(12)]],
|
|
const constant int64_t* s_strides [[buffer(13)]],
|
|
const constant int64_t* b_strides [[buffer(14)]],
|
|
uint3 tid [[threadgroup_position_in_grid]],
|
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
if (batched) {
|
|
int M = x_shape[x_batch_ndims];
|
|
adjust_matrix_offsets<T>(
|
|
x,
|
|
w,
|
|
scales,
|
|
biases,
|
|
y,
|
|
out_vec_size * M,
|
|
x_batch_ndims,
|
|
x_shape,
|
|
x_strides,
|
|
w_batch_ndims,
|
|
w_shape,
|
|
w_strides,
|
|
s_strides,
|
|
b_strides,
|
|
tid);
|
|
}
|
|
qmv_impl<T, group_size, bits>(
|
|
w,
|
|
scales,
|
|
biases,
|
|
x,
|
|
y,
|
|
in_vec_size,
|
|
out_vec_size,
|
|
tid,
|
|
simd_gid,
|
|
simd_lid);
|
|
}
|
|
|
|
template <typename T, const int group_size, const int bits, bool batched>
|
|
[[kernel]] void qvm(
|
|
const device uint32_t* w [[buffer(0)]],
|
|
const device T* scales [[buffer(1)]],
|
|
const device T* biases [[buffer(2)]],
|
|
const device T* x [[buffer(3)]],
|
|
device T* y [[buffer(4)]],
|
|
const constant int& in_vec_size [[buffer(5)]],
|
|
const constant int& out_vec_size [[buffer(6)]],
|
|
const constant int& x_batch_ndims [[buffer(7)]],
|
|
const constant int* x_shape [[buffer(8)]],
|
|
const constant int64_t* x_strides [[buffer(9)]],
|
|
const constant int& w_batch_ndims [[buffer(10)]],
|
|
const constant int* w_shape [[buffer(11)]],
|
|
const constant int64_t* w_strides [[buffer(12)]],
|
|
const constant int64_t* s_strides [[buffer(13)]],
|
|
const constant int64_t* b_strides [[buffer(14)]],
|
|
uint3 tid [[threadgroup_position_in_grid]],
|
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
if (batched) {
|
|
int M = x_shape[x_batch_ndims];
|
|
adjust_matrix_offsets<T>(
|
|
x,
|
|
w,
|
|
scales,
|
|
biases,
|
|
y,
|
|
out_vec_size * M,
|
|
x_batch_ndims,
|
|
x_shape,
|
|
x_strides,
|
|
w_batch_ndims,
|
|
w_shape,
|
|
w_strides,
|
|
s_strides,
|
|
b_strides,
|
|
tid);
|
|
}
|
|
qvm_impl<T, group_size, bits>(
|
|
w,
|
|
scales,
|
|
biases,
|
|
x,
|
|
y,
|
|
in_vec_size,
|
|
out_vec_size,
|
|
tid,
|
|
simd_gid,
|
|
simd_lid);
|
|
}
|
|
|
|
template <typename T, const int group_size, const int bits, int split_k = 32>
|
|
[[kernel]] void qvm_split_k(
|
|
const device uint32_t* w [[buffer(0)]],
|
|
const device T* scales [[buffer(1)]],
|
|
const device T* biases [[buffer(2)]],
|
|
const device T* x [[buffer(3)]],
|
|
device T* y [[buffer(4)]],
|
|
const constant int& in_vec_size [[buffer(5)]],
|
|
const constant int& out_vec_size [[buffer(6)]],
|
|
const constant int& x_batch_ndims [[buffer(7)]],
|
|
const constant int* x_shape [[buffer(8)]],
|
|
const constant int64_t* x_strides [[buffer(9)]],
|
|
const constant int& w_batch_ndims [[buffer(10)]],
|
|
const constant int* w_shape [[buffer(11)]],
|
|
const constant int64_t* w_strides [[buffer(12)]],
|
|
const constant int64_t* s_strides [[buffer(13)]],
|
|
const constant int64_t* b_strides [[buffer(14)]],
|
|
const constant int& final_block_size [[buffer(15)]],
|
|
uint3 tid [[threadgroup_position_in_grid]],
|
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
int M = x_shape[x_batch_ndims];
|
|
adjust_matrix_offsets<T>(
|
|
x,
|
|
w,
|
|
scales,
|
|
biases,
|
|
y,
|
|
out_vec_size * M,
|
|
x_batch_ndims,
|
|
x_shape,
|
|
x_strides,
|
|
w_batch_ndims,
|
|
w_shape,
|
|
w_strides,
|
|
s_strides,
|
|
b_strides,
|
|
tid);
|
|
|
|
// When (in_vec_size % split_k != 0) the final block needs to be smaller
|
|
int in_vec_size_adj =
|
|
tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
|
|
|
|
qvm_impl<T, group_size, bits>(
|
|
w,
|
|
scales,
|
|
biases,
|
|
x,
|
|
y,
|
|
in_vec_size_adj,
|
|
out_vec_size,
|
|
tid,
|
|
simd_gid,
|
|
simd_lid);
|
|
}
|
|
|
|
template <
|
|
typename T,
|
|
const int group_size,
|
|
const int bits,
|
|
const bool aligned_N,
|
|
const bool batched,
|
|
const int BM = 32,
|
|
const int BK = 32,
|
|
const int BN = 32>
|
|
[[kernel]] void qmm_t(
|
|
const device uint32_t* w [[buffer(0)]],
|
|
const device T* scales [[buffer(1)]],
|
|
const device T* biases [[buffer(2)]],
|
|
const device T* x [[buffer(3)]],
|
|
device T* y [[buffer(4)]],
|
|
const constant int& K [[buffer(5)]],
|
|
const constant int& N [[buffer(6)]],
|
|
const constant int& M [[buffer(7)]],
|
|
const constant int& x_batch_ndims [[buffer(8)]],
|
|
const constant int* x_shape [[buffer(9)]],
|
|
const constant int64_t* x_strides [[buffer(10)]],
|
|
const constant int& w_batch_ndims [[buffer(11)]],
|
|
const constant int* w_shape [[buffer(12)]],
|
|
const constant int64_t* w_strides [[buffer(13)]],
|
|
const constant int64_t* s_strides [[buffer(14)]],
|
|
const constant int64_t* b_strides [[buffer(15)]],
|
|
uint3 tid [[threadgroup_position_in_grid]],
|
|
uint lid [[thread_index_in_threadgroup]],
|
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
(void)lid;
|
|
|
|
constexpr int BK_padded = (BK + 16 / sizeof(T));
|
|
|
|
threadgroup T Xs[BM * BK_padded];
|
|
threadgroup T Ws[BN * BK_padded];
|
|
|
|
if (batched) {
|
|
adjust_matrix_offsets<T>(
|
|
x,
|
|
w,
|
|
scales,
|
|
biases,
|
|
y,
|
|
M * N,
|
|
x_batch_ndims,
|
|
x_shape,
|
|
x_strides,
|
|
w_batch_ndims,
|
|
w_shape,
|
|
w_strides,
|
|
s_strides,
|
|
b_strides,
|
|
tid);
|
|
}
|
|
qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
|
|
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
|
|
}
|
|
|
|
template <
|
|
typename T,
|
|
const int group_size,
|
|
const int bits,
|
|
const bool batched,
|
|
const int BM = 32,
|
|
const int BK = 32,
|
|
const int BN = 32>
|
|
[[kernel]] void qmm_n(
|
|
const device uint32_t* w [[buffer(0)]],
|
|
const device T* scales [[buffer(1)]],
|
|
const device T* biases [[buffer(2)]],
|
|
const device T* x [[buffer(3)]],
|
|
device T* y [[buffer(4)]],
|
|
const constant int& K [[buffer(5)]],
|
|
const constant int& N [[buffer(6)]],
|
|
const constant int& M [[buffer(7)]],
|
|
const constant int& x_batch_ndims [[buffer(8)]],
|
|
const constant int* x_shape [[buffer(9)]],
|
|
const constant int64_t* x_strides [[buffer(10)]],
|
|
const constant int& w_batch_ndims [[buffer(11)]],
|
|
const constant int* w_shape [[buffer(12)]],
|
|
const constant int64_t* w_strides [[buffer(13)]],
|
|
const constant int64_t* s_strides [[buffer(14)]],
|
|
const constant int64_t* b_strides [[buffer(15)]],
|
|
uint3 tid [[threadgroup_position_in_grid]],
|
|
uint lid [[thread_index_in_threadgroup]],
|
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
(void)lid;
|
|
|
|
constexpr int BK_padded = (BK + 16 / sizeof(T));
|
|
constexpr int BN_padded = (BN + 16 / sizeof(T));
|
|
|
|
threadgroup T Xs[BM * BK_padded];
|
|
threadgroup T Ws[BK * BN_padded];
|
|
|
|
if (batched) {
|
|
adjust_matrix_offsets<T>(
|
|
x,
|
|
w,
|
|
scales,
|
|
biases,
|
|
y,
|
|
M * N,
|
|
x_batch_ndims,
|
|
x_shape,
|
|
x_strides,
|
|
w_batch_ndims,
|
|
w_shape,
|
|
w_strides,
|
|
s_strides,
|
|
b_strides,
|
|
tid);
|
|
}
|
|
|
|
qmm_n_impl<T, group_size, bits, BM, BK, BN>(
|
|
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
|
|
}
|
|
|
|
template <typename T, int group_size, int bits>
[[kernel]] void gather_qmv_fast(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    const device uint32_t* lhs_indices [[buffer(4)]],
    const device uint32_t* rhs_indices [[buffer(5)]],
    device T* y [[buffer(6)]],
    const constant int& in_vec_size [[buffer(7)]],
    const constant int& out_vec_size [[buffer(8)]],
    const constant int& x_batch_ndims [[buffer(9)]],
    const constant int* x_shape [[buffer(10)]],
    const constant int64_t* x_strides [[buffer(11)]],
    const constant int& w_batch_ndims [[buffer(12)]],
    const constant int* w_shape [[buffer(13)]],
    const constant int64_t* w_strides [[buffer(14)]],
    const constant int64_t* s_strides [[buffer(15)]],
    const constant int64_t* b_strides [[buffer(16)]],
    const constant int& batch_ndims [[buffer(17)]],
    const constant int* batch_shape [[buffer(18)]],
    const constant int64_t* lhs_strides [[buffer(19)]],
    const constant int64_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  int M = x_shape[x_batch_ndims];
  adjust_matrix_offsets<T>(
      x,
      w,
      scales,
      biases,
      lhs_indices,
      rhs_indices,
      y,
      out_vec_size * M,
      batch_ndims,
      batch_shape,
      lhs_strides,
      rhs_strides,
      x_batch_ndims,
      x_shape,
      x_strides,
      w_batch_ndims,
      w_shape,
      w_strides,
      s_strides,
      b_strides,
      tid);
  qmv_fast_impl<T, group_size, bits>(
      w,
      scales,
      biases,
      x,
      y,
      in_vec_size,
      out_vec_size,
      tid,
      simd_gid,
      simd_lid);
}

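// Same gather logic as gather_qmv_fast, but dispatching to the general
// qmv_impl instead of the fast path.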
template <typename T, int group_size, int bits>
[[kernel]] void gather_qmv(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    const device uint32_t* lhs_indices [[buffer(4)]],
    const device uint32_t* rhs_indices [[buffer(5)]],
    device T* y [[buffer(6)]],
    const constant int& in_vec_size [[buffer(7)]],
    const constant int& out_vec_size [[buffer(8)]],
    const constant int& x_batch_ndims [[buffer(9)]],
    const constant int* x_shape [[buffer(10)]],
    const constant int64_t* x_strides [[buffer(11)]],
    const constant int& w_batch_ndims [[buffer(12)]],
    const constant int* w_shape [[buffer(13)]],
    const constant int64_t* w_strides [[buffer(14)]],
    const constant int64_t* s_strides [[buffer(15)]],
    const constant int64_t* b_strides [[buffer(16)]],
    const constant int& batch_ndims [[buffer(17)]],
    const constant int* batch_shape [[buffer(18)]],
    const constant int64_t* lhs_strides [[buffer(19)]],
    const constant int64_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  int M = x_shape[x_batch_ndims];
  adjust_matrix_offsets<T>(
      x,
      w,
      scales,
      biases,
      lhs_indices,
      rhs_indices,
      y,
      out_vec_size * M,
      batch_ndims,
      batch_shape,
      lhs_strides,
      rhs_strides,
      x_batch_ndims,
      x_shape,
      x_strides,
      w_batch_ndims,
      w_shape,
      w_strides,
      s_strides,
      b_strides,
      tid);
  qmv_impl<T, group_size, bits>(
      w,
      scales,
      biases,
      x,
      y,
      in_vec_size,
      out_vec_size,
      tid,
      simd_gid,
      simd_lid);
}

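// Gathered vector-matrix product: the same offset adjustment followed by
// qvm_impl.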
template <typename T, int group_size, int bits>
[[kernel]] void gather_qvm(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    const device uint32_t* lhs_indices [[buffer(4)]],
    const device uint32_t* rhs_indices [[buffer(5)]],
    device T* y [[buffer(6)]],
    const constant int& in_vec_size [[buffer(7)]],
    const constant int& out_vec_size [[buffer(8)]],
    const constant int& x_batch_ndims [[buffer(9)]],
    const constant int* x_shape [[buffer(10)]],
    const constant int64_t* x_strides [[buffer(11)]],
    const constant int& w_batch_ndims [[buffer(12)]],
    const constant int* w_shape [[buffer(13)]],
    const constant int64_t* w_strides [[buffer(14)]],
    const constant int64_t* s_strides [[buffer(15)]],
    const constant int64_t* b_strides [[buffer(16)]],
    const constant int& batch_ndims [[buffer(17)]],
    const constant int* batch_shape [[buffer(18)]],
    const constant int64_t* lhs_strides [[buffer(19)]],
    const constant int64_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  int M = x_shape[x_batch_ndims];
  adjust_matrix_offsets<T>(
      x,
      w,
      scales,
      biases,
      lhs_indices,
      rhs_indices,
      y,
      out_vec_size * M,
      batch_ndims,
      batch_shape,
      lhs_strides,
      rhs_strides,
      x_batch_ndims,
      x_shape,
      x_strides,
      w_batch_ndims,
      w_shape,
      w_strides,
      s_strides,
      b_strides,
      tid);
  qvm_impl<T, group_size, bits>(
      w,
      scales,
      biases,
      x,
      y,
      in_vec_size,
      out_vec_size,
      tid,
      simd_gid,
      simd_lid);
}

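// Gathered tiled matmuls: like qmm_t / qmm_n above, except the operands for
// each output tile are first selected through lhs_indices and rhs_indices.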
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void gather_qmm_t(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    const device uint32_t* lhs_indices [[buffer(4)]],
    const device uint32_t* rhs_indices [[buffer(5)]],
    device T* y [[buffer(6)]],
    const constant int& K [[buffer(7)]],
    const constant int& N [[buffer(8)]],
    const constant int& M [[buffer(9)]],
    const constant int& x_batch_ndims [[buffer(10)]],
    const constant int* x_shape [[buffer(11)]],
    const constant int64_t* x_strides [[buffer(12)]],
    const constant int& w_batch_ndims [[buffer(13)]],
    const constant int* w_shape [[buffer(14)]],
    const constant int64_t* w_strides [[buffer(15)]],
    const constant int64_t* s_strides [[buffer(16)]],
    const constant int64_t* b_strides [[buffer(17)]],
    const constant int& batch_ndims [[buffer(18)]],
    const constant int* batch_shape [[buffer(19)]],
    const constant int64_t* lhs_strides [[buffer(20)]],
    const constant int64_t* rhs_strides [[buffer(21)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  (void)lid;

  constexpr int BK_padded = (BK + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];

  adjust_matrix_offsets<T>(
      x,
      w,
      scales,
      biases,
      lhs_indices,
      rhs_indices,
      y,
      M * N,
      batch_ndims,
      batch_shape,
      lhs_strides,
      rhs_strides,
      x_batch_ndims,
      x_shape,
      x_strides,
      w_batch_ndims,
      w_shape,
      w_strides,
      s_strides,
      b_strides,
      tid);
  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}

template <
    typename T,
    const int group_size,
    const int bits,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void gather_qmm_n(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    const device uint32_t* lhs_indices [[buffer(4)]],
    const device uint32_t* rhs_indices [[buffer(5)]],
    device T* y [[buffer(6)]],
    const constant int& K [[buffer(7)]],
    const constant int& N [[buffer(8)]],
    const constant int& M [[buffer(9)]],
    const constant int& x_batch_ndims [[buffer(10)]],
    const constant int* x_shape [[buffer(11)]],
    const constant int64_t* x_strides [[buffer(12)]],
    const constant int& w_batch_ndims [[buffer(13)]],
    const constant int* w_shape [[buffer(14)]],
    const constant int64_t* w_strides [[buffer(15)]],
    const constant int64_t* s_strides [[buffer(16)]],
    const constant int64_t* b_strides [[buffer(17)]],
    const constant int& batch_ndims [[buffer(18)]],
    const constant int* batch_shape [[buffer(19)]],
    const constant int64_t* lhs_strides [[buffer(20)]],
    const constant int64_t* rhs_strides [[buffer(21)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  (void)lid;

  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BK * BN_padded];

  adjust_matrix_offsets<T>(
      x,
      w,
      scales,
      biases,
      lhs_indices,
      rhs_indices,
      y,
      M * N,
      batch_ndims,
      batch_shape,
      lhs_strides,
      rhs_strides,
      x_batch_ndims,
      x_shape,
      x_strides,
      w_batch_ndims,
      w_shape,
      w_strides,
      s_strides,
      b_strides,
      tid);
  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}

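// Inner GEMM loops shared by gather_qmm_rhs below. The aligned variant
// assumes every tile is full, so both loaders can skip bounds checks.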
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_aligned(
    threadgroup T* As,
    threadgroup T* Bs,
    thread mma_t& mma_op,
    thread loader_a_t& loader_a,
    thread loader_b_t& loader_b,
    const int k_iterations) {
  for (int k = 0; k < k_iterations; k++) {
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Load elements into threadgroup memory
    loader_a.load_unsafe();
    loader_b.load_unsafe();

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply and accumulate threadgroup elements
    mma_op.mma(As, Bs);

    // Prepare for next iteration
    loader_a.next();
    loader_b.next();
  }
}

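// Unaligned variant: rows_aligned / cols_aligned select, per operand, whether
// the unchecked load or a bounds-checked load_safe is used inside the K loop.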
template <
    bool rows_aligned,
    bool cols_aligned,
    bool transpose,
    typename T,
    typename mma_t,
    typename loader_a_t,
    typename loader_b_t>
METAL_FUNC void gemm_loop_unaligned(
    threadgroup T* As,
    threadgroup T* Bs,
    thread mma_t& mma_op,
    thread loader_a_t& loader_a,
    thread loader_b_t& loader_b,
    const int k_iterations,
    const short tgp_bm,
    const short tgp_bn,
    const short tgp_bk) {
  for (int k = 0; k < k_iterations; k++) {
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Load elements into threadgroup memory
    if (rows_aligned) {
      loader_a.load_unsafe();
    } else {
      loader_a.load_safe(short2(tgp_bk, tgp_bm));
    }
    if (cols_aligned) {
      loader_b.load_unsafe();
    } else {
      loader_b.load_safe(
          transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply and accumulate threadgroup elements
    mma_op.mma(As, Bs);

    // Prepare for next iteration
    loader_a.next();
    loader_b.next();
  }
}

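// Final partial iteration along K: both operands are loaded with bounds
// checks and accumulated once.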
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_finalize(
    threadgroup T* As,
    threadgroup T* Bs,
    thread mma_t& mma_op,
    thread loader_a_t& loader_a,
    thread loader_b_t& loader_b,
    const short2 tile_a,
    const short2 tile_b) {
  loader_a.load_safe(tile_a);
  loader_b.load_safe(tile_b);
  threadgroup_barrier(mem_flags::mem_threadgroup);
  mma_op.mma(As, Bs);
}

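// gather_qmm_rhs: matmul in which only the right-hand side is gathered. Runs
// of consecutive x rows that map to the same entry of indices are multiplied
// together, so each distinct quantized matrix is loaded once per run.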
template <
    typename T,
    int group_size,
    int bits,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose>
[[kernel]] void gather_qmm_rhs(
    const device T* x [[buffer(0)]],
    const device uint32_t* w [[buffer(1)]],
    const device T* scales [[buffer(2)]],
    const device T* biases [[buffer(3)]],
    const device uint32_t* indices [[buffer(4)]],
    device T* y [[buffer(5)]],
    const constant int& M [[buffer(6)]],
    const constant int& N [[buffer(7)]],
    const constant int& K [[buffer(8)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]]) {
  constexpr int pack_factor = get_pack_factor<bits, 8>();
  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  using mma_t = mlx::steel::BlockMMA<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      false,
      transpose,
      BK_padded,
      transpose ? BK_padded : BN_padded>;
  using loader_x_t =
      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
  using loader_w_t = QuantizedBlockLoader<
      T,
      transpose ? BN : BK,
      transpose ? BK : BN,
      transpose ? BK_padded : BN_padded,
      transpose,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];

  // Compute the packed sizes, strides and the position of this output block
  const int K_w = K * bytes_per_pack / pack_factor;
  const int K_g = K / group_size;
  const int N_w = N * bytes_per_pack / pack_factor;
  const int N_g = N / group_size;
  const int K_it = K / BK;
  const size_t stride_w = transpose ? N * K_w : K * N_w;
  const size_t stride_s = transpose ? N * K_g : K * N_g;
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;
  const size_t y_row_long = size_t(y_row);
  const size_t y_col_long = size_t(y_col);

  // Prepare threadgroup bounds
  const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));
  const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));

  // Calculate the final tiles in the case that K is not aligned
  const int k_remain = K - K_it * BK;
  const short2 tile_x = short2(k_remain, tgp_bm);
  const short2 tile_w =
      transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);

  // Move x and output to the correct block
  auto wl = (const device uint8_t*)w;
  x += y_row_long * K;
  y += y_row_long * N + y_col_long;
  wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;
  scales += transpose ? y_col_long * K_g : y_col / group_size;
  biases += transpose ? y_col_long * K_g : y_col / group_size;

  // Do as many matmuls as necessary
  uint32_t index;
  short offset;
  uint32_t index_next = indices[y_row];
  short offset_next = 0;
  int n = 0;
  while (n < tgp_bm) {
    n++;
    offset = offset_next;
    index = index_next;
    offset_next = tgp_bm;
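    // Scan ahead to find how many consecutive rows share this index; the rows
    // in [offset, offset_next) are all multiplied against the same w.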
    for (; n < tgp_bm; n++) {
      if (indices[y_row + n] != index) {
        offset_next = n;
        index_next = indices[y_row + n];
        break;
      }
    }
    threadgroup_barrier(mem_flags::mem_none);

    // Prepare threadgroup mma operation
    thread mma_t mma_op(simd_group_id, simd_lane_id);

    // Prepare threadgroup loading operations
    thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id);
    thread loader_w_t loader_w(
        wl + index * stride_w,
        scales + index * stride_s,
        biases + index * stride_s,
        transpose ? K : N,
        Ws,
        simd_group_id,
        simd_lane_id);

    // Matrices are all aligned so nothing needs to be bounds checked
    if (align_M && align_N) {
      gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
      if (!align_K) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
      }

      // Store results to device memory
      if (offset_next - offset == BM) {
        mma_op.store_result(y, N);
      } else {
        mma_op.store_result_slice(
            y, N, short2(0, offset), short2(BN, offset_next));
      }
    } else {
      // The tile is aligned at runtime so do the check outside of the hot loop
      if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
        gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
        if (!align_K) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          gemm_loop_finalize(
              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
        }

        // Store results to device memory
        if (offset_next - offset == BM) {
          mma_op.store_result(y, N);
        } else {
          mma_op.store_result_slice(
              y, N, short2(0, offset), short2(BN, offset_next));
        }
      }

      // Tile partially aligned so bounds-check the rows
      else if (align_N || tgp_bn == BN) {
        gemm_loop_unaligned<false, true, transpose>(
            Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
        if (!align_K) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          gemm_loop_finalize(
              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
        }
        mma_op.store_result_slice(
            y, N, short2(0, offset), short2(BN, offset_next));
      }

      // Tile partially aligned so bounds-check the columns
      else if (align_M || tgp_bm == BM) {
        gemm_loop_unaligned<true, false, transpose>(
            Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
        if (!align_K) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          gemm_loop_finalize(
              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
        }
        mma_op.store_result_slice(
            y, N, short2(0, offset), short2(tgp_bn, offset_next));
      }

      // Nothing is aligned so bounds-check both rows and columns
      else {
        gemm_loop_unaligned<false, false, transpose>(
            Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
        if (!align_K) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          gemm_loop_finalize(
              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
        }
        mma_op.store_result_slice(
            y, N, short2(0, offset), short2(tgp_bn, offset_next));
      }
    }
  }
}

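// affine_quantize: each simdgroup quantizes one group of group_size weights
// to `bits` bits, storing a per-group scale and bias so that a value is
// reconstructed as q * scale + bias.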
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
    const device T* w [[buffer(0)]],
    device uint8_t* out [[buffer(1)]],
    device T* scales [[buffer(2)]],
    device T* biases [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  constexpr float eps = 1e-7;
  constexpr int simd_size = 32;
  constexpr float n_bins = (1 << bits) - 1;
  constexpr int pack_factor = get_pack_factor<bits, 8>();
  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
  constexpr int values_per_reduce = group_size / simd_size;
  constexpr int writes_per_reduce = pack_factor / values_per_reduce;
  constexpr int writes_per_pack =
      writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor;
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;

  static_assert(
      group_size % simd_size == 0,
      "Group size must be divisible by simd size.");

  size_t offset = index.x + grid_dim.x * size_t(index.y);
  size_t in_index = offset * values_per_reduce;
  size_t out_index = power_of_2_bits
      ? offset * writes_per_pack
      : offset * bytes_per_pack / writes_per_reduce;

  float w_thread[values_per_reduce];
  float w_min = Limits<T>::max;
  float w_max = 0;

#pragma clang loop unroll(full)
  for (int i = 0; i < values_per_reduce; i++) {
    float val = w[in_index + i];
    w_thread[i] = val;
    w_min = min(w_min, val);
    w_max = max(w_max, val);
  }

  w_min = simd_min(w_min);
  w_max = simd_max(w_max);

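  // Pick the range edge with the larger magnitude, snap the scale so that
  // edge / scale is exactly an integer, and use that edge as the bias so it
  // quantizes to level 0 (unless the edge rounds to zero, in which case the
  // bias is 0).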
  float scale = max((w_max - w_min) / n_bins, eps);
  bool side = abs(w_min) > abs(w_max);
  scale = side ? scale : -scale;
  float edge = side ? w_min : w_max;
  float q0 = round(edge / scale);
  bool at_zero = q0 == 0.0f;
  scale = at_zero ? scale : edge / q0;
  float bias = at_zero ? 0 : edge;

  // Write out the scales and biases
  size_t gindex = in_index / group_size;
  if (in_index % group_size == 0) {
    scales[gindex] = static_cast<T>(scale);
    biases[gindex] = static_cast<T>(bias);
  }

  using OutType = metal::conditional_t<bits == 5, uint64_t, uint32_t>;
  OutType output = 0;

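  // Quantize this thread's values and pack them into output; when a packed
  // word spans several threads, the neighbours' values are gathered with
  // simd_shuffle_down before writing.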
#pragma clang loop unroll(full)
  for (int i = 0; i < values_per_reduce; i++) {
    uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
    if (bits == 8) {
      output = val;
    } else {
      output |= val << (bits * (i % pack_factor));
    }

    if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) {
      out[out_index + i / pack_factor] = output;
      output = 0;
    } else {
#pragma clang loop unroll(full)
      for (int j = 1; j < writes_per_reduce; j++) {
        uint8_t sval = simd_shuffle_down(val, j);
        output |= static_cast<OutType>(sval)
            << (bits * (j * values_per_reduce + i));
      }
    }
  }
  if (bits == 3 || bits == 6) {
    if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
      out[out_index] = output & 0xff;
      out[out_index + 1] = (output & 0xff00) >> 8;
      out[out_index + 2] = (output & 0xff0000) >> 16;
    }
  } else if (bits == 5) {
    if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
      out[out_index] = output & 0xff;
      out[out_index + 1] = (output & 0xff00) >> 8;
      out[out_index + 2] = (output & 0xff0000) >> 16;
      out[out_index + 3] = (output & 0xff000000) >> 24;
      out[out_index + 4] = (output & 0xff00000000) >> 32;
    }
  } else {
    if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
      out[out_index / writes_per_reduce] = output;
    }
  }
}

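// affine_dequantize: the inverse of affine_quantize. Each thread unpacks
// pack_factor values from the packed words and reconstructs them as
// q * scale + bias.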
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_dequantize(
    const device uint8_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    device T* out [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  constexpr int pack_factor = get_pack_factor<bits, 8>();
  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();

  size_t offset = index.x + grid_dim.x * size_t(index.y);
  size_t oindex = offset * pack_factor;
  size_t gindex = oindex / group_size;
  T scale = scales[gindex];
  T bias = biases[gindex];

  out += oindex;

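  // 3-, 5- and 6-bit values straddle byte boundaries, so they are unpacked
  // with explicit masks and shifts; the power-of-two widths below share a
  // single shift-and-mask loop.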
  if (bits == 3) {
    w += offset * bytes_per_pack;
    out[0] = (w[0] & 0x7) * scale + bias;
    out[1] = ((w[0] & 0x38) >> 3) * scale + bias;
    out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
    out[3] = ((w[1] & 0xe) >> 1) * scale + bias;
    out[4] = ((w[1] & 0x70) >> 4) * scale + bias;
    out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
    out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
    out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
  } else if (bits == 5) {
    w += offset * bytes_per_pack;
    out[0] = (w[0] & 0x1f) * scale + bias;
    out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
    out[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
    out[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;
    out[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;
    out[5] = ((w[3] & 0x3e) >> 1) * scale + bias;
    out[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;
    out[7] = ((w[4] & 0xf8) >> 3) * scale + bias;
  } else if (bits == 6) {
    w += offset * bytes_per_pack;
    out[0] = (w[0] & 0x3f) * scale + bias;
    out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
    out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
    out[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
  } else {
    uint val = w[offset];
#pragma clang loop unroll(full)
    for (int i = 0; i < pack_factor; i++) {
      uint8_t d;
      if (bits == 2) {
        d = (val >> (bits * i)) & 0x03;
      } else if (bits == 4) {
        d = (val >> (bits * i)) & 0x0f;
      } else if (bits == 8) {
        d = val;
      }
      out[i] = scale * d + bias;
    }
  }
}