|
|
|
@@ -14,11 +14,23 @@ using namespace metal;
|
|
|
|
|
// Lane count assumed for one SIMD group. The qmv/qvm kernels in this file
// derive their per-iteration block size from it.
MLX_MTL_CONST int SIMD_SIZE = 32;
|
|
|
|
|
// NOTE(review): presumably the number of threads in a quad group; no use of
// this constant is visible in this excerpt -- confirm against the kernels.
MLX_MTL_CONST int QUAD_SIZE = 4;
|
|
|
|
|
|
|
|
|
|
/**
 * Number of quantized values stored in one "pack" of weights.
 *
 * @tparam bits   Bit width of a single quantized value.
 * @tparam wsize  Width in bits of the storage word used for power-of-two
 *                bit widths (8 by default; some kernels pass 32).
 * @return For 3- and 5-bit values a pack always holds 8 values, and for
 *         6-bit values it holds 4, regardless of `wsize`. For every other
 *         width the pack is one storage word, i.e. `wsize / bits` values.
 */
template <int bits, int wsize = 8>
inline constexpr short get_pack_factor() {
  // Reject a zero/negative width up front instead of failing with an opaque
  // divide-by-zero during constant evaluation.
  static_assert(bits > 0, "Template undefined for non-positive bits");
  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
|
|
|
|
|
|
|
|
|
|
/**
 * Number of bytes occupied by one pack of quantized values.
 *
 * @tparam bits   Bit width of a single quantized value.
 * @tparam wsize  Width in bits of the storage word used for power-of-two
 *                bit widths (8 by default; some kernels pass 32).
 * @return `wsize / 8` when `bits` is a power of two (the pack is exactly one
 *         storage word). Otherwise 5 for 5-bit values (8 values * 5 bits =
 *         40 bits) and 3 for the remaining widths (3-bit: 8 * 3 = 24 bits,
 *         6-bit: 4 * 6 = 24 bits), matching get_pack_factor().
 */
template <int bits, int wsize = 8>
inline constexpr short get_bytes_per_pack() {
  // A power of two has a single set bit, so clearing the lowest set bit
  // leaves zero.
  constexpr bool power_of_2_bits = (bits & (bits - 1)) == 0;
  return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
|
|
|
|
|
|
|
|
|
|
template <typename T, typename U, int values_per_thread, int bits>
|
|
|
|
|
inline U load_vector(const device T* x, thread U* x_thread) {
|
|
|
|
|
static_assert(
|
|
|
|
|
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
|
|
|
|
|
"Template undefined for bits not in {2, 3, 4, 6, 8}");
|
|
|
|
|
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
|
|
|
|
|
bits == 8,
|
|
|
|
|
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
|
|
|
|
|
|
|
|
|
|
U sum = 0;
|
|
|
|
|
|
|
|
|
@@ -57,6 +69,21 @@ inline U load_vector(const device T* x, thread U* x_thread) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
else if (bits == 5) {
|
|
|
|
|
for (int i = 0; i < values_per_thread; i += 8) {
|
|
|
|
|
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
|
|
|
|
|
x[i + 6] + x[i + 7];
|
|
|
|
|
x_thread[i] = x[i];
|
|
|
|
|
x_thread[i + 1] = x[i + 1] / 32.0f;
|
|
|
|
|
x_thread[i + 2] = x[i + 2] / 4.0f;
|
|
|
|
|
x_thread[i + 3] = x[i + 3] / 128.0f;
|
|
|
|
|
x_thread[i + 4] = x[i + 4] / 16.0f;
|
|
|
|
|
x_thread[i + 5] = x[i + 5] / 2.0f;
|
|
|
|
|
x_thread[i + 6] = x[i + 6] / 64.0f;
|
|
|
|
|
x_thread[i + 7] = x[i + 7] / 8.0f;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
else if (bits == 6) {
|
|
|
|
|
for (int i = 0; i < values_per_thread; i += 4) {
|
|
|
|
|
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
|
|
|
|
@@ -80,8 +107,9 @@ inline U load_vector(const device T* x, thread U* x_thread) {
|
|
|
|
|
template <typename T, typename U, int values_per_thread, int bits>
|
|
|
|
|
inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
|
|
|
|
|
static_assert(
|
|
|
|
|
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
|
|
|
|
|
"Template undefined for bits not in {2, 3, 4, 6, 8}");
|
|
|
|
|
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
|
|
|
|
|
bits == 8,
|
|
|
|
|
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
|
|
|
|
|
|
|
|
|
|
U sum = 0;
|
|
|
|
|
|
|
|
|
@@ -121,6 +149,21 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
else if (bits == 5) {
|
|
|
|
|
for (int i = 0; i < N; i += 8) {
|
|
|
|
|
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
|
|
|
|
|
x[i + 6] + x[i + 7];
|
|
|
|
|
x_thread[i] = x[i];
|
|
|
|
|
x_thread[i + 1] = x[i + 1] / 32.0f;
|
|
|
|
|
x_thread[i + 2] = x[i + 2] / 4.0f;
|
|
|
|
|
x_thread[i + 3] = x[i + 3] / 128.0f;
|
|
|
|
|
x_thread[i + 4] = x[i + 4] / 16.0f;
|
|
|
|
|
x_thread[i + 5] = x[i + 5] / 2.0f;
|
|
|
|
|
x_thread[i + 6] = x[i + 6] / 64.0f;
|
|
|
|
|
x_thread[i + 7] = x[i + 7] / 8.0f;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
else if (bits == 6) {
|
|
|
|
|
for (int i = 0; i < N; i += 4) {
|
|
|
|
|
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
|
|
|
|
@@ -153,8 +196,9 @@ inline U qdot(
|
|
|
|
|
U bias,
|
|
|
|
|
U sum) {
|
|
|
|
|
static_assert(
|
|
|
|
|
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
|
|
|
|
|
"Template undefined for bits not in {2, 3, 4, 6, 8}");
|
|
|
|
|
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
|
|
|
|
|
bits == 8,
|
|
|
|
|
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
|
|
|
|
|
|
|
|
|
|
U accum = 0;
|
|
|
|
|
|
|
|
|
@@ -199,6 +243,26 @@ inline U qdot(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
else if (bits == 5) {
|
|
|
|
|
for (int i = 0; i < (values_per_thread / 8); i++) {
|
|
|
|
|
x_thread += 8 * i;
|
|
|
|
|
w += 5 * i;
|
|
|
|
|
|
|
|
|
|
accum += (w[0] & 0x1f) * x_thread[0];
|
|
|
|
|
accum += (w[0] & 0xe0) * x_thread[1];
|
|
|
|
|
accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);
|
|
|
|
|
accum += (w[1] & 0x7c) * x_thread[2];
|
|
|
|
|
accum += (w[1] & 0x80) * x_thread[3];
|
|
|
|
|
accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);
|
|
|
|
|
accum += (w[2] & 0xf0) * x_thread[4];
|
|
|
|
|
accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);
|
|
|
|
|
accum += (w[3] & 0x3e) * x_thread[5];
|
|
|
|
|
accum += (w[3] & 0xc0) * x_thread[6];
|
|
|
|
|
accum += (w[4] & 0x7) * (x_thread[6] * 256.0f);
|
|
|
|
|
accum += (w[4] & 0xf8) * x_thread[7];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
else if (bits == 6) {
|
|
|
|
|
for (int i = 0; i < (values_per_thread / 4); i++) {
|
|
|
|
|
x_thread += 4 * i;
|
|
|
|
@@ -234,8 +298,9 @@ inline U qdot_safe(
|
|
|
|
|
U sum,
|
|
|
|
|
int N) {
|
|
|
|
|
static_assert(
|
|
|
|
|
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
|
|
|
|
|
"Template undefined for bits not in {2, 3, 4, 6, 8}");
|
|
|
|
|
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
|
|
|
|
|
bits == 8,
|
|
|
|
|
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
|
|
|
|
|
|
|
|
|
|
U accum = 0;
|
|
|
|
|
|
|
|
|
@@ -280,6 +345,26 @@ inline U qdot_safe(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
else if (bits == 5) {
|
|
|
|
|
for (int i = 0; i < (N / 8); i++) {
|
|
|
|
|
x_thread += 8 * i;
|
|
|
|
|
w += 5 * i;
|
|
|
|
|
|
|
|
|
|
accum += (w[0] & 0x1f) * x_thread[0];
|
|
|
|
|
accum += (w[0] & 0xe0) * x_thread[1];
|
|
|
|
|
accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);
|
|
|
|
|
accum += (w[1] & 0x7c) * x_thread[2];
|
|
|
|
|
accum += (w[1] & 0x80) * x_thread[3];
|
|
|
|
|
accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);
|
|
|
|
|
accum += (w[2] & 0xf0) * x_thread[4];
|
|
|
|
|
accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);
|
|
|
|
|
accum += (w[3] & 0x3e) * x_thread[5];
|
|
|
|
|
accum += (w[3] & 0xc0) * x_thread[6];
|
|
|
|
|
accum += (w[4] & 0x7) * (x_thread[6] * 256.0f);
|
|
|
|
|
accum += (w[4] & 0xf8) * x_thread[7];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
else if (bits == 6) {
|
|
|
|
|
for (int i = 0; i < (N / 4); i++) {
|
|
|
|
|
x_thread += 4 * i;
|
|
|
|
@@ -310,8 +395,9 @@ template <typename U, int values_per_thread, int bits>
|
|
|
|
|
inline void
|
|
|
|
|
qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
|
|
|
|
|
static_assert(
|
|
|
|
|
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
|
|
|
|
|
"Template undefined for bits not in {2, 3, 4, 6, 8}");
|
|
|
|
|
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
|
|
|
|
|
bits == 8,
|
|
|
|
|
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
|
|
|
|
|
|
|
|
|
|
if (bits == 2) {
|
|
|
|
|
U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
|
|
|
|
@@ -348,8 +434,31 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
|
|
|
|
|
result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
|
|
|
|
|
result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} else if (bits == 6) {
|
|
|
|
|
else if (bits == 5) {
|
|
|
|
|
for (int i = 0; i < (values_per_thread / 8); i++) {
|
|
|
|
|
uint8_t w0 = w[5 * i];
|
|
|
|
|
uint8_t w1 = w[5 * i + 1];
|
|
|
|
|
uint8_t w2 = w[5 * i + 2];
|
|
|
|
|
uint8_t w3 = w[5 * i + 3];
|
|
|
|
|
uint8_t w4 = w[5 * i + 4];
|
|
|
|
|
result[8 * i] += x * ((w0 & 0x1f) * scale + bias);
|
|
|
|
|
result[8 * i + 1] +=
|
|
|
|
|
x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias);
|
|
|
|
|
result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias);
|
|
|
|
|
result[8 * i + 3] +=
|
|
|
|
|
x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias);
|
|
|
|
|
result[8 * i + 4] +=
|
|
|
|
|
x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias);
|
|
|
|
|
result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias);
|
|
|
|
|
result[8 * i + 6] +=
|
|
|
|
|
x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias);
|
|
|
|
|
result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
else if (bits == 6) {
|
|
|
|
|
for (int i = 0; i < (values_per_thread / 4); i++) {
|
|
|
|
|
uint8_t w0 = w[3 * i];
|
|
|
|
|
uint8_t w1 = w[3 * i + 1];
|
|
|
|
@@ -375,8 +484,9 @@ template <typename U, int N, int bits>
|
|
|
|
|
inline void
|
|
|
|
|
dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
|
|
|
|
|
static_assert(
|
|
|
|
|
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
|
|
|
|
|
"Template undefined for bits not in {2, 3, 4, 6, 8}");
|
|
|
|
|
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
|
|
|
|
|
bits == 8,
|
|
|
|
|
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
|
|
|
|
|
|
|
|
|
|
if (bits == 2) {
|
|
|
|
|
U s[4] = {
|
|
|
|
@@ -416,11 +526,26 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
else if (bits == 5) {
|
|
|
|
|
for (int i = 0; i < (N / 8); i++) {
|
|
|
|
|
w_local += 8 * i;
|
|
|
|
|
w += 5 * i;
|
|
|
|
|
|
|
|
|
|
w_local[0] = (w[0] & 0x1f) * scale + bias;
|
|
|
|
|
w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
|
|
|
|
|
w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
|
|
|
|
|
w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;
|
|
|
|
|
w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;
|
|
|
|
|
w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias;
|
|
|
|
|
w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;
|
|
|
|
|
w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
else if (bits == 6) {
|
|
|
|
|
for (int i = 0; i < (N / 4); i++) {
|
|
|
|
|
w_local += 4 * i;
|
|
|
|
|
w += 3 * i;
|
|
|
|
|
|
|
|
|
|
w_local[0] = (w[0] & 0x3f) * scale + bias;
|
|
|
|
|
w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
|
|
|
|
|
w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
|
|
|
|
@@ -452,11 +577,12 @@ struct QuantizedBlockLoader {
|
|
|
|
|
group_size % BCOLS == 0,
|
|
|
|
|
"The group size should be divisible by the columns");
|
|
|
|
|
static_assert(
|
|
|
|
|
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
|
|
|
|
|
"Template undefined for bits not in {2, 3, 4, 6, 8}");
|
|
|
|
|
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
|
|
|
|
|
bits == 8,
|
|
|
|
|
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
|
|
|
|
|
|
|
|
|
|
MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
|
|
|
|
|
MLX_MTL_CONST short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
|
|
|
|
|
MLX_MTL_CONST short pack_factor = get_pack_factor<bits, 8>();
|
|
|
|
|
MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack<bits>();
|
|
|
|
|
MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
|
|
|
|
|
MLX_MTL_CONST short n_reads =
|
|
|
|
|
(BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
|
|
|
|
@@ -632,12 +758,11 @@ METAL_FUNC void qmv_fast_impl(
|
|
|
|
|
uint3 tid [[threadgroup_position_in_grid]],
|
|
|
|
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
|
|
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
|
|
|
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
|
|
|
|
constexpr int packs_per_thread = bits == 2 ? 1 : 2;
|
|
|
|
|
constexpr int num_simdgroups = 2;
|
|
|
|
|
constexpr int results_per_simdgroup = 4;
|
|
|
|
|
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
|
|
|
|
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
|
|
|
|
|
constexpr int pack_factor = get_pack_factor<bits, 32>();
|
|
|
|
|
constexpr int bytes_per_pack = get_bytes_per_pack<bits, 32>();
|
|
|
|
|
constexpr int values_per_thread = pack_factor * packs_per_thread;
|
|
|
|
|
constexpr int block_size = values_per_thread * SIMD_SIZE;
|
|
|
|
|
constexpr int scale_step_per_thread = group_size / values_per_thread;
|
|
|
|
@@ -700,12 +825,12 @@ METAL_FUNC void qmv_impl(
|
|
|
|
|
uint3 tid [[threadgroup_position_in_grid]],
|
|
|
|
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
|
|
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
|
|
|
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
|
|
|
|
constexpr int num_simdgroups = 2;
|
|
|
|
|
constexpr int results_per_simdgroup = 4;
|
|
|
|
|
constexpr int packs_per_thread = 1;
|
|
|
|
|
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
|
|
|
|
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
|
|
|
|
|
constexpr int pack_factor = get_pack_factor<bits, 32>();
|
|
|
|
|
constexpr int bytes_per_pack = get_bytes_per_pack<bits, 32>();
|
|
|
|
|
|
|
|
|
|
constexpr int values_per_thread = pack_factor * packs_per_thread;
|
|
|
|
|
constexpr int block_size = values_per_thread * SIMD_SIZE;
|
|
|
|
|
constexpr int scale_step_per_thread = group_size / values_per_thread;
|
|
|
|
@@ -857,8 +982,9 @@ METAL_FUNC void qvm_impl(
|
|
|
|
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
|
|
|
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
|
|
|
|
constexpr int num_simdgroups = 2;
|
|
|
|
|
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
|
|
|
|
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
|
|
|
|
|
constexpr int pack_factor = get_pack_factor<bits, 32>();
|
|
|
|
|
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
|
|
|
|
|
|
|
|
|
|
constexpr int tn = 32 / pack_factor;
|
|
|
|
|
constexpr int block_size = SIMD_SIZE;
|
|
|
|
|
|
|
|
|
@@ -981,9 +1107,10 @@ METAL_FUNC void qmm_t_impl(
|
|
|
|
|
|
|
|
|
|
constexpr int WM = 2;
|
|
|
|
|
constexpr int WN = 2;
|
|
|
|
|
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
|
|
|
|
|
constexpr int pack_factor = get_pack_factor<bits, 8>();
|
|
|
|
|
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
|
|
|
|
|
|
|
|
|
|
constexpr int BK_padded = (BK + 16 / sizeof(T));
|
|
|
|
|
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
|
|
|
|
|
|
|
|
|
|
// Instantiate the appropriate BlockMMA and Loader
|
|
|
|
|
using mma_t = mlx::steel::
|
|
|
|
@@ -1106,11 +1233,11 @@ METAL_FUNC void qmm_n_impl(
|
|
|
|
|
|
|
|
|
|
constexpr int WM = 2;
|
|
|
|
|
constexpr int WN = 2;
|
|
|
|
|
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
|
|
|
|
|
constexpr int pack_factor = get_pack_factor<bits, 8>();
|
|
|
|
|
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
|
|
|
|
|
|
|
|
|
|
constexpr int BK_padded = (BK + 16 / sizeof(T));
|
|
|
|
|
constexpr int BN_padded = (BN + 16 / sizeof(T));
|
|
|
|
|
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
|
|
|
|
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
|
|
|
|
|
|
|
|
|
|
// Instantiate the appropriate BlockMMA and Loader
|
|
|
|
|
using mma_t = mlx::steel::
|
|
|
|
@@ -2120,11 +2247,10 @@ template <
|
|
|
|
|
uint3 tid [[threadgroup_position_in_grid]],
|
|
|
|
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
|
|
|
uint simd_lane_id [[thread_index_in_simdgroup]]) {
|
|
|
|
|
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
|
|
|
|
|
constexpr int pack_factor = get_pack_factor<bits, 8>();
|
|
|
|
|
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
|
|
|
|
|
constexpr int BK_padded = (BK + 16 / sizeof(T));
|
|
|
|
|
constexpr int BN_padded = (BN + 16 / sizeof(T));
|
|
|
|
|
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
|
|
|
|
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
|
|
|
|
|
|
|
|
|
|
using mma_t = mlx::steel::BlockMMA<
|
|
|
|
|
T,
|
|
|
|
@@ -2305,13 +2431,13 @@ template <typename T, const int group_size, const int bits>
|
|
|
|
|
constexpr float eps = 1e-7;
|
|
|
|
|
constexpr int simd_size = 32;
|
|
|
|
|
constexpr float n_bins = (1 << bits) - 1;
|
|
|
|
|
constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
|
|
|
|
|
constexpr int pack_factor = get_pack_factor<bits, 8>();
|
|
|
|
|
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
|
|
|
|
|
constexpr int values_per_reduce = group_size / simd_size;
|
|
|
|
|
constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
|
|
|
|
|
constexpr int writes_per_reduce = pack_factor / values_per_reduce;
|
|
|
|
|
constexpr int writes_per_pack =
|
|
|
|
|
writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
|
|
|
|
|
writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor;
|
|
|
|
|
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
|
|
|
|
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
|
|
|
|
|
|
|
|
|
|
static_assert(
|
|
|
|
|
group_size % simd_size == 0,
|
|
|
|
@@ -2354,8 +2480,8 @@ template <typename T, const int group_size, const int bits>
|
|
|
|
|
biases[gindex] = static_cast<T>(bias);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t
|
|
|
|
|
uint32_t output = 0;
|
|
|
|
|
using OutType = metal::conditional_t<bits == 5, uint64_t, uint32_t>;
|
|
|
|
|
OutType output = 0;
|
|
|
|
|
|
|
|
|
|
#pragma clang loop unroll(full)
|
|
|
|
|
for (int i = 0; i < values_per_reduce; i++) {
|
|
|
|
@@ -2363,27 +2489,35 @@ template <typename T, const int group_size, const int bits>
|
|
|
|
|
if (bits == 8) {
|
|
|
|
|
output = val;
|
|
|
|
|
} else {
|
|
|
|
|
output += val << (bits * (i % packs_per_int));
|
|
|
|
|
output |= val << (bits * (i % pack_factor));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (packs_per_int < values_per_reduce &&
|
|
|
|
|
i % packs_per_int == packs_per_int - 1) {
|
|
|
|
|
out[out_index + i / packs_per_int] = output;
|
|
|
|
|
if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) {
|
|
|
|
|
out[out_index + i / pack_factor] = output;
|
|
|
|
|
output = 0;
|
|
|
|
|
} else {
|
|
|
|
|
#pragma clang loop unroll(full)
|
|
|
|
|
for (int j = 1; j < writes_per_reduce; j++) {
|
|
|
|
|
uint8_t sval = simd_shuffle_down(val, j);
|
|
|
|
|
output += sval << (bits * (j * values_per_reduce + i));
|
|
|
|
|
output |= static_cast<OutType>(sval)
|
|
|
|
|
<< (bits * (j * values_per_reduce + i));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (bits == 3 || bits == 6) {
|
|
|
|
|
if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) {
|
|
|
|
|
if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
|
|
|
|
|
out[out_index] = output & 0xff;
|
|
|
|
|
out[out_index + 1] = (output & 0xff00) >> 8;
|
|
|
|
|
out[out_index + 2] = (output & 0xff0000) >> 16;
|
|
|
|
|
}
|
|
|
|
|
} else if (bits == 5) {
|
|
|
|
|
if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
|
|
|
|
|
out[out_index] = output & 0xff;
|
|
|
|
|
out[out_index + 1] = (output & 0xff00) >> 8;
|
|
|
|
|
out[out_index + 2] = (output & 0xff0000) >> 16;
|
|
|
|
|
out[out_index + 3] = (output & 0xff000000) >> 24;
|
|
|
|
|
out[out_index + 4] = (output & 0xff00000000) >> 32;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
|
|
|
|
|
out[out_index / writes_per_reduce] = output;
|
|
|
|
@@ -2399,12 +2533,11 @@ template <typename T, const int group_size, const int bits>
|
|
|
|
|
device T* out [[buffer(3)]],
|
|
|
|
|
uint2 index [[thread_position_in_grid]],
|
|
|
|
|
uint2 grid_dim [[threads_per_grid]]) {
|
|
|
|
|
constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
|
|
|
|
|
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
|
|
|
|
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
|
|
|
|
|
constexpr int pack_factor = get_pack_factor<bits, 8>();
|
|
|
|
|
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
|
|
|
|
|
|
|
|
|
|
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
|
|
|
|
size_t oindex = offset * packs_per_int;
|
|
|
|
|
size_t oindex = offset * pack_factor;
|
|
|
|
|
size_t gindex = oindex / group_size;
|
|
|
|
|
T scale = scales[gindex];
|
|
|
|
|
T bias = biases[gindex];
|
|
|
|
@@ -2421,7 +2554,16 @@ template <typename T, const int group_size, const int bits>
|
|
|
|
|
out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
|
|
|
|
|
out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
|
|
|
|
|
out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
|
|
|
|
|
|
|
|
|
|
} else if (bits == 5) {
|
|
|
|
|
w += offset * bytes_per_pack;
|
|
|
|
|
out[0] = (w[0] & 0x1f) * scale + bias;
|
|
|
|
|
out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
|
|
|
|
|
out[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
|
|
|
|
|
out[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;
|
|
|
|
|
out[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;
|
|
|
|
|
out[5] = ((w[3] & 0x3e) >> 1) * scale + bias;
|
|
|
|
|
out[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;
|
|
|
|
|
out[7] = ((w[4] & 0xf8) >> 3) * scale + bias;
|
|
|
|
|
} else if (bits == 6) {
|
|
|
|
|
w += offset * bytes_per_pack;
|
|
|
|
|
out[0] = (w[0] & 0x3f) * scale + bias;
|
|
|
|
@@ -2431,7 +2573,7 @@ template <typename T, const int group_size, const int bits>
|
|
|
|
|
} else {
|
|
|
|
|
uint val = w[offset];
|
|
|
|
|
#pragma clang loop unroll(full)
|
|
|
|
|
for (int i = 0; i < packs_per_int; i++) {
|
|
|
|
|
for (int i = 0; i < pack_factor; i++) {
|
|
|
|
|
uint8_t d;
|
|
|
|
|
if (bits == 2) {
|
|
|
|
|
d = (val >> (bits * i)) & 0x03;
|
|
|
|
|