@@ -13,8 +13,8 @@ MLX_MTL_CONST int QUAD_SIZE = 4;
 template <typename T, typename U, int values_per_thread, int bits>
 inline U load_vector(const device T* x, thread U* x_thread) {
   static_assert(
-      bits == 2 || bits == 4 || bits == 8,
-      "Template undefined for bits not in {2, 4, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 6, 8}");
 
   U sum = 0;
 
@@ -28,6 +28,21 @@ inline U load_vector(const device T* x, thread U* x_thread) {
     }
   }
 
+  else if (bits == 3) {
+    for (int i = 0; i < values_per_thread; i += 8) {
+      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
+          x[i + 6] + x[i + 7];
+      x_thread[i] = x[i];
+      x_thread[i + 1] = x[i + 1] / 8.0f;
+      x_thread[i + 2] = x[i + 2] / 64.0f;
+      x_thread[i + 3] = x[i + 3] / 2.0f;
+      x_thread[i + 4] = x[i + 4] / 16.0f;
+      x_thread[i + 5] = x[i + 5] / 128.0f;
+      x_thread[i + 6] = x[i + 6] / 4.0f;
+      x_thread[i + 7] = x[i + 7] / 32.0f;
+    }
+  }
+
   else if (bits == 4) {
     for (int i = 0; i < values_per_thread; i += 4) {
       sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
@@ -38,6 +53,16 @@ inline U load_vector(const device T* x, thread U* x_thread) {
     }
   }
 
+  else if (bits == 6) {
+    for (int i = 0; i < values_per_thread; i += 4) {
+      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
+      x_thread[i] = x[i];
+      x_thread[i + 1] = x[i + 1] / 64.0f;
+      x_thread[i + 2] = x[i + 2] / 16.0f;
+      x_thread[i + 3] = x[i + 3] / 4.0f;
+    }
+  }
+
   else if (bits == 8) {
     for (int i = 0; i < values_per_thread; i++) {
       sum += x[i];
@@ -51,8 +76,8 @@ inline U load_vector(const device T* x, thread U* x_thread) {
 template <typename T, typename U, int values_per_thread, int bits>
 inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
   static_assert(
-      bits == 2 || bits == 4 || bits == 8,
-      "Template undefined for bits not in {2, 4, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 6, 8}");
 
   U sum = 0;
 
@@ -64,8 +89,21 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
       x_thread[i + 2] = x[i + 2] / 16.0f;
       x_thread[i + 3] = x[i + 3] / 64.0f;
     }
-    for (int i = N; i < values_per_thread; i++) {
-      x_thread[i] = 0;
-    }
   }
 
+  else if (bits == 3) {
+    for (int i = 0; i < N; i += 8) {
+      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
+          x[i + 6] + x[i + 7];
+
+      x_thread[i] = x[i];
+      x_thread[i + 1] = x[i + 1] / 8.0f;
+      x_thread[i + 2] = x[i + 2] / 64.0f;
+      x_thread[i + 3] = x[i + 3] / 2.0f;
+      x_thread[i + 4] = x[i + 4] / 16.0f;
+      x_thread[i + 5] = x[i + 5] / 128.0f;
+      x_thread[i + 6] = x[i + 6] / 4.0f;
+      x_thread[i + 7] = x[i + 7] / 32.0f;
+    }
+  }
+
@@ -77,8 +115,15 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
       x_thread[i + 2] = x[i + 2] / 256.0f;
       x_thread[i + 3] = x[i + 3] / 4096.0f;
     }
-    for (int i = N; i < values_per_thread; i++) {
-      x_thread[i] = 0;
-    }
   }
 
+  else if (bits == 6) {
+    for (int i = 0; i < N; i += 4) {
+      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
+      x_thread[i] = x[i];
+      x_thread[i + 1] = x[i + 1] / 64.0f;
+      x_thread[i + 2] = x[i + 2] / 16.0f;
+      x_thread[i + 3] = x[i + 3] / 4.0f;
+    }
+  }
+
@@ -87,9 +132,10 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
       sum += x[i];
       x_thread[i] = x[i];
     }
-    for (int i = N; i < values_per_thread; i++) {
-      x_thread[i] = 0;
-    }
   }
 
+  for (int i = N; i < values_per_thread; i++) {
+    x_thread[i] = 0;
+  }
+
   return sum;
@@ -103,8 +149,8 @@ inline U qdot(
     U bias,
     U sum) {
   static_assert(
-      bits == 2 || bits == 4 || bits == 8,
-      "Template undefined for bits not in {2, 4, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 6, 8}");
 
   U accum = 0;
 
@@ -118,6 +164,26 @@ inline U qdot(
     }
   }
 
+  else if (bits == 3) {
+    for (int i = 0; i < (values_per_thread / 8); i++) {
+      x_thread += 8 * i;
+      w += 3 * i;
+
+      accum += (w[0] & 0x07) * x_thread[0];
+      accum += (w[0] & 0x38) * x_thread[1];
+      accum += (w[0] & 0xc0) * x_thread[2];
+      accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
+
+      accum += (w[1] & 0x0e) * x_thread[3];
+      accum += (w[1] & 0x70) * x_thread[4];
+      accum += (w[1] & 0x80) * x_thread[5];
+      accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
+
+      accum += (w[2] & 0x1c) * x_thread[6];
+      accum += (w[2] & 0xe0) * x_thread[7];
+    }
+  }
+
   else if (bits == 4) {
     const device uint16_t* ws = (const device uint16_t*)w;
     for (int i = 0; i < (values_per_thread / 4); i++) {
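
Aside (not part of the diff): the 3-bit `qdot` branch above reads 8 values out of 3 consecutive bytes. The 24 bits are packed LSB-first, so values 2 and 5 straddle byte boundaries, and the kernel never shifts a field down — it multiplies by the still-shifted field and relies on `load_vector` pre-dividing `x` by 8, 64, 2, 16, … to cancel the shifts (the `* 256.0f` terms are the carry bits of the straddling values). A minimal host-side C++ sketch of the same layout, useful for checking the masks; `pack3`/`unpack3` are hypothetical helper names, not MLX functions:

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical host-side helpers: pack 8 3-bit values into 3 bytes and back,
// using the same LSB-first bit layout the kernel's masks assume.
void pack3(const uint8_t vals[8], uint8_t out[3]) {
  uint32_t word = 0;
  for (int i = 0; i < 8; ++i) {
    word |= uint32_t(vals[i] & 0x7) << (3 * i); // 3 bits per value, LSB first
  }
  out[0] = word & 0xff;
  out[1] = (word >> 8) & 0xff;
  out[2] = (word >> 16) & 0xff;
}

void unpack3(const uint8_t w[3], uint8_t vals[8]) {
  // Mirrors the kernel's masks; values 2 and 5 straddle byte boundaries.
  vals[0] = w[0] & 0x7;
  vals[1] = (w[0] & 0x38) >> 3;
  vals[2] = ((w[0] & 0xc0) >> 6) | ((w[1] & 0x1) << 2);
  vals[3] = (w[1] & 0xe) >> 1;
  vals[4] = (w[1] & 0x70) >> 4;
  vals[5] = ((w[1] & 0x80) >> 7) | ((w[2] & 0x3) << 1);
  vals[6] = (w[2] & 0x1c) >> 2;
  vals[7] = (w[2] & 0xe0) >> 5;
}

int main() {
  uint8_t vals[8] = {1, 5, 7, 0, 3, 6, 2, 4};
  uint8_t packed[3], back[8];
  pack3(vals, packed);
  unpack3(packed, back);
  for (int i = 0; i < 8; ++i) {
    std::printf("%d -> %d\n", vals[i], back[i]); // round-trips unchanged
  }
}
```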
@@ -129,6 +195,23 @@ inline U qdot(
     }
   }
 
+  else if (bits == 6) {
+    for (int i = 0; i < (values_per_thread / 4); i++) {
+      x_thread += 4 * i;
+      w += 3 * i;
+
+      accum += (w[0] & 0x3f) * x_thread[0];
+
+      accum += (w[0] & 0xc0) * x_thread[1];
+      accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
+
+      accum += (w[1] & 0xf0) * x_thread[2];
+      accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
+
+      accum += (w[2] & 0xfc) * x_thread[3];
+    }
+  }
+
   else if (bits == 8) {
     for (int i = 0; i < values_per_thread; i++) {
       accum += x_thread[i] * w[i];
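
Aside (not part of the diff): the 6-bit case is the same idea with 4 values per 3 bytes. A host-side sketch of that layout (hypothetical `pack6`/`unpack6` helpers), written in the shift-down form that the `dequantize`/`qouter` branches use:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical host-side reference for the 6-bit layout:
// 4 values x 6 bits = 24 bits, stored LSB-first across 3 bytes.
void pack6(const uint8_t vals[4], uint8_t out[3]) {
  uint32_t word = 0;
  for (int i = 0; i < 4; ++i) {
    word |= uint32_t(vals[i] & 0x3f) << (6 * i);
  }
  out[0] = word & 0xff;
  out[1] = (word >> 8) & 0xff;
  out[2] = (word >> 16) & 0xff;
}

void unpack6(const uint8_t w[3], uint8_t vals[4]) {
  // Same bit positions the kernel masks pick out; values 1 and 2 straddle bytes.
  vals[0] = w[0] & 0x3f;
  vals[1] = ((w[0] & 0xc0) >> 6) | ((w[1] & 0x0f) << 2);
  vals[2] = ((w[1] & 0xf0) >> 4) | ((w[2] & 0x03) << 4);
  vals[3] = (w[2] & 0xfc) >> 2;
}

int main() {
  uint8_t vals[4] = {63, 0, 42, 17};
  uint8_t packed[3], back[4];
  pack6(vals, packed);
  unpack6(packed, back);
  for (int i = 0; i < 4; ++i) {
    assert(back[i] == vals[i]); // round-trips unchanged
  }
  return 0;
}
```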
@@ -147,8 +230,8 @@ inline U qdot_safe(
     U sum,
     int N) {
   static_assert(
-      bits == 2 || bits == 4 || bits == 8,
-      "Template undefined for bits not in {2, 4, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 6, 8}");
 
   U accum = 0;
 
@@ -162,6 +245,26 @@ inline U qdot_safe(
     }
   }
 
+  else if (bits == 3) {
+    for (int i = 0; i < (N / 8); i++) {
+      x_thread += 8 * i;
+      w += 3 * i;
+
+      accum += (w[0] & 0x07) * x_thread[0];
+      accum += (w[0] & 0x38) * x_thread[1];
+      accum += (w[0] & 0xc0) * x_thread[2];
+      accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
+
+      accum += (w[1] & 0x0e) * x_thread[3];
+      accum += (w[1] & 0x70) * x_thread[4];
+      accum += (w[1] & 0x80) * x_thread[5];
+      accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
+
+      accum += (w[2] & 0x1c) * x_thread[6];
+      accum += (w[2] & 0xe0) * x_thread[7];
+    }
+  }
+
   else if (bits == 4) {
     const device uint16_t* ws = (const device uint16_t*)w;
     for (int i = 0; i < (N / 4); i++) {
@@ -173,6 +276,23 @@ inline U qdot_safe(
     }
   }
 
+  else if (bits == 6) {
+    for (int i = 0; i < (N / 4); i++) {
+      x_thread += 4 * i;
+      w += 3 * i;
+
+      accum += (w[0] & 0x3f) * x_thread[0];
+
+      accum += (w[0] & 0xc0) * x_thread[1];
+      accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
+
+      accum += (w[1] & 0xf0) * x_thread[2];
+      accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
+
+      accum += (w[2] & 0xfc) * x_thread[3];
+    }
+  }
+
   else if (bits == 8) {
     for (int i = 0; i < N; i++) {
       accum += x_thread[i] * w[i];
@@ -186,8 +306,8 @@ template <typename U, int values_per_thread, int bits>
 inline void
 qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
   static_assert(
-      bits == 2 || bits == 4 || bits == 8,
-      "Template undefined for bits not in {2, 4, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 6, 8}");
 
   if (bits == 2) {
     U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
@@ -199,12 +319,45 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
     }
   }
 
+  else if (bits == 3) {
+    for (int i = 0; i < (values_per_thread / 8); i++) {
+      uint8_t w0 = w[3 * i];
+      uint8_t w1 = w[3 * i + 1];
+      uint8_t w2 = w[3 * i + 2];
+
+      result[8 * i] += x * ((w0 & 0x7) * scale + bias);
+      result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias);
+      result[8 * i + 2] +=
+          x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias);
+      result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias);
+      result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias);
+      result[8 * i + 5] +=
+          x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias);
+      result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias);
+      result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias);
+    }
+  }
+
   else if (bits == 4) {
     U s[2] = {scale, scale / 16.0f};
     for (int i = 0; i < (values_per_thread / 2); i++) {
       result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
       result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
     }
-  }
+  } else if (bits == 6) {
+    for (int i = 0; i < (values_per_thread / 4); i++) {
+      uint8_t w0 = w[3 * i];
+      uint8_t w1 = w[3 * i + 1];
+      uint8_t w2 = w[3 * i + 2];
+
+      result[4 * i] += x * ((w0 & 0x3f) * scale + bias);
+      result[4 * i + 1] +=
+          x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias);
+      result[4 * i + 2] +=
+          x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias);
+      result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias);
+    }
+  }
 
   else if (bits == 8) {
@@ -218,8 +371,8 @@ template <typename U, int N, int bits>
 inline void
 dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
   static_assert(
-      bits == 2 || bits == 4 || bits == 8,
-      "Template undefined for bits not in {2, 4, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 6, 8}");
 
   if (bits == 2) {
     U s[4] = {
@@ -235,6 +388,22 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
     }
   }
 
+  else if (bits == 3) {
+    for (int i = 0; i < (N / 8); i++) {
+      w_local += 8 * i;
+      w += 3 * i;
+
+      w_local[0] = (w[0] & 0x7) * scale + bias;
+      w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias;
+      w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
+      w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias;
+      w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias;
+      w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
+      w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
+      w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
+    }
+  }
+
   else if (bits == 4) {
     U s[2] = {scale, scale / static_cast<U>(16.0f)};
     for (int i = 0; i < (N / 2); i++) {
@@ -243,6 +412,18 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
     }
   }
 
+  else if (bits == 6) {
+    for (int i = 0; i < (N / 4); i++) {
+      w_local += 4 * i;
+      w += 3 * i;
+
+      w_local[0] = (w[0] & 0x3f) * scale + bias;
+      w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
+      w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
+      w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
+    }
+  }
+
   else if (bits == 8) {
     for (int i = 0; i < N; i++) {
       w_local[i] = scale * w[i] + bias;
@@ -267,10 +448,11 @@ struct QuantizedBlockLoader {
       group_size % BCOLS == 0,
       "The group size should be divisible by the columns");
   static_assert(
-      bits == 2 || bits == 4 || bits == 8,
-      "Template undefined for bits not in {2, 4, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 6, 8}");
 
-  MLX_MTL_CONST short pack_factor = 32 / bits;
+  MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
+  MLX_MTL_CONST short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
   MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
   MLX_MTL_CONST short n_reads =
       (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
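
Aside (not part of the diff): with 3- and 6-bit support, `pack_factor` (values per packed unit) and `bytes_per_pack` (bytes per packed unit) no longer collapse to a single division, since 3 and 6 do not divide 8 or 32. A small host-side sketch of the same arithmetic, handy for checking row strides on the CPU; the helper names are hypothetical:

```cpp
#include <cstdio>
#include <initializer_list>

// Host-side sanity sketch reproducing the loader's constants: for 3- and
// 6-bit, one "pack" is 3 bytes holding 8 or 4 values respectively.
constexpr int pack_factor(int bits) {
  return bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
}
constexpr int bytes_per_pack(int bits) {
  return (bits == 3 || bits == 6) ? 3 : 1;
}

static_assert(pack_factor(3) * 3 == bytes_per_pack(3) * 8, "8 x 3 bits = 3 bytes");
static_assert(pack_factor(6) * 6 == bytes_per_pack(6) * 8, "4 x 6 bits = 3 bytes");
static_assert(pack_factor(4) * 4 == bytes_per_pack(4) * 8, "2 x 4 bits = 1 byte");

int main() {
  const int cols = 32;
  for (int bits : {2, 3, 4, 6, 8}) {
    // Bytes needed for `cols` packed values, as the loader computes strides.
    std::printf("bits=%d  pack_factor=%d  bytes for %d cols=%d\n",
                bits, pack_factor(bits), cols,
                cols * bytes_per_pack(bits) / pack_factor(bits));
  }
}
```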
@@ -286,12 +468,12 @@ struct QuantizedBlockLoader {
   const short bj;
 
   threadgroup T* dst;
-  const device uint32_t* src;
+  const device uint8_t* src;
   const device T* scales;
   const device T* biases;
 
   QuantizedBlockLoader(
-      const device uint32_t* src_,
+      const device uint8_t* src_,
       const device T* scales_,
       const device T* biases_,
       const int src_ld_,
@@ -300,14 +482,16 @@ struct QuantizedBlockLoader {
       ushort simd_lane_id [[thread_index_in_simdgroup]])
       : src_ld(src_ld_),
         tile_stride(
-            reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
+            reduction_dim ? BCOLS_PACKED * bytes_per_pack
+                          : BROWS * src_ld * bytes_per_pack / pack_factor),
         group_step_cnt(0),
         group_stride(BROWS * src_ld / group_size),
         thread_idx(simd_group_id * 32 + simd_lane_id),
         bi(n_reads * thread_idx / BCOLS_PACKED),
         bj((n_reads * thread_idx) % BCOLS_PACKED),
         dst(dst_ + bi * dst_ld + bj * pack_factor),
-        src(src_ + bi * src_ld / pack_factor + bj),
+        src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
+            bj * bytes_per_pack),
         scales(scales_ + bi * src_ld / group_size),
         biases(biases_ + bi * src_ld / group_size) {}
 
@@ -320,7 +504,7 @@ struct QuantizedBlockLoader {
     T bias = *biases;
     for (int i = 0; i < n_reads; i++) {
       dequantize<T, pack_factor, bits>(
-          (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
+          src + i * bytes_per_pack, scale, bias, dst + i * pack_factor);
     }
   }
 
@@ -347,7 +531,10 @@ struct QuantizedBlockLoader {
     T bias = *biases;
     for (int i = 0; i < n_reads; i++) {
       dequantize<T, pack_factor, bits>(
-          (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
+          (device uint8_t*)(src + i * bytes_per_pack),
+          scale,
+          bias,
+          dst + i * pack_factor);
     }
   }
 
@@ -410,8 +597,7 @@ METAL_FUNC void qmv_quad_impl(
     U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
 
     for (int row = 0; row < results_per_quadgroup; row++) {
-      const device uint8_t* wl =
-          (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
+      auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
       const device T* sl = scales + row * in_vec_size_g * quads_per_simd;
       const device T* bl = biases + row * in_vec_size_g * quads_per_simd;
 
@@ -442,25 +628,34 @@ METAL_FUNC void qmv_fast_impl(
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-  constexpr int packs_per_thread = bits > 2 ? 2 : 1;
+  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
+  constexpr int packs_per_thread = bits == 2 ? 1 : 2;
   constexpr int num_simdgroups = 2;
   constexpr int results_per_simdgroup = 4;
-  constexpr int pack_factor = 32 / bits;
+  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
+  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
   constexpr int values_per_thread = pack_factor * packs_per_thread;
   constexpr int block_size = values_per_thread * SIMD_SIZE;
   constexpr int scale_step_per_thread = group_size / values_per_thread;
 
+  // When bits is a power of two, read 1 uint32_t at a time
+  // When bits is 3 or 6, read 3 uint8_ts at a time
+  using W_T =
+      typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
+  const device W_T* ws = (const device W_T*)w;
+
   typedef float U;
 
   thread U x_thread[values_per_thread];
   thread U result[results_per_simdgroup] = {0};
 
   // Adjust positions
-  const int in_vec_size_w = in_vec_size / pack_factor;
+  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
   const int in_vec_size_g = in_vec_size / group_size;
   const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
       simd_gid * results_per_simdgroup;
-  w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
+
+  ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
   scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
   biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
   x += tid.y * in_vec_size + simd_lid * values_per_thread;
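
Aside (not part of the diff): `ConditionalType` is an MLX utility; the standard-library analogue is `std::conditional`. A host-side sketch of the same `W_T` selection, written with the standard trait (an assumed equivalent, not the MLX definition):

```cpp
#include <cstdint>
#include <type_traits>

// For power-of-two bit widths the kernel walks the weights one uint32_t at a
// time; for 3- and 6-bit it walks them byte by byte (3 bytes per pack).
template <int bits>
using WeightWord = typename std::
    conditional<(bits & (bits - 1)) == 0, uint32_t, uint8_t>::type;

static_assert(std::is_same<WeightWord<4>, uint32_t>::value, "power of two -> u32");
static_assert(std::is_same<WeightWord<3>, uint8_t>::value, "3 bit -> bytes");
static_assert(std::is_same<WeightWord<6>, uint8_t>::value, "6 bit -> bytes");

int main() {
  return 0;
}
```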
@@ -470,8 +665,7 @@ METAL_FUNC void qmv_fast_impl(
     U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
 
     for (int row = 0; row < results_per_simdgroup; row++) {
-      const device uint8_t* wl =
-          (const device uint8_t*)(w + row * in_vec_size_w);
+      auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
       const device T* sl = scales + row * in_vec_size_g;
       const device T* bl = biases + row * in_vec_size_g;
 
@@ -480,7 +674,7 @@ METAL_FUNC void qmv_fast_impl(
       result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
     }
 
-    w += block_size / pack_factor;
+    ws += block_size * bytes_per_pack / pack_factor;
     scales += block_size / group_size;
     biases += block_size / group_size;
     x += block_size;
@@ -506,21 +700,29 @@ METAL_FUNC void qmv_impl(
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
+  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
   constexpr int num_simdgroups = 2;
   constexpr int results_per_simdgroup = 4;
   constexpr int packs_per_thread = 1;
-  constexpr int pack_factor = 32 / bits;
+  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
+  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
   constexpr int values_per_thread = pack_factor * packs_per_thread;
   constexpr int block_size = values_per_thread * SIMD_SIZE;
   constexpr int scale_step_per_thread = group_size / values_per_thread;
 
+  // When bits is a power of two, read 1 uint32_t at a time
+  // When bits is 3 or 6, read 3 uint8_ts at a time
+  using W_T =
+      typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
+  const device W_T* ws = (const device W_T*)w;
+
   typedef float U;
 
   thread U x_thread[values_per_thread];
   thread U result[results_per_simdgroup] = {0};
 
   // Adjust positions
-  const int in_vec_size_w = in_vec_size / pack_factor;
+  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
   const int in_vec_size_g = in_vec_size / group_size;
   const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
       simd_gid * results_per_simdgroup;
@@ -533,7 +735,8 @@ METAL_FUNC void qmv_impl(
   // In this case we need to properly guard all our reads because there isn't
   // even 1 tile in the matrix
   if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
-    w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
+    ws +=
+        out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
     scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
     biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
     x += tid.y * in_vec_size + simd_lid * values_per_thread;
@@ -544,8 +747,7 @@ METAL_FUNC void qmv_impl(
       U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
 
       for (int row = 0; out_row + row < out_vec_size; row++) {
-        const device uint8_t* wl =
-            (const device uint8_t*)(w + row * in_vec_size_w);
+        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
         const device T* sl = scales + row * in_vec_size_g;
         const device T* bl = biases + row * in_vec_size_g;
 
@@ -555,7 +757,7 @@ METAL_FUNC void qmv_impl(
             qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
       }
 
-      w += block_size / pack_factor;
+      ws += block_size * bytes_per_pack / pack_factor;
       scales += block_size / group_size;
       biases += block_size / group_size;
       x += block_size;
@@ -569,8 +771,7 @@ METAL_FUNC void qmv_impl(
           x, x_thread, remaining);
 
       for (int row = 0; out_row + row < out_vec_size; row++) {
-        const device uint8_t* wl =
-            (const device uint8_t*)(w + row * in_vec_size_w);
+        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
         const device T* sl = scales + row * in_vec_size_g;
         const device T* bl = biases + row * in_vec_size_g;
 
@@ -591,7 +792,8 @@ METAL_FUNC void qmv_impl(
 
   // In this case the last tile is moved back to redo some output values
   else {
-    w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread;
+    ws += used_out_row * in_vec_size_w +
+        simd_lid * packs_per_thread * bytes_per_pack;
     scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
     biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
     x += tid.y * in_vec_size + simd_lid * values_per_thread;
@@ -602,8 +804,7 @@ METAL_FUNC void qmv_impl(
       U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
 
       for (int row = 0; row < results_per_simdgroup; row++) {
-        const device uint8_t* wl =
-            (const device uint8_t*)(w + row * in_vec_size_w);
+        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
         const device T* sl = scales + row * in_vec_size_g;
         const device T* bl = biases + row * in_vec_size_g;
 
@@ -613,7 +814,7 @@ METAL_FUNC void qmv_impl(
             qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
       }
 
-      w += block_size / pack_factor;
+      ws += block_size * bytes_per_pack / pack_factor;
       scales += block_size / group_size;
       biases += block_size / group_size;
       x += block_size;
@@ -627,8 +828,7 @@ METAL_FUNC void qmv_impl(
           x, x_thread, remaining);
 
       for (int row = 0; row < results_per_simdgroup; row++) {
-        const device uint8_t* wl =
-            (const device uint8_t*)(w + row * in_vec_size_w);
+        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
         const device T* sl = scales + row * in_vec_size_g;
         const device T* bl = biases + row * in_vec_size_g;
 
@@ -659,14 +859,22 @@ METAL_FUNC void qvm_impl(
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
+  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
   constexpr int num_simdgroups = 2;
-  constexpr int pack_factor = 32 / bits;
+  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
+  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
   constexpr int tn = 32 / pack_factor;
-  constexpr int blocksize = SIMD_SIZE;
+  constexpr int block_size = SIMD_SIZE;
 
+  // When bits is a power of two, read 1 uint32_t at a time
+  // When bits is 3 or 6, read 3 uint8_ts at a time
+  using W_T =
+      typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
+  const device W_T* ws = (const device W_T*)w;
+
   typedef float U;
   typedef struct {
-    uint32_t wi[tn];
+    W_T wi[tn * bytes_per_pack];
   } vec_w;
 
   thread vec_w w_local;
@@ -676,11 +884,10 @@ METAL_FUNC void qvm_impl(
   thread U x_local = 0;
 
   // Adjust positions
-  const int out_vec_size_w = out_vec_size / pack_factor;
+  const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
   const int out_vec_size_g = out_vec_size / group_size;
-  int out_col =
-      tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn;
-  w += out_col / pack_factor + simd_lid * out_vec_size_w;
+  int out_col = pack_factor * tn * (tid.x * num_simdgroups + simd_gid);
+  ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w;
   scales += out_col / group_size + simd_lid * out_vec_size_g;
   biases += out_col / group_size + simd_lid * out_vec_size_g;
   x += tid.y * in_vec_size + simd_lid;
@@ -690,43 +897,42 @@ METAL_FUNC void qvm_impl(
     return;
   }
 
-  // Loop over in_vec in blocks of blocksize
-  int remaining = in_vec_size % blocksize;
+  // Loop over in_vec in blocks of block_size
+  int remaining = in_vec_size % block_size;
   if (remaining == 0) {
-    for (int i = 0; i < in_vec_size; i += blocksize) {
+    for (int i = 0; i < in_vec_size; i += block_size) {
       x_local = *x;
       scale = *scales;
       bias = *biases;
-      w_local = *((device vec_w*)w);
+      w_local = *((device vec_w*)ws);
       qouter<U, tn * pack_factor, bits>(
           (thread uint8_t*)&w_local, x_local, scale, bias, result);
 
-      x += blocksize;
-      scales += blocksize * out_vec_size_g;
-      biases += blocksize * out_vec_size_g;
-      w += blocksize * out_vec_size_w;
+      x += block_size;
+      scales += block_size * out_vec_size_g;
+      biases += block_size * out_vec_size_g;
+      ws += block_size * out_vec_size_w;
     }
   } else {
-    for (int i = blocksize; i < in_vec_size; i += blocksize) {
+    for (int i = block_size; i < in_vec_size; i += block_size) {
       x_local = *x;
       scale = *scales;
      bias = *biases;
-      w_local = *((device vec_w*)w);
+      w_local = *((device vec_w*)ws);
      qouter<U, tn * pack_factor, bits>(
          (thread uint8_t*)&w_local, x_local, scale, bias, result);

-      x += blocksize;
-      scales += blocksize * out_vec_size_g;
-      biases += blocksize * out_vec_size_g;
-      w += blocksize * out_vec_size_w;
+      x += block_size;
+      scales += block_size * out_vec_size_g;
+      biases += block_size * out_vec_size_g;
+      ws += block_size * out_vec_size_w;
     }
     if (static_cast<int>(simd_lid) < remaining) {
       x_local = *x;
       scale = *scales;
       bias = *biases;
-      w_local = *((device vec_w*)w);
+      w_local = *((device vec_w*)ws);
     } else {
       x_local = 0;
       scale = 0;
@@ -781,8 +987,9 @@ METAL_FUNC void qmm_t_impl(
 
   constexpr int WM = 2;
   constexpr int WN = 2;
-  constexpr int pack_factor = 32 / bits;
+  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
   constexpr int BK_padded = (BK + 16 / sizeof(T));
+  constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
 
   // Instantiate the appropriate BlockMMA and Loader
   using mma_t = mlx::steel::
@@ -800,13 +1007,15 @@ METAL_FUNC void qmm_t_impl(
       bits>;
 
   // Set the block
-  const int K_w = K / pack_factor;
+  const int K_w = K * bytes_per_pack / pack_factor;
   const int K_g = K / group_size;
   const int y_row = tid.y * BM;
   const int y_col = tid.x * BN;
 
+  auto wl = (const device uint8_t*)w;
+
   x += y_row * K;
-  w += y_col * K_w;
+  wl += y_col * K_w;
   scales += y_col * K_g;
   biases += y_col * K_g;
   y += y_row * N + y_col;
@@ -815,7 +1024,7 @@ METAL_FUNC void qmm_t_impl(
     const short num_els = min(BM, M - y_row);
     const short num_outs = min(BN, N - y_col);
     loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
-    loader_w_t loader_w(w, scales, biases, K, Ws, simd_gid, simd_lid);
+    loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);
     mma_t mma_op(simd_gid, simd_lid);
 
     if (num_els < BM) {
@@ -857,6 +1066,7 @@ METAL_FUNC void qmm_t_impl(
       loader_x.load_unsafe();
       loader_w.load_unsafe();
       threadgroup_barrier(mem_flags::mem_threadgroup);
+
       mma_op.mma(Xs, Ws);
       loader_x.next();
       loader_w.next();
@@ -902,9 +1112,11 @@ METAL_FUNC void qmm_n_impl(
 
   constexpr int WM = 2;
   constexpr int WN = 2;
-  constexpr int pack_factor = 32 / bits;
+  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
   constexpr int BK_padded = (BK + 16 / sizeof(T));
   constexpr int BN_padded = (BN + 16 / sizeof(T));
+  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
+  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
 
   // Instantiate the appropriate BlockMMA and Loader
   using mma_t = mlx::steel::
@@ -921,11 +1133,13 @@ METAL_FUNC void qmm_n_impl(
       group_size,
       bits>;
 
+  auto wl = (const device uint8_t*)w;
+
   // Set the block
   const int y_row = tid.y * BM;
   const int y_col = tid.x * BN;
   x += y_row * K;
-  w += y_col / pack_factor;
+  wl += y_col * bytes_per_pack / pack_factor;
   scales += y_col / group_size;
   biases += y_col / group_size;
   y += y_row * N + y_col;
@@ -933,7 +1147,7 @@ METAL_FUNC void qmm_n_impl(
   // Make the x loader and mma operation
   const short num_els = min(BM, M - y_row);
   loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
-  loader_w_t loader_w(w, scales, biases, N, Ws, simd_gid, simd_lid);
+  loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid);
   mma_t mma_op(simd_gid, simd_lid);
 
   if (num_els < BM) {
@@ -1805,13 +2019,14 @@ template <typename T, const int group_size, const int bits>
     uint2 grid_dim [[threads_per_grid]]) {
   constexpr T eps = T(1e-7);
   constexpr int simd_size = 32;
-  constexpr int uint8_bits = 8;
   constexpr T n_bins = (1 << bits) - 1;
-  constexpr int packs_per_int = uint8_bits / bits;
+  constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
   constexpr int values_per_reduce = group_size / simd_size;
   constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
   constexpr int writes_per_pack =
       writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
+  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
+  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
 
   static_assert(
       group_size % simd_size == 0,
@@ -1819,7 +2034,9 @@ template <typename T, const int group_size, const int bits>
 
   size_t offset = index.x + grid_dim.x * size_t(index.y);
   size_t in_index = offset * values_per_reduce;
-  size_t out_index = offset * writes_per_pack;
+  size_t out_index = power_of_2_bits
+      ? offset * writes_per_pack
+      : offset * bytes_per_pack / writes_per_reduce;
 
   T w_thread[values_per_reduce];
   T w_min = Limits<T>::max;
@@ -1852,7 +2069,11 @@ template <typename T, const int group_size, const int bits>
     biases[gindex] = bias;
   }
 
-  uint8_t output = 0;
+  // We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t
+  using OutT =
+      typename ConditionalType<power_of_2_bits, uint8_t, uint32_t>::type;
+  OutT output = 0;
+
 #pragma clang loop unroll(full)
   for (int i = 0; i < values_per_reduce; i++) {
     uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
@@ -1868,47 +2089,23 @@ template <typename T, const int group_size, const int bits>
       output = 0;
     } else {
 #pragma clang loop unroll(full)
-      for (int j = 0; j < writes_per_reduce - 1; j++) {
-        uint8_t sval = simd_shuffle_down(val, j + 1);
-        output += sval << (bits * (values_per_reduce + j + i));
+      for (int j = 1; j < writes_per_reduce; j++) {
+        uint8_t sval = simd_shuffle_down(val, j);
+        output += sval << (bits * (j * values_per_reduce + i));
       }
     }
   }
-  if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
-    out[out_index / writes_per_reduce] = output;
-  }
-}
-
-template <typename T, const int group_size, const int bits>
-[[kernel]] void affine_quantize_scales_biases(
-    const device T* w [[buffer(0)]],
-    const device T* scales [[buffer(1)]],
-    const device T* biases [[buffer(2)]],
-    device uint8_t* out [[buffer(3)]],
-    uint2 index [[thread_position_in_grid]],
-    uint2 grid_dim [[threads_per_grid]]) {
-  constexpr int uint8_bits = 8;
-  constexpr int packs_per_int = uint8_bits / bits;
-  constexpr T n_bins = (1 << bits) - 1;
-
-  size_t offset = index.x + grid_dim.x * size_t(index.y);
-  size_t in_index = offset * packs_per_int;
-  size_t gindex = in_index / group_size;
-
-  T scale = scales[gindex];
-  T bias = biases[gindex];
-
-  uint8_t output = 0;
-#pragma clang loop unroll(full)
-  for (int i = 0; i < packs_per_int; i++) {
-    uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins);
-    if (bits == 8) {
-      output = val;
-    } else {
-      output += val << (bits * i);
+  if (bits == 3 || bits == 6) {
+    if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) {
+      out[out_index] = output & 0xff;
+      out[out_index + 1] = (output & 0xff00) >> 8;
+      out[out_index + 2] = (output & 0xff0000) >> 16;
+    }
+  } else {
+    if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
+      out[out_index / writes_per_reduce] = output;
     }
   }
-  out[offset] = output;
 }
 
 template <typename T, const int group_size, const int bits>
@@ -1919,26 +2116,48 @@ template <typename T, const int group_size, const int bits>
     device T* out [[buffer(3)]],
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  constexpr int uint8_bits = 8;
-  constexpr int packs_per_int = uint8_bits / bits;
+  constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
+  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
+  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
 
   size_t offset = index.x + grid_dim.x * size_t(index.y);
   size_t oindex = offset * packs_per_int;
   size_t gindex = oindex / group_size;
   T scale = scales[gindex];
   T bias = biases[gindex];
-  uint val = w[offset];
 
+  out += oindex;
+
+  if (bits == 3) {
+    w += offset * bytes_per_pack;
+    out[0] = (w[0] & 0x7) * scale + bias;
+    out[1] = ((w[0] & 0x38) >> 3) * scale + bias;
+    out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
+    out[3] = ((w[1] & 0xe) >> 1) * scale + bias;
+    out[4] = ((w[1] & 0x70) >> 4) * scale + bias;
+    out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
+    out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
+    out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
+
+  } else if (bits == 6) {
+    w += offset * bytes_per_pack;
+    out[0] = (w[0] & 0x3f) * scale + bias;
+    out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
+    out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
+    out[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
+  } else {
+    uint val = w[offset];
 #pragma clang loop unroll(full)
-  for (int i = 0; i < packs_per_int; i++) {
-    uint8_t d;
-    if (bits == 2) {
-      d = (val >> (bits * i)) & 0x03;
-    } else if (bits == 4) {
-      d = (val >> (bits * i)) & 0x0f;
-    } else if (bits == 8) {
-      d = val;
-    }
-    out[oindex + i] = scale * d + bias;
+    for (int i = 0; i < packs_per_int; i++) {
+      uint8_t d;
+      if (bits == 2) {
+        d = (val >> (bits * i)) & 0x03;
+      } else if (bits == 4) {
+        d = (val >> (bits * i)) & 0x0f;
+      } else if (bits == 8) {
+        d = val;
+      }
+      out[i] = scale * d + bias;
+    }
   }
 }
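
Aside (not part of the diff): end to end, the kernels above implement affine quantization — `q = min(round((w - bias) / scale), n_bins)` on the way in and `w ≈ scale * q + bias` on the way out. A rough host-side round trip at 3 bits; the scale/bias selection here is simplified (the real `affine_quantize` derives them per group from the running min/max with an `eps` guard):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative host-side model of the affine quantize/dequantize math only;
// grouping, packing, and the eps guard from the kernel are omitted.
int main() {
  const int bits = 3;
  const float n_bins = float((1 << bits) - 1); // 7 for 3-bit
  std::vector<float> w = {-0.9f, -0.3f, 0.0f, 0.2f, 0.5f, 0.8f, 1.1f, 1.4f};

  float w_min = *std::min_element(w.begin(), w.end());
  float w_max = *std::max_element(w.begin(), w.end());
  float scale = (w_max - w_min) / n_bins;
  float bias = w_min;

  for (float v : w) {
    float q = std::min(std::round((v - bias) / scale), n_bins); // quantize
    float v_hat = scale * q + bias;                             // dequantize
    std::printf("% .3f -> %d -> % .3f\n", v, int(q), v_hat);
  }
}
```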